mirror of
https://github.com/langgenius/dify.git
synced 2026-06-23 04:11:09 +08:00
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
b60f83e308
commit
44464c8c63
@ -44,6 +44,7 @@ Tests follow the Arrange-Act-Assert pattern for clarity.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from decimal import Decimal
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@ -406,7 +407,7 @@ class TestCacheEmbeddingDocuments:
|
||||
assert len(calls[1].kwargs["texts"]) == 10
|
||||
assert len(calls[2].kwargs["texts"]) == 5
|
||||
|
||||
def test_embed_documents_nan_handling(self, mock_model_instance):
|
||||
def test_embed_documents_nan_handling(self, mock_model_instance, caplog):
|
||||
"""Test handling of NaN values in embeddings.
|
||||
|
||||
Verifies:
|
||||
@ -446,7 +447,7 @@ class TestCacheEmbeddingDocuments:
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
|
||||
with caplog.at_level(logging.WARNING, logger="core.rag.embedding.cached_embedding"):
|
||||
# Act
|
||||
result = cache_embedding.embed_documents(texts)
|
||||
|
||||
@ -461,8 +462,8 @@ class TestCacheEmbeddingDocuments:
|
||||
assert result[1] is None
|
||||
|
||||
# Verify warning was logged
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "Normalized embedding is nan" in str(mock_logger.warning.call_args)
|
||||
assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) >= 1
|
||||
assert any("Normalized embedding is nan" in record.message for record in caplog.records)
|
||||
|
||||
def test_embed_documents_api_connection_error(self, mock_model_instance):
|
||||
"""Test handling of API connection errors during embedding.
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
@ -384,7 +385,7 @@ class TestParagraphIndexProcessor:
|
||||
with pytest.raises(ValueError, match="model_name and model_provider_name"):
|
||||
ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": True})
|
||||
|
||||
def test_generate_summary_text_only_flow(self) -> None:
|
||||
def test_generate_summary_text_only_flow(self, caplog) -> None:
|
||||
model_instance = Mock()
|
||||
model_instance.credentials = {"k": "v"}
|
||||
model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[])
|
||||
@ -402,19 +403,22 @@ class TestParagraphIndexProcessor:
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota",
|
||||
side_effect=RuntimeError("quota"),
|
||||
),
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
|
||||
):
|
||||
mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock()
|
||||
summary, usage = ParagraphIndexProcessor.generate_summary(
|
||||
"tenant-1",
|
||||
"text content",
|
||||
{"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
|
||||
document_language="English",
|
||||
)
|
||||
with caplog.at_level(
|
||||
logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"
|
||||
):
|
||||
summary, usage = ParagraphIndexProcessor.generate_summary(
|
||||
"tenant-1",
|
||||
"text content",
|
||||
{"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
|
||||
document_language="English",
|
||||
)
|
||||
|
||||
assert summary == "text summary"
|
||||
assert isinstance(usage, LLMUsage)
|
||||
mock_logger.warning.assert_called_with("Failed to deduct quota for summary generation: %s", "quota")
|
||||
assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1
|
||||
assert any("Failed to deduct quota for summary generation" in record.message for record in caplog.records)
|
||||
|
||||
def test_generate_summary_handles_vision_and_image_conversion(self) -> None:
|
||||
model_instance = Mock()
|
||||
@ -455,7 +459,7 @@ class TestParagraphIndexProcessor:
|
||||
assert summary == "vision summary"
|
||||
mock_extract_text.assert_not_called()
|
||||
|
||||
def test_generate_summary_fallbacks_for_prompt_and_result_types(self) -> None:
|
||||
def test_generate_summary_fallbacks_for_prompt_and_result_types(self, caplog) -> None:
|
||||
model_instance = Mock()
|
||||
model_instance.credentials = {"k": "v"}
|
||||
model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(
|
||||
@ -482,21 +486,24 @@ class TestParagraphIndexProcessor:
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content",
|
||||
side_effect=RuntimeError("bad image"),
|
||||
),
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
|
||||
):
|
||||
mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock()
|
||||
with pytest.raises(ValueError, match="Expected LLMResult"):
|
||||
ParagraphIndexProcessor.generate_summary(
|
||||
"tenant-1",
|
||||
"text content",
|
||||
{"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
|
||||
)
|
||||
with caplog.at_level(
|
||||
logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"
|
||||
):
|
||||
ParagraphIndexProcessor.generate_summary(
|
||||
"tenant-1",
|
||||
"text content",
|
||||
{"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
|
||||
)
|
||||
|
||||
mock_logger.warning.assert_called_with(
|
||||
"Failed to convert image file to prompt message content: %s", "bad image"
|
||||
assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1
|
||||
assert any(
|
||||
"Failed to convert image file to prompt message content" in record.message for record in caplog.records
|
||||
)
|
||||
|
||||
def test_extract_images_from_text_handles_patterns_and_build_errors(self) -> None:
|
||||
def test_extract_images_from_text_handles_patterns_and_build_errors(self, caplog) -> None:
|
||||
text = (
|
||||
" "
|
||||
" "
|
||||
@ -532,13 +539,13 @@ class TestParagraphIndexProcessor:
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping",
|
||||
return_value=SimpleNamespace(id="file-1"),
|
||||
) as mock_builder,
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
|
||||
caplog.at_level(logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"),
|
||||
):
|
||||
files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text, session)
|
||||
|
||||
assert len(files) == 1
|
||||
assert mock_builder.call_count == 1
|
||||
mock_logger.warning.assert_not_called()
|
||||
assert not any(record.levelno == logging.WARNING for record in caplog.records)
|
||||
|
||||
def test_extract_images_from_text_returns_empty_when_no_matches(self) -> None:
|
||||
scalars_result = Mock()
|
||||
@ -547,7 +554,7 @@ class TestParagraphIndexProcessor:
|
||||
session.scalars.return_value = scalars_result
|
||||
assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here", session) == []
|
||||
|
||||
def test_extract_images_from_text_logs_when_build_fails(self) -> None:
|
||||
def test_extract_images_from_text_logs_when_build_fails(self, caplog) -> None:
|
||||
text = ""
|
||||
image_upload = SimpleNamespace(
|
||||
id="11111111-1111-1111-1111-111111111111",
|
||||
@ -569,14 +576,14 @@ class TestParagraphIndexProcessor:
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping",
|
||||
side_effect=RuntimeError("build failed"),
|
||||
),
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
|
||||
caplog.at_level(logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"),
|
||||
):
|
||||
files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text, session)
|
||||
|
||||
assert files == []
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1
|
||||
|
||||
def test_extract_images_from_segment_attachments(self) -> None:
|
||||
def test_extract_images_from_segment_attachments(self, caplog) -> None:
|
||||
image_upload = SimpleNamespace(
|
||||
id="file-1",
|
||||
name="image",
|
||||
@ -609,13 +616,11 @@ class TestParagraphIndexProcessor:
|
||||
session = Mock()
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
|
||||
):
|
||||
with caplog.at_level(logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"):
|
||||
files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1", session)
|
||||
|
||||
assert len(files) == 1
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1
|
||||
|
||||
def test_extract_images_from_segment_attachments_empty(self) -> None:
|
||||
execute_result = Mock()
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
@ -350,7 +351,7 @@ class TestQAIndexProcessor:
|
||||
assert all_qa_documents[0].metadata["answer"] == "A test."
|
||||
assert all_qa_documents[1].metadata["answer"] == "Coverage."
|
||||
|
||||
def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app) -> None:
|
||||
def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app, caplog) -> None:
|
||||
all_qa_documents: list[Document] = []
|
||||
source_document = Document(page_content="source text", metadata={"origin": "doc-1"})
|
||||
|
||||
@ -359,12 +360,14 @@ class TestQAIndexProcessor:
|
||||
"core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document",
|
||||
side_effect=RuntimeError("llm failure"),
|
||||
),
|
||||
patch("core.rag.index_processor.processor.qa_index_processor.logger") as mock_logger,
|
||||
caplog.at_level(logging.ERROR, logger="core.rag.index_processor.processor.qa_index_processor"),
|
||||
):
|
||||
processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English")
|
||||
|
||||
assert all_qa_documents == []
|
||||
mock_logger.exception.assert_called_once_with("Failed to format qa document")
|
||||
assert len(caplog.records) == 1
|
||||
assert caplog.records[0].levelname == "ERROR"
|
||||
assert "Failed to format qa document" in caplog.records[0].message
|
||||
|
||||
def test_format_split_text_extracts_question_answer_pairs(self, processor: QAIndexProcessor) -> None:
|
||||
parsed = processor._format_split_text("Q1: First?\nA1: One.\nQ2: Second?\nA2: Two.\n")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user