From e6d1431a02db69c002609b63347d1f91790acb61 Mon Sep 17 00:00:00 2001 From: Dev Sharma <50591491+cryptus-neoxys@users.noreply.github.com> Date: Sun, 22 Mar 2026 21:59:18 +0530 Subject: [PATCH] test: improve code-cov for controller tests (#33225) --- .../controllers/console/app/test_message.py | 320 ++++++++++++++ .../controllers/console/app/test_statistic.py | 275 ++++++++++++ .../app/test_workflow_draft_variable.py | 313 +++++++++++++ .../auth/test_data_source_bearer_auth.py | 209 +++++++++ .../console/auth/test_data_source_oauth.py | 192 ++++++++ .../console/auth/test_oauth_server.py | 417 ++++++++++++++++++ 6 files changed, 1726 insertions(+) create mode 100644 api/tests/unit_tests/controllers/console/app/test_message.py create mode 100644 api/tests/unit_tests/controllers/console/app/test_statistic.py create mode 100644 api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py create mode 100644 api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py create mode 100644 api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py create mode 100644 api/tests/unit_tests/controllers/console/auth/test_oauth_server.py diff --git a/api/tests/unit_tests/controllers/console/app/test_message.py b/api/tests/unit_tests/controllers/console/app/test_message.py new file mode 100644 index 0000000000..3ffa53b6db --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_message.py @@ -0,0 +1,320 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, request +from werkzeug.exceptions import InternalServerError, NotFound +from werkzeug.local import LocalProxy + +from controllers.console.app.error import ( + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.app.message import ( + ChatMessageListApi, + ChatMessagesQuery, + FeedbackExportQuery, + MessageAnnotationCountApi, + MessageApi, + MessageFeedbackApi, + MessageFeedbackExportApi, + MessageFeedbackPayload, + MessageSuggestedQuestionApi, +) +from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from models import App, AppMode +from services.errors.conversation import ConversationNotExistsError +from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + flask_app.config["RESTX_MASK_HEADER"] = "X-Fields" + return flask_app + + +@pytest.fixture +def mock_account(): + from models.account import Account, AccountStatus + + account = MagicMock(spec=Account) + account.id = "user_123" + account.timezone = "UTC" + account.status = AccountStatus.ACTIVE + account.is_admin_or_owner = True + account.current_tenant.current_role = "owner" + account.has_edit_permission = True + return account + + +@pytest.fixture +def mock_app_model(): + app_model = MagicMock(spec=App) + app_model.id = "app_123" + app_model.mode = AppMode.CHAT + app_model.tenant_id = "tenant_123" + return app_model + + +@pytest.fixture(autouse=True) +def mock_csrf(): + with patch("libs.login.check_csrf_token") as mock: + yield mock + + +import contextlib + + +@contextlib.contextmanager +def setup_test_context( + test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None +): + with ( + patch("extensions.ext_database.db") as mock_db, + patch("controllers.console.app.wraps.db", mock_db), + patch("controllers.console.wraps.db", mock_db), + patch("controllers.console.app.message.db", mock_db), + patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + ): + # Set up a generic query mock that usually returns mock_app_model when getting app + app_query_mock = MagicMock() + app_query_mock.filter.return_value.first.return_value = mock_app_model + app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model + app_query_mock.where.return_value.first.return_value = mock_app_model + app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model + + data_query_mock = MagicMock() + + def query_side_effect(*args, **kwargs): + if args and hasattr(args[0], "__name__") and args[0].__name__ == "App": + return app_query_mock + return data_query_mock + + mock_db.session.query.side_effect = query_side_effect + mock_db.data_query = data_query_mock + + # Let the caller override the stat db query logic + proxy_mock = LocalProxy(lambda: mock_account) + + query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()]) + full_path = f"{route_path}?{query_string}" if qs else route_path + + with ( + patch("libs.login.current_user", proxy_mock), + patch("flask_login.current_user", proxy_mock), + patch("controllers.console.app.message.attach_message_extra_contents", return_value=None), + ): + with test_app.test_request_context(full_path, method=method, json=payload): + request.view_args = {"app_id": "app_123"} + + if "suggested-questions" in route_path: + # simplistic extraction for message_id + parts = route_path.split("chat-messages/") + if len(parts) > 1: + request.view_args["message_id"] = parts[1].split("/")[0] + elif "messages/" in route_path and "chat-messages" not in route_path: + parts = route_path.split("messages/") + if len(parts) > 1: + request.view_args["message_id"] = parts[1].split("/")[0] + + api_instance = endpoint_class() + + # Check if it has a dispatch_request or method + if hasattr(api_instance, method.lower()): + yield api_instance, mock_db, request.view_args + + +class TestMessageValidators: + def test_chat_messages_query_validators(self): + # Test empty_to_none + assert ChatMessagesQuery.empty_to_none("") is None + assert ChatMessagesQuery.empty_to_none("val") == "val" + + # Test validate_uuid + assert ChatMessagesQuery.validate_uuid(None) is None + assert ( + ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_message_feedback_validators(self): + assert ( + MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_feedback_export_validators(self): + assert FeedbackExportQuery.parse_bool(None) is None + assert FeedbackExportQuery.parse_bool(True) is True + assert FeedbackExportQuery.parse_bool("1") is True + assert FeedbackExportQuery.parse_bool("0") is False + assert FeedbackExportQuery.parse_bool("off") is False + + with pytest.raises(ValueError): + FeedbackExportQuery.parse_bool("invalid") + + +class TestMessageEndpoints: + def test_chat_message_list_not_found(self, app, mock_account, mock_app_model): + with setup_test_context( + app, + ChatMessageListApi, + "/apps/app_123/chat-messages", + "GET", + mock_account, + mock_app_model, + qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"}, + ) as (api, mock_db, v_args): + mock_db.data_query.where.return_value.first.return_value = None + + with pytest.raises(NotFound): + api.get(**v_args) + + def test_chat_message_list_success(self, app, mock_account, mock_app_model): + with setup_test_context( + app, + ChatMessageListApi, + "/apps/app_123/chat-messages", + "GET", + mock_account, + mock_app_model, + qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1}, + ) as (api, mock_db, v_args): + mock_conv = MagicMock() + mock_conv.id = "123e4567-e89b-12d3-a456-426614174000" + mock_msg = MagicMock() + mock_msg.id = "msg_123" + mock_msg.feedbacks = [] + mock_msg.annotation = None + mock_msg.annotation_hit_history = None + mock_msg.agent_thoughts = [] + mock_msg.message_files = [] + mock_msg.extra_contents = [] + mock_msg.message = {} + mock_msg.message_metadata_dict = {} + + # mock returns + q_mock = mock_db.data_query + q_mock.where.return_value.first.side_effect = [mock_conv] + q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg] + mock_db.session.scalar.return_value = False + + resp = api.get(**v_args) + assert resp["limit"] == 1 + assert resp["has_more"] is False + assert len(resp["data"]) == 1 + + def test_message_feedback_not_found(self, app, mock_account, mock_app_model): + with setup_test_context( + app, + MessageFeedbackApi, + "/apps/app_123/feedbacks", + "POST", + mock_account, + mock_app_model, + payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"}, + ) as (api, mock_db, v_args): + mock_db.data_query.where.return_value.first.return_value = None + + with pytest.raises(NotFound): + api.post(**v_args) + + def test_message_feedback_success(self, app, mock_account, mock_app_model): + payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"} + with setup_test_context( + app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload + ) as (api, mock_db, v_args): + mock_msg = MagicMock() + mock_msg.admin_feedback = None + mock_db.data_query.where.return_value.first.return_value = mock_msg + + resp = api.post(**v_args) + assert resp == {"result": "success"} + + def test_message_annotation_count(self, app, mock_account, mock_app_model): + with setup_test_context( + app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model + ) as (api, mock_db, v_args): + mock_db.data_query.where.return_value.count.return_value = 5 + + resp = api.get(**v_args) + assert resp == {"count": 5} + + @patch("controllers.console.app.message.MessageService") + def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model): + mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"] + + with setup_test_context( + app, + MessageSuggestedQuestionApi, + "/apps/app_123/chat-messages/msg_123/suggested-questions", + "GET", + mock_account, + mock_app_model, + ) as (api, mock_db, v_args): + resp = api.get(**v_args) + assert resp == {"data": ["q1", "q2"]} + + @pytest.mark.parametrize( + ("exc", "expected_exc"), + [ + (MessageNotExistsError, NotFound), + (ConversationNotExistsError, NotFound), + (ProviderTokenNotInitError, ProviderNotInitializeError), + (QuotaExceededError, ProviderQuotaExceededError), + (ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError), + (SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError), + (Exception, InternalServerError), + ], + ) + @patch("controllers.console.app.message.MessageService") + def test_message_suggested_questions_errors( + self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model + ): + mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc() + + with setup_test_context( + app, + MessageSuggestedQuestionApi, + "/apps/app_123/chat-messages/msg_123/suggested-questions", + "GET", + mock_account, + mock_app_model, + ) as (api, mock_db, v_args): + with pytest.raises(expected_exc): + api.get(**v_args) + + @patch("services.feedback_service.FeedbackService.export_feedbacks") + def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model): + mock_export.return_value = {"exported": True} + + with setup_test_context( + app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model + ) as (api, mock_db, v_args): + resp = api.get(**v_args) + assert resp == {"exported": True} + + def test_message_api_get_success(self, app, mock_account, mock_app_model): + with setup_test_context( + app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model + ) as (api, mock_db, v_args): + mock_msg = MagicMock() + mock_msg.id = "msg_123" + mock_msg.feedbacks = [] + mock_msg.annotation = None + mock_msg.annotation_hit_history = None + mock_msg.agent_thoughts = [] + mock_msg.message_files = [] + mock_msg.extra_contents = [] + mock_msg.message = {} + mock_msg.message_metadata_dict = {} + + mock_db.data_query.where.return_value.first.return_value = mock_msg + + resp = api.get(**v_args) + assert resp["id"] == "msg_123" diff --git a/api/tests/unit_tests/controllers/console/app/test_statistic.py b/api/tests/unit_tests/controllers/console/app/test_statistic.py new file mode 100644 index 0000000000..beba23385d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_statistic.py @@ -0,0 +1,275 @@ +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, request +from werkzeug.local import LocalProxy + +from controllers.console.app.statistic import ( + AverageResponseTimeStatistic, + AverageSessionInteractionStatistic, + DailyConversationStatistic, + DailyMessageStatistic, + DailyTerminalsStatistic, + DailyTokenCostStatistic, + TokensPerSecondStatistic, + UserSatisfactionRateStatistic, +) +from models import App, AppMode + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture +def mock_account(): + from models.account import Account, AccountStatus + + account = MagicMock(spec=Account) + account.id = "user_123" + account.timezone = "UTC" + account.status = AccountStatus.ACTIVE + account.is_admin_or_owner = True + account.current_tenant.current_role = "owner" + account.has_edit_permission = True + return account + + +@pytest.fixture +def mock_app_model(): + app_model = MagicMock(spec=App) + app_model.id = "app_123" + app_model.mode = AppMode.CHAT + app_model.tenant_id = "tenant_123" + return app_model + + +@pytest.fixture(autouse=True) +def mock_csrf(): + with patch("libs.login.check_csrf_token") as mock: + yield mock + + +def setup_test_context( + test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None) +): + with ( + patch("controllers.console.app.statistic.db") as mock_db_stat, + patch("controllers.console.app.wraps.db") as mock_db_wraps, + patch("controllers.console.wraps.db", mock_db_wraps), + patch( + "controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123") + ), + patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + ): + mock_conn = MagicMock() + mock_conn.execute.return_value = mock_rs + + mock_begin = MagicMock() + mock_begin.__enter__.return_value = mock_conn + mock_db_stat.engine.begin.return_value = mock_begin + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = mock_app_model + mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model + mock_query.where.return_value.first.return_value = mock_app_model + mock_query.where.return_value.where.return_value.first.return_value = mock_app_model + mock_db_wraps.session.query.return_value = mock_query + + proxy_mock = LocalProxy(lambda: mock_account) + + with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + with test_app.test_request_context(route_path, method="GET"): + request.view_args = {"app_id": "app_123"} + api_instance = endpoint_class() + response = api_instance.get(app_id="app_123") + return response + + +class TestStatisticEndpoints: + def test_daily_message_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.message_count = 10 + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + DailyMessageStatistic, + "/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["message_count"] == 10 + + def test_daily_conversation_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.conversation_count = 5 + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + DailyConversationStatistic, + "/apps/app_123/statistics/daily-conversations", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["conversation_count"] == 5 + + def test_daily_terminals_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.terminal_count = 2 + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + DailyTerminalsStatistic, + "/apps/app_123/statistics/daily-end-users", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["terminal_count"] == 2 + + def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.token_count = 100 + mock_row.total_price = Decimal("0.02") + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + DailyTokenCostStatistic, + "/apps/app_123/statistics/token-costs", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["token_count"] == 100 + assert response.json["data"][0]["total_price"] == "0.02" + + def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.interactions = Decimal("3.523") + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + AverageSessionInteractionStatistic, + "/apps/app_123/statistics/average-session-interactions", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["interactions"] == 3.52 + + def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.message_count = 100 + mock_row.feedback_count = 10 + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + UserSatisfactionRateStatistic, + "/apps/app_123/statistics/user-satisfaction-rate", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["rate"] == 100.0 + + def test_average_response_time_statistic(self, app, mock_account, mock_app_model): + mock_app_model.mode = AppMode.COMPLETION + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.latency = 1.234 + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + AverageResponseTimeStatistic, + "/apps/app_123/statistics/average-response-time", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["latency"] == 1234.0 + + def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.tokens_per_second = 15.5 + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + TokensPerSecondStatistic, + "/apps/app_123/statistics/tokens-per-second", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["tps"] == 15.5 + + @patch("controllers.console.app.statistic.parse_time_range") + def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model): + mock_parse.side_effect = ValueError("Invalid time") + + from werkzeug.exceptions import BadRequest + + with pytest.raises(BadRequest): + setup_test_context( + app, + DailyMessageStatistic, + "/apps/app_123/statistics/daily-messages?start=invalid&end=invalid", + mock_account, + mock_app_model, + [], + ) + + @patch("controllers.console.app.statistic.parse_time_range") + def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model): + import datetime + + start = datetime.datetime.now() + end = datetime.datetime.now() + mock_parse.return_value = (start, end) + + response = setup_test_context( + app, + DailyMessageStatistic, + "/apps/app_123/statistics/daily-messages?start=something&end=something", + mock_account, + mock_app_model, + [], + ) + assert response.status_code == 200 + mock_parse.assert_called_once() diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py new file mode 100644 index 0000000000..9b5d47c208 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py @@ -0,0 +1,313 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, request +from werkzeug.local import LocalProxy + +from controllers.console.app.error import DraftWorkflowNotExist +from controllers.console.app.workflow_draft_variable import ( + ConversationVariableCollectionApi, + EnvironmentVariableCollectionApi, + NodeVariableCollectionApi, + SystemVariableCollectionApi, + VariableApi, + VariableResetApi, + WorkflowVariableCollectionApi, +) +from controllers.web.error import InvalidArgumentError, NotFoundError +from models import App, AppMode +from models.enums import DraftVariableType + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + flask_app.config["RESTX_MASK_HEADER"] = "X-Fields" + return flask_app + + +@pytest.fixture +def mock_account(): + from models.account import Account, AccountStatus + + account = MagicMock(spec=Account) + account.id = "user_123" + account.timezone = "UTC" + account.status = AccountStatus.ACTIVE + account.is_admin_or_owner = True + account.current_tenant.current_role = "owner" + account.has_edit_permission = True + return account + + +@pytest.fixture +def mock_app_model(): + app_model = MagicMock(spec=App) + app_model.id = "app_123" + app_model.mode = AppMode.WORKFLOW + app_model.tenant_id = "tenant_123" + return app_model + + +@pytest.fixture(autouse=True) +def mock_csrf(): + with patch("libs.login.check_csrf_token") as mock: + yield mock + + +def setup_test_context(test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None): + with ( + patch("controllers.console.app.wraps.db") as mock_db_wraps, + patch("controllers.console.wraps.db", mock_db_wraps), + patch("controllers.console.app.workflow_draft_variable.db"), + patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + ): + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = mock_app_model + mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model + mock_query.where.return_value.first.return_value = mock_app_model + mock_query.where.return_value.where.return_value.first.return_value = mock_app_model + mock_db_wraps.session.query.return_value = mock_query + + proxy_mock = LocalProxy(lambda: mock_account) + + with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + with test_app.test_request_context(route_path, method=method, json=payload): + request.view_args = {"app_id": "app_123"} + # extract node_id or variable_id from path manually since view_args overrides + if "nodes/" in route_path: + request.view_args["node_id"] = route_path.split("nodes/")[1].split("/")[0] + if "variables/" in route_path: + # simplistic extraction + parts = route_path.split("variables/") + if len(parts) > 1 and parts[1] and parts[1] != "reset": + request.view_args["variable_id"] = parts[1].split("/")[0] + + api_instance = endpoint_class() + # we just call dispatch_request to avoid manual argument passing + if hasattr(api_instance, method.lower()): + func = getattr(api_instance, method.lower()) + return func(**request.view_args) + + +class TestWorkflowDraftVariableEndpoints: + @staticmethod + def _mock_workflow_variable(variable_type: DraftVariableType = DraftVariableType.NODE) -> MagicMock: + class DummyValueType: + def exposed_type(self): + return DraftVariableType.NODE + + mock_var = MagicMock() + mock_var.app_id = "app_123" + mock_var.id = "var_123" + mock_var.name = "test_var" + mock_var.description = "" + mock_var.get_variable_type.return_value = variable_type + mock_var.get_selector.return_value = [] + mock_var.value_type = DummyValueType() + mock_var.edited = False + mock_var.visible = True + mock_var.file_id = None + mock_var.variable_file = None + mock_var.is_truncated.return_value = False + mock_var.get_value.return_value.model_copy.return_value.value = "test_value" + return mock_var + + @patch("controllers.console.app.workflow_draft_variable.WorkflowService") + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_workflow_variable_collection_get_success( + self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model + ): + mock_wf_srv.return_value.is_workflow_exist.return_value = True + from services.workflow_draft_variable_service import WorkflowDraftVariableList + + mock_draft_srv.return_value.list_variables_without_values.return_value = WorkflowDraftVariableList( + variables=[], total=0 + ) + + resp = setup_test_context( + app, + WorkflowVariableCollectionApi, + "/apps/app_123/workflows/draft/variables?page=1&limit=20", + "GET", + mock_account, + mock_app_model, + ) + assert resp == {"items": [], "total": 0} + + @patch("controllers.console.app.workflow_draft_variable.WorkflowService") + def test_workflow_variable_collection_get_not_exist(self, mock_wf_srv, app, mock_account, mock_app_model): + mock_wf_srv.return_value.is_workflow_exist.return_value = False + + with pytest.raises(DraftWorkflowNotExist): + setup_test_context( + app, + WorkflowVariableCollectionApi, + "/apps/app_123/workflows/draft/variables", + "GET", + mock_account, + mock_app_model, + ) + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_workflow_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model): + resp = setup_test_context( + app, + WorkflowVariableCollectionApi, + "/apps/app_123/workflows/draft/variables", + "DELETE", + mock_account, + mock_app_model, + ) + assert resp.status_code == 204 + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_node_variable_collection_get_success(self, mock_draft_srv, app, mock_account, mock_app_model): + from services.workflow_draft_variable_service import WorkflowDraftVariableList + + mock_draft_srv.return_value.list_node_variables.return_value = WorkflowDraftVariableList(variables=[]) + resp = setup_test_context( + app, + NodeVariableCollectionApi, + "/apps/app_123/workflows/draft/nodes/node_123/variables", + "GET", + mock_account, + mock_app_model, + ) + assert resp == {"items": []} + + def test_node_variable_collection_get_invalid_node_id(self, app, mock_account, mock_app_model): + with pytest.raises(InvalidArgumentError): + setup_test_context( + app, + NodeVariableCollectionApi, + "/apps/app_123/workflows/draft/nodes/sys/variables", + "GET", + mock_account, + mock_app_model, + ) + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_node_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model): + resp = setup_test_context( + app, + NodeVariableCollectionApi, + "/apps/app_123/workflows/draft/nodes/node_123/variables", + "DELETE", + mock_account, + mock_app_model, + ) + assert resp.status_code == 204 + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_variable_api_get_success(self, mock_draft_srv, app, mock_account, mock_app_model): + mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() + + resp = setup_test_context( + app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model + ) + assert resp["id"] == "var_123" + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_variable_api_get_not_found(self, mock_draft_srv, app, mock_account, mock_app_model): + mock_draft_srv.return_value.get_variable.return_value = None + + with pytest.raises(NotFoundError): + setup_test_context( + app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model + ) + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_variable_api_patch_success(self, mock_draft_srv, app, mock_account, mock_app_model): + mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() + + resp = setup_test_context( + app, + VariableApi, + "/apps/app_123/workflows/draft/variables/var_123", + "PATCH", + mock_account, + mock_app_model, + payload={"name": "new_name"}, + ) + assert resp["id"] == "var_123" + mock_draft_srv.return_value.update_variable.assert_called_once() + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_variable_api_delete_success(self, mock_draft_srv, app, mock_account, mock_app_model): + mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() + + resp = setup_test_context( + app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "DELETE", mock_account, mock_app_model + ) + assert resp.status_code == 204 + mock_draft_srv.return_value.delete_variable.assert_called_once() + + @patch("controllers.console.app.workflow_draft_variable.WorkflowService") + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_variable_reset_api_put_success(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model): + mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock() + mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() + mock_draft_srv.return_value.reset_variable.return_value = None # means no content + + resp = setup_test_context( + app, + VariableResetApi, + "/apps/app_123/workflows/draft/variables/var_123/reset", + "PUT", + mock_account, + mock_app_model, + ) + assert resp.status_code == 204 + + @patch("controllers.console.app.workflow_draft_variable.WorkflowService") + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_conversation_variable_collection_get(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model): + mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock() + from services.workflow_draft_variable_service import WorkflowDraftVariableList + + mock_draft_srv.return_value.list_conversation_variables.return_value = WorkflowDraftVariableList(variables=[]) + + resp = setup_test_context( + app, + ConversationVariableCollectionApi, + "/apps/app_123/workflows/draft/conversation-variables", + "GET", + mock_account, + mock_app_model, + ) + assert resp == {"items": []} + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_system_variable_collection_get(self, mock_draft_srv, app, mock_account, mock_app_model): + from services.workflow_draft_variable_service import WorkflowDraftVariableList + + mock_draft_srv.return_value.list_system_variables.return_value = WorkflowDraftVariableList(variables=[]) + + resp = setup_test_context( + app, + SystemVariableCollectionApi, + "/apps/app_123/workflows/draft/system-variables", + "GET", + mock_account, + mock_app_model, + ) + assert resp == {"items": []} + + @patch("controllers.console.app.workflow_draft_variable.WorkflowService") + def test_environment_variable_collection_get(self, mock_wf_srv, app, mock_account, mock_app_model): + mock_wf = MagicMock() + mock_wf.environment_variables = [] + mock_wf_srv.return_value.get_draft_workflow.return_value = mock_wf + + resp = setup_test_context( + app, + EnvironmentVariableCollectionApi, + "/apps/app_123/workflows/draft/environment-variables", + "GET", + mock_account, + mock_app_model, + ) + assert resp == {"items": []} diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py new file mode 100644 index 0000000000..bc4c7e0993 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -0,0 +1,209 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.auth.data_source_bearer_auth import ( + ApiKeyAuthDataSource, + ApiKeyAuthDataSourceBinding, + ApiKeyAuthDataSourceBindingDelete, +) +from controllers.console.auth.error import ApiKeyAuthFailedError + + +class TestApiKeyAuthDataSource: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["WTF_CSRF_ENABLED"] = False + return app + + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list") + def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app): + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + mock_binding = MagicMock() + mock_binding.id = "bind_123" + mock_binding.category = "api_key" + mock_binding.provider = "custom_provider" + mock_binding.disabled = False + mock_binding.created_at.timestamp.return_value = 1620000000 + mock_binding.updated_at.timestamp.return_value = 1620000001 + + mock_get_list.return_value = [mock_binding] + + with ( + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch( + "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", + return_value=(mock_account, "tenant_123"), + ), + ): + with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): + proxy_mock = MagicMock() + proxy_mock._get_current_object.return_value = mock_account + with patch("libs.login.current_user", proxy_mock): + api_instance = ApiKeyAuthDataSource() + response = api_instance.get() + + assert "sources" in response + assert len(response["sources"]) == 1 + assert response["sources"][0]["provider"] == "custom_provider" + + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list") + def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app): + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + mock_get_list.return_value = None + + with ( + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch( + "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", + return_value=(mock_account, "tenant_123"), + ), + ): + with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): + proxy_mock = MagicMock() + proxy_mock._get_current_object.return_value = mock_account + with patch("libs.login.current_user", proxy_mock): + api_instance = ApiKeyAuthDataSource() + response = api_instance.get() + + assert "sources" in response + assert len(response["sources"]) == 0 + + +class TestApiKeyAuthDataSourceBinding: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["WTF_CSRF_ENABLED"] = False + return app + + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args") + def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app): + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + with ( + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch( + "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", + return_value=(mock_account, "tenant_123"), + ), + ): + with app.test_request_context( + "/console/api/api-key-auth/data-source/binding", + method="POST", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + ): + proxy_mock = MagicMock() + proxy_mock._get_current_object.return_value = mock_account + with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + api_instance = ApiKeyAuthDataSourceBinding() + response = api_instance.post() + + assert response[0]["result"] == "success" + assert response[1] == 200 + mock_validate.assert_called_once() + mock_create.assert_called_once() + + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args") + def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app): + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + mock_create.side_effect = ValueError("Invalid structure") + + with ( + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch( + "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", + return_value=(mock_account, "tenant_123"), + ), + ): + with app.test_request_context( + "/console/api/api-key-auth/data-source/binding", + method="POST", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + ): + proxy_mock = MagicMock() + proxy_mock._get_current_object.return_value = mock_account + with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + api_instance = ApiKeyAuthDataSourceBinding() + with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"): + api_instance.post() + + +class TestApiKeyAuthDataSourceBindingDelete: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["WTF_CSRF_ENABLED"] = False + return app + + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth") + def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app): + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + with ( + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch( + "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", + return_value=(mock_account, "tenant_123"), + ), + ): + with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"): + proxy_mock = MagicMock() + proxy_mock._get_current_object.return_value = mock_account + with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + api_instance = ApiKeyAuthDataSourceBindingDelete() + response = api_instance.delete("binding_123") + + assert response[0]["result"] == "success" + assert response[1] == 204 + mock_delete.assert_called_once_with("tenant_123", "binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py new file mode 100644 index 0000000000..f369565946 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py @@ -0,0 +1,192 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.local import LocalProxy + +from controllers.console.auth.data_source_oauth import ( + OAuthDataSource, + OAuthDataSourceBinding, + OAuthDataSourceCallback, + OAuthDataSourceSync, +) + + +class TestOAuthDataSource: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + @patch("flask_login.current_user") + @patch("libs.login.current_user") + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None) + def test_get_oauth_url_successful( + self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app + ): + mock_oauth_provider = MagicMock() + mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth" + mock_get_providers.return_value = {"notion": mock_oauth_provider} + + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + mock_libs_user.return_value = mock_account + mock_flask_user.return_value = mock_account + + # also patch current_account_with_tenant + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): + with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"): + proxy_mock = LocalProxy(lambda: mock_account) + with patch("libs.login.current_user", proxy_mock): + api_instance = OAuthDataSource() + response = api_instance.get("notion") + + assert response[0]["data"] == "http://oauth.provider/auth" + assert response[1] == 200 + mock_oauth_provider.get_authorization_url.assert_called_once() + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + @patch("flask_login.current_user") + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + def test_get_oauth_url_invalid_provider(self, mock_db, mock_csrf, mock_flask_user, mock_get_providers, app): + mock_get_providers.return_value = {"notion": MagicMock()} + + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): + with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"): + proxy_mock = LocalProxy(lambda: mock_account) + with patch("libs.login.current_user", proxy_mock): + api_instance = OAuthDataSource() + response = api_instance.get("unknown_provider") + + assert response[0]["error"] == "Invalid provider" + assert response[1] == 400 + + +class TestOAuthDataSourceCallback: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + def test_oauth_callback_successful(self, mock_get_providers, app): + provider_mock = MagicMock() + mock_get_providers.return_value = {"notion": provider_mock} + + with app.test_request_context("/console/api/oauth/data-source/notion/callback?code=mock_code", method="GET"): + api_instance = OAuthDataSourceCallback() + response = api_instance.get("notion") + + assert response.status_code == 302 + assert "code=mock_code" in response.location + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + def test_oauth_callback_missing_code(self, mock_get_providers, app): + provider_mock = MagicMock() + mock_get_providers.return_value = {"notion": provider_mock} + + with app.test_request_context("/console/api/oauth/data-source/notion/callback", method="GET"): + api_instance = OAuthDataSourceCallback() + response = api_instance.get("notion") + + assert response.status_code == 302 + assert "error=Access denied" in response.location + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + def test_oauth_callback_invalid_provider(self, mock_get_providers, app): + mock_get_providers.return_value = {"notion": MagicMock()} + + with app.test_request_context("/console/api/oauth/data-source/invalid/callback?code=mock_code", method="GET"): + api_instance = OAuthDataSourceCallback() + response = api_instance.get("invalid") + + assert response[0]["error"] == "Invalid provider" + assert response[1] == 400 + + +class TestOAuthDataSourceBinding: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + def test_get_binding_successful(self, mock_get_providers, app): + mock_provider = MagicMock() + mock_provider.get_access_token.return_value = None + mock_get_providers.return_value = {"notion": mock_provider} + + with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=auth_code_123", method="GET"): + api_instance = OAuthDataSourceBinding() + response = api_instance.get("notion") + + assert response[0]["result"] == "success" + assert response[1] == 200 + mock_provider.get_access_token.assert_called_once_with("auth_code_123") + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + def test_get_binding_missing_code(self, mock_get_providers, app): + mock_get_providers.return_value = {"notion": MagicMock()} + + with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=", method="GET"): + api_instance = OAuthDataSourceBinding() + response = api_instance.get("notion") + + assert response[0]["error"] == "Invalid code" + assert response[1] == 400 + + +class TestOAuthDataSourceSync: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + def test_sync_successful(self, mock_db, mock_csrf, mock_get_providers, app): + mock_provider = MagicMock() + mock_provider.sync_data_source.return_value = None + mock_get_providers.return_value = {"notion": mock_provider} + + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): + with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"): + proxy_mock = LocalProxy(lambda: mock_account) + with patch("libs.login.current_user", proxy_mock): + api_instance = OAuthDataSourceSync() + # The route pattern uses , so we just pass a string for unit testing + response = api_instance.get("notion", "binding_123") + + assert response[0]["result"] == "success" + assert response[1] == 200 + mock_provider.sync_data_source.assert_called_once_with("binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py new file mode 100644 index 0000000000..fc5663e72d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py @@ -0,0 +1,417 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import BadRequest, NotFound + +from controllers.console.auth.oauth_server import ( + OAuthServerAppApi, + OAuthServerUserAccountApi, + OAuthServerUserAuthorizeApi, + OAuthServerUserTokenApi, +) + + +class TestOAuthServerAppApi: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @pytest.fixture + def mock_oauth_provider_app(self): + from models.model import OAuthProviderApp + + oauth_app = MagicMock(spec=OAuthProviderApp) + oauth_app.client_id = "test_client_id" + oauth_app.redirect_uris = ["http://localhost/callback"] + oauth_app.app_icon = "icon_url" + oauth_app.app_label = "Test App" + oauth_app.scope = "read,write" + return oauth_app + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider", + method="POST", + json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}, + ): + api_instance = OAuthServerAppApi() + response = api_instance.post() + + assert response["app_icon"] == "icon_url" + assert response["app_label"] == "Test App" + assert response["scope"] == "read,write" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider", + method="POST", + json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}, + ): + api_instance = OAuthServerAppApi() + with pytest.raises(BadRequest, match="redirect_uri is invalid"): + api_instance.post() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_invalid_client_id(self, mock_get_app, mock_db, app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = None + + with app.test_request_context( + "/oauth/provider", + method="POST", + json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}, + ): + api_instance = OAuthServerAppApi() + with pytest.raises(NotFound, match="client_id is invalid"): + api_instance.post() + + +class TestOAuthServerUserAuthorizeApi: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @pytest.fixture + def mock_oauth_provider_app(self): + oauth_app = MagicMock() + oauth_app.client_id = "test_client_id" + return oauth_app + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + @patch("controllers.console.auth.oauth_server.current_account_with_tenant") + @patch("controllers.console.wraps.current_account_with_tenant") + @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code") + @patch("libs.login.check_csrf_token") + def test_successful_authorize( + self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + mock_account = MagicMock() + mock_account.id = "user_123" + from models.account import AccountStatus + + mock_account.status = AccountStatus.ACTIVE + + mock_current.return_value = (mock_account, MagicMock()) + mock_wrap_current.return_value = (mock_account, MagicMock()) + + mock_sign.return_value = "auth_code_123" + + with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}): + with patch("libs.login.current_user", mock_account): + api_instance = OAuthServerUserAuthorizeApi() + response = api_instance.post() + + assert response["code"] == "auth_code_123" + mock_sign.assert_called_once_with("test_client_id", "user_123") + + +class TestOAuthServerUserTokenApi: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @pytest.fixture + def mock_oauth_provider_app(self): + from models.model import OAuthProviderApp + + oauth_app = MagicMock(spec=OAuthProviderApp) + oauth_app.client_id = "test_client_id" + oauth_app.client_secret = "test_secret" + oauth_app.redirect_uris = ["http://localhost/callback"] + return oauth_app + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token") + def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + mock_sign.return_value = ("access_123", "refresh_123") + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ): + api_instance = OAuthServerUserTokenApi() + response = api_instance.post() + + assert response["access_token"] == "access_123" + assert response["refresh_token"] == "refresh_123" + assert response["token_type"] == "Bearer" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ): + api_instance = OAuthServerUserTokenApi() + with pytest.raises(BadRequest, match="code is required"): + api_instance.post() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "invalid_secret", + "redirect_uri": "http://localhost/callback", + }, + ): + api_instance = OAuthServerUserTokenApi() + with pytest.raises(BadRequest, match="client_secret is invalid"): + api_instance.post() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://invalid/callback", + }, + ): + api_instance = OAuthServerUserTokenApi() + with pytest.raises(BadRequest, match="redirect_uri is invalid"): + api_instance.post() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token") + def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + mock_sign.return_value = ("new_access", "new_refresh") + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}, + ): + api_instance = OAuthServerUserTokenApi() + response = api_instance.post() + + assert response["access_token"] == "new_access" + assert response["refresh_token"] == "new_refresh" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={ + "client_id": "test_client_id", + "grant_type": "refresh_token", + }, + ): + api_instance = OAuthServerUserTokenApi() + with pytest.raises(BadRequest, match="refresh_token is required"): + api_instance.post() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={ + "client_id": "test_client_id", + "grant_type": "invalid_grant", + }, + ): + api_instance = OAuthServerUserTokenApi() + with pytest.raises(BadRequest, match="invalid grant_type"): + api_instance.post() + + +class TestOAuthServerUserAccountApi: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @pytest.fixture + def mock_oauth_provider_app(self): + from models.model import OAuthProviderApp + + oauth_app = MagicMock(spec=OAuthProviderApp) + oauth_app.client_id = "test_client_id" + return oauth_app + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token") + def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + mock_account = MagicMock() + mock_account.name = "Test User" + mock_account.email = "test@example.com" + mock_account.avatar = "avatar_url" + mock_account.interface_language = "en-US" + mock_account.timezone = "UTC" + mock_validate.return_value = mock_account + + with app.test_request_context( + "/oauth/provider/account", + method="POST", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Bearer valid_access_token"}, + ): + api_instance = OAuthServerUserAccountApi() + response = api_instance.post() + + assert response["name"] == "Test User" + assert response["email"] == "test@example.com" + assert response["avatar"] == "avatar_url" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}): + api_instance = OAuthServerUserAccountApi() + response = api_instance.post() + + assert response.status_code == 401 + assert response.json["error"] == "Authorization header is required" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/account", + method="POST", + json={"client_id": "test_client_id"}, + headers={"Authorization": "InvalidFormat"}, + ): + api_instance = OAuthServerUserAccountApi() + response = api_instance.post() + + assert response.status_code == 401 + assert response.json["error"] == "Invalid Authorization header format" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/account", + method="POST", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Basic something"}, + ): + api_instance = OAuthServerUserAccountApi() + response = api_instance.post() + + assert response.status_code == 401 + assert response.json["error"] == "token_type is invalid" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/account", + method="POST", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Bearer "}, + ): + api_instance = OAuthServerUserAccountApi() + response = api_instance.post() + + assert response.status_code == 401 + assert response.json["error"] == "Invalid Authorization header format" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token") + def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + mock_validate.return_value = None + + with app.test_request_context( + "/oauth/provider/account", + method="POST", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Bearer invalid_token"}, + ): + api_instance = OAuthServerUserAccountApi() + response = api_instance.post() + + assert response.status_code == 401 + assert response.json["error"] == "access_token or client_id is invalid"