mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 05:29:50 +08:00
test: add new unit tests for message service utilities, get message, feedback, and retention services. (#33169)
This commit is contained in:
parent
45a8967b8b
commit
a808389122
@ -0,0 +1,309 @@
|
||||
import datetime
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.retention.conversation.messages_clean_policy import (
|
||||
BillingDisabledPolicy,
|
||||
)
|
||||
from services.retention.conversation.messages_clean_service import MessagesCleanService
|
||||
|
||||
|
||||
class TestMessagesCleanService:
    """Unit tests for MessagesCleanService with the DB layer fully mocked.

    Fix: `test_clean_messages_by_time_range_basic` previously assigned
    `mock_db_session.execute.side_effect` twice; the first list was dead
    code (immediately overwritten) and its inline comment admitted it was
    wrong. Only the correct, second sequence is kept.
    """

    @pytest.fixture(autouse=True)
    def mock_db_engine(self):
        # Patch the module-level db handle so no real engine is touched.
        with patch("services.retention.conversation.messages_clean_service.db") as mock_db:
            mock_db.engine = MagicMock()
            yield mock_db.engine

    @pytest.fixture
    def mock_db_session(self, mock_db_engine):
        # Replace Session so `with Session(...) as session:` yields our mock.
        with patch("services.retention.conversation.messages_clean_service.Session") as mock_session_cls:
            mock_session = MagicMock()
            mock_session_cls.return_value.__enter__.return_value = mock_session
            yield mock_session

    @pytest.fixture
    def mock_policy(self):
        policy = MagicMock(spec=BillingDisabledPolicy)
        return policy

    def test_run_calls_clean_messages(self, mock_policy):
        """run() delegates to _clean_messages_by_time_range and returns its stats."""
        service = MessagesCleanService(
            policy=mock_policy,
            end_before=datetime.datetime.now(),
            batch_size=10,
        )
        with patch.object(service, "_clean_messages_by_time_range") as mock_clean:
            mock_clean.return_value = {"total_deleted": 5}
            result = service.run()
            assert result == {"total_deleted": 5}
            mock_clean.assert_called_once()

    def test_clean_messages_by_time_range_basic(self, mock_db_session, mock_policy):
        # Arrange
        end_before = datetime.datetime(2024, 1, 1, 12, 0, 0)
        service = MessagesCleanService(
            policy=mock_policy,
            end_before=end_before,
            batch_size=10,
        )

        # The service calls session.execute for:
        # 1. Fetch messages
        # 2. Fetch apps
        # 3. Batch delete relations (8 calls if IDs exist)
        # 4. Delete messages
        # 5. Fetch the next (empty) batch
        mock_returns = [
            MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]),  # fetch messages
            MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]),  # fetch apps
        ]
        # 8 deletes for relations
        mock_returns.extend([MagicMock() for _ in range(8)])
        # 1 delete for messages
        mock_returns.append(MagicMock(rowcount=1))
        # Final fetch messages (empty)
        mock_returns.append(MagicMock(all=lambda: []))

        mock_db_session.execute.side_effect = mock_returns
        mock_policy.filter_message_ids.return_value = ["msg1"]

        # Act
        with patch("services.retention.conversation.messages_clean_service.time.sleep"):
            stats = service.run()

        # Assert
        assert stats["total_messages"] == 1
        assert stats["total_deleted"] == 1
        assert stats["batches"] == 2

    def test_clean_messages_by_time_range_with_start_from(self, mock_db_session, mock_policy):
        """An explicit start_from window with no matching messages is a no-op."""
        start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
        end_before = datetime.datetime(2024, 1, 1, 12, 0, 0)
        service = MessagesCleanService(
            policy=mock_policy,
            start_from=start_from,
            end_before=end_before,
            batch_size=10,
        )

        mock_db_session.execute.side_effect = [
            MagicMock(all=lambda: []),  # No messages
        ]

        stats = service.run()
        assert stats["total_messages"] == 0

    def test_clean_messages_by_time_range_with_cursor(self, mock_db_session, mock_policy):
        # Test pagination with cursor: batch_size=1 forces two data batches
        # plus a final empty fetch.
        end_before = datetime.datetime(2024, 1, 1, 12, 0, 0)
        service = MessagesCleanService(
            policy=mock_policy,
            end_before=end_before,
            batch_size=1,
        )

        msg1_time = datetime.datetime(2024, 1, 1, 10, 0, 0)
        msg2_time = datetime.datetime(2024, 1, 1, 11, 0, 0)

        mock_returns = []
        # Batch 1
        mock_returns.append(MagicMock(all=lambda: [("msg1", "app1", msg1_time)]))
        mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]))
        mock_returns.extend([MagicMock() for _ in range(8)])  # relations
        mock_returns.append(MagicMock(rowcount=1))  # messages

        # Batch 2
        mock_returns.append(MagicMock(all=lambda: [("msg2", "app1", msg2_time)]))
        mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]))
        mock_returns.extend([MagicMock() for _ in range(8)])  # relations
        mock_returns.append(MagicMock(rowcount=1))  # messages

        # Batch 3 (empty -> loop terminates)
        mock_returns.append(MagicMock(all=lambda: []))

        mock_db_session.execute.side_effect = mock_returns
        mock_policy.filter_message_ids.return_value = ["msg1"]  # Simplified

        with patch("services.retention.conversation.messages_clean_service.time.sleep"):
            stats = service.run()

        assert stats["batches"] == 3
        assert stats["total_messages"] == 2

    def test_clean_messages_by_time_range_dry_run(self, mock_db_session, mock_policy):
        """Dry-run counts/samples filtered messages but deletes nothing."""
        service = MessagesCleanService(
            policy=mock_policy,
            end_before=datetime.datetime.now(),
            batch_size=10,
            dry_run=True,
        )

        mock_db_session.execute.side_effect = [
            MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]),  # messages
            MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]),  # apps
            MagicMock(all=lambda: []),  # next batch empty
        ]
        mock_policy.filter_message_ids.return_value = ["msg1"]

        with patch("services.retention.conversation.messages_clean_service.random.sample") as mock_sample:
            mock_sample.return_value = ["msg1"]
            stats = service.run()
            assert stats["filtered_messages"] == 1
            assert stats["total_deleted"] == 0  # Dry run
            mock_sample.assert_called()

    def test_clean_messages_by_time_range_no_apps_found(self, mock_db_session, mock_policy):
        """Messages whose app row is missing are counted but not deleted."""
        service = MessagesCleanService(
            policy=mock_policy,
            end_before=datetime.datetime.now(),
            batch_size=10,
        )

        mock_db_session.execute.side_effect = [
            MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]),  # messages
            MagicMock(all=lambda: []),  # apps NOT found
            MagicMock(all=lambda: []),  # next batch empty
        ]

        stats = service.run()
        assert stats["total_messages"] == 1
        assert stats["total_deleted"] == 0

    def test_clean_messages_by_time_range_no_app_ids(self, mock_db_session, mock_policy):
        """Force the app-id set to come back empty by intercepting list()."""
        service = MessagesCleanService(
            policy=mock_policy,
            end_before=datetime.datetime.now(),
            batch_size=10,
        )

        mock_db_session.execute.side_effect = [
            MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]),  # messages
            MagicMock(all=lambda: []),  # next batch empty
        ]

        # We need to successfully execute the raw-message fetch
        # (raw_messages = list(session.execute(msg_stmt).all())), then make the
        # second list() call — app_ids = list({msg.app_id for msg in messages}) —
        # return an empty list.
        calls = []

        def list_side_effect(arg):
            calls.append(arg)
            if len(calls) == 2:  # This is the second call to list() in the loop
                return []
            return list(arg)

        with patch("services.retention.conversation.messages_clean_service.list", side_effect=list_side_effect):
            stats = service.run()
        assert stats["batches"] == 2
        assert stats["total_messages"] == 1

    def test_from_time_range_validation(self, mock_policy):
        now = datetime.datetime.now()
        # Test start_from >= end_before
        with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
            MessagesCleanService.from_time_range(mock_policy, now, now)

        # Test batch_size <= 0
        with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
            MessagesCleanService.from_time_range(mock_policy, now - datetime.timedelta(days=1), now, batch_size=0)

    def test_from_time_range_success(self, mock_policy):
        start = datetime.datetime(2024, 1, 1)
        end = datetime.datetime(2024, 2, 1)
        service = MessagesCleanService.from_time_range(mock_policy, start, end)
        assert service._start_from == start
        assert service._end_before == end

    def test_from_days_validation(self, mock_policy):
        # Test days < 0
        with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"):
            MessagesCleanService.from_days(mock_policy, days=-1)

        # Test batch_size <= 0
        with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
            MessagesCleanService.from_days(mock_policy, days=30, batch_size=0)

    def test_from_days_success(self, mock_policy):
        with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now:
            fixed_now = datetime.datetime(2024, 6, 1)
            mock_now.return_value = fixed_now

            service = MessagesCleanService.from_days(mock_policy, days=10)
            assert service._start_from is None
            assert service._end_before == fixed_now - datetime.timedelta(days=10)

    def test_clean_messages_by_time_range_no_messages_to_delete(self, mock_db_session, mock_policy):
        """When the policy filters everything out, nothing is deleted."""
        service = MessagesCleanService(
            policy=mock_policy,
            end_before=datetime.datetime.now(),
            batch_size=10,
        )

        mock_db_session.execute.side_effect = [
            MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]),  # messages
            MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]),  # apps
            MagicMock(all=lambda: []),  # next batch empty
        ]
        mock_policy.filter_message_ids.return_value = []  # Policy says NO

        stats = service.run()
        assert stats["total_messages"] == 1
        assert stats["filtered_messages"] == 0
        assert stats["total_deleted"] == 0

    def test_batch_delete_message_relations_empty(self, mock_db_session):
        MessagesCleanService._batch_delete_message_relations(mock_db_session, [])
        mock_db_session.execute.assert_not_called()

    def test_batch_delete_message_relations_with_ids(self, mock_db_session):
        MessagesCleanService._batch_delete_message_relations(mock_db_session, ["msg1", "msg2"])
        assert mock_db_session.execute.call_count == 8  # 8 tables to clean up

    @patch.dict(os.environ, {"SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL": "500"})
    def test_clean_messages_interval_from_env(self, mock_db_session, mock_policy):
        """The inter-batch sleep interval is read from the environment (ms -> s)."""
        service = MessagesCleanService(
            policy=mock_policy,
            end_before=datetime.datetime.now(),
            batch_size=10,
        )

        mock_returns = [
            MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]),  # messages
            MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]),  # apps
        ]
        mock_returns.extend([MagicMock() for _ in range(8)])  # relations
        mock_returns.append(MagicMock(rowcount=1))  # messages
        mock_returns.append(MagicMock(all=lambda: []))  # next batch empty

        mock_db_session.execute.side_effect = mock_returns
        mock_policy.filter_message_ids.return_value = ["msg1"]

        with patch("services.retention.conversation.messages_clean_service.time.sleep") as mock_sleep:
            with patch("services.retention.conversation.messages_clean_service.random.uniform") as mock_uniform:
                mock_uniform.return_value = 300.0
                service.run()
                mock_uniform.assert_called_with(0, 500)
                mock_sleep.assert_called_with(0.3)
@ -0,0 +1,499 @@
|
||||
"""
|
||||
Unit tests for WorkflowRunCleanup service.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
|
||||
|
||||
|
||||
def make_run(tenant_id: str = "t1", run_id: str = "r1", created_at: datetime.datetime | None = None):
|
||||
run = MagicMock()
|
||||
run.tenant_id = tenant_id
|
||||
run.id = run_id
|
||||
run.created_at = created_at or datetime.datetime(2024, 1, 1, tzinfo=datetime.UTC)
|
||||
return run
|
||||
|
||||
|
||||
@pytest.fixture
def mock_repo():
    """A stand-in workflow-run repository; every method is a MagicMock."""
    repo = MagicMock()
    return repo
|
||||
|
||||
|
||||
@pytest.fixture
def cleanup(mock_repo):
    """A WorkflowRunCleanup built under stubbed config (no grace period, billing off)."""
    target = "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config"
    with patch(target) as config:
        config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
        config.BILLING_ENABLED = False
        yield WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constructor validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWorkflowRunCleanupInit:
    """Constructor validation for WorkflowRunCleanup.

    Refactor: every test repeated the same three-line dify_config patch
    boilerplate; it is extracted into `_patched_config`, which returns a
    ready-to-enter `patch(target, new)` context manager. Behavior is
    unchanged.
    """

    @staticmethod
    def _patched_config(grace_period: int = 0, billing_enabled: bool = False):
        """Return a context manager that swaps in a stubbed dify_config."""
        stub = MagicMock(
            SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=grace_period,
            BILLING_ENABLED=billing_enabled,
        )
        return patch(
            "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config",
            stub,
        )

    def test_only_start_from_raises(self, mock_repo):
        with self._patched_config():
            with pytest.raises(ValueError, match="both set or both omitted"):
                WorkflowRunCleanup(
                    days=30,
                    batch_size=10,
                    start_from=datetime.datetime(2024, 1, 1),
                    workflow_run_repo=mock_repo,
                )

    def test_only_end_before_raises(self, mock_repo):
        with self._patched_config():
            with pytest.raises(ValueError, match="both set or both omitted"):
                WorkflowRunCleanup(
                    days=30,
                    batch_size=10,
                    end_before=datetime.datetime(2024, 1, 1),
                    workflow_run_repo=mock_repo,
                )

    def test_end_before_not_greater_than_start_raises(self, mock_repo):
        with self._patched_config():
            with pytest.raises(ValueError, match="end_before must be greater than start_from"):
                WorkflowRunCleanup(
                    days=30,
                    batch_size=10,
                    start_from=datetime.datetime(2024, 6, 1),
                    end_before=datetime.datetime(2024, 1, 1),
                    workflow_run_repo=mock_repo,
                )

    def test_equal_start_end_raises(self, mock_repo):
        dt = datetime.datetime(2024, 1, 1)
        with self._patched_config():
            with pytest.raises(ValueError):
                WorkflowRunCleanup(days=30, batch_size=10, start_from=dt, end_before=dt, workflow_run_repo=mock_repo)

    def test_zero_batch_size_raises(self, mock_repo):
        with self._patched_config():
            with pytest.raises(ValueError, match="batch_size must be greater than 0"):
                WorkflowRunCleanup(days=30, batch_size=0, workflow_run_repo=mock_repo)

    def test_negative_batch_size_raises(self, mock_repo):
        with self._patched_config():
            with pytest.raises(ValueError):
                WorkflowRunCleanup(days=30, batch_size=-1, workflow_run_repo=mock_repo)

    def test_valid_window_init(self, mock_repo):
        # Uses a non-zero grace period to mirror the original test setup.
        with self._patched_config(grace_period=7):
            start = datetime.datetime(2024, 1, 1)
            end = datetime.datetime(2024, 6, 1)
            c = WorkflowRunCleanup(days=30, batch_size=5, start_from=start, end_before=end, workflow_run_repo=mock_repo)
            assert c.window_start == start
            assert c.window_end == end
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _empty_related_counts / _format_related_counts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStaticHelpers:
    """Tests for the static related-count helpers."""

    def test_empty_related_counts(self):
        expected = {
            "node_executions": 0,
            "offloads": 0,
            "app_logs": 0,
            "trigger_logs": 0,
            "pauses": 0,
            "pause_reasons": 0,
        }
        assert WorkflowRunCleanup._empty_related_counts() == expected

    def test_format_related_counts(self):
        sample = {
            "node_executions": 1,
            "offloads": 2,
            "app_logs": 3,
            "trigger_logs": 4,
            "pauses": 5,
            "pause_reasons": 6,
        }
        formatted = WorkflowRunCleanup._format_related_counts(sample)
        for fragment in ("node_executions 1", "offloads 2", "trigger_logs 4"):
            assert fragment in formatted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _expiration_datetime
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExpirationDatetime:
    """Tests for _expiration_datetime timestamp conversion."""

    def test_negative_returns_none(self, cleanup):
        assert cleanup._expiration_datetime("t1", -1) is None

    def test_valid_timestamp(self, cleanup):
        moment = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)
        converted = cleanup._expiration_datetime("t1", int(moment.timestamp()))
        assert converted is not None
        assert converted.year == 2025

    def test_overflow_returns_none(self, cleanup):
        # A timestamp far beyond the representable datetime range.
        assert cleanup._expiration_datetime("t1", 2**62) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_within_grace_period
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsWithinGracePeriod:
    """Tests for the free-plan grace-period check."""

    def test_zero_grace_period_returns_false(self, cleanup):
        cleanup.free_plan_grace_period_days = 0
        assert cleanup._is_within_grace_period("t1", {"expiration_date": 9999999999}) is False

    def test_within_grace_period(self, cleanup):
        cleanup.free_plan_grace_period_days = 30
        # Subscription expired only yesterday — well inside the 30-day window.
        one_day_ago = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=1)
        assert cleanup._is_within_grace_period("t1", {"expiration_date": int(one_day_ago.timestamp())}) is True

    def test_outside_grace_period(self, cleanup):
        cleanup.free_plan_grace_period_days = 5
        # Expired 100 days ago — far past the 5-day window.
        long_ago = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=100)
        assert cleanup._is_within_grace_period("t1", {"expiration_date": int(long_ago.timestamp())}) is False

    def test_missing_expiration_date_returns_false(self, cleanup):
        cleanup.free_plan_grace_period_days = 30
        assert cleanup._is_within_grace_period("t1", {"expiration_date": -1}) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_cleanup_whitelist
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetCleanupWhitelist:
    """Tests for whitelist retrieval and caching."""

    # Patch targets factored out of the individual tests.
    _CONFIG_TARGET = "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config"
    _BILLING_TARGET = "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService"

    def test_billing_disabled_returns_empty(self, cleanup):
        cleanup._cleanup_whitelist = None
        with patch(self._CONFIG_TARGET) as config:
            config.BILLING_ENABLED = False
            assert cleanup._get_cleanup_whitelist() == set()

    def test_billing_enabled_fetches_whitelist(self, mock_repo):
        with patch(self._CONFIG_TARGET) as config:
            config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
            config.BILLING_ENABLED = True
            service = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo)
            with patch(self._BILLING_TARGET) as billing:
                billing.get_expired_subscription_cleanup_whitelist.return_value = ["t1", "t2"]
                assert service._get_cleanup_whitelist() == {"t1", "t2"}

    def test_cached_whitelist_returned(self, cleanup):
        cleanup._cleanup_whitelist = {"cached"}
        assert cleanup._get_cleanup_whitelist() == {"cached"}

    def test_billing_service_error_returns_empty(self, mock_repo):
        with patch(self._CONFIG_TARGET) as config:
            config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
            config.BILLING_ENABLED = True
            service = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo)
            with patch(self._BILLING_TARGET) as billing:
                billing.get_expired_subscription_cleanup_whitelist.side_effect = Exception("error")
                assert service._get_cleanup_whitelist() == set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _filter_free_tenants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFilterFreeTenants:
    """Tests for _filter_free_tenants plan filtering."""

    # Patch targets factored out of the individual tests.
    _CONFIG_TARGET = "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config"
    _BILLING_TARGET = "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService"

    def test_billing_disabled_all_tenants_free(self, cleanup):
        assert cleanup._filter_free_tenants(["t1", "t2"]) == {"t1", "t2"}

    def test_empty_tenants_returns_empty(self, cleanup):
        with patch(self._CONFIG_TARGET) as config:
            config.BILLING_ENABLED = True
            assert cleanup._filter_free_tenants([]) == set()

    def test_whitelisted_tenant_excluded(self, mock_repo):
        with patch(self._CONFIG_TARGET) as config:
            config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
            config.BILLING_ENABLED = True
            service = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo)
            service._cleanup_whitelist = {"t1"}
            with patch(self._BILLING_TARGET) as billing:
                from enums.cloud_plan import CloudPlan

                billing.get_plan_bulk_with_cache.return_value = {
                    "t1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
                    "t2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
                }
                survivors = service._filter_free_tenants(["t1", "t2"])
                assert "t1" not in survivors
                assert "t2" in survivors

    def test_paid_tenant_excluded(self, mock_repo):
        with patch(self._CONFIG_TARGET) as config:
            config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
            config.BILLING_ENABLED = True
            service = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo)
            service._cleanup_whitelist = set()
            with patch(self._BILLING_TARGET) as billing:
                billing.get_plan_bulk_with_cache.return_value = {
                    "t1": {"plan": "professional", "expiration_date": -1},
                }
                assert service._filter_free_tenants(["t1"]) == set()

    def test_missing_billing_info_treats_as_non_free(self, mock_repo):
        with patch(self._CONFIG_TARGET) as config:
            config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
            config.BILLING_ENABLED = True
            service = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo)
            service._cleanup_whitelist = set()
            with patch(self._BILLING_TARGET) as billing:
                billing.get_plan_bulk_with_cache.return_value = {}
                assert service._filter_free_tenants(["t1"]) == set()

    def test_billing_bulk_error_treats_as_non_free(self, mock_repo):
        with patch(self._CONFIG_TARGET) as config:
            config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
            config.BILLING_ENABLED = True
            service = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo)
            service._cleanup_whitelist = set()
            with patch(self._BILLING_TARGET) as billing:
                billing.get_plan_bulk_with_cache.side_effect = Exception("fail")
                assert service._filter_free_tenants(["t1"]) == set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run() — delete mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunDeleteMode:
    """Tests for run() in delete (non-dry-run) mode."""

    _CONFIG_TARGET = "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config"

    def _make_cleanup(self, mock_repo, billing_enabled=False):
        """Construct a cleanup instance under stubbed config."""
        with patch(self._CONFIG_TARGET) as config:
            config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
            config.BILLING_ENABLED = billing_enabled
            return WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo)

    def test_no_rows_stops_immediately(self, mock_repo):
        mock_repo.get_runs_batch_by_time_range.return_value = []
        service = self._make_cleanup(mock_repo)
        with patch(self._CONFIG_TARGET) as config:
            config.BILLING_ENABLED = False
            service.run()
        mock_repo.delete_runs_with_related.assert_not_called()

    def test_all_paid_skips_delete(self, mock_repo):
        row = make_run("t1")
        mock_repo.get_runs_batch_by_time_range.side_effect = [[row], []]
        service = self._make_cleanup(mock_repo)
        # billing disabled -> all free; but override _filter_free_tenants to return empty
        service._filter_free_tenants = MagicMock(return_value=set())
        with patch(self._CONFIG_TARGET) as config:
            config.BILLING_ENABLED = False
            service.run()
        mock_repo.delete_runs_with_related.assert_not_called()

    def test_runs_deleted_successfully(self, mock_repo):
        row = make_run("t1")
        mock_repo.get_runs_batch_by_time_range.side_effect = [[row], []]
        mock_repo.delete_runs_with_related.return_value = {
            "runs": 1,
            "node_executions": 0,
            "offloads": 0,
            "app_logs": 0,
            "trigger_logs": 0,
            "pauses": 0,
            "pause_reasons": 0,
        }
        service = self._make_cleanup(mock_repo)
        with patch(self._CONFIG_TARGET) as config:
            config.BILLING_ENABLED = False
            with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.time.sleep"):
                service.run()
        mock_repo.delete_runs_with_related.assert_called_once()

    def test_delete_exception_reraises(self, mock_repo):
        row = make_run("t1")
        mock_repo.get_runs_batch_by_time_range.side_effect = [[row], []]
        mock_repo.delete_runs_with_related.side_effect = RuntimeError("db error")
        service = self._make_cleanup(mock_repo)
        with patch(self._CONFIG_TARGET) as config:
            config.BILLING_ENABLED = False
            with pytest.raises(RuntimeError):
                service.run()

    def test_summary_with_window_start(self, mock_repo):
        mock_repo.get_runs_batch_by_time_range.return_value = []
        with patch(self._CONFIG_TARGET) as config:
            config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
            config.BILLING_ENABLED = False
            service = WorkflowRunCleanup(
                days=30,
                batch_size=10,
                start_from=datetime.datetime(2024, 1, 1),
                end_before=datetime.datetime(2024, 6, 1),
                workflow_run_repo=mock_repo,
            )
            service.run()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run() — dry run mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunDryRunMode:
    """Tests for run() in dry-run mode (counts only, never deletes)."""

    _CONFIG_TARGET = "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config"

    def _make_dry_cleanup(self, mock_repo):
        """Construct a dry-run cleanup instance under stubbed config."""
        with patch(self._CONFIG_TARGET) as config:
            config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
            config.BILLING_ENABLED = False
            return WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo, dry_run=True)

    def test_dry_run_no_delete_called(self, mock_repo):
        row = make_run("t1")
        mock_repo.get_runs_batch_by_time_range.side_effect = [[row], []]
        mock_repo.count_runs_with_related.return_value = {
            "node_executions": 2,
            "offloads": 0,
            "app_logs": 0,
            "trigger_logs": 1,
            "pauses": 0,
            "pause_reasons": 0,
        }
        service = self._make_dry_cleanup(mock_repo)
        with patch(self._CONFIG_TARGET) as config:
            config.BILLING_ENABLED = False
            service.run()
        mock_repo.delete_runs_with_related.assert_not_called()
        mock_repo.count_runs_with_related.assert_called_once()

    def test_dry_run_summary_with_window_start(self, mock_repo):
        mock_repo.get_runs_batch_by_time_range.return_value = []
        with patch(self._CONFIG_TARGET) as config:
            config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
            config.BILLING_ENABLED = False
            service = WorkflowRunCleanup(
                days=30,
                batch_size=10,
                start_from=datetime.datetime(2024, 1, 1),
                end_before=datetime.datetime(2024, 6, 1),
                workflow_run_repo=mock_repo,
                dry_run=True,
            )
            service.run()

    def test_dry_run_all_paid_skips_count(self, mock_repo):
        row = make_run("t1")
        mock_repo.get_runs_batch_by_time_range.side_effect = [[row], []]
        service = self._make_dry_cleanup(mock_repo)
        service._filter_free_tenants = MagicMock(return_value=set())
        with patch(self._CONFIG_TARGET) as config:
            config.BILLING_ENABLED = False
            service.run()
        mock_repo.count_runs_with_related.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _delete_trigger_logs / _count_trigger_logs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTriggerLogMethods:
    """Tests for the trigger-log helpers, which delegate to SQLAlchemyWorkflowTriggerLogRepository."""

    def test_delete_trigger_logs(self, cleanup):
        session = MagicMock()
        with patch(
            "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.SQLAlchemyWorkflowTriggerLogRepository"
        ) as RepoClass:
            # The helper constructs the repository itself; stub its instance.
            instance = RepoClass.return_value
            instance.delete_by_run_ids.return_value = 5
            result = cleanup._delete_trigger_logs(session, ["r1", "r2"])
        # The helper returns the repository's deletion count unchanged.
        assert result == 5

    def test_count_trigger_logs(self, cleanup):
        session = MagicMock()
        with patch(
            "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.SQLAlchemyWorkflowTriggerLogRepository"
        ) as RepoClass:
            instance = RepoClass.return_value
            instance.count_by_run_ids.return_value = 3
            result = cleanup._count_trigger_logs(session, ["r1"])
        # The helper returns the repository's count unchanged.
        assert result == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _count_node_executions / _delete_node_executions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNodeExecutionMethods:
    """Tests for node-execution count/delete helpers built on DifyAPIRepositoryFactory."""

    def test_count_node_executions(self, cleanup):
        session = MagicMock()
        # The helper derives a sessionmaker from the session's bind — TODO confirm against impl.
        session.get_bind.return_value = MagicMock()
        runs = [make_run("t1", "r1")]
        with patch(
            "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.DifyAPIRepositoryFactory"
        ) as factory:
            repo = factory.create_api_workflow_node_execution_repository.return_value
            repo.count_by_runs.return_value = (10, 2)
            with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.sessionmaker"):
                result = cleanup._count_node_executions(session, runs)
        # Tuple is passed through from the repository unchanged.
        assert result == (10, 2)

    def test_delete_node_executions(self, cleanup):
        session = MagicMock()
        session.get_bind.return_value = MagicMock()
        runs = [make_run("t1", "r1")]
        with patch(
            "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.DifyAPIRepositoryFactory"
        ) as factory:
            repo = factory.create_api_workflow_node_execution_repository.return_value
            repo.delete_by_runs.return_value = (5, 1)
            with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.sessionmaker"):
                result = cleanup._delete_node_executions(session, runs)
        assert result == (5, 1)
|
||||
@ -0,0 +1,216 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.workflow import WorkflowRun
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion, DeleteResult
|
||||
|
||||
|
||||
class TestArchivedWorkflowRunDeletion:
    """Unit tests for ArchivedWorkflowRunDeletion.

    Covers single-run deletion (success / not found / not archived), batch
    deletion, dry-run mode, error wrapping, the static trigger-log and
    node-execution helpers, and repo caching in _get_workflow_run_repo.
    """

    @pytest.fixture
    def mock_db(self):
        # Patch the module-level db handle so no real engine is touched.
        with patch("services.retention.workflow_run.delete_archived_workflow_run.db") as mock_db:
            mock_db.engine = MagicMock()
            yield mock_db

    @pytest.fixture
    def mock_sessionmaker(self):
        # Wire sessionmaker so `with sessionmaker(...)() as session:` yields our mock session.
        with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm:
            mock_session = MagicMock(spec=Session)
            mock_sm.return_value.return_value.__enter__.return_value = mock_session
            yield mock_sm, mock_session

    @pytest.fixture
    def mock_workflow_run_repo(self):
        with patch(
            "services.retention.workflow_run.delete_archived_workflow_run.APIWorkflowRunRepository"
        ) as mock_repo_cls:
            mock_repo = MagicMock()
            # Fix: previously the yielded mock was never connected to the patched
            # class, so code constructing APIWorkflowRunRepository() got an
            # unrelated MagicMock. Wire it up so the fixture is actually useful.
            mock_repo_cls.return_value = mock_repo
            yield mock_repo

    def test_delete_by_run_id_success(self, mock_db, mock_sessionmaker):
        """A found, archived run is handed to _delete_run and its result returned."""
        mock_sm, mock_session = mock_sessionmaker
        run_id = "run-123"
        tenant_id = "tenant-456"

        mock_run = MagicMock(spec=WorkflowRun)
        mock_run.id = run_id
        mock_run.tenant_id = tenant_id
        mock_session.get.return_value = mock_run

        deletion = ArchivedWorkflowRunDeletion()

        with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
            mock_repo = MagicMock()
            mock_get_repo.return_value = mock_repo
            # The run id is reported as archived, so deletion may proceed.
            mock_repo.get_archived_run_ids.return_value = [run_id]

            with patch.object(deletion, "_delete_run") as mock_delete_run:
                expected_result = DeleteResult(run_id=run_id, tenant_id=tenant_id, success=True)
                mock_delete_run.return_value = expected_result

                result = deletion.delete_by_run_id(run_id)

        assert result == expected_result
        mock_session.get.assert_called_once_with(WorkflowRun, run_id)
        mock_repo.get_archived_run_ids.assert_called_once()
        mock_delete_run.assert_called_once_with(mock_run)

    def test_delete_by_run_id_not_found(self, mock_db, mock_sessionmaker):
        """A missing run yields a failed DeleteResult with a 'not found' error."""
        mock_sm, mock_session = mock_sessionmaker
        run_id = "run-123"
        mock_session.get.return_value = None

        deletion = ArchivedWorkflowRunDeletion()
        with patch.object(deletion, "_get_workflow_run_repo"):
            result = deletion.delete_by_run_id(run_id)

        assert result.success is False
        assert "not found" in result.error
        assert result.run_id == run_id

    def test_delete_by_run_id_not_archived(self, mock_db, mock_sessionmaker):
        """A run that exists but is not in the archived set must not be deleted."""
        mock_sm, mock_session = mock_sessionmaker
        run_id = "run-123"

        mock_run = MagicMock(spec=WorkflowRun)
        mock_run.id = run_id
        mock_session.get.return_value = mock_run

        deletion = ArchivedWorkflowRunDeletion()
        with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
            mock_repo = MagicMock()
            mock_get_repo.return_value = mock_repo
            mock_repo.get_archived_run_ids.return_value = []  # not archived

            result = deletion.delete_by_run_id(run_id)

        assert result.success is False
        assert "is not archived" in result.error

    def test_delete_batch(self, mock_db, mock_sessionmaker):
        """delete_batch deletes every archived run in the window and returns per-run results."""
        mock_sm, mock_session = mock_sessionmaker
        deletion = ArchivedWorkflowRunDeletion()

        mock_run1 = MagicMock(spec=WorkflowRun)
        mock_run1.id = "run-1"
        mock_run2 = MagicMock(spec=WorkflowRun)
        mock_run2.id = "run-2"

        with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
            mock_repo = MagicMock()
            mock_get_repo.return_value = mock_repo
            mock_repo.get_archived_runs_by_time_range.return_value = [mock_run1, mock_run2]

            with patch.object(deletion, "_delete_run") as mock_delete_run:
                mock_delete_run.side_effect = [
                    DeleteResult(run_id="run-1", tenant_id="t1", success=True),
                    DeleteResult(run_id="run-2", tenant_id="t1", success=True),
                ]

                results = deletion.delete_batch(tenant_ids=["t1"], start_date=datetime.now(), end_date=datetime.now())

        assert len(results) == 2
        assert results[0].run_id == "run-1"
        assert results[1].run_id == "run-2"
        assert mock_delete_run.call_count == 2

    def test_delete_run_dry_run(self):
        """In dry-run mode _delete_run reports success without touching the repo."""
        deletion = ArchivedWorkflowRunDeletion(dry_run=True)
        mock_run = MagicMock(spec=WorkflowRun)
        mock_run.id = "run-123"
        mock_run.tenant_id = "tenant-456"

        result = deletion._delete_run(mock_run)

        assert result.success is True
        assert result.run_id == "run-123"

    def test_delete_run_success(self):
        """A real delete returns the repository's per-table deletion counts."""
        deletion = ArchivedWorkflowRunDeletion(dry_run=False)
        mock_run = MagicMock(spec=WorkflowRun)
        mock_run.id = "run-123"
        mock_run.tenant_id = "tenant-456"

        with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
            mock_repo = MagicMock()
            mock_get_repo.return_value = mock_repo
            mock_repo.delete_runs_with_related.return_value = {"workflow_runs": 1}

            result = deletion._delete_run(mock_run)

        assert result.success is True
        assert result.deleted_counts == {"workflow_runs": 1}

    def test_delete_run_exception(self):
        """Repository errors are captured into a failed DeleteResult, not raised."""
        deletion = ArchivedWorkflowRunDeletion(dry_run=False)
        mock_run = MagicMock(spec=WorkflowRun)
        mock_run.id = "run-123"

        with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
            mock_repo = MagicMock()
            mock_get_repo.return_value = mock_repo
            mock_repo.delete_runs_with_related.side_effect = Exception("Database error")

            result = deletion._delete_run(mock_run)

        assert result.success is False
        assert result.error == "Database error"

    def test_delete_trigger_logs(self):
        """_delete_trigger_logs builds a repo from the session and forwards the run ids."""
        mock_session = MagicMock(spec=Session)
        run_ids = ["run-1", "run-2"]

        with patch(
            "services.retention.workflow_run.delete_archived_workflow_run.SQLAlchemyWorkflowTriggerLogRepository"
        ) as mock_repo_cls:
            mock_repo = MagicMock()
            mock_repo_cls.return_value = mock_repo
            mock_repo.delete_by_run_ids.return_value = 5

            count = ArchivedWorkflowRunDeletion._delete_trigger_logs(mock_session, run_ids)

        assert count == 5
        mock_repo_cls.assert_called_once_with(mock_session)
        mock_repo.delete_by_run_ids.assert_called_once_with(run_ids)

    def test_delete_node_executions(self):
        """_delete_node_executions delegates to the factory repo with the run ids."""
        mock_session = MagicMock(spec=Session)
        mock_run = MagicMock(spec=WorkflowRun)
        mock_run.id = "run-1"
        runs = [mock_run]

        with patch(
            "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository"
        ) as mock_create_repo:
            mock_repo = MagicMock()
            mock_create_repo.return_value = mock_repo
            mock_repo.delete_by_runs.return_value = (1, 2)

            # sessionmaker is patched only to keep the helper off a real engine;
            # the alias was unused, so it is not bound.
            with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker"):
                result = ArchivedWorkflowRunDeletion._delete_node_executions(mock_session, runs)

        assert result == (1, 2)
        mock_create_repo.assert_called_once()
        mock_repo.delete_by_runs.assert_called_once_with(mock_session, ["run-1"])

    def test_get_workflow_run_repo(self, mock_db):
        """_get_workflow_run_repo creates the repo once and caches it on the instance."""
        deletion = ArchivedWorkflowRunDeletion()

        with patch(
            "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository"
        ) as mock_create_repo:
            mock_repo = MagicMock()
            mock_create_repo.return_value = mock_repo

            # First call creates the repo via the factory.
            repo1 = deletion._get_workflow_run_repo()
            assert repo1 == mock_repo
            assert deletion.workflow_run_repo == mock_repo

            # Second call returns the cached instance without another factory call.
            repo2 = deletion._get_workflow_run_repo()
            assert repo2 == mock_repo
            mock_create_repo.assert_called_once()
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,421 @@
|
||||
"""
|
||||
Comprehensive unit tests for services/api_based_extension_service.py
|
||||
|
||||
Covers:
|
||||
- APIBasedExtensionService.get_all_by_tenant_id
|
||||
- APIBasedExtensionService.save
|
||||
- APIBasedExtensionService.delete
|
||||
- APIBasedExtensionService.get_with_tenant_id
|
||||
- APIBasedExtensionService._validation (new record & existing record branches)
|
||||
- APIBasedExtensionService._ping_connection (pong success, wrong response, exception)
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_extension(
|
||||
*,
|
||||
id_: str | None = None,
|
||||
tenant_id: str = "tenant-001",
|
||||
name: str = "my-ext",
|
||||
api_endpoint: str = "https://example.com/hook",
|
||||
api_key: str = "secret-key-123",
|
||||
) -> MagicMock:
|
||||
"""Return a lightweight mock that mimics APIBasedExtension."""
|
||||
ext = MagicMock()
|
||||
ext.id = id_
|
||||
ext.tenant_id = tenant_id
|
||||
ext.name = name
|
||||
ext.api_endpoint = api_endpoint
|
||||
ext.api_key = api_key
|
||||
return ext
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: get_all_by_tenant_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetAllByTenantId:
    """Tests for APIBasedExtensionService.get_all_by_tenant_id."""

    @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key")
    @patch("services.api_based_extension_service.db")
    def test_returns_extensions_with_decrypted_keys(self, mock_db, mock_decrypt):
        """Each api_key is decrypted and the list is returned."""
        ext1 = _make_extension(id_="id-1", api_key="enc-key-1")
        ext2 = _make_extension(id_="id-2", api_key="enc-key-2")

        # Stub the query(...).filter_by(...).order_by(...).all() chain used by the service.
        mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [
            ext1,
            ext2,
        ]

        result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001")

        assert result == [ext1, ext2]
        # The service decrypts in place: one decrypt_token call per record.
        assert ext1.api_key == "decrypted-key"
        assert ext2.api_key == "decrypted-key"
        assert mock_decrypt.call_count == 2

    @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key")
    @patch("services.api_based_extension_service.db")
    def test_returns_empty_list_when_no_extensions(self, mock_db, mock_decrypt):
        """Returns an empty list gracefully when no records exist."""
        mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = []

        result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001")

        assert result == []
        mock_decrypt.assert_not_called()

    @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key")
    @patch("services.api_based_extension_service.db")
    def test_calls_query_with_correct_tenant_id(self, mock_db, mock_decrypt):
        """Verifies the DB is queried with the supplied tenant_id."""
        mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = []

        APIBasedExtensionService.get_all_by_tenant_id("tenant-xyz")

        mock_db.session.query.return_value.filter_by.assert_called_once_with(tenant_id="tenant-xyz")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: save
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSave:
    """Tests for APIBasedExtensionService.save."""

    # Patch decorators apply bottom-up, so the mock parameters arrive in the
    # order (_validation, db, encrypt_token).
    @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key")
    @patch("services.api_based_extension_service.db")
    @patch.object(APIBasedExtensionService, "_validation")
    def test_save_new_record_encrypts_key_and_commits(self, mock_validation, mock_db, mock_encrypt):
        """Happy path: validation passes, key is encrypted, record is added and committed."""
        ext = _make_extension(id_=None, api_key="plain-key-123")

        result = APIBasedExtensionService.save(ext)

        mock_validation.assert_called_once_with(ext)
        mock_encrypt.assert_called_once_with(ext.tenant_id, "plain-key-123")
        # The plaintext key is replaced in place before persisting.
        assert ext.api_key == "encrypted-key"
        mock_db.session.add.assert_called_once_with(ext)
        mock_db.session.commit.assert_called_once()
        assert result is ext

    @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key")
    @patch("services.api_based_extension_service.db")
    @patch.object(APIBasedExtensionService, "_validation", side_effect=ValueError("name must not be empty"))
    def test_save_raises_when_validation_fails(self, mock_validation, mock_db, mock_encrypt):
        """If _validation raises, save should propagate the error without touching the DB."""
        ext = _make_extension(name="")

        with pytest.raises(ValueError, match="name must not be empty"):
            APIBasedExtensionService.save(ext)

        mock_db.session.add.assert_not_called()
        mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: delete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDelete:
    """Tests for APIBasedExtensionService.delete."""

    @patch("services.api_based_extension_service.db")
    def test_delete_removes_record_and_commits(self, mock_db):
        """delete() must call session.delete with the extension and then commit."""
        extension = _make_extension(id_="delete-me")

        APIBasedExtensionService.delete(extension)

        session = mock_db.session
        session.delete.assert_called_once_with(extension)
        session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: get_with_tenant_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetWithTenantId:
    """Tests for APIBasedExtensionService.get_with_tenant_id."""

    @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key")
    @patch("services.api_based_extension_service.db")
    def test_returns_extension_with_decrypted_key(self, mock_db, mock_decrypt):
        """Found extension has its api_key decrypted before being returned."""
        ext = _make_extension(id_="ext-123", api_key="enc-key")

        # The service chains two filter_by calls: first tenant_id, then id.
        (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = ext

        result = APIBasedExtensionService.get_with_tenant_id("tenant-001", "ext-123")

        assert result is ext
        assert ext.api_key == "decrypted-key"
        mock_decrypt.assert_called_once_with(ext.tenant_id, "enc-key")

    @patch("services.api_based_extension_service.db")
    def test_raises_value_error_when_not_found(self, mock_db):
        """Raises ValueError when no matching extension exists."""
        (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = None

        with pytest.raises(ValueError, match="API based extension is not found"):
            APIBasedExtensionService.get_with_tenant_id("tenant-001", "non-existent")

    @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key")
    @patch("services.api_based_extension_service.db")
    def test_queries_with_correct_tenant_and_extension_id(self, mock_db, mock_decrypt):
        """Verifies both tenant_id and extension id are used in the query."""
        ext = _make_extension(id_="ext-abc")
        chain = mock_db.session.query.return_value
        chain.filter_by.return_value.filter_by.return_value.first.return_value = ext

        APIBasedExtensionService.get_with_tenant_id("tenant-002", "ext-abc")

        # First filter_by call uses tenant_id
        chain.filter_by.assert_called_once_with(tenant_id="tenant-002")
        # Second filter_by call uses id
        chain.filter_by.return_value.filter_by.assert_called_once_with(id="ext-abc")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _validation (new record — id is falsy)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidationNewRecord:
    """Tests for _validation() with a brand-new record (no id)."""

    def _build_mock_db(self, name_exists: bool = False):
        # NOTE(review): this helper is not referenced by any test in this class;
        # the tests below patch the module-level db directly. Candidate for removal.
        mock_db = MagicMock()
        mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = (
            MagicMock() if name_exists else None
        )
        return mock_db

    @patch.object(APIBasedExtensionService, "_ping_connection")
    @patch("services.api_based_extension_service.db")
    def test_valid_new_extension_passes(self, mock_db, mock_ping):
        """A new record with all valid fields should pass without exceptions."""
        # first() → None means no other record already uses this name.
        mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
        ext = _make_extension(id_=None, name="valid-ext", api_key="longenoughkey")

        # Should not raise
        APIBasedExtensionService._validation(ext)
        mock_ping.assert_called_once_with(ext)

    @patch("services.api_based_extension_service.db")
    def test_raises_if_name_is_empty(self, mock_db):
        """Empty name raises ValueError."""
        ext = _make_extension(id_=None, name="")
        with pytest.raises(ValueError, match="name must not be empty"):
            APIBasedExtensionService._validation(ext)

    @patch("services.api_based_extension_service.db")
    def test_raises_if_name_is_none(self, mock_db):
        """None name raises ValueError."""
        ext = _make_extension(id_=None, name=None)
        with pytest.raises(ValueError, match="name must not be empty"):
            APIBasedExtensionService._validation(ext)

    @patch("services.api_based_extension_service.db")
    def test_raises_if_name_already_exists_for_new_record(self, mock_db):
        """A new record whose name already exists raises ValueError."""
        # first() → a record: the name is taken by another extension.
        mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = (
            MagicMock()
        )
        ext = _make_extension(id_=None, name="duplicate-name")

        with pytest.raises(ValueError, match="name must be unique, it is already existed"):
            APIBasedExtensionService._validation(ext)

    @patch("services.api_based_extension_service.db")
    def test_raises_if_api_endpoint_is_empty(self, mock_db):
        """Empty api_endpoint raises ValueError."""
        mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
        ext = _make_extension(id_=None, api_endpoint="")

        with pytest.raises(ValueError, match="api_endpoint must not be empty"):
            APIBasedExtensionService._validation(ext)

    @patch("services.api_based_extension_service.db")
    def test_raises_if_api_endpoint_is_none(self, mock_db):
        """None api_endpoint raises ValueError."""
        mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
        ext = _make_extension(id_=None, api_endpoint=None)

        with pytest.raises(ValueError, match="api_endpoint must not be empty"):
            APIBasedExtensionService._validation(ext)

    @patch("services.api_based_extension_service.db")
    def test_raises_if_api_key_is_empty(self, mock_db):
        """Empty api_key raises ValueError."""
        mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
        ext = _make_extension(id_=None, api_key="")

        with pytest.raises(ValueError, match="api_key must not be empty"):
            APIBasedExtensionService._validation(ext)

    @patch("services.api_based_extension_service.db")
    def test_raises_if_api_key_is_none(self, mock_db):
        """None api_key raises ValueError."""
        mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
        ext = _make_extension(id_=None, api_key=None)

        with pytest.raises(ValueError, match="api_key must not be empty"):
            APIBasedExtensionService._validation(ext)

    @patch("services.api_based_extension_service.db")
    def test_raises_if_api_key_too_short(self, mock_db):
        """api_key shorter than 5 characters raises ValueError."""
        mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
        ext = _make_extension(id_=None, api_key="abc")

        with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
            APIBasedExtensionService._validation(ext)

    @patch("services.api_based_extension_service.db")
    def test_raises_if_api_key_exactly_four_chars(self, mock_db):
        """api_key with exactly 4 characters raises ValueError (boundary condition)."""
        mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
        ext = _make_extension(id_=None, api_key="1234")

        with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
            APIBasedExtensionService._validation(ext)

    @patch.object(APIBasedExtensionService, "_ping_connection")
    @patch("services.api_based_extension_service.db")
    def test_api_key_exactly_five_chars_is_accepted(self, mock_db, mock_ping):
        """api_key with exactly 5 characters should pass (boundary condition)."""
        mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
        ext = _make_extension(id_=None, api_key="12345")

        # Should not raise
        APIBasedExtensionService._validation(ext)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _validation (existing record — id is truthy)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidationExistingRecord:
    """Tests for _validation() with an existing record (id is set)."""

    @patch.object(APIBasedExtensionService, "_ping_connection")
    @patch("services.api_based_extension_service.db")
    def test_valid_existing_extension_passes(self, mock_db, mock_ping):
        """An existing record whose name is unique (excluding self) should pass."""
        # .where(...).first() → None means no *other* record has that name
        (
            mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value
        ) = None
        ext = _make_extension(id_="existing-id", name="unique-name", api_key="longenoughkey")

        # Should not raise
        APIBasedExtensionService._validation(ext)
        mock_ping.assert_called_once_with(ext)

    @patch("services.api_based_extension_service.db")
    def test_raises_if_existing_record_name_conflicts_with_another(self, mock_db):
        """Existing record cannot use a name already owned by a different record."""
        # .where(...).first() → a record means another extension owns the name.
        (
            mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value
        ) = MagicMock()
        ext = _make_extension(id_="existing-id", name="taken-name")

        with pytest.raises(ValueError, match="name must be unique, it is already existed"):
            APIBasedExtensionService._validation(ext)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _ping_connection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPingConnection:
    """Tests for APIBasedExtensionService._ping_connection."""

    @patch("services.api_based_extension_service.APIBasedExtensionRequestor")
    def test_successful_ping_returns_pong(self, mock_requestor_class):
        """When the endpoint returns {"result": "pong"}, no exception is raised."""
        mock_client = MagicMock()
        mock_client.request.return_value = {"result": "pong"}
        mock_requestor_class.return_value = mock_client

        ext = _make_extension(api_endpoint="https://ok.example.com", api_key="secret-key")
        # Should not raise
        APIBasedExtensionService._ping_connection(ext)

        # The requestor is constructed with the extension's endpoint and key.
        mock_requestor_class.assert_called_once_with(ext.api_endpoint, ext.api_key)

    @patch("services.api_based_extension_service.APIBasedExtensionRequestor")
    def test_wrong_ping_response_raises_value_error(self, mock_requestor_class):
        """When the response is not {"result": "pong"}, a ValueError is raised."""
        mock_client = MagicMock()
        mock_client.request.return_value = {"result": "error"}
        mock_requestor_class.return_value = mock_client

        ext = _make_extension()
        with pytest.raises(ValueError, match="connection error"):
            APIBasedExtensionService._ping_connection(ext)

    @patch("services.api_based_extension_service.APIBasedExtensionRequestor")
    def test_network_exception_wraps_in_value_error(self, mock_requestor_class):
        """Any exception raised during request is wrapped in a ValueError."""
        mock_client = MagicMock()
        mock_client.request.side_effect = ConnectionError("network failure")
        mock_requestor_class.return_value = mock_client

        ext = _make_extension()
        with pytest.raises(ValueError, match="connection error: network failure"):
            APIBasedExtensionService._ping_connection(ext)

    @patch("services.api_based_extension_service.APIBasedExtensionRequestor")
    def test_requestor_constructor_exception_wraps_in_value_error(self, mock_requestor_class):
        """Exception raised by the requestor constructor itself is wrapped."""
        mock_requestor_class.side_effect = RuntimeError("bad config")

        ext = _make_extension()
        with pytest.raises(ValueError, match="connection error: bad config"):
            APIBasedExtensionService._ping_connection(ext)

    @patch("services.api_based_extension_service.APIBasedExtensionRequestor")
    def test_missing_result_key_raises_value_error(self, mock_requestor_class):
        """A response dict without a 'result' key does not equal 'pong' → raises."""
        mock_client = MagicMock()
        mock_client.request.return_value = {}  # no 'result' key
        mock_requestor_class.return_value = mock_client

        ext = _make_extension()
        with pytest.raises(ValueError, match="connection error"):
            APIBasedExtensionService._ping_connection(ext)

    @patch("services.api_based_extension_service.APIBasedExtensionRequestor")
    def test_uses_ping_extension_point(self, mock_requestor_class):
        """The PING extension point is passed to the client.request call."""
        # Imported locally to keep the module import surface minimal.
        from models.api_based_extension import APIBasedExtensionPoint

        mock_client = MagicMock()
        mock_client.request.return_value = {"result": "pong"}
        mock_requestor_class.return_value = mock_client

        ext = _make_extension()
        APIBasedExtensionService._ping_connection(ext)

        call_kwargs = mock_client.request.call_args
        assert call_kwargs.kwargs["point"] == APIBasedExtensionPoint.PING
        assert call_kwargs.kwargs["params"] == {}
|
||||
913
api/tests/unit_tests/services/test_app_dsl_service.py
Normal file
913
api/tests/unit_tests/services/test_app_dsl_service.py
Normal file
@ -0,0 +1,913 @@
|
||||
import base64
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from dify_graph.enums import NodeType
|
||||
from models import Account, AppMode
|
||||
from models.model import IconType
|
||||
from services import app_dsl_service
|
||||
from services.app_dsl_service import (
|
||||
AppDslService,
|
||||
CheckDependenciesPendingData,
|
||||
ImportMode,
|
||||
ImportStatus,
|
||||
PendingData,
|
||||
_check_version_compatibility,
|
||||
)
|
||||
|
||||
|
||||
class _FakeHttpResponse:
|
||||
def __init__(self, content: bytes, *, raises: Exception | None = None):
|
||||
self.content = content
|
||||
self._raises = raises
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
if self._raises is not None:
|
||||
raise self._raises
|
||||
|
||||
|
||||
def _account_mock(*, tenant_id: str = "tenant-1", account_id: str = "account-1") -> MagicMock:
|
||||
account = MagicMock(spec=Account)
|
||||
account.current_tenant_id = tenant_id
|
||||
account.id = account_id
|
||||
return account
|
||||
|
||||
|
||||
def _yaml_dump(data: dict) -> str:
|
||||
return yaml.safe_dump(data, allow_unicode=True)
|
||||
|
||||
|
||||
def _workflow_yaml(*, version: str = app_dsl_service.CURRENT_DSL_VERSION) -> str:
|
||||
return _yaml_dump(
|
||||
{
|
||||
"version": version,
|
||||
"kind": "app",
|
||||
"app": {"name": "My App", "mode": AppMode.WORKFLOW.value},
|
||||
"workflow": {"graph": {"nodes": []}, "features": {}},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_check_version_compatibility_invalid_version_returns_failed():
|
||||
assert _check_version_compatibility("not-a-version") == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_check_version_compatibility_newer_version_returns_pending():
|
||||
assert _check_version_compatibility("99.0.0") == ImportStatus.PENDING
|
||||
|
||||
|
||||
def test_check_version_compatibility_major_older_returns_pending(monkeypatch):
|
||||
monkeypatch.setattr(app_dsl_service, "CURRENT_DSL_VERSION", "1.0.0")
|
||||
assert _check_version_compatibility("0.9.9") == ImportStatus.PENDING
|
||||
|
||||
|
||||
def test_check_version_compatibility_minor_older_returns_completed_with_warnings():
|
||||
assert _check_version_compatibility("0.5.0") == ImportStatus.COMPLETED_WITH_WARNINGS
|
||||
|
||||
|
||||
def test_check_version_compatibility_equal_returns_completed():
|
||||
assert _check_version_compatibility(app_dsl_service.CURRENT_DSL_VERSION) == ImportStatus.COMPLETED
|
||||
|
||||
|
||||
def test_import_app_invalid_import_mode_raises_value_error():
|
||||
service = AppDslService(MagicMock())
|
||||
with pytest.raises(ValueError, match="Invalid import_mode"):
|
||||
service.import_app(account=_account_mock(), import_mode="invalid-mode", yaml_content="version: '0.1.0'")
|
||||
|
||||
|
||||
def test_import_app_yaml_url_requires_url():
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url=None)
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "yaml_url is required" in result.error
|
||||
|
||||
|
||||
def test_import_app_yaml_content_requires_content():
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=None)
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "yaml_content is required" in result.error
|
||||
|
||||
|
||||
def test_import_app_yaml_url_fetch_error_returns_failed(monkeypatch):
|
||||
def fake_get(_url: str, **_kwargs):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(
|
||||
account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url="https://example.com/a.yml"
|
||||
)
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Error fetching YAML from URL: boom" in result.error
|
||||
|
||||
|
||||
def test_import_app_yaml_url_empty_content_returns_failed(monkeypatch):
|
||||
def fake_get(_url: str, **_kwargs):
|
||||
return _FakeHttpResponse(b"")
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(
|
||||
account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url="https://example.com/a.yml"
|
||||
)
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Empty content" in result.error
|
||||
|
||||
|
||||
def test_import_app_yaml_url_file_too_large_returns_failed(monkeypatch):
|
||||
def fake_get(_url: str, **_kwargs):
|
||||
return _FakeHttpResponse(b"x" * (app_dsl_service.DSL_MAX_SIZE + 1))
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(
|
||||
account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url="https://example.com/a.yml"
|
||||
)
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "File size exceeds" in result.error
|
||||
|
||||
|
||||
def test_import_app_yaml_not_mapping_returns_failed():
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content="[]")
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "content must be a mapping" in result.error
|
||||
|
||||
|
||||
def test_import_app_version_not_str_returns_failed():
|
||||
service = AppDslService(MagicMock())
|
||||
yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}})
|
||||
result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=yaml_content)
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Invalid version type" in result.error
|
||||
|
||||
|
||||
def test_import_app_missing_app_data_returns_failed():
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=_yaml_dump({"version": "0.6.0", "kind": "app"}),
|
||||
)
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Missing app data" in result.error
|
||||
|
||||
|
||||
def test_import_app_app_id_not_found_returns_failed(monkeypatch):
|
||||
def fake_select(_model):
|
||||
stmt = MagicMock()
|
||||
stmt.where.return_value = stmt
|
||||
return stmt
|
||||
|
||||
monkeypatch.setattr(app_dsl_service, "select", fake_select)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None
|
||||
service = AppDslService(session)
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=_workflow_yaml(),
|
||||
app_id="missing-app",
|
||||
)
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert result.error == "App not found"
|
||||
|
||||
|
||||
def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(monkeypatch):
|
||||
def fake_select(_model):
|
||||
stmt = MagicMock()
|
||||
stmt.where.return_value = stmt
|
||||
return stmt
|
||||
|
||||
monkeypatch.setattr(app_dsl_service, "select", fake_select)
|
||||
|
||||
existing_app = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = existing_app
|
||||
service = AppDslService(session)
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=_workflow_yaml(),
|
||||
app_id="app-1",
|
||||
)
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Only workflow or advanced chat apps" in result.error
|
||||
|
||||
|
||||
def test_import_app_pending_stores_import_info_in_redis():
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=_workflow_yaml(version="99.0.0"),
|
||||
name="n",
|
||||
description="d",
|
||||
icon_type="emoji",
|
||||
icon="i",
|
||||
icon_background="#000000",
|
||||
)
|
||||
assert result.status == ImportStatus.PENDING
|
||||
assert result.imported_dsl_version == "99.0.0"
|
||||
|
||||
app_dsl_service.redis_client.setex.assert_called_once()
|
||||
call = app_dsl_service.redis_client.setex.call_args
|
||||
redis_key = call.args[0]
|
||||
assert redis_key.startswith(app_dsl_service.IMPORT_INFO_REDIS_KEY_PREFIX)
|
||||
|
||||
|
||||
def test_import_app_completed_uses_declared_dependencies(monkeypatch):
|
||||
dependencies_payload = [{"id": "langgenius/google", "version": "1.0.0"}]
|
||||
|
||||
plugin_deps = [SimpleNamespace(model_dump=lambda: dependencies_payload[0])]
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.PluginDependency,
|
||||
"model_validate",
|
||||
lambda d: plugin_deps[0],
|
||||
)
|
||||
|
||||
created_app = SimpleNamespace(id="app-new", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1")
|
||||
monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app)
|
||||
|
||||
draft_var_service = MagicMock()
|
||||
monkeypatch.setattr(app_dsl_service, "WorkflowDraftVariableService", lambda *args, **kwargs: draft_var_service)
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=_yaml_dump(
|
||||
{
|
||||
"version": app_dsl_service.CURRENT_DSL_VERSION,
|
||||
"kind": "app",
|
||||
"app": {"name": "My App", "mode": AppMode.WORKFLOW.value},
|
||||
"workflow": {"graph": {"nodes": []}, "features": {}},
|
||||
"dependencies": dependencies_payload,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
assert result.status == ImportStatus.COMPLETED
|
||||
assert result.app_id == "app-new"
|
||||
draft_var_service.delete_workflow_variables.assert_called_once_with(app_id="app-new")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("has_workflow", [True, False])
|
||||
def test_import_app_legacy_versions_extract_dependencies(monkeypatch, has_workflow: bool):
|
||||
monkeypatch.setattr(
|
||||
AppDslService,
|
||||
"_extract_dependencies_from_workflow_graph",
|
||||
lambda *_args, **_kwargs: ["from-workflow"],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
AppDslService,
|
||||
"_extract_dependencies_from_model_config",
|
||||
lambda *_args, **_kwargs: ["from-model-config"],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"generate_latest_dependencies",
|
||||
lambda deps: [SimpleNamespace(model_dump=lambda: {"dep": deps[0]})],
|
||||
)
|
||||
|
||||
created_app = SimpleNamespace(id="app-legacy", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1")
|
||||
monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app)
|
||||
|
||||
draft_var_service = MagicMock()
|
||||
monkeypatch.setattr(app_dsl_service, "WorkflowDraftVariableService", lambda *args, **kwargs: draft_var_service)
|
||||
|
||||
data: dict = {
|
||||
"version": "0.1.5",
|
||||
"kind": "app",
|
||||
"app": {"name": "Legacy", "mode": AppMode.WORKFLOW.value},
|
||||
}
|
||||
if has_workflow:
|
||||
data["workflow"] = {"graph": {"nodes": []}, "features": {}}
|
||||
else:
|
||||
data["model_config"] = {"model": {"provider": "openai"}}
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(
|
||||
account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=_yaml_dump(data)
|
||||
)
|
||||
assert result.status == ImportStatus.COMPLETED_WITH_WARNINGS
|
||||
draft_var_service.delete_workflow_variables.assert_called_once_with(app_id="app-legacy")
|
||||
|
||||
|
||||
def test_import_app_yaml_error_returns_failed(monkeypatch):
|
||||
def bad_safe_load(_content: str):
|
||||
raise yaml.YAMLError("bad")
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.yaml, "safe_load", bad_safe_load)
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content="x: y")
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert result.error.startswith("Invalid YAML format:")
|
||||
|
||||
|
||||
def test_import_app_unexpected_error_returns_failed(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("oops"))
|
||||
)
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(
|
||||
account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=_workflow_yaml()
|
||||
)
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert result.error == "oops"
|
||||
|
||||
|
||||
def test_confirm_import_expired_returns_failed():
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.confirm_import(import_id="import-1", account=_account_mock())
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "expired" in result.error
|
||||
|
||||
|
||||
def test_confirm_import_invalid_pending_data_type_returns_failed():
|
||||
app_dsl_service.redis_client.get.return_value = 123
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.confirm_import(import_id="import-1", account=_account_mock())
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Invalid import information" in result.error
|
||||
|
||||
|
||||
def test_confirm_import_success_deletes_redis_key(monkeypatch):
|
||||
def fake_select(_model):
|
||||
stmt = MagicMock()
|
||||
stmt.where.return_value = stmt
|
||||
return stmt
|
||||
|
||||
monkeypatch.setattr(app_dsl_service, "select", fake_select)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None
|
||||
service = AppDslService(session)
|
||||
|
||||
pending = PendingData(
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=_workflow_yaml(),
|
||||
name="name",
|
||||
description="desc",
|
||||
icon_type="emoji",
|
||||
icon="🤖",
|
||||
icon_background="#fff",
|
||||
app_id=None,
|
||||
)
|
||||
app_dsl_service.redis_client.get.return_value = pending.model_dump_json()
|
||||
|
||||
created_app = SimpleNamespace(id="confirmed-app", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1")
|
||||
monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app)
|
||||
|
||||
result = service.confirm_import(import_id="import-1", account=_account_mock())
|
||||
assert result.status == ImportStatus.COMPLETED
|
||||
assert result.app_id == "confirmed-app"
|
||||
app_dsl_service.redis_client.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_confirm_import_exception_returns_failed(monkeypatch):
|
||||
app_dsl_service.redis_client.get.return_value = "not-json"
|
||||
monkeypatch.setattr(
|
||||
PendingData, "model_validate_json", lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad"))
|
||||
)
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.confirm_import(import_id="import-1", account=_account_mock())
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert result.error == "bad"
|
||||
|
||||
|
||||
def test_check_dependencies_returns_empty_when_no_redis_data():
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.check_dependencies(app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"))
|
||||
assert result.leaked_dependencies == []
|
||||
|
||||
|
||||
def test_check_dependencies_calls_analysis_service(monkeypatch):
|
||||
pending = CheckDependenciesPendingData(dependencies=[], app_id="app-1").model_dump_json()
|
||||
app_dsl_service.redis_client.get.return_value = pending
|
||||
dep = app_dsl_service.PluginDependency.model_validate(
|
||||
{"type": "package", "value": {"plugin_unique_identifier": "acme/foo", "version": "1.0.0"}}
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"get_leaked_dependencies",
|
||||
lambda *, tenant_id, dependencies: [dep],
|
||||
)
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.check_dependencies(app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"))
|
||||
assert len(result.leaked_dependencies) == 1
|
||||
|
||||
|
||||
def test_create_or_update_app_missing_mode_raises():
|
||||
service = AppDslService(MagicMock())
|
||||
with pytest.raises(ValueError, match="loss app mode"):
|
||||
service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock())
|
||||
|
||||
|
||||
def test_create_or_update_app_existing_app_updates_fields(monkeypatch):
|
||||
fixed_now = object()
|
||||
monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now)
|
||||
|
||||
workflow_service = MagicMock()
|
||||
workflow_service.get_draft_workflow.return_value = None
|
||||
monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service)
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.variable_factory,
|
||||
"build_environment_variable_from_mapping",
|
||||
lambda _m: SimpleNamespace(kind="env"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.variable_factory,
|
||||
"build_conversation_variable_from_mapping",
|
||||
lambda _m: SimpleNamespace(kind="conv"),
|
||||
)
|
||||
|
||||
app = SimpleNamespace(
|
||||
id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
name="old",
|
||||
description="old-desc",
|
||||
icon_type=IconType.EMOJI,
|
||||
icon="old-icon",
|
||||
icon_background="#111111",
|
||||
updated_by=None,
|
||||
updated_at=None,
|
||||
app_model_config=None,
|
||||
)
|
||||
service = AppDslService(MagicMock())
|
||||
updated = service._create_or_update_app(
|
||||
app=app,
|
||||
data={
|
||||
"app": {"mode": AppMode.WORKFLOW.value, "name": "yaml-name", "icon_type": IconType.IMAGE, "icon": "X"},
|
||||
"workflow": {"graph": {"nodes": []}, "features": {}},
|
||||
},
|
||||
account=_account_mock(),
|
||||
name="override-name",
|
||||
description=None,
|
||||
icon_background="#222222",
|
||||
)
|
||||
assert updated is app
|
||||
assert app.name == "override-name"
|
||||
assert app.icon_type == IconType.IMAGE
|
||||
assert app.icon == "X"
|
||||
assert app.icon_background == "#222222"
|
||||
assert app.updated_at is fixed_now
|
||||
|
||||
|
||||
def test_create_or_update_app_new_app_requires_tenant():
|
||||
account = _account_mock()
|
||||
account.current_tenant_id = None
|
||||
service = AppDslService(MagicMock())
|
||||
with pytest.raises(ValueError, match="Current tenant is not set"):
|
||||
service._create_or_update_app(
|
||||
app=None,
|
||||
data={"app": {"mode": AppMode.WORKFLOW.value, "name": "n"}},
|
||||
account=account,
|
||||
)
|
||||
|
||||
|
||||
def test_create_or_update_app_creates_workflow_app_and_saves_dependencies(monkeypatch):
|
||||
class DummyApp(SimpleNamespace):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(app_dsl_service, "App", DummyApp)
|
||||
|
||||
sent: list[tuple[str, object]] = []
|
||||
monkeypatch.setattr(app_dsl_service.app_was_created, "send", lambda app, account: sent.append((app.id, account.id)))
|
||||
|
||||
workflow_service = MagicMock()
|
||||
workflow_service.get_draft_workflow.return_value = SimpleNamespace(unique_hash="uh")
|
||||
monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service)
|
||||
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.variable_factory,
|
||||
"build_environment_variable_from_mapping",
|
||||
lambda _m: SimpleNamespace(kind="env"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.variable_factory,
|
||||
"build_conversation_variable_from_mapping",
|
||||
lambda _m: SimpleNamespace(kind="conv"),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
AppDslService, "decrypt_dataset_id", lambda *_args, **_kwargs: "00000000-0000-0000-0000-000000000000"
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
service = AppDslService(session)
|
||||
deps = [
|
||||
app_dsl_service.PluginDependency.model_validate(
|
||||
{"type": "package", "value": {"plugin_unique_identifier": "acme/foo", "version": "1.0.0"}}
|
||||
)
|
||||
]
|
||||
data = {
|
||||
"app": {"mode": AppMode.WORKFLOW.value, "name": "n"},
|
||||
"workflow": {
|
||||
"environment_variables": [{"x": 1}],
|
||||
"conversation_variables": [{"y": 2}],
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{"data": {"type": NodeType.KNOWLEDGE_RETRIEVAL, "dataset_ids": ["enc-1", "enc-2"]}},
|
||||
]
|
||||
},
|
||||
"features": {},
|
||||
},
|
||||
}
|
||||
|
||||
app = service._create_or_update_app(app=None, data=data, account=_account_mock(), dependencies=deps)
|
||||
|
||||
assert app.tenant_id == "tenant-1"
|
||||
assert sent == [(app.id, "account-1")]
|
||||
app_dsl_service.redis_client.setex.assert_called()
|
||||
workflow_service.sync_draft_workflow.assert_called_once()
|
||||
|
||||
passed_graph = workflow_service.sync_draft_workflow.call_args.kwargs["graph"]
|
||||
dataset_ids = passed_graph["nodes"][0]["data"]["dataset_ids"]
|
||||
assert dataset_ids == ["00000000-0000-0000-0000-000000000000", "00000000-0000-0000-0000-000000000000"]
|
||||
|
||||
|
||||
def test_create_or_update_app_workflow_missing_workflow_data_raises():
|
||||
service = AppDslService(MagicMock())
|
||||
with pytest.raises(ValueError, match="Missing workflow data"):
|
||||
service._create_or_update_app(
|
||||
app=SimpleNamespace(
|
||||
id="a",
|
||||
tenant_id="t",
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
name="n",
|
||||
description="d",
|
||||
icon_background="#fff",
|
||||
app_model_config=None,
|
||||
),
|
||||
data={"app": {"mode": AppMode.WORKFLOW.value}},
|
||||
account=_account_mock(),
|
||||
)
|
||||
|
||||
|
||||
def test_create_or_update_app_chat_requires_model_config():
|
||||
service = AppDslService(MagicMock())
|
||||
with pytest.raises(ValueError, match="Missing model_config"):
|
||||
service._create_or_update_app(
|
||||
app=SimpleNamespace(
|
||||
id="a",
|
||||
tenant_id="t",
|
||||
mode=AppMode.CHAT.value,
|
||||
name="n",
|
||||
description="d",
|
||||
icon_background="#fff",
|
||||
app_model_config=None,
|
||||
),
|
||||
data={"app": {"mode": AppMode.CHAT.value}},
|
||||
account=_account_mock(),
|
||||
)
|
||||
|
||||
|
||||
def test_create_or_update_app_chat_creates_model_config_and_sends_event(monkeypatch):
|
||||
class DummyModelConfig(SimpleNamespace):
|
||||
def from_model_config_dict(self, _cfg: dict):
|
||||
return self
|
||||
|
||||
monkeypatch.setattr(app_dsl_service, "AppModelConfig", DummyModelConfig)
|
||||
|
||||
sent: list[str] = []
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.app_model_config_was_updated, "send", lambda app, app_model_config: sent.append(app.id)
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
service = AppDslService(session)
|
||||
|
||||
app = SimpleNamespace(
|
||||
id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
mode=AppMode.CHAT.value,
|
||||
name="n",
|
||||
description="d",
|
||||
icon_background="#fff",
|
||||
app_model_config=None,
|
||||
)
|
||||
service._create_or_update_app(
|
||||
app=app,
|
||||
data={"app": {"mode": AppMode.CHAT.value}, "model_config": {"model": {"provider": "openai"}}},
|
||||
account=_account_mock(),
|
||||
)
|
||||
|
||||
assert app.app_model_config_id is not None
|
||||
assert sent == ["app-1"]
|
||||
session.add.assert_called()
|
||||
|
||||
|
||||
def test_create_or_update_app_invalid_mode_raises():
|
||||
service = AppDslService(MagicMock())
|
||||
with pytest.raises(ValueError, match="Invalid app mode"):
|
||||
service._create_or_update_app(
|
||||
app=SimpleNamespace(
|
||||
id="a",
|
||||
tenant_id="t",
|
||||
mode=AppMode.RAG_PIPELINE.value,
|
||||
name="n",
|
||||
description="d",
|
||||
icon_background="#fff",
|
||||
app_model_config=None,
|
||||
),
|
||||
data={"app": {"mode": AppMode.RAG_PIPELINE.value}},
|
||||
account=_account_mock(),
|
||||
)
|
||||
|
||||
|
||||
def test_export_dsl_delegates_by_mode(monkeypatch):
|
||||
workflow_calls: list[bool] = []
|
||||
model_calls: list[bool] = []
|
||||
monkeypatch.setattr(AppDslService, "_append_workflow_export_data", lambda **_kwargs: workflow_calls.append(True))
|
||||
monkeypatch.setattr(
|
||||
AppDslService, "_append_model_config_export_data", lambda *_args, **_kwargs: model_calls.append(True)
|
||||
)
|
||||
|
||||
workflow_app = SimpleNamespace(
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
tenant_id="tenant-1",
|
||||
name="n",
|
||||
icon="i",
|
||||
icon_type="emoji",
|
||||
icon_background="#fff",
|
||||
description="d",
|
||||
use_icon_as_answer_icon=False,
|
||||
app_model_config=None,
|
||||
)
|
||||
AppDslService.export_dsl(workflow_app)
|
||||
assert workflow_calls == [True]
|
||||
|
||||
chat_app = SimpleNamespace(
|
||||
mode=AppMode.CHAT.value,
|
||||
tenant_id="tenant-1",
|
||||
name="n",
|
||||
icon="i",
|
||||
icon_type="emoji",
|
||||
icon_background="#fff",
|
||||
description="d",
|
||||
use_icon_as_answer_icon=False,
|
||||
app_model_config=SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": []}}),
|
||||
)
|
||||
AppDslService.export_dsl(chat_app)
|
||||
assert model_calls == [True]
|
||||
|
||||
|
||||
def test_append_workflow_export_data_filters_and_overrides(monkeypatch):
|
||||
workflow_dict = {
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{"data": {"type": NodeType.KNOWLEDGE_RETRIEVAL, "dataset_ids": ["d1", "d2"]}},
|
||||
{"data": {"type": NodeType.TOOL, "credential_id": "secret"}},
|
||||
{
|
||||
"data": {
|
||||
"type": NodeType.AGENT,
|
||||
"agent_parameters": {"tools": {"value": [{"credential_id": "secret"}]}},
|
||||
}
|
||||
},
|
||||
{"data": {"type": NodeType.TRIGGER_SCHEDULE.value, "config": {"x": 1}}},
|
||||
{"data": {"type": NodeType.TRIGGER_WEBHOOK.value, "webhook_url": "x", "webhook_debug_url": "y"}},
|
||||
{"data": {"type": NodeType.TRIGGER_PLUGIN.value, "subscription_id": "s"}},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
workflow = SimpleNamespace(to_dict=lambda *, include_secret: workflow_dict)
|
||||
workflow_service = MagicMock()
|
||||
workflow_service.get_draft_workflow.return_value = workflow
|
||||
monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service)
|
||||
|
||||
monkeypatch.setattr(
|
||||
AppDslService, "encrypt_dataset_id", lambda *, dataset_id, tenant_id: f"enc:{tenant_id}:{dataset_id}"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
TriggerScheduleNode := app_dsl_service.TriggerScheduleNode,
|
||||
"get_default_config",
|
||||
lambda: {"config": {"default": True}},
|
||||
)
|
||||
monkeypatch.setattr(AppDslService, "_extract_dependencies_from_workflow", lambda *_args, **_kwargs: ["dep-1"])
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"generate_dependencies",
|
||||
lambda *, tenant_id, dependencies: [
|
||||
SimpleNamespace(model_dump=lambda: {"tenant": tenant_id, "dep": dependencies[0]})
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x)
|
||||
|
||||
export_data: dict = {}
|
||||
AppDslService._append_workflow_export_data(
|
||||
export_data=export_data,
|
||||
app_model=SimpleNamespace(tenant_id="tenant-1"),
|
||||
include_secret=False,
|
||||
workflow_id=None,
|
||||
)
|
||||
|
||||
nodes = export_data["workflow"]["graph"]["nodes"]
|
||||
assert nodes[0]["data"]["dataset_ids"] == ["enc:tenant-1:d1", "enc:tenant-1:d2"]
|
||||
assert "credential_id" not in nodes[1]["data"]
|
||||
assert "credential_id" not in nodes[2]["data"]["agent_parameters"]["tools"]["value"][0]
|
||||
assert nodes[3]["data"]["config"] == {"default": True}
|
||||
assert nodes[4]["data"]["webhook_url"] == ""
|
||||
assert nodes[4]["data"]["webhook_debug_url"] == ""
|
||||
assert nodes[5]["data"]["subscription_id"] == ""
|
||||
assert export_data["dependencies"] == [{"tenant": "tenant-1", "dep": "dep-1"}]
|
||||
|
||||
|
||||
def test_append_workflow_export_data_missing_workflow_raises(monkeypatch):
|
||||
workflow_service = MagicMock()
|
||||
workflow_service.get_draft_workflow.return_value = None
|
||||
monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service)
|
||||
|
||||
with pytest.raises(ValueError, match="Missing draft workflow configuration"):
|
||||
AppDslService._append_workflow_export_data(
|
||||
export_data={},
|
||||
app_model=SimpleNamespace(tenant_id="tenant-1"),
|
||||
include_secret=False,
|
||||
workflow_id=None,
|
||||
)
|
||||
|
||||
|
||||
def test_append_model_config_export_data_filters_credential_id(monkeypatch):
|
||||
monkeypatch.setattr(AppDslService, "_extract_dependencies_from_model_config", lambda *_args, **_kwargs: ["dep-1"])
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"generate_dependencies",
|
||||
lambda *, tenant_id, dependencies: [
|
||||
SimpleNamespace(model_dump=lambda: {"tenant": tenant_id, "dep": dependencies[0]})
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x)
|
||||
|
||||
app_model_config = SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": [{"credential_id": "secret"}]}})
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1", app_model_config=app_model_config)
|
||||
export_data: dict = {}
|
||||
|
||||
AppDslService._append_model_config_export_data(export_data, app_model)
|
||||
assert export_data["model_config"]["agent_mode"]["tools"] == [{}]
|
||||
assert export_data["dependencies"] == [{"tenant": "tenant-1", "dep": "dep-1"}]
|
||||
|
||||
|
||||
def test_append_model_config_export_data_requires_app_config():
|
||||
with pytest.raises(ValueError, match="Missing app configuration"):
|
||||
AppDslService._append_model_config_export_data({}, SimpleNamespace(app_model_config=None))
|
||||
|
||||
|
||||
def test_extract_dependencies_from_workflow_graph_covers_all_node_types(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"analyze_tool_dependency",
|
||||
lambda provider_id: f"tool:{provider_id}",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"analyze_model_provider_dependency",
|
||||
lambda provider: f"model:{provider}",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.ToolNodeData, "model_validate", lambda _d: SimpleNamespace(provider_id="p1"))
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.LLMNodeData, "model_validate", lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m1"))
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.QuestionClassifierNodeData,
|
||||
"model_validate",
|
||||
lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m2")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.ParameterExtractorNodeData,
|
||||
"model_validate",
|
||||
lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m3")),
|
||||
)
|
||||
|
||||
def kr_validate(_d):
|
||||
return SimpleNamespace(
|
||||
retrieval_mode="multiple",
|
||||
multiple_retrieval_config=SimpleNamespace(
|
||||
reranking_mode="weighted_score",
|
||||
weights=SimpleNamespace(vector_setting=SimpleNamespace(embedding_provider_name="m4")),
|
||||
reranking_model=None,
|
||||
),
|
||||
single_retrieval_config=None,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.KnowledgeRetrievalNodeData, "model_validate", kr_validate)
|
||||
|
||||
graph = {
|
||||
"nodes": [
|
||||
{"data": {"type": NodeType.TOOL}},
|
||||
{"data": {"type": NodeType.LLM}},
|
||||
{"data": {"type": NodeType.QUESTION_CLASSIFIER}},
|
||||
{"data": {"type": NodeType.PARAMETER_EXTRACTOR}},
|
||||
{"data": {"type": NodeType.KNOWLEDGE_RETRIEVAL}},
|
||||
{"data": {"type": "unknown"}},
|
||||
]
|
||||
}
|
||||
|
||||
deps = AppDslService._extract_dependencies_from_workflow_graph(graph)
|
||||
assert deps == ["tool:p1", "model:m1", "model:m2", "model:m3", "model:m4"]
|
||||
|
||||
|
||||
def test_extract_dependencies_from_workflow_graph_handles_exceptions(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.ToolNodeData, "model_validate", lambda _d: (_ for _ in ()).throw(ValueError("bad"))
|
||||
)
|
||||
deps = AppDslService._extract_dependencies_from_workflow_graph({"nodes": [{"data": {"type": NodeType.TOOL}}]})
|
||||
assert deps == []
|
||||
|
||||
|
||||
def test_extract_dependencies_from_model_config_parses_providers(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"analyze_model_provider_dependency",
|
||||
lambda provider: f"model:{provider}",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"analyze_tool_dependency",
|
||||
lambda provider_id: f"tool:{provider_id}",
|
||||
)
|
||||
|
||||
deps = AppDslService._extract_dependencies_from_model_config(
|
||||
{
|
||||
"model": {"provider": "p1"},
|
||||
"dataset_configs": {
|
||||
"datasets": {"datasets": [{"reranking_model": {"reranking_provider_name": {"provider": "p2"}}}]}
|
||||
},
|
||||
"agent_mode": {"tools": [{"provider_id": "t1"}]},
|
||||
}
|
||||
)
|
||||
assert deps == ["model:p1", "model:p2", "tool:t1"]
|
||||
|
||||
|
||||
def test_extract_dependencies_from_model_config_handles_exceptions(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"analyze_model_provider_dependency",
|
||||
lambda _p: (_ for _ in ()).throw(ValueError("bad")),
|
||||
)
|
||||
deps = AppDslService._extract_dependencies_from_model_config({"model": {"provider": "p1"}})
|
||||
assert deps == []
|
||||
|
||||
|
||||
def test_get_leaked_dependencies_empty_returns_empty():
|
||||
assert AppDslService.get_leaked_dependencies("tenant-1", []) == []
|
||||
|
||||
|
||||
def test_get_leaked_dependencies_delegates(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"get_leaked_dependencies",
|
||||
lambda *, tenant_id, dependencies: [SimpleNamespace(tenant_id=tenant_id, deps=dependencies)],
|
||||
)
|
||||
res = AppDslService.get_leaked_dependencies("tenant-1", [SimpleNamespace(id="x")])
|
||||
assert len(res) == 1
|
||||
|
||||
|
||||
def test_encrypt_decrypt_dataset_id_respects_config(monkeypatch):
|
||||
tenant_id = "tenant-1"
|
||||
dataset_uuid = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", False)
|
||||
assert AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) == dataset_uuid
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", True)
|
||||
encrypted = AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id)
|
||||
assert encrypted != dataset_uuid
|
||||
assert base64.b64decode(encrypted.encode())
|
||||
assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id=tenant_id) == dataset_uuid
|
||||
|
||||
|
||||
def test_decrypt_dataset_id_returns_plain_uuid_unchanged():
|
||||
value = "00000000-0000-0000-0000-000000000000"
|
||||
assert AppDslService.decrypt_dataset_id(encrypted_data=value, tenant_id="tenant-1") == value
|
||||
|
||||
|
||||
def test_decrypt_dataset_id_returns_none_on_invalid_data(monkeypatch):
|
||||
monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", True)
|
||||
assert AppDslService.decrypt_dataset_id(encrypted_data="not-base64", tenant_id="tenant-1") is None
|
||||
|
||||
|
||||
def test_decrypt_dataset_id_returns_none_when_decrypted_is_not_uuid(monkeypatch):
|
||||
monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", True)
|
||||
encrypted = AppDslService.encrypt_dataset_id(dataset_id="not-a-uuid", tenant_id="tenant-1")
|
||||
assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id="tenant-1") is None
|
||||
|
||||
|
||||
def test_is_valid_uuid_handles_bad_inputs():
|
||||
assert AppDslService._is_valid_uuid("00000000-0000-0000-0000-000000000000") is True
|
||||
assert AppDslService._is_valid_uuid("nope") is False
|
||||
@ -1,14 +1,50 @@
|
||||
"""
|
||||
Comprehensive unit tests for services.app_generate_service.AppGenerateService.
|
||||
|
||||
Covers:
|
||||
- _build_streaming_task_on_subscribe (streams / pubsub / exception / idempotency)
|
||||
- generate (COMPLETION / AGENT_CHAT / CHAT / ADVANCED_CHAT / WORKFLOW / invalid mode,
|
||||
streaming & blocking, billing, quota-refund-on-error, rate_limit.exit)
|
||||
- _get_max_active_requests (all limit combos)
|
||||
- generate_single_iteration (ADVANCED_CHAT / WORKFLOW / invalid mode)
|
||||
- generate_single_loop (ADVANCED_CHAT / WORKFLOW / invalid mode)
|
||||
- generate_more_like_this
|
||||
- _get_workflow (debugger / non-debugger / specific id / invalid format / not found)
|
||||
- get_response_generator (ended / non-ended workflow run)
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import services.app_generate_service as app_generate_service_module
|
||||
import pytest
|
||||
|
||||
import services.app_generate_service as ags_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers / Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
class _DummyRateLimit:
|
||||
"""Minimal stand-in for RateLimit that never touches Redis."""
|
||||
|
||||
_instance_dict: dict[str, "_DummyRateLimit"] = {}
|
||||
|
||||
def __new__(cls, client_id: str, max_active_requests: int):
|
||||
# avoid singleton caching across tests
|
||||
instance = object.__new__(cls)
|
||||
return instance
|
||||
|
||||
def __init__(self, client_id: str, max_active_requests: int) -> None:
|
||||
self.client_id = client_id
|
||||
self.max_active_requests = max_active_requests
|
||||
self._exited: list[str] = []
|
||||
|
||||
@staticmethod
|
||||
def gen_request_key() -> str:
|
||||
@ -18,101 +54,720 @@ class _DummyRateLimit:
|
||||
return request_id or "dummy-request-id"
|
||||
|
||||
def exit(self, request_id: str) -> None:
|
||||
return None
|
||||
self._exited.append(request_id)
|
||||
|
||||
def generate(self, generator, request_id: str):
|
||||
return generator
|
||||
|
||||
|
||||
def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch):
|
||||
monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False)
|
||||
mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit)
|
||||
def _make_app(mode: AppMode | str, *, max_active_requests: int = 0, is_agent: bool = False) -> MagicMock:
|
||||
app = MagicMock()
|
||||
app.mode = mode
|
||||
app.id = "app-id"
|
||||
app.tenant_id = "tenant-id"
|
||||
app.max_active_requests = max_active_requests
|
||||
app.is_agent = is_agent
|
||||
return app
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-id"
|
||||
workflow.created_by = "owner-id"
|
||||
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
|
||||
generator_spy = mocker.patch(
|
||||
"services.app_generate_service.WorkflowAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
|
||||
app_model = MagicMock()
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
app_model.id = "app-id"
|
||||
app_model.tenant_id = "tenant-id"
|
||||
app_model.max_active_requests = 0
|
||||
app_model.is_agent = False
|
||||
|
||||
def _make_user() -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = "user-id"
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args={"inputs": {"k": "v"}},
|
||||
invoke_from=MagicMock(),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
call_kwargs = generator_spy.call_args.kwargs
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert pause_state_config is not None
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
return user
|
||||
|
||||
|
||||
def test_advanced_chat_blocking_returns_dict_and_does_not_use_event_retrieval(mocker, monkeypatch):
|
||||
"""
|
||||
Regression test: ADVANCED_CHAT in blocking mode should return a plain dict
|
||||
(non-streaming), and must not go through the async retrieve_events path.
|
||||
Keeps behavior consistent with WORKFLOW blocking branch.
|
||||
"""
|
||||
# Disable billing and stub RateLimit to a no-op that just passes values through
|
||||
monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False)
|
||||
mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit)
|
||||
|
||||
# Arrange a fake workflow and wire AppGenerateService._get_workflow to return it
|
||||
def _make_workflow(*, workflow_id: str = "workflow-id", created_by: str = "owner-id") -> MagicMock:
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-id"
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
workflow.id = workflow_id
|
||||
workflow.created_by = created_by
|
||||
return workflow
|
||||
|
||||
# Spy on the streaming retrieval path to ensure it's NOT called
|
||||
retrieve_spy = mocker.patch("services.app_generate_service.AdvancedChatAppGenerator.retrieve_events")
|
||||
|
||||
# Make AdvancedChatAppGenerator.generate return a plain dict when streaming=False
|
||||
generate_spy = mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
@contextmanager
|
||||
def _noop_rate_limit_context(rate_limit, request_id):
|
||||
"""Drop-in replacement for rate_limit_context that doesn't touch Redis."""
|
||||
yield
|
||||
|
||||
# Minimal app model for ADVANCED_CHAT
|
||||
app_model = MagicMock()
|
||||
app_model.mode = AppMode.ADVANCED_CHAT
|
||||
app_model.id = "app-id"
|
||||
app_model.tenant_id = "tenant-id"
|
||||
app_model.max_active_requests = 0
|
||||
app_model.is_agent = False
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "user-id"
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_streaming_task_on_subscribe
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestBuildStreamingTaskOnSubscribe:
|
||||
"""Tests for AppGenerateService._build_streaming_task_on_subscribe."""
|
||||
|
||||
# Must include query and inputs for AdvancedChatAppGenerator
|
||||
args = {"workflow_id": "wf-1", "query": "hello", "inputs": {}}
|
||||
def test_streams_mode_starts_immediately(self, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams")
|
||||
called = []
|
||||
cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1))
|
||||
# task started immediately during build
|
||||
assert called == [1]
|
||||
# calling the returned callback is idempotent
|
||||
cb()
|
||||
assert called == [1] # not called again
|
||||
|
||||
# Act: call service with streaming=False (blocking mode)
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=MagicMock(),
|
||||
streaming=False,
|
||||
)
|
||||
def test_pubsub_mode_starts_on_subscribe(self, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub")
|
||||
monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 60_000) # large to prevent timer
|
||||
called = []
|
||||
cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1))
|
||||
assert called == []
|
||||
cb()
|
||||
assert called == [1]
|
||||
# second call is idempotent
|
||||
cb()
|
||||
assert called == [1]
|
||||
|
||||
# Assert: returns the dict from generate(), and did not call retrieve_events()
|
||||
assert result == {"result": "ok"}
|
||||
assert generate_spy.call_args.kwargs.get("streaming") is False
|
||||
retrieve_spy.assert_not_called()
|
||||
def test_sharded_mode_starts_on_subscribe(self, monkeypatch):
|
||||
"""sharded is treated like pubsub (i.e. not 'streams')."""
|
||||
monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded")
|
||||
monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 60_000)
|
||||
called = []
|
||||
cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1))
|
||||
assert called == []
|
||||
cb()
|
||||
assert called == [1]
|
||||
|
||||
def test_pubsub_fallback_timer_fires(self, monkeypatch):
|
||||
"""When nobody subscribes fast enough the fallback timer fires."""
|
||||
monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub")
|
||||
monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 50) # 50 ms
|
||||
called = []
|
||||
_cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1))
|
||||
time.sleep(0.2) # give the timer time to fire
|
||||
assert called == [1]
|
||||
|
||||
def test_exception_in_start_task_returns_false(self, monkeypatch):
|
||||
"""When start_task raises, _try_start returns False and next call retries."""
|
||||
monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams")
|
||||
call_count = 0
|
||||
|
||||
def _bad():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
cb = AppGenerateService._build_streaming_task_on_subscribe(_bad)
|
||||
# first call inside build raised, but is caught; second call via cb succeeds
|
||||
assert call_count == 1
|
||||
cb()
|
||||
assert call_count == 2
|
||||
|
||||
def test_concurrent_subscribe_only_starts_once(self, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub")
|
||||
monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 60_000)
|
||||
call_count = 0
|
||||
|
||||
def _inc():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
cb = AppGenerateService._build_streaming_task_on_subscribe(_inc)
|
||||
threads = [threading.Thread(target=cb) for _ in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_max_active_requests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestGetMaxActiveRequests:
|
||||
def test_both_zero_returns_zero(self, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 0)
|
||||
monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0)
|
||||
app = _make_app(AppMode.CHAT, max_active_requests=0)
|
||||
assert AppGenerateService._get_max_active_requests(app) == 0
|
||||
|
||||
def test_app_limit_only(self, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 0)
|
||||
monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0)
|
||||
app = _make_app(AppMode.CHAT, max_active_requests=5)
|
||||
assert AppGenerateService._get_max_active_requests(app) == 5
|
||||
|
||||
def test_config_limit_only(self, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 10)
|
||||
monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0)
|
||||
app = _make_app(AppMode.CHAT, max_active_requests=0)
|
||||
assert AppGenerateService._get_max_active_requests(app) == 10
|
||||
|
||||
def test_both_non_zero_returns_min(self, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 20)
|
||||
monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0)
|
||||
app = _make_app(AppMode.CHAT, max_active_requests=5)
|
||||
assert AppGenerateService._get_max_active_requests(app) == 5
|
||||
|
||||
def test_default_active_requests_used_when_app_has_none(self, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 0)
|
||||
monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 15)
|
||||
app = _make_app(AppMode.CHAT, max_active_requests=0)
|
||||
assert AppGenerateService._get_max_active_requests(app) == 15
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate – every AppMode branch
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestGenerate:
|
||||
"""Tests for AppGenerateService.generate covering each mode."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _common(self, mocker, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False)
|
||||
mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit)
|
||||
# Prevent AppExecutionParams.new from touching real models via isinstance
|
||||
mocker.patch(
|
||||
"services.app_generate_service.rate_limit_context",
|
||||
_noop_rate_limit_context,
|
||||
)
|
||||
|
||||
# -- COMPLETION ---------------------------------------------------------
|
||||
def test_completion_mode(self, mocker):
|
||||
gen_spy = mocker.patch(
|
||||
"services.app_generate_service.CompletionAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.CompletionAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
result = AppGenerateService.generate(
|
||||
app_model=_make_app(AppMode.COMPLETION),
|
||||
user=_make_user(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
assert result == {"result": "ok"}
|
||||
gen_spy.assert_called_once()
|
||||
|
||||
# -- AGENT_CHAT via mode ------------------------------------------------
|
||||
def test_agent_chat_mode(self, mocker):
|
||||
gen_spy = mocker.patch(
|
||||
"services.app_generate_service.AgentChatAppGenerator.generate",
|
||||
return_value={"result": "agent"},
|
||||
)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.AgentChatAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
result = AppGenerateService.generate(
|
||||
app_model=_make_app(AppMode.AGENT_CHAT),
|
||||
user=_make_user(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
assert result == {"result": "agent"}
|
||||
gen_spy.assert_called_once()
|
||||
|
||||
# -- AGENT_CHAT via is_agent flag (non-AGENT_CHAT mode) -----------------
|
||||
def test_agent_via_is_agent_flag(self, mocker):
|
||||
gen_spy = mocker.patch(
|
||||
"services.app_generate_service.AgentChatAppGenerator.generate",
|
||||
return_value={"result": "agent-via-flag"},
|
||||
)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.AgentChatAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
app = _make_app(AppMode.CHAT, is_agent=True)
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=_make_user(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
assert result == {"result": "agent-via-flag"}
|
||||
gen_spy.assert_called_once()
|
||||
|
||||
# -- CHAT ---------------------------------------------------------------
|
||||
def test_chat_mode(self, mocker):
|
||||
gen_spy = mocker.patch(
|
||||
"services.app_generate_service.ChatAppGenerator.generate",
|
||||
return_value={"result": "chat"},
|
||||
)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.ChatAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
app = _make_app(AppMode.CHAT, is_agent=False)
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=_make_user(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
assert result == {"result": "chat"}
|
||||
gen_spy.assert_called_once()
|
||||
|
||||
# -- ADVANCED_CHAT blocking ---------------------------------------------
|
||||
def test_advanced_chat_blocking(self, mocker):
|
||||
workflow = _make_workflow()
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
|
||||
retrieve_spy = mocker.patch("services.app_generate_service.AdvancedChatAppGenerator.retrieve_events")
|
||||
gen_spy = mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator.generate",
|
||||
return_value={"result": "advanced-blocking"},
|
||||
)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=_make_app(AppMode.ADVANCED_CHAT),
|
||||
user=_make_user(),
|
||||
args={"workflow_id": None, "query": "hi", "inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
assert result == {"result": "advanced-blocking"}
|
||||
assert gen_spy.call_args.kwargs.get("streaming") is False
|
||||
retrieve_spy.assert_not_called()
|
||||
|
||||
# -- ADVANCED_CHAT streaming --------------------------------------------
|
||||
def test_advanced_chat_streaming(self, mocker, monkeypatch):
|
||||
workflow = _make_workflow()
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.AppExecutionParams.new",
|
||||
return_value=MagicMock(workflow_run_id="wfr-1", model_dump_json=MagicMock(return_value="{}")),
|
||||
)
|
||||
delay_spy = mocker.patch("services.app_generate_service.workflow_based_app_execution_task.delay")
|
||||
# Let _build_streaming_task_on_subscribe call the real on_subscribe
|
||||
# so the inner closure (line 165) actually executes.
|
||||
monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams")
|
||||
gen_instance = MagicMock()
|
||||
gen_instance.retrieve_events.return_value = iter([])
|
||||
gen_instance.convert_to_event_stream.side_effect = lambda x: x
|
||||
mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator",
|
||||
return_value=gen_instance,
|
||||
)
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=_make_app(AppMode.ADVANCED_CHAT),
|
||||
user=_make_user(),
|
||||
args={"workflow_id": None, "query": "hi", "inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=True,
|
||||
)
|
||||
# In streaming mode it should go through retrieve_events, not generate
|
||||
gen_instance.retrieve_events.assert_called_once()
|
||||
# The inner on_subscribe closure was invoked by _build_streaming_task_on_subscribe
|
||||
delay_spy.assert_called_once()
|
||||
|
||||
# -- WORKFLOW blocking --------------------------------------------------
|
||||
def test_workflow_blocking(self, mocker):
|
||||
workflow = _make_workflow()
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
gen_spy = mocker.patch(
|
||||
"services.app_generate_service.WorkflowAppGenerator.generate",
|
||||
return_value={"result": "workflow-blocking"},
|
||||
)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.WorkflowAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=_make_app(AppMode.WORKFLOW),
|
||||
user=_make_user(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
assert result == {"result": "workflow-blocking"}
|
||||
call_kwargs = gen_spy.call_args.kwargs
|
||||
assert call_kwargs.get("pause_state_config") is not None
|
||||
assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id"
|
||||
|
||||
# -- WORKFLOW streaming -------------------------------------------------
|
||||
def test_workflow_streaming(self, mocker, monkeypatch):
|
||||
workflow = _make_workflow()
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.AppExecutionParams.new",
|
||||
return_value=MagicMock(workflow_run_id="wfr-2", model_dump_json=MagicMock(return_value="{}")),
|
||||
)
|
||||
delay_spy = mocker.patch("services.app_generate_service.workflow_based_app_execution_task.delay")
|
||||
# Let _build_streaming_task_on_subscribe invoke the real on_subscribe
|
||||
# so the inner closure (line 216) actually executes.
|
||||
monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams")
|
||||
retrieve_spy = mocker.patch(
|
||||
"services.app_generate_service.MessageBasedAppGenerator.retrieve_events",
|
||||
return_value=iter([]),
|
||||
)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.WorkflowAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=_make_app(AppMode.WORKFLOW),
|
||||
user=_make_user(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=True,
|
||||
)
|
||||
retrieve_spy.assert_called_once()
|
||||
# The inner on_subscribe closure was invoked by _build_streaming_task_on_subscribe
|
||||
delay_spy.assert_called_once()
|
||||
|
||||
# -- Invalid mode -------------------------------------------------------
|
||||
def test_invalid_mode_raises(self, mocker):
|
||||
app = _make_app("invalid-mode", is_agent=False)
|
||||
with pytest.raises(ValueError, match="Invalid app mode"):
|
||||
AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=_make_user(),
|
||||
args={},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate – billing / quota
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestGenerateBilling:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _common(self, mocker, monkeypatch):
|
||||
mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.rate_limit_context",
|
||||
_noop_rate_limit_context,
|
||||
)
|
||||
|
||||
def test_billing_enabled_consumes_quota(self, mocker, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
|
||||
quota_charge = MagicMock()
|
||||
consume_mock = mocker.patch(
|
||||
"services.app_generate_service.QuotaType.WORKFLOW.consume",
|
||||
return_value=quota_charge,
|
||||
)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.CompletionAppGenerator.generate",
|
||||
return_value={"ok": True},
|
||||
)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.CompletionAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
|
||||
AppGenerateService.generate(
|
||||
app_model=_make_app(AppMode.COMPLETION),
|
||||
user=_make_user(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
consume_mock.assert_called_once_with("tenant-id")
|
||||
|
||||
def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch):
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.QuotaType.WORKFLOW.consume",
|
||||
side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1),
|
||||
)
|
||||
|
||||
with pytest.raises(InvokeRateLimitError):
|
||||
AppGenerateService.generate(
|
||||
app_model=_make_app(AppMode.COMPLETION),
|
||||
user=_make_user(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_exception_refunds_quota_and_exits_rate_limit(self, mocker, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
|
||||
quota_charge = MagicMock()
|
||||
mocker.patch(
|
||||
"services.app_generate_service.QuotaType.WORKFLOW.consume",
|
||||
return_value=quota_charge,
|
||||
)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.CompletionAppGenerator.generate",
|
||||
side_effect=RuntimeError("boom"),
|
||||
)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.CompletionAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
AppGenerateService.generate(
|
||||
app_model=_make_app(AppMode.COMPLETION),
|
||||
user=_make_user(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
quota_charge.refund.assert_called_once()
|
||||
|
||||
def test_rate_limit_exit_called_in_finally_for_blocking(self, mocker, monkeypatch):
|
||||
"""For non-streaming (blocking) calls, rate_limit.exit should be called in finally."""
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False)
|
||||
|
||||
exit_calls: list[str] = []
|
||||
|
||||
class _TrackingRateLimit(_DummyRateLimit):
|
||||
def exit(self, request_id: str) -> None:
|
||||
exit_calls.append(request_id)
|
||||
|
||||
mocker.patch("services.app_generate_service.RateLimit", _TrackingRateLimit)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.CompletionAppGenerator.generate",
|
||||
return_value={"ok": True},
|
||||
)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.CompletionAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
|
||||
AppGenerateService.generate(
|
||||
app_model=_make_app(AppMode.COMPLETION),
|
||||
user=_make_user(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
# exit is called in finally block for non-streaming
|
||||
assert len(exit_calls) >= 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_workflow
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestGetWorkflow:
|
||||
def test_debugger_fetches_draft(self, mocker):
|
||||
draft_wf = _make_workflow()
|
||||
ws = MagicMock()
|
||||
ws.get_draft_workflow.return_value = draft_wf
|
||||
mocker.patch("services.app_generate_service.WorkflowService", return_value=ws)
|
||||
|
||||
result = AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.DEBUGGER)
|
||||
assert result is draft_wf
|
||||
ws.get_draft_workflow.assert_called_once()
|
||||
|
||||
def test_debugger_raises_when_no_draft(self, mocker):
|
||||
ws = MagicMock()
|
||||
ws.get_draft_workflow.return_value = None
|
||||
mocker.patch("services.app_generate_service.WorkflowService", return_value=ws)
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not initialized"):
|
||||
AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.DEBUGGER)
|
||||
|
||||
def test_non_debugger_fetches_published(self, mocker):
|
||||
pub_wf = _make_workflow()
|
||||
ws = MagicMock()
|
||||
ws.get_published_workflow.return_value = pub_wf
|
||||
mocker.patch("services.app_generate_service.WorkflowService", return_value=ws)
|
||||
|
||||
result = AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API)
|
||||
assert result is pub_wf
|
||||
ws.get_published_workflow.assert_called_once()
|
||||
|
||||
def test_non_debugger_raises_when_no_published(self, mocker):
|
||||
ws = MagicMock()
|
||||
ws.get_published_workflow.return_value = None
|
||||
mocker.patch("services.app_generate_service.WorkflowService", return_value=ws)
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not published"):
|
||||
AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API)
|
||||
|
||||
def test_specific_workflow_id_valid_uuid(self, mocker):
|
||||
valid_uuid = str(uuid.uuid4())
|
||||
specific_wf = _make_workflow(workflow_id=valid_uuid)
|
||||
ws = MagicMock()
|
||||
ws.get_published_workflow_by_id.return_value = specific_wf
|
||||
mocker.patch("services.app_generate_service.WorkflowService", return_value=ws)
|
||||
|
||||
result = AppGenerateService._get_workflow(
|
||||
_make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API, workflow_id=valid_uuid
|
||||
)
|
||||
assert result is specific_wf
|
||||
ws.get_published_workflow_by_id.assert_called_once()
|
||||
|
||||
def test_specific_workflow_id_invalid_uuid(self, mocker):
|
||||
ws = MagicMock()
|
||||
mocker.patch("services.app_generate_service.WorkflowService", return_value=ws)
|
||||
|
||||
with pytest.raises(WorkflowIdFormatError):
|
||||
AppGenerateService._get_workflow(
|
||||
_make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API, workflow_id="not-a-uuid"
|
||||
)
|
||||
|
||||
def test_specific_workflow_id_not_found(self, mocker):
|
||||
valid_uuid = str(uuid.uuid4())
|
||||
ws = MagicMock()
|
||||
ws.get_published_workflow_by_id.return_value = None
|
||||
mocker.patch("services.app_generate_service.WorkflowService", return_value=ws)
|
||||
|
||||
with pytest.raises(WorkflowNotFoundError):
|
||||
AppGenerateService._get_workflow(
|
||||
_make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API, workflow_id=valid_uuid
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_single_iteration
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestGenerateSingleIteration:
|
||||
def test_advanced_chat_mode(self, mocker):
|
||||
workflow = _make_workflow()
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
gen_spy = mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
iter_spy = mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator.single_iteration_generate",
|
||||
return_value={"event": "iteration"},
|
||||
)
|
||||
app = _make_app(AppMode.ADVANCED_CHAT)
|
||||
result = AppGenerateService.generate_single_iteration(
|
||||
app_model=app, user=_make_user(), node_id="n1", args={"k": "v"}
|
||||
)
|
||||
iter_spy.assert_called_once()
|
||||
assert result == {"event": "iteration"}
|
||||
|
||||
def test_workflow_mode(self, mocker):
|
||||
workflow = _make_workflow()
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
iter_spy = mocker.patch(
|
||||
"services.app_generate_service.WorkflowAppGenerator.single_iteration_generate",
|
||||
return_value={"event": "wf-iteration"},
|
||||
)
|
||||
app = _make_app(AppMode.WORKFLOW)
|
||||
result = AppGenerateService.generate_single_iteration(
|
||||
app_model=app, user=_make_user(), node_id="n1", args={"k": "v"}
|
||||
)
|
||||
iter_spy.assert_called_once()
|
||||
assert result == {"event": "wf-iteration"}
|
||||
|
||||
def test_invalid_mode_raises(self, mocker):
|
||||
app = _make_app(AppMode.CHAT)
|
||||
with pytest.raises(ValueError, match="Invalid app mode"):
|
||||
AppGenerateService.generate_single_iteration(app_model=app, user=_make_user(), node_id="n1", args={})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_single_loop
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestGenerateSingleLoop:
|
||||
def test_advanced_chat_mode(self, mocker):
|
||||
workflow = _make_workflow()
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
loop_spy = mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator.single_loop_generate",
|
||||
return_value={"event": "loop"},
|
||||
)
|
||||
app = _make_app(AppMode.ADVANCED_CHAT)
|
||||
result = AppGenerateService.generate_single_loop(
|
||||
app_model=app, user=_make_user(), node_id="n1", args=MagicMock()
|
||||
)
|
||||
loop_spy.assert_called_once()
|
||||
assert result == {"event": "loop"}
|
||||
|
||||
def test_workflow_mode(self, mocker):
|
||||
workflow = _make_workflow()
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream",
|
||||
side_effect=lambda x: x,
|
||||
)
|
||||
loop_spy = mocker.patch(
|
||||
"services.app_generate_service.WorkflowAppGenerator.single_loop_generate",
|
||||
return_value={"event": "wf-loop"},
|
||||
)
|
||||
app = _make_app(AppMode.WORKFLOW)
|
||||
result = AppGenerateService.generate_single_loop(
|
||||
app_model=app, user=_make_user(), node_id="n1", args=MagicMock()
|
||||
)
|
||||
loop_spy.assert_called_once()
|
||||
assert result == {"event": "wf-loop"}
|
||||
|
||||
def test_invalid_mode_raises(self, mocker):
|
||||
app = _make_app(AppMode.COMPLETION)
|
||||
with pytest.raises(ValueError, match="Invalid app mode"):
|
||||
AppGenerateService.generate_single_loop(app_model=app, user=_make_user(), node_id="n1", args=MagicMock())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_more_like_this
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestGenerateMoreLikeThis:
|
||||
def test_delegates_to_completion_generator(self, mocker):
|
||||
gen_spy = mocker.patch(
|
||||
"services.app_generate_service.CompletionAppGenerator.generate_more_like_this",
|
||||
return_value={"result": "similar"},
|
||||
)
|
||||
result = AppGenerateService.generate_more_like_this(
|
||||
app_model=_make_app(AppMode.COMPLETION),
|
||||
user=_make_user(),
|
||||
message_id="msg-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=True,
|
||||
)
|
||||
assert result == {"result": "similar"}
|
||||
gen_spy.assert_called_once()
|
||||
assert gen_spy.call_args.kwargs["stream"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_response_generator
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestGetResponseGenerator:
|
||||
def test_non_ended_workflow_run(self, mocker):
|
||||
app = _make_app(AppMode.ADVANCED_CHAT)
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.id = "run-1"
|
||||
workflow_run.status.is_ended.return_value = False
|
||||
|
||||
gen_instance = MagicMock()
|
||||
gen_instance.retrieve_events.return_value = iter([{"event": "started"}])
|
||||
gen_instance.convert_to_event_stream.side_effect = lambda x: x
|
||||
mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator",
|
||||
return_value=gen_instance,
|
||||
)
|
||||
|
||||
result = AppGenerateService.get_response_generator(app_model=app, workflow_run=workflow_run)
|
||||
gen_instance.retrieve_events.assert_called_once()
|
||||
|
||||
def test_ended_workflow_run_still_returns_generator(self, mocker):
|
||||
"""Even when the run is ended, the current code still returns a generator (TODO branch)."""
|
||||
app = _make_app(AppMode.WORKFLOW)
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.id = "run-2"
|
||||
workflow_run.status.is_ended.return_value = True
|
||||
|
||||
gen_instance = MagicMock()
|
||||
gen_instance.retrieve_events.return_value = iter([])
|
||||
gen_instance.convert_to_event_stream.side_effect = lambda x: x
|
||||
mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator",
|
||||
return_value=gen_instance,
|
||||
)
|
||||
|
||||
result = AppGenerateService.get_response_generator(app_model=app, workflow_run=workflow_run)
|
||||
# current impl falls through the TODO and still creates a generator
|
||||
gen_instance.retrieve_events.assert_called_once()
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services import clear_free_plan_tenant_expired_logs as service_module
|
||||
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
||||
|
||||
|
||||
@ -156,13 +159,453 @@ class TestClearFreePlanTenantExpiredLogs:
|
||||
# Should call delete for each table that has records
|
||||
assert mock_session.query.return_value.where.return_value.delete.called
|
||||
|
||||
def test_clear_message_related_tables_logging_output(
|
||||
self, mock_session, sample_message_ids, sample_records, capsys
|
||||
def test_clear_message_related_tables_all_serialization_fails_skips_backup_but_deletes(
|
||||
self, mock_session, sample_message_ids
|
||||
):
|
||||
"""Test that logging output is generated."""
|
||||
record = Mock()
|
||||
record.id = "record-1"
|
||||
record.to_dict.side_effect = Exception("Serialization error")
|
||||
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.where.return_value.all.return_value = sample_records
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [record]
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
pass
|
||||
mock_storage.save.assert_not_called()
|
||||
assert mock_session.query.return_value.where.return_value.delete.called
|
||||
|
||||
|
||||
class _ImmediateFuture:
|
||||
def __init__(self, fn, args, kwargs):
|
||||
self._fn = fn
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
|
||||
def result(self):
|
||||
return self._fn(*self._args, **self._kwargs)
|
||||
|
||||
|
||||
class _ImmediateExecutor:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self.submitted: list[tuple[object, tuple[object, ...], dict[str, object]]] = []
|
||||
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
self.submitted.append((fn, args, kwargs))
|
||||
return _ImmediateFuture(fn, args, kwargs)
|
||||
|
||||
|
||||
def _session_wrapper_for_no_autoflush(session: Mock) -> Mock:
|
||||
"""
|
||||
ClearFreePlanTenantExpiredLogs.process_tenant uses:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
so Session(db.engine) must return an object with a no_autoflush context manager.
|
||||
"""
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
cm.__exit__.return_value = None
|
||||
|
||||
wrapper = MagicMock()
|
||||
wrapper.no_autoflush = cm
|
||||
return wrapper
|
||||
|
||||
|
||||
def _session_wrapper_for_direct(session: Mock) -> Mock:
|
||||
"""ClearFreePlanTenantExpiredLogs.process uses: with Session(db.engine) as session:"""
|
||||
wrapper = MagicMock()
|
||||
wrapper.__enter__.return_value = session
|
||||
wrapper.__exit__.return_value = None
|
||||
return wrapper
|
||||
|
||||
|
||||
def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
flask_app = service_module.Flask("test-app")
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module,
|
||||
"db",
|
||||
SimpleNamespace(
|
||||
engine=object(),
|
||||
session=SimpleNamespace(
|
||||
scalars=lambda _stmt: SimpleNamespace(
|
||||
all=lambda: [SimpleNamespace(id="app-1"), SimpleNamespace(id="app-2")]
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(service_module, "storage", mock_storage)
|
||||
monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg)
|
||||
|
||||
clear_related = MagicMock()
|
||||
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "_clear_message_related_tables", clear_related)
|
||||
|
||||
# Session sequence for messages, conversations, workflow_app_logs loops:
|
||||
# - messages: one batch then empty
|
||||
# - conversations: one batch then empty
|
||||
# - workflow app logs: one batch then empty
|
||||
msg1 = SimpleNamespace(id="m1", to_dict=lambda: {"id": "m1"})
|
||||
conv1 = SimpleNamespace(id="c1", to_dict=lambda: {"id": "c1"})
|
||||
log1 = SimpleNamespace(id="l1", to_dict=lambda: {"id": "l1"})
|
||||
|
||||
def make_query_with_batches(batches: list[list[object]]):
|
||||
q = MagicMock()
|
||||
q.where.return_value = q
|
||||
q.limit.return_value = q
|
||||
q.all.side_effect = batches
|
||||
q.delete.return_value = 1
|
||||
return q
|
||||
|
||||
msg_session_1 = MagicMock()
|
||||
msg_session_1.query.side_effect = (
|
||||
lambda model: make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock()
|
||||
)
|
||||
msg_session_1.commit.return_value = None
|
||||
|
||||
msg_session_2 = MagicMock()
|
||||
msg_session_2.query.side_effect = (
|
||||
lambda model: make_query_with_batches([[]]) if model == service_module.Message else MagicMock()
|
||||
)
|
||||
msg_session_2.commit.return_value = None
|
||||
|
||||
conv_session_1 = MagicMock()
|
||||
conv_session_1.query.side_effect = (
|
||||
lambda model: make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock()
|
||||
)
|
||||
conv_session_1.commit.return_value = None
|
||||
|
||||
conv_session_2 = MagicMock()
|
||||
conv_session_2.query.side_effect = (
|
||||
lambda model: make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock()
|
||||
)
|
||||
conv_session_2.commit.return_value = None
|
||||
|
||||
wal_session_1 = MagicMock()
|
||||
wal_session_1.query.side_effect = (
|
||||
lambda model: make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock()
|
||||
)
|
||||
wal_session_1.commit.return_value = None
|
||||
|
||||
wal_session_2 = MagicMock()
|
||||
wal_session_2.query.side_effect = (
|
||||
lambda model: make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock()
|
||||
)
|
||||
wal_session_2.commit.return_value = None
|
||||
|
||||
session_wrappers = [
|
||||
_session_wrapper_for_no_autoflush(msg_session_1),
|
||||
_session_wrapper_for_no_autoflush(msg_session_2),
|
||||
_session_wrapper_for_no_autoflush(conv_session_1),
|
||||
_session_wrapper_for_no_autoflush(conv_session_2),
|
||||
_session_wrapper_for_no_autoflush(wal_session_1),
|
||||
_session_wrapper_for_no_autoflush(wal_session_2),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0))
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
stmt = MagicMock()
|
||||
stmt.where.return_value = stmt
|
||||
return stmt
|
||||
|
||||
monkeypatch.setattr(service_module, "select", fake_select)
|
||||
|
||||
# Repositories for workflow node executions and workflow runs
|
||||
node_repo = MagicMock()
|
||||
node_repo.get_expired_executions_batch.side_effect = [[SimpleNamespace(id="ne-1")], []]
|
||||
node_repo.delete_executions_by_ids.return_value = 1
|
||||
|
||||
run_repo = MagicMock()
|
||||
run_repo.get_expired_runs_batch.side_effect = [[SimpleNamespace(id="wr-1", to_dict=lambda: {"id": "wr-1"})], []]
|
||||
run_repo.delete_runs_by_ids.return_value = 1
|
||||
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object())
|
||||
monkeypatch.setattr(
|
||||
service_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_node_execution_repository",
|
||||
lambda _sm: node_repo,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
service_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda _sm: run_repo,
|
||||
)
|
||||
|
||||
ClearFreePlanTenantExpiredLogs.process_tenant(flask_app, "tenant-1", days=7, batch=10)
|
||||
|
||||
# messages backup, conversations backup, node executions backup, runs backup, workflow app logs backup
|
||||
assert mock_storage.save.call_count >= 5
|
||||
clear_related.assert_called()
|
||||
|
||||
|
||||
def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
# Total tenant count query
|
||||
count_session = MagicMock()
|
||||
count_query = MagicMock()
|
||||
count_query.count.return_value = 2
|
||||
count_session.query.return_value = count_query
|
||||
|
||||
monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session))
|
||||
|
||||
# Avoid LocalProxy usage
|
||||
flask_app = service_module.Flask("test-app")
|
||||
monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app))
|
||||
|
||||
executor = _ImmediateExecutor()
|
||||
monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor)
|
||||
|
||||
monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg)
|
||||
echo_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.click, "echo", echo_mock)
|
||||
|
||||
monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", True)
|
||||
|
||||
def fake_get_info(tenant_id: str):
|
||||
if tenant_id == "t_sandbox":
|
||||
return {"subscription": {"plan": CloudPlan.SANDBOX}}
|
||||
if tenant_id == "t_fail":
|
||||
raise RuntimeError("boom")
|
||||
return {"subscription": {"plan": "team"}}
|
||||
|
||||
monkeypatch.setattr(service_module.BillingService, "get_info", staticmethod(fake_get_info))
|
||||
|
||||
process_tenant_mock = MagicMock(side_effect=lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("err")))
|
||||
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)
|
||||
|
||||
logger_exc = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exc)
|
||||
|
||||
ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=["t_sandbox", "t_paid", "t_fail"])
|
||||
|
||||
# Only sandbox tenant should attempt processing, and its failure should be swallowed + logged.
|
||||
assert process_tenant_mock.call_count == 1
|
||||
assert logger_exc.call_count >= 1
|
||||
|
||||
|
||||
def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", False)
|
||||
|
||||
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
|
||||
fixed_now = started_at + datetime.timedelta(hours=2)
|
||||
|
||||
class FixedDateTime(datetime.datetime):
|
||||
@classmethod
|
||||
def now(cls, tz=None):
|
||||
return fixed_now
|
||||
|
||||
monkeypatch.setattr(service_module.datetime, "datetime", FixedDateTime)
|
||||
|
||||
# Avoid LocalProxy usage
|
||||
flask_app = service_module.Flask("test-app")
|
||||
monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app))
|
||||
|
||||
executor = _ImmediateExecutor()
|
||||
monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor)
|
||||
|
||||
monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg)
|
||||
monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None)
|
||||
|
||||
# Sessions used:
|
||||
# 1) total tenant count
|
||||
# 2) per-batch tenant scan (count + tenant list)
|
||||
total_session = MagicMock()
|
||||
total_query = MagicMock()
|
||||
total_query.count.return_value = 250
|
||||
total_session.query.return_value = total_query
|
||||
|
||||
batch_session = MagicMock()
|
||||
q1 = MagicMock()
|
||||
q1.where.return_value = q1
|
||||
q1.count.return_value = 200
|
||||
q2 = MagicMock()
|
||||
q2.where.return_value = q2
|
||||
q2.count.return_value = 200
|
||||
q3 = MagicMock()
|
||||
q3.where.return_value = q3
|
||||
q3.count.return_value = 200
|
||||
q4 = MagicMock()
|
||||
q4.where.return_value = q4
|
||||
q4.count.return_value = 50 # choose this interval, then scale it
|
||||
|
||||
rows = [SimpleNamespace(id="tenant-a"), SimpleNamespace(id="tenant-b")]
|
||||
q_rs = MagicMock()
|
||||
q_rs.where.return_value = q_rs
|
||||
q_rs.order_by.return_value = rows
|
||||
|
||||
batch_session.query.side_effect = [q1, q2, q3, q4, q_rs]
|
||||
|
||||
sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)]
|
||||
monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0))
|
||||
|
||||
process_tenant_mock = MagicMock()
|
||||
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)
|
||||
|
||||
ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[])
|
||||
|
||||
# Should submit/process tenants from the batch query
|
||||
assert process_tenant_mock.call_count == 2
|
||||
|
||||
|
||||
def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
count_session = MagicMock()
|
||||
count_query = MagicMock()
|
||||
count_query.count.return_value = 100
|
||||
count_session.query.return_value = count_query
|
||||
monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session))
|
||||
|
||||
flask_app = service_module.Flask("test-app")
|
||||
monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app))
|
||||
monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", False)
|
||||
|
||||
executor = _ImmediateExecutor()
|
||||
monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor)
|
||||
|
||||
echo_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg)
|
||||
monkeypatch.setattr(service_module.click, "echo", echo_mock)
|
||||
|
||||
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", MagicMock())
|
||||
|
||||
tenant_ids = [f"t{i}" for i in range(100)]
|
||||
ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=tenant_ids)
|
||||
|
||||
assert any("Processed 100 tenants" in str(call.args[0]) for call in echo_mock.call_args_list)
|
||||
|
||||
|
||||
def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", False)
|
||||
|
||||
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
|
||||
# Keep the total range smaller than the minimum interval (1 hour) so the loop runs once.
|
||||
fixed_now = started_at + datetime.timedelta(minutes=30)
|
||||
|
||||
class FixedDateTime(datetime.datetime):
|
||||
@classmethod
|
||||
def now(cls, tz=None):
|
||||
return fixed_now
|
||||
|
||||
monkeypatch.setattr(service_module.datetime, "datetime", FixedDateTime)
|
||||
|
||||
flask_app = service_module.Flask("test-app")
|
||||
monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app))
|
||||
|
||||
executor = _ImmediateExecutor()
|
||||
monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor)
|
||||
|
||||
monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg)
|
||||
monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None)
|
||||
|
||||
total_session = MagicMock()
|
||||
total_query = MagicMock()
|
||||
total_query.count.return_value = 250
|
||||
total_session.query.return_value = total_query
|
||||
|
||||
batch_session = MagicMock()
|
||||
# Count results for all 5 intervals, all > 100 => take the for-else path.
|
||||
count_queries = []
|
||||
for _ in range(5):
|
||||
q = MagicMock()
|
||||
q.where.return_value = q
|
||||
q.count.return_value = 200
|
||||
count_queries.append(q)
|
||||
|
||||
rows = [SimpleNamespace(id="tenant-a")]
|
||||
q_rs = MagicMock()
|
||||
q_rs.where.return_value = q_rs
|
||||
q_rs.order_by.return_value = rows
|
||||
|
||||
batch_session.query.side_effect = [*count_queries, q_rs]
|
||||
|
||||
sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)]
|
||||
monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0))
|
||||
|
||||
process_tenant_mock = MagicMock()
|
||||
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)
|
||||
|
||||
ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[])
|
||||
|
||||
assert process_tenant_mock.call_count == 1
|
||||
assert len(count_queries) == 5
|
||||
assert batch_session.query.call_count >= 6
|
||||
|
||||
|
||||
def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
flask_app = service_module.Flask("test-app")
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module,
|
||||
"db",
|
||||
SimpleNamespace(
|
||||
engine=object(),
|
||||
session=SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [SimpleNamespace(id="app-1")])),
|
||||
),
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(service_module, "storage", mock_storage)
|
||||
monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg)
|
||||
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "_clear_message_related_tables", MagicMock())
|
||||
|
||||
# Make message/conversation/workflow_app_log loops no-op (empty immediately)
|
||||
empty_session = MagicMock()
|
||||
q_empty = MagicMock()
|
||||
q_empty.where.return_value = q_empty
|
||||
q_empty.limit.return_value = q_empty
|
||||
q_empty.all.return_value = []
|
||||
empty_session.query.return_value = q_empty
|
||||
empty_session.commit.return_value = None
|
||||
session_wrappers = [
|
||||
_session_wrapper_for_no_autoflush(empty_session),
|
||||
_session_wrapper_for_no_autoflush(empty_session),
|
||||
_session_wrapper_for_no_autoflush(empty_session),
|
||||
]
|
||||
monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0))
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
stmt = MagicMock()
|
||||
stmt.where.return_value = stmt
|
||||
return stmt
|
||||
|
||||
monkeypatch.setattr(service_module, "select", fake_select)
|
||||
|
||||
# Repos: first returns exactly batch items -> no "< batch" break, second returns [] -> hit the len==0 break.
|
||||
node_repo = MagicMock()
|
||||
node_repo.get_expired_executions_batch.side_effect = [
|
||||
[SimpleNamespace(id="ne-1"), SimpleNamespace(id="ne-2")],
|
||||
[],
|
||||
]
|
||||
node_repo.delete_executions_by_ids.return_value = 2
|
||||
|
||||
run_repo = MagicMock()
|
||||
run_repo.get_expired_runs_batch.side_effect = [
|
||||
[
|
||||
SimpleNamespace(id="wr-1", to_dict=lambda: {"id": "wr-1"}),
|
||||
SimpleNamespace(id="wr-2", to_dict=lambda: {"id": "wr-2"}),
|
||||
],
|
||||
[],
|
||||
]
|
||||
run_repo.delete_runs_by_ids.return_value = 2
|
||||
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object())
|
||||
monkeypatch.setattr(
|
||||
service_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_node_execution_repository",
|
||||
lambda _sm: node_repo,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
service_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda _sm: run_repo,
|
||||
)
|
||||
|
||||
ClearFreePlanTenantExpiredLogs.process_tenant(flask_app, "tenant-1", days=7, batch=2)
|
||||
|
||||
assert node_repo.get_expired_executions_batch.call_count == 2
|
||||
assert run_repo.get_expired_runs_batch.call_count == 2
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.model import App, EndUser
|
||||
from models.model import App, DefaultEndUserSessionID, EndUser
|
||||
from services.end_user_service import EndUserService
|
||||
|
||||
|
||||
@ -44,6 +44,145 @@ class TestEndUserServiceFactory:
|
||||
return end_user
|
||||
|
||||
|
||||
class TestEndUserServiceGetEndUserById:
    """Unit tests for EndUserService.get_end_user_by_id."""

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestEndUserServiceFactory()

    @staticmethod
    def _wire_session(mock_session_class, found):
        """Wire a mocked Session context manager whose query chain yields *found*.

        Returns (session, query, context) so tests can assert on the query
        construction and on the context-manager protocol.
        """
        session = MagicMock()
        ctx = MagicMock()
        ctx.__enter__.return_value = session
        mock_session_class.return_value = ctx

        query = MagicMock()
        session.query.return_value = query
        query.where.return_value = query
        query.first.return_value = found
        return session, query, ctx

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_end_user_by_id_success(self, mock_db, mock_session_class, factory):
        """An existing end user is found and the session is opened/closed once."""
        expected = factory.create_end_user_mock(
            user_id="user-789", tenant_id="tenant-123", app_id="app-456"
        )
        session, query, ctx = self._wire_session(mock_session_class, expected)

        found = EndUserService.get_end_user_by_id(
            tenant_id="tenant-123", app_id="app-456", end_user_id="user-789"
        )

        assert found == expected
        session.query.assert_called_once_with(EndUser)
        query.where.assert_called_once()
        query.first.assert_called_once()
        ctx.__enter__.assert_called_once()
        ctx.__exit__.assert_called_once()

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_end_user_by_id_not_found(self, mock_db, mock_session_class):
        """A lookup that matches nothing returns None."""
        self._wire_session(mock_session_class, None)

        found = EndUserService.get_end_user_by_id(
            tenant_id="tenant-123", app_id="app-456", end_user_id="user-789"
        )

        assert found is None

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_end_user_by_id_query_parameters(self, mock_db, mock_session_class):
        """The lookup filters on exactly three conditions (tenant, app, user id)."""
        _, query, _ = self._wire_session(mock_session_class, None)

        EndUserService.get_end_user_by_id(
            tenant_id="tenant-123", app_id="app-456", end_user_id="user-789"
        )

        # The SQLAlchemy expressions themselves are not directly comparable,
        # so only the number of filter conditions is checked here.
        assert len(query.where.call_args[0]) == 3
|
||||
|
||||
|
||||
class TestEndUserServiceGetOrCreateEndUser:
    """Unit tests for EndUserService.get_or_create_end_user."""

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestEndUserServiceFactory()

    @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type")
    def test_get_or_create_end_user_with_user_id(self, mock_get_or_create_by_type, factory):
        """A concrete user_id is forwarded verbatim to get_or_create_end_user_by_type."""
        app = factory.create_app_mock()
        expected = factory.create_end_user_mock()
        mock_get_or_create_by_type.return_value = expected

        outcome = EndUserService.get_or_create_end_user(app, "user-123")

        assert outcome == expected
        mock_get_or_create_by_type.assert_called_once_with(
            InvokeFrom.SERVICE_API, app.tenant_id, app.id, "user-123"
        )

    @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type")
    def test_get_or_create_end_user_without_user_id(self, mock_get_or_create_by_type, factory):
        """A missing user_id (None) is forwarded unchanged."""
        app = factory.create_app_mock()
        expected = factory.create_end_user_mock()
        mock_get_or_create_by_type.return_value = expected

        outcome = EndUserService.get_or_create_end_user(app, None)

        assert outcome == expected
        mock_get_or_create_by_type.assert_called_once_with(
            InvokeFrom.SERVICE_API, app.tenant_id, app.id, None
        )
|
||||
|
||||
|
||||
class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
"""
|
||||
Unit tests for EndUserService.get_or_create_end_user_by_type method.
|
||||
@ -60,6 +199,191 @@ class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
"""Provide test data factory."""
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_new_end_user_with_user_id(self, mock_db, mock_session_class, factory):
|
||||
"""Test creating a new end user with specific user_id."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None # No existing user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify new EndUser was created with correct parameters
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.tenant_id == tenant_id
|
||||
assert added_user.app_id == app_id
|
||||
assert added_user.type == type_enum
|
||||
assert added_user.session_id == user_id
|
||||
assert added_user.external_user_id == user_id
|
||||
assert added_user._is_anonymous is False
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_new_end_user_default_session(self, mock_db, mock_session_class, factory):
|
||||
"""Test creating a new end user with default session ID."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = None
|
||||
type_enum = InvokeFrom.WEB_APP
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None # No existing user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
assert added_user._is_anonymous is True
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
@patch("services.end_user_service.logger")
|
||||
def test_existing_user_same_type(self, mock_logger, mock_db, mock_session_class, factory):
|
||||
"""Test retrieving existing user with same type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
existing_user = factory.create_end_user_mock(
|
||||
tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=type_enum
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_user
|
||||
mock_session.add.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
mock_logger.info.assert_not_called()
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
@patch("services.end_user_service.logger")
|
||||
def test_existing_user_different_type_upgrade(self, mock_logger, mock_db, mock_session_class, factory):
|
||||
"""Test upgrading existing user with different type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
old_type = InvokeFrom.WEB_APP
|
||||
new_type = InvokeFrom.SERVICE_API
|
||||
|
||||
existing_user = factory.create_end_user_mock(
|
||||
tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=old_type
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=new_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_user
|
||||
assert existing_user.type == new_type
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_logger.info.assert_called_once()
|
||||
logger_call_args = mock_logger.info.call_args[0]
|
||||
assert "Upgrading legacy EndUser" in logger_call_args[0]
|
||||
# The old and new types are passed as separate arguments
|
||||
assert mock_logger.info.call_args[0][1] == existing_user.id
|
||||
assert mock_logger.info.call_args[0][2] == old_type
|
||||
assert mock_logger.info.call_args[0][3] == new_type
|
||||
assert mock_logger.info.call_args[0][4] == user_id
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_query_ordering_prioritizes_exact_type_match(self, mock_db, mock_session_class, factory):
|
||||
"""Test that query ordering prioritizes exact type matches."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
target_type = InvokeFrom.SERVICE_API
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
EndUserService.get_or_create_end_user_by_type(
|
||||
type=target_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_query.order_by.assert_called_once()
|
||||
# Verify that case statement is used for ordering
|
||||
order_by_call = mock_query.order_by.call_args[0][0]
|
||||
# The exact structure depends on SQLAlchemy's case implementation
|
||||
# but we can verify it was called
|
||||
|
||||
# Test 10: Session context manager properly closes
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
@ -93,3 +417,425 @@ class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
# Verify context manager was entered and exited
|
||||
mock_context.__enter__.assert_called_once()
|
||||
mock_context.__exit__.assert_called_once()
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_all_invokefrom_types_supported(self, mock_db, mock_session_class):
|
||||
"""Test that all InvokeFrom enum values are supported."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
for invoke_type in InvokeFrom:
|
||||
with patch("services.end_user_service.Session") as mock_session_class:
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=invoke_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.type == invoke_type
|
||||
|
||||
|
||||
class TestEndUserServiceCreateEndUserBatch:
    """Unit tests for EndUserService.create_end_user_batch method.

    Improvements over the previous version: the eight-line mock-session
    wiring repeated in every test is extracted into ``_wire_session``; an
    unused dict key iteration and an inner-patch shadowing of a decorator
    argument are fixed; unused local bindings are removed.
    """

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestEndUserServiceFactory()

    @staticmethod
    def _wire_session(mock_session_class, existing_users):
        """Wire a mocked SQLAlchemy Session context manager.

        The session's query chain (query -> where -> all) yields
        *existing_users*. Returns (session, query, context) so tests can
        assert on add_all/commit, the query construction, and the
        context-manager protocol.
        """
        session = MagicMock()
        context = MagicMock()
        context.__enter__.return_value = session
        mock_session_class.return_value = context

        query = MagicMock()
        session.query.return_value = query
        query.where.return_value = query
        query.all.return_value = list(existing_users)
        return session, query, context

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_batch_empty_app_ids(self, mock_db, mock_session_class):
        """An empty app_ids list returns an empty dict without opening a session."""
        app_ids: list[str] = []

        result = EndUserService.create_end_user_batch(
            type=InvokeFrom.SERVICE_API, tenant_id="tenant-123", app_ids=app_ids, user_id="user-789"
        )

        assert result == {}
        mock_session_class.assert_not_called()

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_batch_default_session_id(self, mock_db, mock_session_class):
        """An empty user_id falls back to the default session ID and anonymous flag."""
        self._wire_session(mock_session_class, [])

        result = EndUserService.create_end_user_batch(
            type=InvokeFrom.SERVICE_API, tenant_id="tenant-123", app_ids=["app-456", "app-789"], user_id=""
        )

        assert len(result) == 2
        # Fix: iterate values() — the dict key was unused.
        for end_user in result.values():
            assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
            assert end_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
            assert end_user._is_anonymous is True

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_batch_deduplicate_app_ids(self, mock_db, mock_session_class):
        """Duplicate app_ids are deduplicated while preserving first-seen order."""
        session, _, _ = self._wire_session(mock_session_class, [])
        app_ids = ["app-456", "app-789", "app-456", "app-123", "app-789"]

        result = EndUserService.create_end_user_batch(
            type=InvokeFrom.SERVICE_API, tenant_id="tenant-123", app_ids=app_ids, user_id="user-789"
        )

        # Three unique app_ids survive.
        assert len(result) == 3
        assert "app-456" in result
        assert "app-789" in result
        assert "app-123" in result

        # Creation order matches first appearance in the input.
        added_users = session.add_all.call_args[0][0]
        assert len(added_users) == 3
        assert added_users[0].app_id == "app-456"
        assert added_users[1].app_id == "app-789"
        assert added_users[2].app_id == "app-123"

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_batch_all_existing_users(self, mock_db, mock_session_class, factory):
        """When every user already exists, nothing is inserted or committed."""
        existing_user1 = factory.create_end_user_mock(
            tenant_id="tenant-123", app_id="app-456", session_id="user-789", type=InvokeFrom.SERVICE_API
        )
        existing_user2 = factory.create_end_user_mock(
            tenant_id="tenant-123", app_id="app-789", session_id="user-789", type=InvokeFrom.SERVICE_API
        )
        session, _, _ = self._wire_session(mock_session_class, [existing_user1, existing_user2])

        result = EndUserService.create_end_user_batch(
            type=InvokeFrom.SERVICE_API, tenant_id="tenant-123", app_ids=["app-456", "app-789"], user_id="user-789"
        )

        assert len(result) == 2
        assert result["app-456"] == existing_user1
        assert result["app-789"] == existing_user2
        session.add_all.assert_not_called()
        session.commit.assert_not_called()

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_batch_partial_existing_users(self, mock_db, mock_session_class, factory):
        """Existing users are reused and only the missing ones are created."""
        existing_user1 = factory.create_end_user_mock(
            tenant_id="tenant-123", app_id="app-456", session_id="user-789", type=InvokeFrom.SERVICE_API
        )
        # app-789 and app-123 don't exist yet.
        session, _, _ = self._wire_session(mock_session_class, [existing_user1])

        result = EndUserService.create_end_user_batch(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_ids=["app-456", "app-789", "app-123"],
            user_id="user-789",
        )

        assert len(result) == 3
        assert result["app-456"] == existing_user1
        assert "app-789" in result
        assert "app-123" in result

        # Exactly the two missing users are inserted.
        session.add_all.assert_called_once()
        added_users = session.add_all.call_args[0][0]
        assert len(added_users) == 2
        session.commit.assert_called_once()

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_batch_handles_duplicates_in_existing(self, mock_db, mock_session_class, factory):
        """Duplicate existing rows for one app are tolerated; the first wins."""
        existing_user1 = factory.create_end_user_mock(
            user_id="user-1", tenant_id="tenant-123", app_id="app-456", session_id="user-789",
            type=InvokeFrom.SERVICE_API,
        )
        existing_user2 = factory.create_end_user_mock(
            user_id="user-2", tenant_id="tenant-123", app_id="app-456", session_id="user-789",
            type=InvokeFrom.SERVICE_API,
        )
        self._wire_session(mock_session_class, [existing_user1, existing_user2])

        result = EndUserService.create_end_user_batch(
            type=InvokeFrom.SERVICE_API, tenant_id="tenant-123", app_ids=["app-456"], user_id="user-789"
        )

        assert len(result) == 1
        assert result["app-456"] == existing_user1

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_batch_all_invokefrom_types(self, mock_db, mock_session_class):
        """Every InvokeFrom type produces users with that type.

        Fix: the inner patch previously shadowed the decorator-injected
        ``mock_session_class``; it now binds a distinct name.
        """
        for invoke_type in InvokeFrom:
            with patch("services.end_user_service.Session") as fresh_session_cls:
                session, _, _ = self._wire_session(fresh_session_cls, [])

                result = EndUserService.create_end_user_batch(
                    type=invoke_type, tenant_id="tenant-123", app_ids=["app-456"], user_id="user-789"
                )

                added_user = session.add_all.call_args[0][0][0]
                assert added_user.type == invoke_type
                # Fix: `result` was previously unused — the created user must
                # be exposed under its app id.
                assert "app-456" in result

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_batch_single_app_id(self, mock_db, mock_session_class, factory):
        """A single app_id yields exactly one created user."""
        session, _, _ = self._wire_session(mock_session_class, [])

        result = EndUserService.create_end_user_batch(
            type=InvokeFrom.SERVICE_API, tenant_id="tenant-123", app_ids=["app-456"], user_id="user-789"
        )

        assert len(result) == 1
        assert "app-456" in result
        session.add_all.assert_called_once()
        added_users = session.add_all.call_args[0][0]
        assert len(added_users) == 1
        assert added_users[0].app_id == "app-456"

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_batch_anonymous_vs_authenticated(self, mock_db, mock_session_class):
        """The anonymous flag tracks whether a real user_id was supplied."""
        session, query, _ = self._wire_session(mock_session_class, [])

        # A regular user id produces non-anonymous users.
        EndUserService.create_end_user_batch(
            type=InvokeFrom.SERVICE_API, tenant_id="tenant-123", app_ids=["app-456", "app-789"],
            user_id="user-789",
        )
        for user in session.add_all.call_args[0][0]:
            assert user._is_anonymous is False

        # Reset and repeat with the default session id: users are anonymous.
        session.reset_mock()
        query.reset_mock()
        query.all.return_value = []

        EndUserService.create_end_user_batch(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_ids=["app-456", "app-789"],
            user_id=DefaultEndUserSessionID.DEFAULT_SESSION_ID,
        )
        for user in session.add_all.call_args[0][0]:
            assert user._is_anonymous is True

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_batch_efficient_single_query(self, mock_db, mock_session_class):
        """Existing users are fetched with exactly one query."""
        session, query, _ = self._wire_session(mock_session_class, [])

        EndUserService.create_end_user_batch(
            type=InvokeFrom.SERVICE_API, tenant_id="tenant-123",
            app_ids=["app-456", "app-789", "app-123"], user_id="user-789",
        )

        session.query.assert_called_once_with(EndUser)
        query.where.assert_called_once()
        query.all.assert_called_once()
        # The filter expressions are SQLAlchemy-internal; assert only that the
        # single where() call carried positional filter arguments.
        # (Fix: the args were previously bound to an unused local.)
        assert query.where.call_args[0]

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_batch_session_context_manager(self, mock_db, mock_session_class):
        """The session context manager is entered/exited and the work committed."""
        session, _, context = self._wire_session(mock_session_class, [])

        EndUserService.create_end_user_batch(
            type=InvokeFrom.SERVICE_API, tenant_id="tenant-123", app_ids=["app-456"], user_id="user-789"
        )

        context.__enter__.assert_called_once()
        context.__exit__.assert_called_once()
        session.commit.assert_called_once()
|
||||
|
||||
420
api/tests/unit_tests/services/test_file_service.py
Normal file
420
api/tests/unit_tests/services/test_file_service.py
Normal file
@ -0,0 +1,420 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import Account, EndUser, UploadFile
|
||||
from services.errors.file import BlockedFileExtensionError, FileTooLargeError, UnsupportedFileTypeError
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class TestFileService:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
session = MagicMock(spec=Session)
|
||||
# Mock context manager behavior
|
||||
session.__enter__.return_value = session
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(self, mock_db_session):
|
||||
maker = MagicMock(spec=sessionmaker)
|
||||
maker.return_value = mock_db_session
|
||||
return maker
|
||||
|
||||
@pytest.fixture
|
||||
def file_service(self, mock_session_maker):
|
||||
return FileService(session_factory=mock_session_maker)
|
||||
|
||||
def test_init_with_engine(self):
|
||||
engine = MagicMock(spec=Engine)
|
||||
service = FileService(session_factory=engine)
|
||||
assert isinstance(service._session_maker, sessionmaker)
|
||||
|
||||
def test_init_with_sessionmaker(self):
|
||||
maker = MagicMock(spec=sessionmaker)
|
||||
service = FileService(session_factory=maker)
|
||||
assert service._session_maker == maker
|
||||
|
||||
def test_init_invalid_factory(self):
|
||||
with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
|
||||
FileService(session_factory="invalid")
|
||||
|
||||
@patch("services.file_service.storage")
|
||||
@patch("services.file_service.naive_utc_now")
|
||||
@patch("services.file_service.extract_tenant_id")
|
||||
@patch("services.file_service.file_helpers.get_signed_file_url")
|
||||
def test_upload_file_success(
|
||||
self, mock_get_url, mock_tenant_id, mock_now, mock_storage, file_service, mock_db_session
|
||||
):
|
||||
# Setup
|
||||
mock_tenant_id.return_value = "tenant_id"
|
||||
mock_now.return_value = "2024-01-01"
|
||||
mock_get_url.return_value = "http://signed-url"
|
||||
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "user_id"
|
||||
content = b"file content"
|
||||
filename = "test.jpg"
|
||||
mimetype = "image/jpeg"
|
||||
|
||||
# Execute
|
||||
result = file_service.upload_file(filename=filename, content=content, mimetype=mimetype, user=user)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, UploadFile)
|
||||
assert result.name == filename
|
||||
assert result.tenant_id == "tenant_id"
|
||||
assert result.size == len(content)
|
||||
assert result.extension == "jpg"
|
||||
assert result.mime_type == mimetype
|
||||
assert result.created_by_role == CreatorUserRole.ACCOUNT
|
||||
assert result.created_by == "user_id"
|
||||
assert result.hash == hashlib.sha3_256(content).hexdigest()
|
||||
assert result.source_url == "http://signed-url"
|
||||
|
||||
mock_storage.save.assert_called_once()
|
||||
mock_db_session.add.assert_called_once_with(result)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_upload_file_invalid_characters(self, file_service):
|
||||
with pytest.raises(ValueError, match="Filename contains invalid characters"):
|
||||
file_service.upload_file(filename="invalid/file.txt", content=b"", mimetype="text/plain", user=MagicMock())
|
||||
|
||||
def test_upload_file_long_filename(self, file_service, mock_db_session):
|
||||
# Setup
|
||||
long_name = "a" * 210 + ".txt"
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "user_id"
|
||||
|
||||
with (
|
||||
patch("services.file_service.storage"),
|
||||
patch("services.file_service.extract_tenant_id") as mock_tenant,
|
||||
patch("services.file_service.file_helpers.get_signed_file_url"),
|
||||
):
|
||||
mock_tenant.return_value = "tenant"
|
||||
result = file_service.upload_file(filename=long_name, content=b"test", mimetype="text/plain", user=user)
|
||||
assert len(result.name) <= 205 # 200 + . + extension
|
||||
assert result.name.endswith(".txt")
|
||||
|
||||
def test_upload_file_blocked_extension(self, file_service):
|
||||
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", "exe"):
|
||||
with pytest.raises(BlockedFileExtensionError):
|
||||
file_service.upload_file(
|
||||
filename="test.exe", content=b"", mimetype="application/octet-stream", user=MagicMock()
|
||||
)
|
||||
|
||||
def test_upload_file_unsupported_type_for_datasets(self, file_service):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
file_service.upload_file(
|
||||
filename="test.jpg", content=b"", mimetype="image/jpeg", user=MagicMock(), source="datasets"
|
||||
)
|
||||
|
||||
def test_upload_file_too_large(self, file_service):
|
||||
# 16MB file for an image with 15MB limit
|
||||
content = b"a" * (16 * 1024 * 1024)
|
||||
with patch.object(dify_config, "UPLOAD_IMAGE_FILE_SIZE_LIMIT", 15):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
file_service.upload_file(filename="test.jpg", content=content, mimetype="image/jpeg", user=MagicMock())
|
||||
|
||||
def test_upload_file_end_user(self, file_service, mock_db_session):
|
||||
user = MagicMock(spec=EndUser)
|
||||
user.id = "end_user_id"
|
||||
|
||||
with (
|
||||
patch("services.file_service.storage"),
|
||||
patch("services.file_service.extract_tenant_id") as mock_tenant,
|
||||
patch("services.file_service.file_helpers.get_signed_file_url"),
|
||||
):
|
||||
mock_tenant.return_value = "tenant"
|
||||
result = file_service.upload_file(filename="test.txt", content=b"test", mimetype="text/plain", user=user)
|
||||
assert result.created_by_role == CreatorUserRole.END_USER
|
||||
|
||||
def test_is_file_size_within_limit(self):
|
||||
with (
|
||||
patch.object(dify_config, "UPLOAD_IMAGE_FILE_SIZE_LIMIT", 10),
|
||||
patch.object(dify_config, "UPLOAD_VIDEO_FILE_SIZE_LIMIT", 20),
|
||||
patch.object(dify_config, "UPLOAD_AUDIO_FILE_SIZE_LIMIT", 30),
|
||||
patch.object(dify_config, "UPLOAD_FILE_SIZE_LIMIT", 5),
|
||||
):
|
||||
# Image
|
||||
assert FileService.is_file_size_within_limit(extension="jpg", file_size=10 * 1024 * 1024) is True
|
||||
assert FileService.is_file_size_within_limit(extension="png", file_size=11 * 1024 * 1024) is False
|
||||
|
||||
# Video
|
||||
assert FileService.is_file_size_within_limit(extension="mp4", file_size=20 * 1024 * 1024) is True
|
||||
assert FileService.is_file_size_within_limit(extension="avi", file_size=21 * 1024 * 1024) is False
|
||||
|
||||
# Audio
|
||||
assert FileService.is_file_size_within_limit(extension="mp3", file_size=30 * 1024 * 1024) is True
|
||||
assert FileService.is_file_size_within_limit(extension="wav", file_size=31 * 1024 * 1024) is False
|
||||
|
||||
# Default
|
||||
assert FileService.is_file_size_within_limit(extension="txt", file_size=5 * 1024 * 1024) is True
|
||||
assert FileService.is_file_size_within_limit(extension="pdf", file_size=6 * 1024 * 1024) is False
|
||||
|
||||
def test_get_file_base64_success(self, file_service, mock_db_session):
|
||||
# Setup
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.key = "test_key"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
|
||||
with patch("services.file_service.storage") as mock_storage:
|
||||
mock_storage.load_once.return_value = b"test content"
|
||||
|
||||
# Execute
|
||||
result = file_service.get_file_base64("file_id")
|
||||
|
||||
# Assert
|
||||
assert result == base64.b64encode(b"test content").decode()
|
||||
mock_storage.load_once.assert_called_once_with("test_key")
|
||||
|
||||
def test_get_file_base64_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.query().where().first.return_value = None
|
||||
with pytest.raises(NotFound, match="File not found"):
|
||||
file_service.get_file_base64("non_existent")
|
||||
|
||||
def test_upload_text_success(self, file_service, mock_db_session):
|
||||
# Setup
|
||||
text = "sample text"
|
||||
text_name = "test.txt"
|
||||
user_id = "user_id"
|
||||
tenant_id = "tenant_id"
|
||||
|
||||
with patch("services.file_service.storage") as mock_storage:
|
||||
# Execute
|
||||
result = file_service.upload_text(text, text_name, user_id, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert result.name == text_name
|
||||
assert result.size == len(text)
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.created_by == user_id
|
||||
assert result.used is True
|
||||
assert result.extension == "txt"
|
||||
mock_storage.save.assert_called_once()
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_upload_text_long_name(self, file_service, mock_db_session):
|
||||
long_name = "a" * 210
|
||||
with patch("services.file_service.storage"):
|
||||
result = file_service.upload_text("text", long_name, "user", "tenant")
|
||||
assert len(result.name) == 200
|
||||
|
||||
def test_get_file_preview_success(self, file_service, mock_db_session):
|
||||
# Setup
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.extension = "pdf"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
|
||||
with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract:
|
||||
mock_extract.return_value = "Extracted text content"
|
||||
|
||||
# Execute
|
||||
result = file_service.get_file_preview("file_id")
|
||||
|
||||
# Assert
|
||||
assert result == "Extracted text content"
|
||||
|
||||
def test_get_file_preview_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.query().where().first.return_value = None
|
||||
with pytest.raises(NotFound, match="File not found"):
|
||||
file_service.get_file_preview("non_existent")
|
||||
|
||||
def test_get_file_preview_unsupported_type(self, file_service, mock_db_session):
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.extension = "exe"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
file_service.get_file_preview("file_id")
|
||||
|
||||
def test_get_image_preview_success(self, file_service, mock_db_session):
|
||||
# Setup
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.extension = "jpg"
|
||||
upload_file.mime_type = "image/jpeg"
|
||||
upload_file.key = "key"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
|
||||
with (
|
||||
patch("services.file_service.file_helpers.verify_image_signature") as mock_verify,
|
||||
patch("services.file_service.storage") as mock_storage,
|
||||
):
|
||||
mock_verify.return_value = True
|
||||
mock_storage.load.return_value = iter([b"chunk1"])
|
||||
|
||||
# Execute
|
||||
gen, mime = file_service.get_image_preview("file_id", "ts", "nonce", "sign")
|
||||
|
||||
# Assert
|
||||
assert list(gen) == [b"chunk1"]
|
||||
assert mime == "image/jpeg"
|
||||
|
||||
def test_get_image_preview_invalid_sig(self, file_service):
|
||||
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
|
||||
mock_verify.return_value = False
|
||||
with pytest.raises(NotFound, match="File not found or signature is invalid"):
|
||||
file_service.get_image_preview("file_id", "ts", "nonce", "sign")
|
||||
|
||||
def test_get_image_preview_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.query().where().first.return_value = None
|
||||
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
|
||||
mock_verify.return_value = True
|
||||
with pytest.raises(NotFound, match="File not found or signature is invalid"):
|
||||
file_service.get_image_preview("file_id", "ts", "nonce", "sign")
|
||||
|
||||
def test_get_image_preview_unsupported_type(self, file_service, mock_db_session):
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.extension = "txt"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
|
||||
mock_verify.return_value = True
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
file_service.get_image_preview("file_id", "ts", "nonce", "sign")
|
||||
|
||||
def test_get_file_generator_by_file_id_success(self, file_service, mock_db_session):
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.key = "key"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
|
||||
with (
|
||||
patch("services.file_service.file_helpers.verify_file_signature") as mock_verify,
|
||||
patch("services.file_service.storage") as mock_storage,
|
||||
):
|
||||
mock_verify.return_value = True
|
||||
mock_storage.load.return_value = iter([b"chunk"])
|
||||
|
||||
gen, file = file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign")
|
||||
assert list(gen) == [b"chunk"]
|
||||
assert file == upload_file
|
||||
|
||||
def test_get_file_generator_by_file_id_invalid_sig(self, file_service):
|
||||
with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify:
|
||||
mock_verify.return_value = False
|
||||
with pytest.raises(NotFound, match="File not found or signature is invalid"):
|
||||
file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign")
|
||||
|
||||
def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.query().where().first.return_value = None
|
||||
with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify:
|
||||
mock_verify.return_value = True
|
||||
with pytest.raises(NotFound, match="File not found or signature is invalid"):
|
||||
file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign")
|
||||
|
||||
def test_get_public_image_preview_success(self, file_service, mock_db_session):
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.extension = "png"
|
||||
upload_file.mime_type = "image/png"
|
||||
upload_file.key = "key"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
|
||||
with patch("services.file_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = b"image content"
|
||||
gen, mime = file_service.get_public_image_preview("file_id")
|
||||
assert gen == b"image content"
|
||||
assert mime == "image/png"
|
||||
|
||||
def test_get_public_image_preview_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.query().where().first.return_value = None
|
||||
with pytest.raises(NotFound, match="File not found or signature is invalid"):
|
||||
file_service.get_public_image_preview("file_id")
|
||||
|
||||
def test_get_public_image_preview_unsupported_type(self, file_service, mock_db_session):
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.extension = "txt"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
file_service.get_public_image_preview("file_id")
|
||||
|
||||
def test_get_file_content_success(self, file_service, mock_db_session):
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.key = "key"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
|
||||
with patch("services.file_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = b"hello world"
|
||||
result = file_service.get_file_content("file_id")
|
||||
assert result == "hello world"
|
||||
|
||||
def test_get_file_content_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.query().where().first.return_value = None
|
||||
with pytest.raises(NotFound, match="File not found"):
|
||||
file_service.get_file_content("file_id")
|
||||
|
||||
def test_delete_file_success(self, file_service, mock_db_session):
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.key = "key"
|
||||
# For session.scalar(select(...))
|
||||
mock_db_session.scalar.return_value = upload_file
|
||||
|
||||
with patch("services.file_service.storage") as mock_storage:
|
||||
file_service.delete_file("file_id")
|
||||
mock_storage.delete.assert_called_once_with("key")
|
||||
mock_db_session.delete.assert_called_once_with(upload_file)
|
||||
|
||||
def test_delete_file_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.scalar.return_value = None
|
||||
file_service.delete_file("file_id")
|
||||
# Should return without doing anything
|
||||
|
||||
@patch("services.file_service.db")
|
||||
def test_get_upload_files_by_ids_empty(self, mock_db):
|
||||
result = FileService.get_upload_files_by_ids("tenant_id", [])
|
||||
assert result == {}
|
||||
|
||||
@patch("services.file_service.db")
|
||||
def test_get_upload_files_by_ids(self, mock_db):
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "550e8400-e29b-41d4-a716-446655440000"
|
||||
upload_file.tenant_id = "tenant_id"
|
||||
mock_db.session.scalars().all.return_value = [upload_file]
|
||||
|
||||
result = FileService.get_upload_files_by_ids("tenant_id", ["550e8400-e29b-41d4-a716-446655440000"])
|
||||
assert result["550e8400-e29b-41d4-a716-446655440000"] == upload_file
|
||||
|
||||
def test_sanitize_zip_entry_name(self):
|
||||
assert FileService._sanitize_zip_entry_name("path/to/file.txt") == "file.txt"
|
||||
assert FileService._sanitize_zip_entry_name("../../../etc/passwd") == "passwd"
|
||||
assert FileService._sanitize_zip_entry_name(" ") == "file"
|
||||
assert FileService._sanitize_zip_entry_name("a\\b") == "a_b"
|
||||
|
||||
def test_dedupe_zip_entry_name(self):
|
||||
used = {"a.txt"}
|
||||
assert FileService._dedupe_zip_entry_name("b.txt", used) == "b.txt"
|
||||
assert FileService._dedupe_zip_entry_name("a.txt", used) == "a (1).txt"
|
||||
used.add("a (1).txt")
|
||||
assert FileService._dedupe_zip_entry_name("a.txt", used) == "a (2).txt"
|
||||
|
||||
def test_build_upload_files_zip_tempfile(self):
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.name = "test.txt"
|
||||
upload_file.key = "key"
|
||||
|
||||
with (
|
||||
patch("services.file_service.storage") as mock_storage,
|
||||
patch("services.file_service.os.remove") as mock_remove,
|
||||
):
|
||||
mock_storage.load.return_value = [b"chunk1", b"chunk2"]
|
||||
|
||||
with FileService.build_upload_files_zip_tempfile(upload_files=[upload_file]) as tmp_path:
|
||||
assert os.path.exists(tmp_path)
|
||||
|
||||
mock_remove.assert_called_once()
|
||||
@ -1,97 +1,291 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from configs import dify_config
|
||||
from dify_graph.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
MemberRecipient,
|
||||
)
|
||||
from dify_graph.runtime import VariablePool
|
||||
from services import human_input_delivery_test_service as service_module
|
||||
from services.human_input_delivery_test_service import (
|
||||
DeliveryTestContext,
|
||||
DeliveryTestEmailRecipient,
|
||||
DeliveryTestError,
|
||||
DeliveryTestRegistry,
|
||||
DeliveryTestResult,
|
||||
DeliveryTestStatus,
|
||||
DeliveryTestUnsupportedError,
|
||||
EmailDeliveryTestHandler,
|
||||
HumanInputDeliveryTestService,
|
||||
_build_form_link,
|
||||
)
|
||||
|
||||
|
||||
def _make_email_method() -> EmailDeliveryMethod:
|
||||
return EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[ExternalRecipient(email="tester@example.com")],
|
||||
),
|
||||
subject="Test subject",
|
||||
body="Test body",
|
||||
@pytest.fixture
|
||||
def mock_db(monkeypatch):
|
||||
mock_db = MagicMock()
|
||||
monkeypatch.setattr(service_module, "db", mock_db)
|
||||
return mock_db
|
||||
|
||||
|
||||
def _make_valid_email_config():
|
||||
return EmailDeliveryConfig(recipients=EmailRecipients(whole_workspace=False, items=[]), subject="Subj", body="Body")
|
||||
|
||||
|
||||
def test_build_form_link():
|
||||
with patch.object(dify_config, "APP_WEB_URL", "http://example.com/"):
|
||||
assert _build_form_link("token123") == "http://example.com/form/token123"
|
||||
|
||||
with patch.object(dify_config, "APP_WEB_URL", "http://example.com"):
|
||||
assert _build_form_link("token123") == "http://example.com/form/token123"
|
||||
|
||||
assert _build_form_link(None) is None
|
||||
|
||||
with patch.object(dify_config, "APP_WEB_URL", None):
|
||||
assert _build_form_link("token123") is None
|
||||
|
||||
|
||||
class TestDeliveryTestRegistry:
|
||||
def test_register(self):
|
||||
registry = DeliveryTestRegistry()
|
||||
assert len(registry._handlers) == 0
|
||||
handler = MagicMock()
|
||||
registry.register(handler)
|
||||
assert len(registry._handlers) == 1
|
||||
assert registry._handlers[0] == handler
|
||||
|
||||
def test_register_and_dispatch(self):
|
||||
handler = MagicMock()
|
||||
handler.supports.return_value = True
|
||||
handler.send_test.return_value = DeliveryTestResult(status=DeliveryTestStatus.OK)
|
||||
|
||||
registry = DeliveryTestRegistry([handler])
|
||||
context = MagicMock(spec=DeliveryTestContext)
|
||||
method = MagicMock()
|
||||
|
||||
result = registry.dispatch(context=context, method=method)
|
||||
|
||||
assert result.status == DeliveryTestStatus.OK
|
||||
handler.supports.assert_called_once_with(method)
|
||||
handler.send_test.assert_called_once_with(context=context, method=method)
|
||||
|
||||
def test_dispatch_unsupported(self):
|
||||
handler = MagicMock()
|
||||
handler.supports.return_value = False
|
||||
|
||||
registry = DeliveryTestRegistry([handler])
|
||||
context = MagicMock(spec=DeliveryTestContext)
|
||||
method = MagicMock()
|
||||
|
||||
with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."):
|
||||
registry.dispatch(context=context, method=method)
|
||||
|
||||
def test_default(self, mock_db):
|
||||
registry = DeliveryTestRegistry.default()
|
||||
assert len(registry._handlers) == 1
|
||||
assert isinstance(registry._handlers[0], EmailDeliveryTestHandler)
|
||||
|
||||
|
||||
def test_human_input_delivery_test_service():
|
||||
registry = MagicMock(spec=DeliveryTestRegistry)
|
||||
service = HumanInputDeliveryTestService(registry=registry)
|
||||
context = MagicMock(spec=DeliveryTestContext)
|
||||
method = MagicMock()
|
||||
|
||||
service.send_test(context=context, method=method)
|
||||
registry.dispatch.assert_called_once_with(context=context, method=method)
|
||||
|
||||
|
||||
class TestEmailDeliveryTestHandler:
|
||||
def test_init_with_engine(self):
|
||||
engine = MagicMock(spec=Engine)
|
||||
handler = EmailDeliveryTestHandler(session_factory=engine)
|
||||
assert handler._session_factory.kw["bind"] == engine
|
||||
|
||||
def test_supports(self):
|
||||
handler = EmailDeliveryTestHandler(session_factory=MagicMock())
|
||||
method = EmailDeliveryMethod(config=_make_valid_email_config())
|
||||
assert handler.supports(method) is True
|
||||
assert handler.supports(MagicMock()) is False
|
||||
|
||||
def test_send_test_unsupported_method(self):
|
||||
handler = EmailDeliveryTestHandler(session_factory=MagicMock())
|
||||
with pytest.raises(DeliveryTestUnsupportedError):
|
||||
handler.send_test(context=MagicMock(), method=MagicMock())
|
||||
|
||||
def test_send_test_feature_disabled(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_email_delivery_test_handler_rejects_when_feature_disabled(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False),
|
||||
)
|
||||
|
||||
handler = EmailDeliveryTestHandler(session_factory=object())
|
||||
context = DeliveryTestContext(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
rendered_content="content",
|
||||
)
|
||||
method = _make_email_method()
|
||||
|
||||
with pytest.raises(DeliveryTestError, match="Email delivery is not available"):
|
||||
handler.send_test(context=context, method=method)
|
||||
|
||||
|
||||
def test_email_delivery_test_handler_replaces_body_variables(monkeypatch: pytest.MonkeyPatch):
|
||||
class DummyMail:
|
||||
def __init__(self):
|
||||
self.sent: list[dict[str, str]] = []
|
||||
|
||||
def is_inited(self) -> bool:
|
||||
return True
|
||||
|
||||
def send(self, *, to: str, subject: str, html: str):
|
||||
self.sent.append({"to": to, "subject": subject, "html": html})
|
||||
|
||||
mail = DummyMail()
|
||||
monkeypatch.setattr(service_module, "mail", mail)
|
||||
monkeypatch.setattr(service_module, "render_email_template", lambda template, _substitutions: template)
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
|
||||
)
|
||||
|
||||
handler = EmailDeliveryTestHandler(session_factory=object())
|
||||
handler._resolve_recipients = lambda **_kwargs: ["tester@example.com"] # type: ignore[assignment]
|
||||
|
||||
method = EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(whole_workspace=False, items=[ExternalRecipient(email="tester@example.com")]),
|
||||
subject="Subject",
|
||||
body="Value {{#node1.value#}}",
|
||||
handler = EmailDeliveryTestHandler(session_factory=MagicMock())
|
||||
context = DeliveryTestContext(
|
||||
tenant_id="t1", app_id="a1", node_id="n1", node_title="title", rendered_content="content"
|
||||
)
|
||||
)
|
||||
variable_pool = VariablePool()
|
||||
variable_pool.add(["node1", "value"], "OK")
|
||||
context = DeliveryTestContext(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
rendered_content="content",
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
method = EmailDeliveryMethod(config=_make_valid_email_config())
|
||||
|
||||
handler.send_test(context=context, method=method)
|
||||
with pytest.raises(DeliveryTestError, match="Email delivery is not available"):
|
||||
handler.send_test(context=context, method=method)
|
||||
|
||||
assert mail.sent[0]["html"] == "Value OK"
|
||||
def test_send_test_mail_not_inited(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True),
|
||||
)
|
||||
monkeypatch.setattr(service_module.mail, "is_inited", lambda: False)
|
||||
|
||||
handler = EmailDeliveryTestHandler(session_factory=MagicMock())
|
||||
context = DeliveryTestContext(
|
||||
tenant_id="t1", app_id="a1", node_id="n1", node_title="title", rendered_content="content"
|
||||
)
|
||||
method = EmailDeliveryMethod(config=_make_valid_email_config())
|
||||
|
||||
with pytest.raises(DeliveryTestError, match="Mail client is not initialized."):
|
||||
handler.send_test(context=context, method=method)
|
||||
|
||||
def test_send_test_no_recipients(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True),
|
||||
)
|
||||
monkeypatch.setattr(service_module.mail, "is_inited", lambda: True)
|
||||
|
||||
handler = EmailDeliveryTestHandler(session_factory=MagicMock())
|
||||
handler._resolve_recipients = MagicMock(return_value=[])
|
||||
|
||||
context = DeliveryTestContext(
|
||||
tenant_id="t1", app_id="a1", node_id="n1", node_title="title", rendered_content="content"
|
||||
)
|
||||
method = EmailDeliveryMethod(config=_make_valid_email_config())
|
||||
|
||||
with pytest.raises(DeliveryTestError, match="No recipients configured"):
|
||||
handler.send_test(context=context, method=method)
|
||||
|
||||
def test_send_test_success(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True),
|
||||
)
|
||||
monkeypatch.setattr(service_module.mail, "is_inited", lambda: True)
|
||||
mock_mail_send = MagicMock()
|
||||
monkeypatch.setattr(service_module.mail, "send", mock_mail_send)
|
||||
monkeypatch.setattr(service_module, "render_email_template", lambda t, s: f"RENDERED_{t}")
|
||||
|
||||
handler = EmailDeliveryTestHandler(session_factory=MagicMock())
|
||||
handler._resolve_recipients = MagicMock(return_value=["test@example.com"])
|
||||
|
||||
variable_pool = VariablePool()
|
||||
context = DeliveryTestContext(
|
||||
tenant_id="t1",
|
||||
app_id="a1",
|
||||
node_id="n1",
|
||||
node_title="title",
|
||||
rendered_content="content",
|
||||
variable_pool=variable_pool,
|
||||
recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")],
|
||||
)
|
||||
|
||||
method = EmailDeliveryMethod(config=_make_valid_email_config())
|
||||
|
||||
result = handler.send_test(context=context, method=method)
|
||||
|
||||
assert result.status == DeliveryTestStatus.OK
|
||||
assert result.delivered_to == ["test@example.com"]
|
||||
mock_mail_send.assert_called_once()
|
||||
args, kwargs = mock_mail_send.call_args
|
||||
assert kwargs["to"] == "test@example.com"
|
||||
assert "RENDERED_Subj" in kwargs["subject"]
|
||||
|
||||
def test_resolve_recipients(self):
|
||||
handler = EmailDeliveryTestHandler(session_factory=MagicMock())
|
||||
|
||||
# Test Case 1: External Recipient
|
||||
method = EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(items=[ExternalRecipient(email="ext@example.com")], whole_workspace=False),
|
||||
subject="",
|
||||
body="",
|
||||
)
|
||||
)
|
||||
assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"]
|
||||
|
||||
# Test Case 2: Member Recipient
|
||||
method = EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(items=[MemberRecipient(user_id="u1")], whole_workspace=False),
|
||||
subject="",
|
||||
body="",
|
||||
)
|
||||
)
|
||||
handler._query_workspace_member_emails = MagicMock(return_value={"u1": "u1@example.com"})
|
||||
assert handler._resolve_recipients(tenant_id="t1", method=method) == ["u1@example.com"]
|
||||
|
||||
# Test Case 3: Whole Workspace
|
||||
method = EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(recipients=EmailRecipients(items=[], whole_workspace=True), subject="", body="")
|
||||
)
|
||||
handler._query_workspace_member_emails = MagicMock(
|
||||
return_value={"u1": "u1@example.com", "u2": "u2@example.com"}
|
||||
)
|
||||
recipients = handler._resolve_recipients(tenant_id="t1", method=method)
|
||||
assert set(recipients) == {"u1@example.com", "u2@example.com"}
|
||||
|
||||
def test_query_workspace_member_emails(self):
|
||||
mock_session = MagicMock()
|
||||
mock_session_factory = MagicMock(return_value=mock_session)
|
||||
mock_session.__enter__.return_value = mock_session
|
||||
|
||||
handler = EmailDeliveryTestHandler(session_factory=mock_session_factory)
|
||||
|
||||
# Empty user_ids
|
||||
assert handler._query_workspace_member_emails(tenant_id="t1", user_ids=[]) == {}
|
||||
|
||||
# user_ids is None (all)
|
||||
mock_execute = MagicMock()
|
||||
mock_session.execute.return_value = mock_execute
|
||||
mock_execute.all.return_value = [("u1", "u1@example.com")]
|
||||
|
||||
result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=None)
|
||||
assert result == {"u1": "u1@example.com"}
|
||||
|
||||
# user_ids with values
|
||||
result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=["u1"])
|
||||
assert result == {"u1": "u1@example.com"}
|
||||
|
||||
def test_build_substitutions(self):
|
||||
context = DeliveryTestContext(
|
||||
tenant_id="t1",
|
||||
app_id="a1",
|
||||
node_id="n1",
|
||||
node_title="title",
|
||||
rendered_content="content",
|
||||
template_vars={"custom": "var"},
|
||||
recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")],
|
||||
)
|
||||
|
||||
subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com")
|
||||
|
||||
assert subs["node_title"] == "title"
|
||||
assert subs["form_content"] == "content"
|
||||
assert subs["recipient_email"] == "test@example.com"
|
||||
assert subs["custom"] == "var"
|
||||
assert subs["form_token"] == "token123"
|
||||
assert "form/token123" in subs["form_link"]
|
||||
|
||||
# Without matching recipient
|
||||
subs_no_match = EmailDeliveryTestHandler._build_substitutions(
|
||||
context=context, recipient_email="other@example.com"
|
||||
)
|
||||
assert subs_no_match["form_token"] == ""
|
||||
assert subs_no_match["form_link"] == ""
|
||||
|
||||
@ -16,7 +16,13 @@ from dify_graph.nodes.human_input.entities import (
|
||||
)
|
||||
from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus
|
||||
from models.human_input import RecipientType
|
||||
from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError
|
||||
from services.human_input_service import (
|
||||
Form,
|
||||
FormExpiredError,
|
||||
FormSubmittedError,
|
||||
HumanInputService,
|
||||
InvalidFormDataError,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -285,3 +291,172 @@ def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_fa
|
||||
|
||||
assert "Missing required inputs" in str(exc_info.value)
|
||||
repo.mark_submitted.assert_not_called()
|
||||
|
||||
|
||||
def test_form_properties(sample_form_record):
|
||||
form = Form(sample_form_record)
|
||||
assert form.id == "form-id"
|
||||
assert form.workflow_run_id == "workflow-run-id"
|
||||
assert form.tenant_id == "tenant-id"
|
||||
assert form.app_id == "app-id"
|
||||
assert form.recipient_id == "recipient-id"
|
||||
assert form.recipient_type == RecipientType.STANDALONE_WEB_APP
|
||||
assert form.status == HumanInputFormStatus.WAITING
|
||||
assert form.form_kind == HumanInputFormKind.RUNTIME
|
||||
assert isinstance(form.created_at, datetime)
|
||||
assert isinstance(form.expiration_time, datetime)
|
||||
|
||||
|
||||
def test_form_submitted_error_init():
|
||||
error = FormSubmittedError(form_id="test-form")
|
||||
assert "form_id=test-form" in error.description
|
||||
assert error.code == 412
|
||||
|
||||
|
||||
def test_human_input_service_init_with_engine(mocker):
|
||||
engine = MagicMock(spec=human_input_service_module.Engine)
|
||||
sessionmaker_mock = mocker.patch("services.human_input_service.sessionmaker")
|
||||
|
||||
HumanInputService(session_factory=engine)
|
||||
sessionmaker_mock.assert_called_once_with(bind=engine)
|
||||
|
||||
|
||||
def test_get_form_by_token_none(mock_session_factory):
    """An unknown token yields None rather than raising."""
    session_factory, _ = mock_session_factory
    form_repo = MagicMock(spec=HumanInputFormSubmissionRepository)
    form_repo.get_by_token.return_value = None

    service = HumanInputService(session_factory, form_repository=form_repo)

    assert service.get_form_by_token("invalid") is None
|
||||
|
||||
|
||||
def test_get_form_definition_by_token_mismatch(sample_form_record, mock_session_factory):
    """A recipient-type mismatch hides the form definition (returns None)."""
    session_factory, _ = mock_session_factory
    form_repo = MagicMock(spec=HumanInputFormSubmissionRepository)
    form_repo.get_by_token.return_value = sample_form_record  # record is STANDALONE_WEB_APP

    service = HumanInputService(session_factory, form_repository=form_repo)

    # Asking with RecipientType.CONSOLE does not match the stored recipient type.
    result = service.get_form_definition_by_token(RecipientType.CONSOLE, "token")
    assert result is None
|
||||
|
||||
|
||||
def test_get_form_definition_by_token_success(sample_form_record, mock_session_factory):
    """A matching recipient type returns a Form wrapping the stored record."""
    session_factory, _ = mock_session_factory
    form_repo = MagicMock(spec=HumanInputFormSubmissionRepository)
    form_repo.get_by_token.return_value = sample_form_record

    service = HumanInputService(session_factory, form_repository=form_repo)
    definition = service.get_form_definition_by_token(RecipientType.STANDALONE_WEB_APP, "token")

    assert definition is not None
    assert definition.id == sample_form_record.form_id
|
||||
|
||||
|
||||
def test_get_form_definition_by_token_for_console_mismatch(sample_form_record, mock_session_factory):
    """The console-specific lookup rejects a web-app-recipient record."""
    session_factory, _ = mock_session_factory
    form_repo = MagicMock(spec=HumanInputFormSubmissionRepository)
    # Stored record targets STANDALONE_WEB_APP, not the console.
    form_repo.get_by_token.return_value = sample_form_record

    service = HumanInputService(session_factory, form_repository=form_repo)

    assert service.get_form_definition_by_token_for_console("token") is None
|
||||
|
||||
|
||||
def test_submit_form_by_token_delivery_not_enabled(mock_session_factory):
    """Submitting with a token that resolves to no record raises the delivery error."""
    session_factory, _ = mock_session_factory
    form_repo = MagicMock(spec=HumanInputFormSubmissionRepository)
    form_repo.get_by_token.return_value = None

    service = HumanInputService(session_factory, form_repository=form_repo)

    with pytest.raises(human_input_service_module.WebAppDeliveryNotEnabledError):
        service.submit_form_by_token(RecipientType.STANDALONE_WEB_APP, "token", "action", {})
|
||||
|
||||
|
||||
def test_submit_form_by_token_no_workflow_run_id(sample_form_record, mock_session_factory, mocker):
    """A submission whose record carries no workflow_run_id never enqueues a resume."""
    session_factory, _ = mock_session_factory
    form_repo = MagicMock(spec=HumanInputFormSubmissionRepository)
    form_repo.get_by_token.return_value = sample_form_record
    # mark_submitted hands back a record stripped of its workflow_run_id.
    form_repo.mark_submitted.return_value = dataclasses.replace(
        sample_form_record, workflow_run_id=None
    )

    service = HumanInputService(session_factory, form_repository=form_repo)
    resume_spy = mocker.patch.object(service, "enqueue_resume")

    service.submit_form_by_token(RecipientType.STANDALONE_WEB_APP, "token", "submit", {})

    # Without a workflow run to resume, the enqueue step must be skipped.
    resume_spy.assert_not_called()
|
||||
|
||||
|
||||
def test_ensure_form_active_errors(sample_form_record, mock_session_factory):
    """ensure_form_active rejects submitted, timed-out, and past-expiry forms."""
    session_factory, _ = mock_session_factory
    service = HumanInputService(session_factory)

    def assert_rejected(record, expected_error):
        # Every inactive variant must raise its dedicated error type.
        with pytest.raises(expected_error):
            service.ensure_form_active(Form(record))

    # Already submitted -> FormSubmittedError.
    assert_rejected(
        dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()),
        human_input_service_module.FormSubmittedError,
    )

    # Status explicitly flagged TIMEOUT -> FormExpiredError.
    assert_rejected(
        dataclasses.replace(sample_form_record, status=HumanInputFormStatus.TIMEOUT),
        FormExpiredError,
    )

    # Expiration timestamp already in the past -> FormExpiredError.
    # NOTE(review): datetime.utcnow() is naive (and deprecated since 3.12);
    # kept as-is since the record's timestamps are presumably naive too — confirm.
    assert_rejected(
        dataclasses.replace(
            sample_form_record, expiration_time=datetime.utcnow() - timedelta(minutes=1)
        ),
        FormExpiredError,
    )
|
||||
|
||||
|
||||
def test_ensure_not_submitted_raises(sample_form_record, mock_session_factory):
    """_ensure_not_submitted raises once a record carries a submitted_at timestamp."""
    session_factory, _ = mock_session_factory
    service = HumanInputService(session_factory)

    already_submitted = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow())
    with pytest.raises(human_input_service_module.FormSubmittedError):
        service._ensure_not_submitted(Form(already_submitted))
|
||||
|
||||
|
||||
def test_enqueue_resume_workflow_not_found(mocker, mock_session_factory):
    """enqueue_resume asserts when the workflow run cannot be looked up."""
    session_factory, _ = mock_session_factory
    service = HumanInputService(session_factory)

    # Repository lookup misses: no WorkflowRun for the given id.
    run_repo = MagicMock()
    run_repo.get_workflow_run_by_id_without_tenant.return_value = None
    mocker.patch(
        "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
        return_value=run_repo,
    )

    with pytest.raises(AssertionError) as excinfo:
        service.enqueue_resume("workflow-run-id")

    assert "WorkflowRun not found" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_enqueue_resume_app_not_found(mocker, mock_session_factory):
    """A missing App row is logged as an error instead of raising."""
    session_factory, session = mock_session_factory
    service = HumanInputService(session_factory)

    # The workflow run exists and points at an app id...
    found_run = MagicMock()
    found_run.app_id = "app-id"
    run_repo = MagicMock()
    run_repo.get_workflow_run_by_id_without_tenant.return_value = found_run
    mocker.patch(
        "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
        return_value=run_repo,
    )

    # ...but the App query comes back empty.
    session.execute.return_value.scalar_one_or_none.return_value = None
    logger_spy = mocker.patch("services.human_input_service.logger")

    service.enqueue_resume("workflow-run-id")

    # Failure path is best-effort: one error log, no exception.
    logger_spy.error.assert_called_once()
|
||||
|
||||
|
||||
def test_is_globally_expired_zero_timeout(monkeypatch, sample_form_record, mock_session_factory):
    """A global timeout of 0 disables global expiry entirely."""
    session_factory, _ = mock_session_factory
    service = HumanInputService(session_factory)

    monkeypatch.setattr(
        human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 0
    )

    assert service._is_globally_expired(Form(sample_form_record)) is False
|
||||
|
||||
@@ -5,8 +5,13 @@ import pytest
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.model import App, AppMode, EndUser, Message
|
||||
from services.errors.message import FirstMessageNotExistsError, LastMessageNotExistsError
|
||||
from services.message_service import MessageService
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
LastMessageNotExistsError,
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService, attach_message_extra_contents
|
||||
|
||||
|
||||
class TestMessageServiceFactory:
|
||||
@ -244,14 +249,12 @@ class TestMessageServicePaginationByFirstId:
|
||||
mock_query_first = MagicMock()
|
||||
mock_query_history = MagicMock()
|
||||
|
||||
query_calls = []
|
||||
|
||||
def query_side_effect(*args):
|
||||
if args[0] == Message:
|
||||
# First call returns mock for first_message query
|
||||
if not hasattr(query_side_effect, "call_count"):
|
||||
query_side_effect.call_count = 0
|
||||
query_side_effect.call_count += 1
|
||||
|
||||
if query_side_effect.call_count == 1:
|
||||
query_calls.append(args)
|
||||
if len(query_calls) == 1:
|
||||
return mock_query_first
|
||||
else:
|
||||
return mock_query_history
|
||||
@ -647,3 +650,410 @@ class TestMessageServicePaginationByLastId:
|
||||
assert len(result.data) == 10 # Last message trimmed
|
||||
assert result.has_more is True
|
||||
assert result.limit == 10
|
||||
|
||||
|
||||
class TestMessageServiceUtilities:
    """Unit tests for MessageService module-level utility functions."""

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestMessageServiceFactory()

    # Test 16: attach_message_extra_contents with empty list
    def test_attach_message_extra_contents_empty(self):
        """Test attach_message_extra_contents with empty list does nothing."""
        # Act & Assert (should not raise error)
        attach_message_extra_contents([])

    # Test 17: attach_message_extra_contents with messages
    @patch("services.message_service._create_execution_extra_content_repository")
    def test_attach_message_extra_contents_with_messages(self, mock_create_repo, factory):
        """Test attach_message_extra_contents correctly attaches content."""
        # Arrange
        messages = [factory.create_message_mock(message_id="msg-1"), factory.create_message_mock(message_id="msg-2")]

        mock_repo = MagicMock()
        mock_create_repo.return_value = mock_repo

        # Mock extra content models
        mock_content1 = MagicMock()
        mock_content1.model_dump.return_value = {"key": "value1"}
        mock_content2 = MagicMock()
        mock_content2.model_dump.return_value = {"key": "value2"}

        # Repository returns one list of contents per message id, positionally aligned.
        mock_repo.get_by_message_ids.return_value = [[mock_content1], [mock_content2]]

        # Act
        attach_message_extra_contents(messages)

        # Assert: ids are batched into a single repository call, and each message
        # receives the model_dump() of its positional content list.
        mock_repo.get_by_message_ids.assert_called_once_with(["msg-1", "msg-2"])
        messages[0].set_extra_contents.assert_called_once_with([{"key": "value1"}])
        messages[1].set_extra_contents.assert_called_once_with([{"key": "value2"}])

    # Test 18: attach_message_extra_contents with index out of bounds
    @patch("services.message_service._create_execution_extra_content_repository")
    def test_attach_message_extra_contents_index_out_of_bounds(self, mock_create_repo, factory):
        """Test attach_message_extra_contents handles missing content lists."""
        # Arrange
        messages = [factory.create_message_mock(message_id="msg-1")]

        mock_repo = MagicMock()
        mock_create_repo.return_value = mock_repo
        mock_repo.get_by_message_ids.return_value = []  # Empty returned list

        # Act
        attach_message_extra_contents(messages)

        # Assert: a shorter-than-expected repository result falls back to [].
        messages[0].set_extra_contents.assert_called_once_with([])

    # Test 19: _create_execution_extra_content_repository
    @patch("services.message_service.db")
    @patch("services.message_service.sessionmaker")
    @patch("services.message_service.SQLAlchemyExecutionExtraContentRepository")
    def test_create_execution_extra_content_repository(self, mock_repo_class, mock_sessionmaker, mock_db):
        """Test _create_execution_extra_content_repository creates expected repository."""
        # Local import so the private helper is resolved after the patches above
        # are in effect.
        from services.message_service import _create_execution_extra_content_repository

        # Act
        _create_execution_extra_content_repository()

        # Assert: a sessionmaker is built and handed to the repository class.
        mock_sessionmaker.assert_called_once()
        mock_repo_class.assert_called_once()
|
||||
|
||||
|
||||
class TestMessageServiceGetMessage:
    """Unit tests for MessageService.get_message method."""

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestMessageServiceFactory()

    # Test 20: get_message success for EndUser
    @patch("services.message_service.db")
    def test_get_message_end_user_success(self, mock_db, factory):
        """Test get_message returns message for EndUser."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_end_user_mock(user_id="end-user-123")
        message = factory.create_message_mock()

        # where() returns the same mock so any length of filter chain resolves
        # to the same query object before .first() is consulted.
        mock_query = MagicMock()
        mock_db.session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.first.return_value = message

        # Act
        result = MessageService.get_message(app_model=app, user=user, message_id="msg-123")

        # Assert
        assert result == message
        mock_query.where.assert_called_once()

    # Test 21: get_message success for Account (Admin)
    @patch("services.message_service.db")
    def test_get_message_account_success(self, mock_db, factory):
        """Test get_message returns message for Account."""
        # Arrange
        from models import Account

        app = factory.create_app_mock()
        user = MagicMock(spec=Account)
        user.id = "account-123"
        message = factory.create_message_mock()

        mock_query = MagicMock()
        mock_db.session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.first.return_value = message

        # Act
        result = MessageService.get_message(app_model=app, user=user, message_id="msg-123")

        # Assert
        assert result == message

    # Test 22: get_message not found
    @patch("services.message_service.db")
    def test_get_message_not_found(self, mock_db, factory):
        """Test get_message raises MessageNotExistsError when not found."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_end_user_mock()

        mock_query = MagicMock()
        mock_db.session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        # No matching row: .first() yields None, which must translate into
        # MessageNotExistsError rather than returning None to the caller.
        mock_query.first.return_value = None

        # Act & Assert
        with pytest.raises(MessageNotExistsError):
            MessageService.get_message(app_model=app, user=user, message_id="msg-123")
|
||||
|
||||
|
||||
class TestMessageServiceFeedback:
    """Unit tests for MessageService feedback-related methods."""

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestMessageServiceFactory()

    # Test 23: create_feedback - new feedback for EndUser
    @patch("services.message_service.db")
    @patch.object(MessageService, "get_message")
    def test_create_feedback_new_end_user(self, mock_get_message, mock_db, factory):
        """Test creating new feedback for an end user."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_end_user_mock()
        message = factory.create_message_mock()
        # No prior feedback on record, so create_feedback must INSERT.
        message.user_feedback = None
        mock_get_message.return_value = message

        # Act
        result = MessageService.create_feedback(
            app_model=app,
            message_id="msg-123",
            user=user,
            rating="like",
            content="Good answer",
        )

        # Assert: new feedback row is populated, added, and committed.
        assert result.rating == "like"
        assert result.content == "Good answer"
        assert result.from_source == "user"
        mock_db.session.add.assert_called_once()
        mock_db.session.commit.assert_called_once()

    # Test 24: create_feedback - update feedback for Account
    @patch("services.message_service.db")
    @patch.object(MessageService, "get_message")
    def test_create_feedback_update_account(self, mock_get_message, mock_db, factory):
        """Test updating existing feedback for an account."""
        # Arrange
        from models import Account, MessageFeedback

        app = factory.create_app_mock()
        user = MagicMock(spec=Account)
        user.id = "account-123"
        message = factory.create_message_mock()
        # Existing admin feedback: the call should mutate it in place.
        feedback = MagicMock(spec=MessageFeedback)
        message.admin_feedback = feedback
        mock_get_message.return_value = message

        # Act
        result = MessageService.create_feedback(
            app_model=app,
            message_id="msg-123",
            user=user,
            rating="dislike",
            content="Bad answer",
        )

        # Assert: the same feedback object is updated and committed (no add()).
        assert result == feedback
        assert feedback.rating == "dislike"
        assert feedback.content == "Bad answer"
        mock_db.session.commit.assert_called_once()

    # Test 25: create_feedback - delete feedback (rating is None)
    @patch("services.message_service.db")
    @patch.object(MessageService, "get_message")
    def test_create_feedback_delete(self, mock_get_message, mock_db, factory):
        """Test deleting feedback by passing rating=None."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_end_user_mock()
        message = factory.create_message_mock()
        feedback = MagicMock()
        message.user_feedback = feedback
        mock_get_message.return_value = message

        # Act: rating=None on an existing feedback signals deletion.
        result = MessageService.create_feedback(
            app_model=app,
            message_id="msg-123",
            user=user,
            rating=None,
            content=None,
        )

        # Assert
        assert result == feedback
        mock_db.session.delete.assert_called_once_with(feedback)
        mock_db.session.commit.assert_called_once()

    # Test 26: get_all_messages_feedbacks
    @patch("services.message_service.db")
    def test_get_all_messages_feedbacks(self, mock_db, factory):
        """Test get_all_messages_feedbacks returns list of dicts."""
        # Arrange
        app = factory.create_app_mock()
        feedback = MagicMock()
        feedback.to_dict.return_value = {"id": "fb-1"}

        # Chainable query mock: every builder step returns the same object.
        mock_query = MagicMock()
        mock_db.session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.limit.return_value = mock_query
        mock_query.offset.return_value = mock_query
        mock_query.all.return_value = [feedback]

        # Act
        result = MessageService.get_all_messages_feedbacks(app_model=app, page=1, limit=10)

        # Assert: rows are serialized via to_dict(), and page=1 maps to offset 0.
        assert result == [{"id": "fb-1"}]
        mock_query.limit.assert_called_with(10)
        mock_query.offset.assert_called_with(0)
|
||||
|
||||
|
||||
class TestMessageServiceSuggestedQuestions:
    """Unit tests for MessageService.get_suggested_questions_after_answer method."""

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestMessageServiceFactory()

    # Test 27: get_suggested_questions_after_answer - user is None
    def test_get_suggested_questions_user_none(self, factory):
        """A None user is rejected up front with a ValueError."""
        app = factory.create_app_mock()
        with pytest.raises(ValueError, match="user cannot be None"):
            MessageService.get_suggested_questions_after_answer(
                app_model=app, user=None, message_id="msg-123", invoke_from=MagicMock()
            )

    # Test 28: get_suggested_questions_after_answer - Advanced Chat success
    # NOTE: @patch decorators apply bottom-up, so the mock parameters below are
    # listed in reverse order of the decorator stack.
    @patch("services.message_service.ModelManager")
    @patch("services.message_service.WorkflowService")
    @patch("services.message_service.AdvancedChatAppConfigManager")
    @patch("services.message_service.TokenBufferMemory")
    @patch("services.message_service.LLMGenerator")
    @patch("services.message_service.TraceQueueManager")
    @patch.object(MessageService, "get_message")
    @patch("services.message_service.ConversationService")
    def test_get_suggested_questions_advanced_chat_success(
        self,
        mock_conversation_service,
        mock_get_message,
        mock_trace_manager,
        mock_llm_gen,
        mock_memory,
        mock_config_manager,
        mock_workflow_service,
        mock_model_manager,
        factory,
    ):
        """Test successful suggested questions generation in Advanced Chat mode."""
        from core.app.entities.app_invoke_entities import InvokeFrom

        # Arrange
        app = factory.create_app_mock(mode=AppMode.ADVANCED_CHAT.value)
        user = factory.create_end_user_mock()
        message = factory.create_message_mock()
        mock_get_message.return_value = message

        # Advanced-chat apps read their config from the published workflow.
        workflow = MagicMock()
        mock_workflow_service.return_value.get_published_workflow.return_value = workflow

        app_config = MagicMock()
        app_config.additional_features.suggested_questions_after_answer = True
        mock_config_manager.get_app_config.return_value = app_config

        mock_llm_gen.generate_suggested_questions_after_answer.return_value = ["Q1?"]

        # Act
        result = MessageService.get_suggested_questions_after_answer(
            app_model=app, user=user, message_id="msg-123", invoke_from=InvokeFrom.WEB_APP
        )

        # Assert
        assert result == ["Q1?"]
        mock_workflow_service.return_value.get_published_workflow.assert_called_once()
        mock_llm_gen.generate_suggested_questions_after_answer.assert_called_once()

    # Test 29: get_suggested_questions_after_answer - Chat app success (no override)
    @patch("services.message_service.db")
    @patch("services.message_service.ModelManager")
    @patch("services.message_service.TokenBufferMemory")
    @patch("services.message_service.LLMGenerator")
    @patch("services.message_service.TraceQueueManager")
    @patch.object(MessageService, "get_message")
    @patch("services.message_service.ConversationService")
    def test_get_suggested_questions_chat_app_success(
        self,
        mock_conversation_service,
        mock_get_message,
        mock_trace_manager,
        mock_llm_gen,
        mock_memory,
        mock_model_manager,
        mock_db,
        factory,
    ):
        """Test successful suggested questions generation in basic Chat mode."""
        # Arrange
        app = factory.create_app_mock(mode=AppMode.CHAT.value)
        user = factory.create_end_user_mock()
        message = factory.create_message_mock()
        mock_get_message.return_value = message

        # No per-conversation model override: the app-level config is consulted.
        conversation = MagicMock()
        conversation.override_model_configs = None
        mock_conversation_service.get_conversation.return_value = conversation

        app_model_config = MagicMock()
        app_model_config.suggested_questions_after_answer_dict = {"enabled": True}
        app_model_config.model_dict = {"provider": "openai", "name": "gpt-4"}

        # DB query resolves to the app model config above.
        mock_query = MagicMock()
        mock_db.session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.first.return_value = app_model_config

        mock_llm_gen.generate_suggested_questions_after_answer.return_value = ["Q1?"]

        # Act
        result = MessageService.get_suggested_questions_after_answer(
            app_model=app, user=user, message_id="msg-123", invoke_from=MagicMock()
        )

        # Assert
        assert result == ["Q1?"]
        mock_query.first.assert_called_once()
        mock_llm_gen.generate_suggested_questions_after_answer.assert_called_once()

    # Test 30: get_suggested_questions_after_answer - Disabled Error
    @patch("services.message_service.WorkflowService")
    @patch("services.message_service.AdvancedChatAppConfigManager")
    @patch.object(MessageService, "get_message")
    @patch("services.message_service.ConversationService")
    def test_get_suggested_questions_disabled_error(
        self, mock_conversation_service, mock_get_message, mock_config_manager, mock_workflow_service, factory
    ):
        """Test SuggestedQuestionsAfterAnswerDisabledError is raised when feature is disabled."""
        # Arrange
        app = factory.create_app_mock(mode=AppMode.ADVANCED_CHAT.value)
        user = factory.create_end_user_mock()
        mock_get_message.return_value = factory.create_message_mock()

        workflow = MagicMock()
        mock_workflow_service.return_value.get_published_workflow.return_value = workflow

        # Feature toggled off in the app config.
        app_config = MagicMock()
        app_config.additional_features.suggested_questions_after_answer = False
        mock_config_manager.get_app_config.return_value = app_config

        # Act & Assert
        with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError):
            MessageService.get_suggested_questions_after_answer(
                app_model=app, user=user, message_id="msg-123", invoke_from=MagicMock()
            )
|
||||
|
||||
Loading…
Reference in New Issue
Block a user