refactor: replace logger patches with pytest caplog in tests (#37890)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Willow Lopez 2026-06-25 05:40:06 +08:00 committed by GitHub
parent a421362847
commit 2483c091aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 127 additions and 114 deletions

View File

@ -195,16 +195,16 @@ class TestTencentDataTrace:
mock_proc.assert_called_once_with(trace_info, 123)
mock_dur.assert_called_once_with(trace_info)
def test_workflow_trace_exception(self, tencent_data_trace):
def test_workflow_trace_exception(self, tencent_data_trace, caplog):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.workflow_run_id = "run-id"
with patch(
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
):
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
with caplog.at_level(logging.ERROR):
tencent_data_trace.workflow_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace")
assert "[Tencent APM] Failed to process workflow trace" in caplog.text
def test_message_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
trace_info = MagicMock(spec=MessageTraceInfo)
@ -228,15 +228,15 @@ class TestTencentDataTrace:
mock_metrics.assert_called_once_with(trace_info)
mock_dur.assert_called_once_with(trace_info)
def test_message_trace_exception(self, tencent_data_trace):
def test_message_trace_exception(self, tencent_data_trace, caplog):
trace_info = MagicMock(spec=MessageTraceInfo)
with patch(
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
):
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
with caplog.at_level(logging.ERROR):
tencent_data_trace.message_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace")
assert "[Tencent APM] Failed to process message trace" in caplog.text
def test_tool_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
trace_info = MagicMock(spec=ToolTraceInfo)
@ -259,16 +259,16 @@ class TestTencentDataTrace:
tencent_data_trace.tool_trace(trace_info)
tencent_data_trace.trace_client.add_span.assert_not_called()
def test_tool_trace_exception(self, tencent_data_trace):
def test_tool_trace_exception(self, tencent_data_trace, caplog):
trace_info = MagicMock(spec=ToolTraceInfo)
trace_info.message_id = "msg-id"
with patch(
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
):
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
with caplog.at_level(logging.ERROR):
tencent_data_trace.tool_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace")
assert "[Tencent APM] Failed to process tool trace" in caplog.text
def test_dataset_retrieval_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
@ -291,29 +291,30 @@ class TestTencentDataTrace:
tencent_data_trace.dataset_retrieval_trace(trace_info)
tencent_data_trace.trace_client.add_span.assert_not_called()
def test_dataset_retrieval_trace_exception(self, tencent_data_trace):
def test_dataset_retrieval_trace_exception(self, tencent_data_trace, caplog):
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
trace_info.message_id = "msg-id"
with patch(
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
):
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
with caplog.at_level(logging.ERROR):
tencent_data_trace.dataset_retrieval_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace")
assert "[Tencent APM] Failed to process dataset retrieval trace" in caplog.text
def test_suggested_question_trace(self, tencent_data_trace):
def test_suggested_question_trace(self, tencent_data_trace, caplog):
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log:
with caplog.at_level(logging.INFO):
tencent_data_trace.suggested_question_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace")
assert "[Tencent APM] Processing suggested question trace" in caplog.text
def test_suggested_question_trace_exception(self, tencent_data_trace):
def test_suggested_question_trace_exception(self, tencent_data_trace, monkeypatch, caplog):
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")):
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.suggested_question_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace")
target_logger = logging.getLogger("dify_trace_tencent.tencent_trace")
monkeypatch.setattr(target_logger, "info", MagicMock(side_effect=Exception("error")))
with caplog.at_level(logging.ERROR):
tencent_data_trace.suggested_question_trace(trace_info)
assert "[Tencent APM] Failed to process suggested question trace" in caplog.text
def test_process_workflow_nodes(self, tencent_data_trace, mock_trace_utils):
trace_info = MagicMock(spec=WorkflowTraceInfo)
@ -335,7 +336,7 @@ class TestTencentDataTrace:
assert tencent_data_trace.trace_client.add_span.call_count == 2
mock_metrics.assert_called_once_with(node1)
def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils):
def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils, caplog):
trace_info = MagicMock(spec=WorkflowTraceInfo)
mock_trace_utils.convert_to_span_id.return_value = 111
@ -344,18 +345,17 @@ class TestTencentDataTrace:
with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]):
with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")):
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
with caplog.at_level(logging.ERROR):
tencent_data_trace._process_workflow_nodes(trace_info, 123)
# The exception should be caught by the outer handler since convert_to_span_id is called first
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
assert "[Tencent APM] Failed to process workflow nodes" in caplog.text
def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils):
def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils, caplog):
trace_info = MagicMock(spec=WorkflowTraceInfo)
mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error")
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
with caplog.at_level(logging.ERROR):
tencent_data_trace._process_workflow_nodes(trace_info, 123)
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
assert "[Tencent APM] Failed to process workflow nodes" in caplog.text
def test_build_workflow_node_span(self, tencent_data_trace, mock_span_builder):
trace_info = MagicMock(spec=WorkflowTraceInfo)
@ -377,16 +377,16 @@ class TestTencentDataTrace:
assert result == "span"
builder_method.assert_called_once_with(123, 456, trace_info, node)
def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder):
def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder, caplog):
node = MagicMock(spec=WorkflowNodeExecution)
node.node_type = BuiltinNodeTypes.LLM
node.id = "n1"
mock_span_builder.build_workflow_llm_span.side_effect = Exception("error")
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
with caplog.at_level(logging.DEBUG):
result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456)
assert result is None
mock_log.assert_called_once()
assert result is None
assert len([r for r in caplog.records if r.levelno == logging.DEBUG]) >= 1
def test_get_workflow_node_executions(self, tencent_data_trace):
trace_info = MagicMock(spec=WorkflowTraceInfo)
@ -419,16 +419,16 @@ class TestTencentDataTrace:
assert results == mock_executions
account.set_tenant_id.assert_called_once_with("tenant-1")
def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace):
def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace, caplog):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.metadata = {}
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
with caplog.at_level(logging.ERROR):
results = tencent_data_trace._get_workflow_node_executions(trace_info)
assert results == []
mock_log.assert_called_once()
assert results == []
assert len([r for r in caplog.records if r.levelno == logging.ERROR]) >= 1
def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace):
def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace, caplog):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.metadata = {"app_id": "app-1"}
@ -439,10 +439,10 @@ class TestTencentDataTrace:
session = mock_session_ctx.return_value.__enter__.return_value
session.scalar.return_value = None
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
with caplog.at_level(logging.ERROR):
results = tencent_data_trace._get_workflow_node_executions(trace_info)
assert results == []
mock_log.assert_called_once()
assert results == []
assert len([r for r in caplog.records if r.levelno == logging.ERROR]) >= 1
def test_get_user_id_workflow(self, tencent_data_trace):
trace_info = MagicMock(spec=WorkflowTraceInfo)
@ -471,16 +471,16 @@ class TestTencentDataTrace:
user_id = tencent_data_trace._get_user_id(trace_info)
assert user_id == "anonymous"
def test_get_user_id_exception(self, tencent_data_trace):
def test_get_user_id_exception(self, tencent_data_trace, caplog):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.tenant_id = "t"
trace_info.metadata = {"user_id": "u"}
with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")):
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
with caplog.at_level(logging.ERROR):
user_id = tencent_data_trace._get_user_id(trace_info)
assert user_id == "unknown"
mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID")
assert user_id == "unknown"
assert "[Tencent APM] Failed to get user ID" in caplog.text
def test_record_llm_metrics_usage_in_process_data(self, tencent_data_trace):
node = MagicMock(spec=WorkflowNodeExecution)
@ -514,14 +514,14 @@ class TestTencentDataTrace:
tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
tencent_data_trace.trace_client.record_token_usage.assert_called_once()
def test_record_llm_metrics_exception(self, tencent_data_trace):
def test_record_llm_metrics_exception(self, tencent_data_trace, caplog):
node = MagicMock(spec=WorkflowNodeExecution)
node.process_data = None
node.outputs = None
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
with caplog.at_level(logging.DEBUG):
tencent_data_trace._record_llm_metrics(node)
# Should not crash
# Should not crash
def test_record_message_llm_metrics(self, tencent_data_trace):
trace_info = MagicMock(spec=MessageTraceInfo)
@ -553,13 +553,13 @@ class TestTencentDataTrace:
tencent_data_trace._record_message_llm_metrics(trace_info)
tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
def test_record_message_llm_metrics_exception(self, tencent_data_trace):
def test_record_message_llm_metrics_exception(self, tencent_data_trace, caplog):
trace_info = MagicMock(spec=MessageTraceInfo)
trace_info.metadata = None
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
with caplog.at_level(logging.DEBUG):
tencent_data_trace._record_message_llm_metrics(trace_info)
# Should not crash
# Should not crash
def test_record_workflow_trace_duration(self, tencent_data_trace):
trace_info = MagicMock(spec=WorkflowTraceInfo)
@ -605,11 +605,11 @@ class TestTencentDataTrace:
attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {}
assert attributes["has_conversation"] == "false"
def test_record_workflow_trace_duration_exception(self, tencent_data_trace):
def test_record_workflow_trace_duration_exception(self, tencent_data_trace, caplog):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
with caplog.at_level(logging.DEBUG):
tencent_data_trace._record_workflow_trace_duration(trace_info)
def test_record_message_trace_duration(self, tencent_data_trace):
@ -627,11 +627,11 @@ class TestTencentDataTrace:
2.0, {"conversation_mode": "chat", "stream": "true"}
)
def test_record_message_trace_duration_exception(self, tencent_data_trace):
def test_record_message_trace_duration_exception(self, tencent_data_trace, caplog):
trace_info = MagicMock(spec=MessageTraceInfo)
trace_info.start_time = None
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
with caplog.at_level(logging.DEBUG):
tencent_data_trace._record_message_trace_duration(trace_info)
def test_close(self, tencent_data_trace):
@ -647,11 +647,11 @@ class TestTencentDataTrace:
client.shutdown.assert_called_once()
def test_close_exception(self, tencent_data_trace):
def test_close_exception(self, tencent_data_trace, caplog):
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
with caplog.at_level(logging.ERROR):
tencent_data_trace.close()
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
assert "[Tencent APM] Failed to shutdown trace client during cleanup" in caplog.text
def test_close_handles_async_shutdown_mock(self, tencent_data_trace):
shutdown = AsyncMock()

View File

@ -9,6 +9,7 @@ This module tests the core authentication endpoints including:
"""
import base64
import logging
from unittest.mock import ANY, MagicMock, Mock, patch
import pytest
@ -191,7 +192,7 @@ class TestLoginApi:
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask):
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask, caplog):
"""
Test login rejection when rate limit is exceeded.
@ -204,22 +205,24 @@ class TestLoginApi:
mock_get_invitation.return_value = None
# Act & Assert
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(EmailPasswordLoginLimitError):
login_api.post()
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(EmailPasswordLoginLimitError):
login_api.post()
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "test@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.LOGIN_RATE_LIMITED
warn_records = [
r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING
]
assert len(warn_records) == 1
assert warn_records[0].args[0] == "test@example.com"
assert warn_records[0].args[1] == LoginFailureReason.LOGIN_RATE_LIMITED
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True)
@patch("controllers.console.auth.login.BillingService.is_email_in_freeze")
def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app: Flask):
def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app: Flask, caplog):
"""
Test login rejection for frozen accounts.
@ -231,17 +234,19 @@ class TestLoginApi:
mock_is_frozen.return_value = True
# Act & Assert
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(AccountInFreezeError):
login_api.post()
with app.test_request_context(
"/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(AccountInFreezeError):
login_api.post()
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "frozen@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_IN_FREEZE
warn_records = [
r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING
]
assert len(warn_records) == 1
assert warn_records[0].args[0] == "frozen@example.com"
assert warn_records[0].args[1] == LoginFailureReason.ACCOUNT_IN_FREEZE
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@ -257,6 +262,7 @@ class TestLoginApi:
mock_is_rate_limit,
mock_db,
app: Flask,
caplog,
):
"""
Test login failure with invalid credentials.
@ -272,20 +278,22 @@ class TestLoginApi:
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
# Act & Assert
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/login",
method="POST",
json={"email": "test@example.com", "password": encode_password("WrongPass123!")},
):
login_api = LoginApi()
with pytest.raises(AuthenticationFailedError):
login_api.post()
with app.test_request_context(
"/login",
method="POST",
json={"email": "test@example.com", "password": encode_password("WrongPass123!")},
):
login_api = LoginApi()
with pytest.raises(AuthenticationFailedError):
login_api.post()
mock_add_rate_limit.assert_called_once_with("test@example.com")
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "test@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS
warn_records = [
r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING
]
assert len(warn_records) == 1
assert warn_records[0].args[0] == "test@example.com"
assert warn_records[0].args[1] == LoginFailureReason.INVALID_CREDENTIALS
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@ -293,7 +301,7 @@ class TestLoginApi:
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
def test_login_fails_for_banned_account(
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask, caplog
):
"""
Test login rejection for banned accounts.
@ -308,19 +316,21 @@ class TestLoginApi:
mock_authenticate.side_effect = AccountLoginError("Account is banned")
# Act & Assert
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/login",
method="POST",
json={"email": "banned@example.com", "password": encode_password("ValidPass123!")},
):
login_api = LoginApi()
with pytest.raises(AccountBannedError):
login_api.post()
with app.test_request_context(
"/login",
method="POST",
json={"email": "banned@example.com", "password": encode_password("ValidPass123!")},
):
login_api = LoginApi()
with pytest.raises(AccountBannedError):
login_api.post()
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "banned@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED
warn_records = [
r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING
]
assert len(warn_records) == 1
assert warn_records[0].args[0] == "banned@example.com"
assert warn_records[0].args[1] == LoginFailureReason.ACCOUNT_BANNED
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@ -452,23 +462,26 @@ class TestLoginApi:
mock_get_token_data: MagicMock,
mock_db: MagicMock,
app: Flask,
caplog,
):
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_account.side_effect = Unauthorized("Account is banned.")
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
):
with pytest.raises(AccountBannedError):
EmailCodeLoginApi().post()
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
):
with pytest.raises(AccountBannedError):
EmailCodeLoginApi().post()
mock_revoke_token.assert_called_once_with("token-123")
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "user@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED
warn_records = [
r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING
]
assert len(warn_records) == 1
assert warn_records[0].args[0] == "user@example.com"
assert warn_records[0].args[1] == LoginFailureReason.ACCOUNT_BANNED
class TestLogoutApi: