refactor: replace mock.patch logger with pytest caplog in tests (#37560)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Evan 2026-06-17 10:22:39 +08:00 committed by GitHub
parent 3f81ec1212
commit 8ca8b3d59a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 35 additions and 32 deletions

View File

@ -1,3 +1,4 @@
import logging
from unittest.mock import Mock, patch
from core.app.layers.timeslice_layer import TimeSliceLayer
@ -64,21 +65,19 @@ class TestTimeSliceLayer:
scheduler.remove_job.assert_called_once_with("job-1")
def test_checker_job_handles_resource_limit_without_command_channel(self):
def test_checker_job_handles_resource_limit_without_command_channel(self, caplog):
scheduler = Mock()
scheduler.running = True
cfs_plan_scheduler = Mock(plan=Mock())
cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED
with (
patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler),
patch("core.app.layers.timeslice_layer.logger") as mock_logger,
):
with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler):
layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler)
layer._checker_job("job-1")
with caplog.at_level(logging.ERROR, logger="core.app.layers.timeslice_layer"):
layer._checker_job("job-1")
scheduler.remove_job.assert_called_once_with("job-1")
mock_logger.exception.assert_called_once()
assert any(record.levelno == logging.ERROR for record in caplog.records)
def test_checker_job_sends_pause_command(self):
scheduler = Mock()

View File

@ -1,3 +1,4 @@
import logging
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import Mock, patch
@ -114,7 +115,7 @@ class TestTriggerPostLayer:
repo.update.assert_called_once_with(trigger_log)
session.commit.assert_called_once()
def test_on_event_handles_missing_trigger_log(self):
def test_on_event_handles_missing_trigger_log(self, caplog):
runtime_state = SimpleNamespace(
outputs={},
variable_pool=VariablePool.from_bootstrap(
@ -126,7 +127,6 @@ class TestTriggerPostLayer:
with (
patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory,
patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls,
patch("core.app.layers.trigger_post_layer.logger") as mock_logger,
):
session = Mock()
mock_session_factory.create_session.return_value.__enter__.return_value = session
@ -142,9 +142,10 @@ class TestTriggerPostLayer:
)
layer.initialize(runtime_state, Mock())
layer.on_event(GraphRunFailedEvent(error="boom"))
with caplog.at_level(logging.ERROR, logger="core.app.layers.trigger_post_layer"):
layer.on_event(GraphRunFailedEvent(error="boom"))
mock_logger.exception.assert_called_once()
assert any(record.levelno == logging.ERROR for record in caplog.records)
session.commit.assert_not_called()
def test_on_event_ignores_non_status_events(self):

View File

@ -1,5 +1,6 @@
"""Unit tests for the message cycle manager optimization."""
import logging
from types import SimpleNamespace
from unittest.mock import Mock, patch
@ -344,7 +345,7 @@ class TestMessageCycleManagerOptimization:
db_session.close.assert_called_once()
mock_redis.setex.assert_called_once()
def test_generate_conversation_name_worker_falls_back_when_generation_fails(self, message_cycle_manager):
def test_generate_conversation_name_worker_falls_back_when_generation_fails(self, message_cycle_manager, caplog):
"""Fallback to truncated query when LLM generation fails."""
flask_app = Flask(__name__)
conversation = SimpleNamespace(
@ -362,19 +363,19 @@ class TestMessageCycleManagerOptimization:
patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis,
patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator,
patch("core.app.task_pipeline.message_cycle_manager.dify_config") as mock_dify_config,
patch("core.app.task_pipeline.message_cycle_manager.logger") as mock_logger,
):
mock_db.session = db_session
mock_redis.get.return_value = None
mock_llm_generator.generate_conversation_name.side_effect = RuntimeError("generation failed")
mock_dify_config.DEBUG = True
message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", long_query)
with caplog.at_level(logging.ERROR, logger="core.app.task_pipeline.message_cycle_manager"):
message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", long_query)
assert conversation.name == (long_query[:47] + "...")
db_session.commit.assert_called_once()
db_session.close.assert_called_once()
mock_logger.exception.assert_called_once()
assert any(record.levelno == logging.ERROR for record in caplog.records)
def test_handle_annotation_reply_sets_metadata(self, message_cycle_manager):
"""Populate task metadata from annotation reply events.

View File

@ -7,6 +7,7 @@ This test file covers the methods not fully tested in test_embedding_service.py:
"""
import base64
import logging
from decimal import Decimal
from unittest.mock import Mock, patch
@ -188,7 +189,7 @@ class TestCacheEmbeddingMultimodalDocuments:
assert len(result) == 3
assert result[0] == normalized_cached
def test_embed_multimodal_documents_nan_handling(self, mock_model_instance):
def test_embed_multimodal_documents_nan_handling(self, mock_model_instance, caplog):
"""Test handling of NaN values in multimodal embeddings."""
cache_embedding = CacheEmbedding(mock_model_instance)
documents = [{"file_id": "valid"}, {"file_id": "nan"}]
@ -216,14 +217,14 @@ class TestCacheEmbeddingMultimodalDocuments:
mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
with caplog.at_level(logging.WARNING, logger="core.rag.embedding.cached_embedding"):
result = cache_embedding.embed_multimodal_documents(documents)
assert len(result) == 2
assert result[0] is not None
assert result[1] is None
mock_logger.warning.assert_called_once()
assert any(record.levelno == logging.WARNING for record in caplog.records)
def test_embed_multimodal_documents_large_batch(self, mock_model_instance):
"""Test embedding large batch of multimodal documents respecting MAX_CHUNKS."""
@ -463,7 +464,7 @@ class TestCacheEmbeddingQueryErrors:
model_instance.credentials = {"api_key": "test-key"}
return model_instance
def test_embed_query_api_error_debug_mode(self, mock_model_instance):
def test_embed_query_api_error_debug_mode(self, mock_model_instance, caplog):
"""Test handling of API errors in debug mode."""
cache_embedding = CacheEmbedding(mock_model_instance)
query = "test query"
@ -475,14 +476,14 @@ class TestCacheEmbeddingQueryErrors:
with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config:
mock_config.DEBUG = True
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
with caplog.at_level(logging.ERROR, logger="core.rag.embedding.cached_embedding"):
with pytest.raises(RuntimeError) as exc_info:
cache_embedding.embed_query(query)
assert "API Error" in str(exc_info.value)
mock_logger.exception.assert_called()
assert any(record.levelno == logging.ERROR for record in caplog.records)
def test_embed_query_redis_set_error_debug_mode(self, mock_model_instance):
def test_embed_query_redis_set_error_debug_mode(self, mock_model_instance, caplog):
"""Test handling of Redis set errors in debug mode."""
cache_embedding = CacheEmbedding(mock_model_instance)
query = "test query"
@ -514,11 +515,11 @@ class TestCacheEmbeddingQueryErrors:
with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config:
mock_config.DEBUG = True
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
with caplog.at_level(logging.ERROR, logger="core.rag.embedding.cached_embedding"):
with pytest.raises(RuntimeError):
cache_embedding.embed_query(query)
mock_logger.exception.assert_called()
assert any(record.levelno == logging.ERROR for record in caplog.records)
class TestCacheEmbeddingInitialization:

View File

@ -1,6 +1,7 @@
"""Primarily used for testing merged cell scenarios"""
import io
import logging
import os
import tempfile
from collections import UserDict
@ -548,7 +549,7 @@ def test_parse_docx_reads_real_paragraph_table_order(monkeypatch: pytest.MonkeyP
os.remove(tmp_path)
def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch: pytest.MonkeyPatch):
def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch: pytest.MonkeyPatch, caplog):
extractor = object.__new__(WordExtractor)
ext_image_id = "ext-image"
@ -709,10 +710,9 @@ def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monke
monkeypatch.setattr(we, "Run", FakeRun)
monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: image_map)
monkeypatch.setattr(extractor, "_table_to_markdown", lambda table, image_map: "TABLE-MARKDOWN")
logger_exception = MagicMock()
monkeypatch.setattr(we.logger, "exception", logger_exception)
content = extractor.parse_docx("dummy.docx")
with caplog.at_level(logging.ERROR, logger="core.rag.extractor.word_extractor"):
content = extractor.parse_docx("dummy.docx")
assert "[EXT]" in content
assert "[INT]" in content
@ -720,7 +720,7 @@ def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monke
assert "[LinkText](https://example.com)" in content
assert "BrokenLink" in content
assert "TABLE-MARKDOWN" in content
logger_exception.assert_called_once()
assert any(record.levelno == logging.ERROR for record in caplog.records)
def test_parse_cell_paragraph_hyperlink_in_table_cell_http():

View File

@ -126,6 +126,7 @@ Run with coverage:
"""
import asyncio
import logging
import string
import sys
import types
@ -644,13 +645,13 @@ class TestTextSplitterBasePaths:
with pytest.raises(NotImplementedError):
asyncio.run(splitter.atransform_documents([Document(page_content="x", metadata={})]))
def test_merge_splits_logs_warning_for_oversized_total(self):
def test_merge_splits_logs_warning_for_oversized_total(self, caplog):
"""Cover logger.warning path in _merge_splits."""
splitter = RecursiveCharacterTextSplitter(chunk_size=5, chunk_overlap=1)
with patch("core.rag.splitter.text_splitter.logger.warning") as mock_warning:
with caplog.at_level(logging.WARNING, logger="core.rag.splitter.text_splitter"):
merged = splitter._merge_splits(["abcdefghij", "b"], "", [10, 1])
assert merged
mock_warning.assert_called_once()
assert any(record.levelno == logging.WARNING for record in caplog.records)
# ============================================================================