mirror of
https://github.com/langgenius/dify.git
synced 2026-06-25 22:31:10 +08:00
refactor: replace logger patches with pytest caplog in tests (#37890)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
parent
a421362847
commit
2483c091aa
@ -195,16 +195,16 @@ class TestTencentDataTrace:
|
||||
mock_proc.assert_called_once_with(trace_info, 123)
|
||||
mock_dur.assert_called_once_with(trace_info)
|
||||
|
||||
def test_workflow_trace_exception(self, tencent_data_trace):
|
||||
def test_workflow_trace_exception(self, tencent_data_trace, caplog):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.workflow_run_id = "run-id"
|
||||
|
||||
with patch(
|
||||
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
|
||||
):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
tencent_data_trace.workflow_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace")
|
||||
assert "[Tencent APM] Failed to process workflow trace" in caplog.text
|
||||
|
||||
def test_message_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
@ -228,15 +228,15 @@ class TestTencentDataTrace:
|
||||
mock_metrics.assert_called_once_with(trace_info)
|
||||
mock_dur.assert_called_once_with(trace_info)
|
||||
|
||||
def test_message_trace_exception(self, tencent_data_trace):
|
||||
def test_message_trace_exception(self, tencent_data_trace, caplog):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
|
||||
with patch(
|
||||
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
|
||||
):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
tencent_data_trace.message_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace")
|
||||
assert "[Tencent APM] Failed to process message trace" in caplog.text
|
||||
|
||||
def test_tool_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
|
||||
trace_info = MagicMock(spec=ToolTraceInfo)
|
||||
@ -259,16 +259,16 @@ class TestTencentDataTrace:
|
||||
tencent_data_trace.tool_trace(trace_info)
|
||||
tencent_data_trace.trace_client.add_span.assert_not_called()
|
||||
|
||||
def test_tool_trace_exception(self, tencent_data_trace):
|
||||
def test_tool_trace_exception(self, tencent_data_trace, caplog):
|
||||
trace_info = MagicMock(spec=ToolTraceInfo)
|
||||
trace_info.message_id = "msg-id"
|
||||
|
||||
with patch(
|
||||
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
|
||||
):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
tencent_data_trace.tool_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace")
|
||||
assert "[Tencent APM] Failed to process tool trace" in caplog.text
|
||||
|
||||
def test_dataset_retrieval_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
|
||||
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
@ -291,29 +291,30 @@ class TestTencentDataTrace:
|
||||
tencent_data_trace.dataset_retrieval_trace(trace_info)
|
||||
tencent_data_trace.trace_client.add_span.assert_not_called()
|
||||
|
||||
def test_dataset_retrieval_trace_exception(self, tencent_data_trace):
|
||||
def test_dataset_retrieval_trace_exception(self, tencent_data_trace, caplog):
|
||||
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
trace_info.message_id = "msg-id"
|
||||
|
||||
with patch(
|
||||
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
|
||||
):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
tencent_data_trace.dataset_retrieval_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace")
|
||||
assert "[Tencent APM] Failed to process dataset retrieval trace" in caplog.text
|
||||
|
||||
def test_suggested_question_trace(self, tencent_data_trace):
|
||||
def test_suggested_question_trace(self, tencent_data_trace, caplog):
|
||||
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log:
|
||||
with caplog.at_level(logging.INFO):
|
||||
tencent_data_trace.suggested_question_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace")
|
||||
assert "[Tencent APM] Processing suggested question trace" in caplog.text
|
||||
|
||||
def test_suggested_question_trace_exception(self, tencent_data_trace):
|
||||
def test_suggested_question_trace_exception(self, tencent_data_trace, monkeypatch, caplog):
|
||||
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.suggested_question_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace")
|
||||
target_logger = logging.getLogger("dify_trace_tencent.tencent_trace")
|
||||
monkeypatch.setattr(target_logger, "info", MagicMock(side_effect=Exception("error")))
|
||||
with caplog.at_level(logging.ERROR):
|
||||
tencent_data_trace.suggested_question_trace(trace_info)
|
||||
assert "[Tencent APM] Failed to process suggested question trace" in caplog.text
|
||||
|
||||
def test_process_workflow_nodes(self, tencent_data_trace, mock_trace_utils):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
@ -335,7 +336,7 @@ class TestTencentDataTrace:
|
||||
assert tencent_data_trace.trace_client.add_span.call_count == 2
|
||||
mock_metrics.assert_called_once_with(node1)
|
||||
|
||||
def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils):
|
||||
def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils, caplog):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
mock_trace_utils.convert_to_span_id.return_value = 111
|
||||
|
||||
@ -344,18 +345,17 @@ class TestTencentDataTrace:
|
||||
|
||||
with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]):
|
||||
with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
tencent_data_trace._process_workflow_nodes(trace_info, 123)
|
||||
# The exception should be caught by the outer handler since convert_to_span_id is called first
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
|
||||
assert "[Tencent APM] Failed to process workflow nodes" in caplog.text
|
||||
|
||||
def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils):
|
||||
def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils, caplog):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error")
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
tencent_data_trace._process_workflow_nodes(trace_info, 123)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
|
||||
assert "[Tencent APM] Failed to process workflow nodes" in caplog.text
|
||||
|
||||
def test_build_workflow_node_span(self, tencent_data_trace, mock_span_builder):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
@ -377,16 +377,16 @@ class TestTencentDataTrace:
|
||||
assert result == "span"
|
||||
builder_method.assert_called_once_with(123, 456, trace_info, node)
|
||||
|
||||
def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder):
|
||||
def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder, caplog):
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.id = "n1"
|
||||
mock_span_builder.build_workflow_llm_span.side_effect = Exception("error")
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456)
|
||||
assert result is None
|
||||
mock_log.assert_called_once()
|
||||
assert result is None
|
||||
assert len([r for r in caplog.records if r.levelno == logging.DEBUG]) >= 1
|
||||
|
||||
def test_get_workflow_node_executions(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
@ -419,16 +419,16 @@ class TestTencentDataTrace:
|
||||
assert results == mock_executions
|
||||
account.set_tenant_id.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace):
|
||||
def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace, caplog):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.metadata = {}
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
results = tencent_data_trace._get_workflow_node_executions(trace_info)
|
||||
assert results == []
|
||||
mock_log.assert_called_once()
|
||||
assert results == []
|
||||
assert len([r for r in caplog.records if r.levelno == logging.ERROR]) >= 1
|
||||
|
||||
def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace):
|
||||
def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace, caplog):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.metadata = {"app_id": "app-1"}
|
||||
|
||||
@ -439,10 +439,10 @@ class TestTencentDataTrace:
|
||||
session = mock_session_ctx.return_value.__enter__.return_value
|
||||
session.scalar.return_value = None
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
results = tencent_data_trace._get_workflow_node_executions(trace_info)
|
||||
assert results == []
|
||||
mock_log.assert_called_once()
|
||||
assert results == []
|
||||
assert len([r for r in caplog.records if r.levelno == logging.ERROR]) >= 1
|
||||
|
||||
def test_get_user_id_workflow(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
@ -471,16 +471,16 @@ class TestTencentDataTrace:
|
||||
user_id = tencent_data_trace._get_user_id(trace_info)
|
||||
assert user_id == "anonymous"
|
||||
|
||||
def test_get_user_id_exception(self, tencent_data_trace):
|
||||
def test_get_user_id_exception(self, tencent_data_trace, caplog):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.tenant_id = "t"
|
||||
trace_info.metadata = {"user_id": "u"}
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
user_id = tencent_data_trace._get_user_id(trace_info)
|
||||
assert user_id == "unknown"
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID")
|
||||
assert user_id == "unknown"
|
||||
assert "[Tencent APM] Failed to get user ID" in caplog.text
|
||||
|
||||
def test_record_llm_metrics_usage_in_process_data(self, tencent_data_trace):
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
@ -514,14 +514,14 @@ class TestTencentDataTrace:
|
||||
tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
|
||||
tencent_data_trace.trace_client.record_token_usage.assert_called_once()
|
||||
|
||||
def test_record_llm_metrics_exception(self, tencent_data_trace):
|
||||
def test_record_llm_metrics_exception(self, tencent_data_trace, caplog):
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
node.process_data = None
|
||||
node.outputs = None
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
tencent_data_trace._record_llm_metrics(node)
|
||||
# Should not crash
|
||||
# Should not crash
|
||||
|
||||
def test_record_message_llm_metrics(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
@ -553,13 +553,13 @@ class TestTencentDataTrace:
|
||||
tencent_data_trace._record_message_llm_metrics(trace_info)
|
||||
tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
|
||||
|
||||
def test_record_message_llm_metrics_exception(self, tencent_data_trace):
|
||||
def test_record_message_llm_metrics_exception(self, tencent_data_trace, caplog):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.metadata = None
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
tencent_data_trace._record_message_llm_metrics(trace_info)
|
||||
# Should not crash
|
||||
# Should not crash
|
||||
|
||||
def test_record_workflow_trace_duration(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
@ -605,11 +605,11 @@ class TestTencentDataTrace:
|
||||
attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {}
|
||||
assert attributes["has_conversation"] == "false"
|
||||
|
||||
def test_record_workflow_trace_duration_exception(self, tencent_data_trace):
|
||||
def test_record_workflow_trace_duration_exception(self, tencent_data_trace, caplog):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
tencent_data_trace._record_workflow_trace_duration(trace_info)
|
||||
|
||||
def test_record_message_trace_duration(self, tencent_data_trace):
|
||||
@ -627,11 +627,11 @@ class TestTencentDataTrace:
|
||||
2.0, {"conversation_mode": "chat", "stream": "true"}
|
||||
)
|
||||
|
||||
def test_record_message_trace_duration_exception(self, tencent_data_trace):
|
||||
def test_record_message_trace_duration_exception(self, tencent_data_trace, caplog):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.start_time = None
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
tencent_data_trace._record_message_trace_duration(trace_info)
|
||||
|
||||
def test_close(self, tencent_data_trace):
|
||||
@ -647,11 +647,11 @@ class TestTencentDataTrace:
|
||||
|
||||
client.shutdown.assert_called_once()
|
||||
|
||||
def test_close_exception(self, tencent_data_trace):
|
||||
def test_close_exception(self, tencent_data_trace, caplog):
|
||||
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
tencent_data_trace.close()
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
|
||||
assert "[Tencent APM] Failed to shutdown trace client during cleanup" in caplog.text
|
||||
|
||||
def test_close_handles_async_shutdown_mock(self, tencent_data_trace):
|
||||
shutdown = AsyncMock()
|
||||
|
||||
@ -9,6 +9,7 @@ This module tests the core authentication endpoints including:
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from unittest.mock import ANY, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
@ -191,7 +192,7 @@ class TestLoginApi:
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask):
|
||||
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask, caplog):
|
||||
"""
|
||||
Test login rejection when rate limit is exceeded.
|
||||
|
||||
@ -204,22 +205,24 @@ class TestLoginApi:
|
||||
mock_get_invitation.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(EmailPasswordLoginLimitError):
|
||||
login_api.post()
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(EmailPasswordLoginLimitError):
|
||||
login_api.post()
|
||||
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "test@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.LOGIN_RATE_LIMITED
|
||||
warn_records = [
|
||||
r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING
|
||||
]
|
||||
assert len(warn_records) == 1
|
||||
assert warn_records[0].args[0] == "test@example.com"
|
||||
assert warn_records[0].args[1] == LoginFailureReason.LOGIN_RATE_LIMITED
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True)
|
||||
@patch("controllers.console.auth.login.BillingService.is_email_in_freeze")
|
||||
def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app: Flask):
|
||||
def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app: Flask, caplog):
|
||||
"""
|
||||
Test login rejection for frozen accounts.
|
||||
|
||||
@ -231,17 +234,19 @@ class TestLoginApi:
|
||||
mock_is_frozen.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AccountInFreezeError):
|
||||
login_api.post()
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AccountInFreezeError):
|
||||
login_api.post()
|
||||
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "frozen@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_IN_FREEZE
|
||||
warn_records = [
|
||||
r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING
|
||||
]
|
||||
assert len(warn_records) == 1
|
||||
assert warn_records[0].args[0] == "frozen@example.com"
|
||||
assert warn_records[0].args[1] == LoginFailureReason.ACCOUNT_IN_FREEZE
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@ -257,6 +262,7 @@ class TestLoginApi:
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
caplog,
|
||||
):
|
||||
"""
|
||||
Test login failure with invalid credentials.
|
||||
@ -272,20 +278,22 @@ class TestLoginApi:
|
||||
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
|
||||
|
||||
# Act & Assert
|
||||
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/login",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "password": encode_password("WrongPass123!")},
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AuthenticationFailedError):
|
||||
login_api.post()
|
||||
with app.test_request_context(
|
||||
"/login",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "password": encode_password("WrongPass123!")},
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AuthenticationFailedError):
|
||||
login_api.post()
|
||||
|
||||
mock_add_rate_limit.assert_called_once_with("test@example.com")
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "test@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS
|
||||
warn_records = [
|
||||
r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING
|
||||
]
|
||||
assert len(warn_records) == 1
|
||||
assert warn_records[0].args[0] == "test@example.com"
|
||||
assert warn_records[0].args[1] == LoginFailureReason.INVALID_CREDENTIALS
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@ -293,7 +301,7 @@ class TestLoginApi:
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
def test_login_fails_for_banned_account(
|
||||
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask
|
||||
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask, caplog
|
||||
):
|
||||
"""
|
||||
Test login rejection for banned accounts.
|
||||
@ -308,19 +316,21 @@ class TestLoginApi:
|
||||
mock_authenticate.side_effect = AccountLoginError("Account is banned")
|
||||
|
||||
# Act & Assert
|
||||
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/login",
|
||||
method="POST",
|
||||
json={"email": "banned@example.com", "password": encode_password("ValidPass123!")},
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AccountBannedError):
|
||||
login_api.post()
|
||||
with app.test_request_context(
|
||||
"/login",
|
||||
method="POST",
|
||||
json={"email": "banned@example.com", "password": encode_password("ValidPass123!")},
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AccountBannedError):
|
||||
login_api.post()
|
||||
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "banned@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED
|
||||
warn_records = [
|
||||
r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING
|
||||
]
|
||||
assert len(warn_records) == 1
|
||||
assert warn_records[0].args[0] == "banned@example.com"
|
||||
assert warn_records[0].args[1] == LoginFailureReason.ACCOUNT_BANNED
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@ -452,23 +462,26 @@ class TestLoginApi:
|
||||
mock_get_token_data: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
app: Flask,
|
||||
caplog,
|
||||
):
|
||||
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
|
||||
mock_get_account.side_effect = Unauthorized("Account is banned.")
|
||||
|
||||
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
|
||||
):
|
||||
with pytest.raises(AccountBannedError):
|
||||
EmailCodeLoginApi().post()
|
||||
with app.test_request_context(
|
||||
"/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
|
||||
):
|
||||
with pytest.raises(AccountBannedError):
|
||||
EmailCodeLoginApi().post()
|
||||
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "user@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED
|
||||
warn_records = [
|
||||
r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING
|
||||
]
|
||||
assert len(warn_records) == 1
|
||||
assert warn_records[0].args[0] == "user@example.com"
|
||||
assert warn_records[0].args[1] == LoginFailureReason.ACCOUNT_BANNED
|
||||
|
||||
|
||||
class TestLogoutApi:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user