From 8ca8b3d59ae564e96d3fbe1477a2a5321c116cd1 Mon Sep 17 00:00:00 2001 From: Evan <2869018789@qq.com> Date: Wed, 17 Jun 2026 10:22:39 +0800 Subject: [PATCH] refactor: replace mock.patch logger with pytest caplog in tests (#37560) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../core/app/layers/test_timeslice_layer.py | 13 ++++++------- .../app/layers/test_trigger_post_layer.py | 9 +++++---- ...test_message_cycle_manager_optimization.py | 9 +++++---- .../rag/embedding/test_cached_embedding.py | 19 ++++++++++--------- .../core/rag/extractor/test_word_extractor.py | 10 +++++----- .../core/rag/splitter/test_text_splitter.py | 7 ++++--- 6 files changed, 35 insertions(+), 32 deletions(-) diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py index 1ac9a4d8c0..191c103a8a 100644 --- a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -1,3 +1,4 @@ +import logging from unittest.mock import Mock, patch from core.app.layers.timeslice_layer import TimeSliceLayer @@ -64,21 +65,19 @@ class TestTimeSliceLayer: scheduler.remove_job.assert_called_once_with("job-1") - def test_checker_job_handles_resource_limit_without_command_channel(self): + def test_checker_job_handles_resource_limit_without_command_channel(self, caplog): scheduler = Mock() scheduler.running = True cfs_plan_scheduler = Mock(plan=Mock()) cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED - with ( - patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler), - patch("core.app.layers.timeslice_layer.logger") as mock_logger, - ): + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) - layer._checker_job("job-1") + with caplog.at_level(logging.ERROR, logger="core.app.layers.timeslice_layer"): + layer._checker_job("job-1") scheduler.remove_job.assert_called_once_with("job-1") - mock_logger.exception.assert_called_once() + assert any(record.levelno == logging.ERROR for record in caplog.records) def test_checker_job_sends_pause_command(self): scheduler = Mock() diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py index f82cf20142..88f4a6cc31 100644 --- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -1,3 +1,4 @@ +import logging from datetime import UTC, datetime, timedelta from types import SimpleNamespace from unittest.mock import Mock, patch @@ -114,7 +115,7 @@ class TestTriggerPostLayer: repo.update.assert_called_once_with(trigger_log) session.commit.assert_called_once() - def test_on_event_handles_missing_trigger_log(self): + def test_on_event_handles_missing_trigger_log(self, caplog): runtime_state = SimpleNamespace( outputs={}, variable_pool=VariablePool.from_bootstrap( @@ -126,7 +127,6 @@ class TestTriggerPostLayer: with ( patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory, patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls, - patch("core.app.layers.trigger_post_layer.logger") as mock_logger, ): session = Mock() mock_session_factory.create_session.return_value.__enter__.return_value = session @@ -142,9 +142,10 @@ class TestTriggerPostLayer: ) layer.initialize(runtime_state, Mock()) - layer.on_event(GraphRunFailedEvent(error="boom")) + with caplog.at_level(logging.ERROR, logger="core.app.layers.trigger_post_layer"): + layer.on_event(GraphRunFailedEvent(error="boom")) - mock_logger.exception.assert_called_once() + assert any(record.levelno == logging.ERROR for record in caplog.records) session.commit.assert_not_called() def test_on_event_ignores_non_status_events(self): diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index 92fe3cbec6..4324fdf884 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -1,5 +1,6 @@ """Unit tests for the message cycle manager optimization.""" +import logging from types import SimpleNamespace from unittest.mock import Mock, patch @@ -344,7 +345,7 @@ class TestMessageCycleManagerOptimization: db_session.close.assert_called_once() mock_redis.setex.assert_called_once() - def test_generate_conversation_name_worker_falls_back_when_generation_fails(self, message_cycle_manager): + def test_generate_conversation_name_worker_falls_back_when_generation_fails(self, message_cycle_manager, caplog): """Fallback to truncated query when LLM generation fails.""" flask_app = Flask(__name__) conversation = SimpleNamespace( @@ -362,19 +363,19 @@ class TestMessageCycleManagerOptimization: patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, patch("core.app.task_pipeline.message_cycle_manager.dify_config") as mock_dify_config, - patch("core.app.task_pipeline.message_cycle_manager.logger") as mock_logger, ): mock_db.session = db_session mock_redis.get.return_value = None mock_llm_generator.generate_conversation_name.side_effect = RuntimeError("generation failed") mock_dify_config.DEBUG = True - message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", long_query) + with caplog.at_level(logging.ERROR, logger="core.app.task_pipeline.message_cycle_manager"): + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", long_query) assert conversation.name == (long_query[:47] + "...") db_session.commit.assert_called_once() db_session.close.assert_called_once() - mock_logger.exception.assert_called_once() + assert any(record.levelno == logging.ERROR for record in caplog.records) def test_handle_annotation_reply_sets_metadata(self, message_cycle_manager): """Populate task metadata from annotation reply events. diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index 051a1455ae..364f688c8e 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -7,6 +7,7 @@ This test file covers the methods not fully tested in test_embedding_service.py: """ import base64 +import logging from decimal import Decimal from unittest.mock import Mock, patch @@ -188,7 +189,7 @@ class TestCacheEmbeddingMultimodalDocuments: assert len(result) == 3 assert result[0] == normalized_cached - def test_embed_multimodal_documents_nan_handling(self, mock_model_instance): + def test_embed_multimodal_documents_nan_handling(self, mock_model_instance, caplog): """Test handling of NaN values in multimodal embeddings.""" cache_embedding = CacheEmbedding(mock_model_instance) documents = [{"file_id": "valid"}, {"file_id": "nan"}] @@ -216,14 +217,14 @@ class TestCacheEmbeddingMultimodalDocuments: mock_session.scalar.return_value = None mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result - with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: + with caplog.at_level(logging.WARNING, logger="core.rag.embedding.cached_embedding"): result = cache_embedding.embed_multimodal_documents(documents) assert len(result) == 2 assert result[0] is not None assert result[1] is None - mock_logger.warning.assert_called_once() + assert any(record.levelno == logging.WARNING for record in caplog.records) def test_embed_multimodal_documents_large_batch(self, mock_model_instance): """Test embedding large batch of multimodal documents respecting MAX_CHUNKS.""" @@ -463,7 +464,7 @@ class TestCacheEmbeddingQueryErrors: model_instance.credentials = {"api_key": "test-key"} return model_instance - def test_embed_query_api_error_debug_mode(self, mock_model_instance): + def test_embed_query_api_error_debug_mode(self, mock_model_instance, caplog): """Test handling of API errors in debug mode.""" cache_embedding = CacheEmbedding(mock_model_instance) query = "test query" @@ -475,14 +476,14 @@ class TestCacheEmbeddingQueryErrors: with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config: mock_config.DEBUG = True - with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: + with caplog.at_level(logging.ERROR, logger="core.rag.embedding.cached_embedding"): with pytest.raises(RuntimeError) as exc_info: cache_embedding.embed_query(query) assert "API Error" in str(exc_info.value) - mock_logger.exception.assert_called() + assert any(record.levelno == logging.ERROR for record in caplog.records) - def test_embed_query_redis_set_error_debug_mode(self, mock_model_instance): + def test_embed_query_redis_set_error_debug_mode(self, mock_model_instance, caplog): """Test handling of Redis set errors in debug mode.""" cache_embedding = CacheEmbedding(mock_model_instance) query = "test query" @@ -514,11 +515,11 @@ class TestCacheEmbeddingQueryErrors: with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config: mock_config.DEBUG = True - with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: + with caplog.at_level(logging.ERROR, logger="core.rag.embedding.cached_embedding"): with pytest.raises(RuntimeError): cache_embedding.embed_query(query) - mock_logger.exception.assert_called() + assert any(record.levelno == logging.ERROR for record in caplog.records) class TestCacheEmbeddingInitialization: diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 45d6fc1cd0..e85bb2f68e 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,6 +1,7 @@ """Primarily used for testing merged cell scenarios""" import io +import logging import os import tempfile from collections import UserDict @@ -548,7 +549,7 @@ def test_parse_docx_reads_real_paragraph_table_order(monkeypatch: pytest.MonkeyP os.remove(tmp_path) -def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch: pytest.MonkeyPatch): +def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch: pytest.MonkeyPatch, caplog): extractor = object.__new__(WordExtractor) ext_image_id = "ext-image" @@ -709,10 +710,9 @@ def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monke monkeypatch.setattr(we, "Run", FakeRun) monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: image_map) monkeypatch.setattr(extractor, "_table_to_markdown", lambda table, image_map: "TABLE-MARKDOWN") - logger_exception = MagicMock() - monkeypatch.setattr(we.logger, "exception", logger_exception) - content = extractor.parse_docx("dummy.docx") + with caplog.at_level(logging.ERROR, logger="core.rag.extractor.word_extractor"): + content = extractor.parse_docx("dummy.docx") assert "[EXT]" in content assert "[INT]" in content @@ -720,7 +720,7 @@ def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monke assert "[LinkText](https://example.com)" in content assert "BrokenLink" in content assert "TABLE-MARKDOWN" in content - logger_exception.assert_called_once() + assert any(record.levelno == logging.ERROR for record in caplog.records) def test_parse_cell_paragraph_hyperlink_in_table_cell_http(): diff --git a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py index 976de10d89..980139192c 100644 --- a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py +++ b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py @@ -126,6 +126,7 @@ Run with coverage: """ import asyncio +import logging import string import sys import types @@ -644,13 +645,13 @@ class TestTextSplitterBasePaths: with pytest.raises(NotImplementedError): asyncio.run(splitter.atransform_documents([Document(page_content="x", metadata={})])) - def test_merge_splits_logs_warning_for_oversized_total(self): + def test_merge_splits_logs_warning_for_oversized_total(self, caplog): """Cover logger.warning path in _merge_splits.""" splitter = RecursiveCharacterTextSplitter(chunk_size=5, chunk_overlap=1) - with patch("core.rag.splitter.text_splitter.logger.warning") as mock_warning: + with caplog.at_level(logging.WARNING, logger="core.rag.splitter.text_splitter"): merged = splitter._merge_splits(["abcdefghij", "b"], "", [10, 1]) assert merged - mock_warning.assert_called_once() + assert any(record.levelno == logging.WARNING for record in caplog.records) # ============================================================================