test: replace patch logger with caplog in core/rag tests (#37468) (#37621)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
frank 2026-06-21 15:30:40 +08:00 committed by GitHub
parent b60f83e308
commit 44464c8c63
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 36 deletions

View File

@ -44,6 +44,7 @@ Tests follow the Arrange-Act-Assert pattern for clarity.
"""
import base64
import logging
from decimal import Decimal
from unittest.mock import Mock, patch
@ -406,7 +407,7 @@ class TestCacheEmbeddingDocuments:
assert len(calls[1].kwargs["texts"]) == 10
assert len(calls[2].kwargs["texts"]) == 5
def test_embed_documents_nan_handling(self, mock_model_instance):
def test_embed_documents_nan_handling(self, mock_model_instance, caplog):
"""Test handling of NaN values in embeddings.
Verifies:
@ -446,7 +447,7 @@ class TestCacheEmbeddingDocuments:
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
with caplog.at_level(logging.WARNING, logger="core.rag.embedding.cached_embedding"):
# Act
result = cache_embedding.embed_documents(texts)
@ -461,8 +462,8 @@ class TestCacheEmbeddingDocuments:
assert result[1] is None
# Verify warning was logged
mock_logger.warning.assert_called_once()
assert "Normalized embedding is nan" in str(mock_logger.warning.call_args)
assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) >= 1
assert any("Normalized embedding is nan" in record.message for record in caplog.records)
def test_embed_documents_api_connection_error(self, mock_model_instance):
"""Test handling of API connection errors during embedding.

View File

@ -1,3 +1,4 @@
import logging
from types import SimpleNamespace
from typing import Any
from unittest.mock import Mock, patch
@ -384,7 +385,7 @@ class TestParagraphIndexProcessor:
with pytest.raises(ValueError, match="model_name and model_provider_name"):
ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": True})
def test_generate_summary_text_only_flow(self) -> None:
def test_generate_summary_text_only_flow(self, caplog) -> None:
model_instance = Mock()
model_instance.credentials = {"k": "v"}
model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[])
@ -402,19 +403,22 @@ class TestParagraphIndexProcessor:
"core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota",
side_effect=RuntimeError("quota"),
),
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
):
mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock()
summary, usage = ParagraphIndexProcessor.generate_summary(
"tenant-1",
"text content",
{"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
document_language="English",
)
with caplog.at_level(
logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"
):
summary, usage = ParagraphIndexProcessor.generate_summary(
"tenant-1",
"text content",
{"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
document_language="English",
)
assert summary == "text summary"
assert isinstance(usage, LLMUsage)
mock_logger.warning.assert_called_with("Failed to deduct quota for summary generation: %s", "quota")
assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1
assert any("Failed to deduct quota for summary generation" in record.message for record in caplog.records)
def test_generate_summary_handles_vision_and_image_conversion(self) -> None:
model_instance = Mock()
@ -455,7 +459,7 @@ class TestParagraphIndexProcessor:
assert summary == "vision summary"
mock_extract_text.assert_not_called()
def test_generate_summary_fallbacks_for_prompt_and_result_types(self) -> None:
def test_generate_summary_fallbacks_for_prompt_and_result_types(self, caplog) -> None:
model_instance = Mock()
model_instance.credentials = {"k": "v"}
model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(
@ -482,21 +486,24 @@ class TestParagraphIndexProcessor:
"core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content",
side_effect=RuntimeError("bad image"),
),
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
):
mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock()
with pytest.raises(ValueError, match="Expected LLMResult"):
ParagraphIndexProcessor.generate_summary(
"tenant-1",
"text content",
{"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
)
with caplog.at_level(
logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"
):
ParagraphIndexProcessor.generate_summary(
"tenant-1",
"text content",
{"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
)
mock_logger.warning.assert_called_with(
"Failed to convert image file to prompt message content: %s", "bad image"
assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1
assert any(
"Failed to convert image file to prompt message content" in record.message for record in caplog.records
)
def test_extract_images_from_text_handles_patterns_and_build_errors(self) -> None:
def test_extract_images_from_text_handles_patterns_and_build_errors(self, caplog) -> None:
text = (
"![img](/files/11111111-1111-1111-1111-111111111111/image-preview) "
"![img2](/files/22222222-2222-2222-2222-222222222222/file-preview) "
@ -532,13 +539,13 @@ class TestParagraphIndexProcessor:
"core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping",
return_value=SimpleNamespace(id="file-1"),
) as mock_builder,
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
caplog.at_level(logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"),
):
files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text, session)
assert len(files) == 1
assert mock_builder.call_count == 1
mock_logger.warning.assert_not_called()
assert not any(record.levelno == logging.WARNING for record in caplog.records)
def test_extract_images_from_text_returns_empty_when_no_matches(self) -> None:
scalars_result = Mock()
@ -547,7 +554,7 @@ class TestParagraphIndexProcessor:
session.scalars.return_value = scalars_result
assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here", session) == []
def test_extract_images_from_text_logs_when_build_fails(self) -> None:
def test_extract_images_from_text_logs_when_build_fails(self, caplog) -> None:
text = "![img](/files/11111111-1111-1111-1111-111111111111/image-preview)"
image_upload = SimpleNamespace(
id="11111111-1111-1111-1111-111111111111",
@ -569,14 +576,14 @@ class TestParagraphIndexProcessor:
"core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping",
side_effect=RuntimeError("build failed"),
),
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
caplog.at_level(logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"),
):
files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text, session)
assert files == []
mock_logger.warning.assert_called_once()
assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1
def test_extract_images_from_segment_attachments(self) -> None:
def test_extract_images_from_segment_attachments(self, caplog) -> None:
image_upload = SimpleNamespace(
id="file-1",
name="image",
@ -609,13 +616,11 @@ class TestParagraphIndexProcessor:
session = Mock()
session.execute.return_value = execute_result
with (
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
):
with caplog.at_level(logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"):
files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1", session)
assert len(files) == 1
mock_logger.warning.assert_called_once()
assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1
def test_extract_images_from_segment_attachments_empty(self) -> None:
execute_result = Mock()

View File

@ -1,3 +1,4 @@
import logging
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, Mock, patch
@ -350,7 +351,7 @@ class TestQAIndexProcessor:
assert all_qa_documents[0].metadata["answer"] == "A test."
assert all_qa_documents[1].metadata["answer"] == "Coverage."
def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app) -> None:
def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app, caplog) -> None:
all_qa_documents: list[Document] = []
source_document = Document(page_content="source text", metadata={"origin": "doc-1"})
@ -359,12 +360,14 @@ class TestQAIndexProcessor:
"core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document",
side_effect=RuntimeError("llm failure"),
),
patch("core.rag.index_processor.processor.qa_index_processor.logger") as mock_logger,
caplog.at_level(logging.ERROR, logger="core.rag.index_processor.processor.qa_index_processor"),
):
processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English")
assert all_qa_documents == []
mock_logger.exception.assert_called_once_with("Failed to format qa document")
assert len(caplog.records) == 1
assert caplog.records[0].levelname == "ERROR"
assert "Failed to format qa document" in caplog.records[0].message
def test_format_split_text_extracts_question_answer_pairs(self, processor: QAIndexProcessor) -> None:
parsed = processor._format_split_text("Q1: First?\nA1: One.\nQ2: Second?\nA2: Two.\n")