diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 692a3639cd..713c4c6782 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -50,12 +50,16 @@ class ConversationService: Conversation.from_account_id == (user.id if isinstance(user, Account) else None), or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) - # Check if include_ids is not None and not empty to avoid WHERE false condition - if include_ids is not None and len(include_ids) > 0: + # Check if include_ids is not None to apply filter + if include_ids is not None: + if len(include_ids) == 0: + # If include_ids is empty, return empty result + return InfiniteScrollPagination(data=[], limit=limit, has_more=False) stmt = stmt.where(Conversation.id.in_(include_ids)) - # Check if exclude_ids is not None and not empty to avoid WHERE false condition - if exclude_ids is not None and len(exclude_ids) > 0: - stmt = stmt.where(~Conversation.id.in_(exclude_ids)) + # Check if exclude_ids is not None to apply filter + if exclude_ids is not None: + if len(exclude_ids) > 0: + stmt = stmt.where(~Conversation.id.in_(exclude_ids)) # define sort fields and directions sort_field, sort_direction = cls._get_sort_params(sort_by) diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py new file mode 100644 index 0000000000..9c1c044f03 --- /dev/null +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -0,0 +1,127 @@ +import uuid +from unittest.mock import MagicMock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from services.conversation_service import ConversationService + + +class TestConversationService: + def test_pagination_with_empty_include_ids(self): + """Test that empty include_ids returns empty result""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=[], # Empty include_ids should return empty result + exclude_ids=None, + ) + + assert result.data == [] + assert result.has_more is False + assert result.limit == 20 + + def test_pagination_with_non_empty_include_ids(self): + """Test that non-empty include_ids filters properly""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + # Mock the query results + mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)] + mock_session.scalars.return_value.all.return_value = mock_conversations + mock_session.scalar.return_value = 0 + + with patch("services.conversation_service.select") as mock_select: + mock_stmt = MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_stmt.subquery.return_value = MagicMock() + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=["conv1", "conv2"], # Non-empty include_ids + exclude_ids=None, + ) + + # Verify the where clause was called with id.in_ + assert mock_stmt.where.called + + def test_pagination_with_empty_exclude_ids(self): + """Test that empty exclude_ids doesn't filter""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + # Mock the query results + mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(5)] + mock_session.scalars.return_value.all.return_value = mock_conversations + mock_session.scalar.return_value = 0 + + with patch("services.conversation_service.select") as mock_select: + mock_stmt = MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_stmt.subquery.return_value = MagicMock() + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=None, + exclude_ids=[], # Empty exclude_ids should not filter + ) + + # Result should contain the mocked conversations + assert len(result.data) == 5 + + def test_pagination_with_non_empty_exclude_ids(self): + """Test that non-empty exclude_ids filters properly""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + # Mock the query results + mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)] + mock_session.scalars.return_value.all.return_value = mock_conversations + mock_session.scalar.return_value = 0 + + with patch("services.conversation_service.select") as mock_select: + mock_stmt = MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_stmt.subquery.return_value = MagicMock() + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=None, + exclude_ids=["conv1", "conv2"], # Non-empty exclude_ids + ) + + # Verify the where clause was called for exclusion + assert mock_stmt.where.called