refactor: core/rag docstore, datasource, embedding, rerank, retrieval (#34203)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Renzo 2026-03-30 10:09:49 +02:00 committed by GitHub
parent 40fa0f365c
commit 456684dfc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 170 additions and 214 deletions

View File

@ -97,13 +97,13 @@ class Jieba(BaseKeyword):
documents = [] documents = []
segment_query_stmt = db.session.query(DocumentSegment).where( segment_query_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices) DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
) )
if document_ids_filter: if document_ids_filter:
segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter)) segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
segments = db.session.execute(segment_query_stmt).scalars().all() segments = db.session.scalars(segment_query_stmt).all()
segment_map = {segment.index_node_id: segment for segment in segments} segment_map = {segment.index_node_id: segment for segment in segments}
for chunk_index in sorted_chunk_indices: for chunk_index in sorted_chunk_indices:
segment = segment_map.get(chunk_index) segment = segment_map.get(chunk_index)

View File

@ -432,10 +432,11 @@ class RetrievalService:
# Batch query dataset documents # Batch query dataset documents
dataset_documents = { dataset_documents = {
doc.id: doc doc.id: doc
for doc in db.session.query(DatasetDocument) for doc in db.session.scalars(
.where(DatasetDocument.id.in_(document_ids)) select(DatasetDocument)
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id)) .where(DatasetDocument.id.in_(document_ids))
.all() .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
).all()
} }
valid_dataset_documents = {} valid_dataset_documents = {}

View File

@ -426,11 +426,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else: else:
idle_tidb_auth_binding = ( idle_tidb_auth_binding = db.session.scalar(
db.session.query(TidbAuthBinding) select(TidbAuthBinding)
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1) .limit(1)
.one_or_none()
) )
if idle_tidb_auth_binding: if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True idle_tidb_auth_binding.active = True

View File

@ -277,7 +277,7 @@ class Vector:
return self._vector_processor.search_by_vector(query_vector, **kwargs) return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]: def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first() upload_file: UploadFile | None = db.session.get(UploadFile, file_id)
if not upload_file: if not upload_file:
return [] return []

View File

@ -4,7 +4,7 @@ from collections.abc import Sequence
from typing import Any from typing import Any
from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import func, select from sqlalchemy import delete, func, select
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.constant.index_type import IndexTechniqueType
@ -63,10 +63,8 @@ class DatasetDocumentStore:
return output return output
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False): def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False):
max_position = ( max_position = db.session.scalar(
db.session.query(func.max(DocumentSegment.position)) select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == self._document_id)
.where(DocumentSegment.document_id == self._document_id)
.scalar()
) )
if max_position is None: if max_position is None:
@ -155,12 +153,14 @@ class DatasetDocumentStore:
) )
if save_child and doc.children: if save_child and doc.children:
# delete the existing child chunks # delete the existing child chunks
db.session.query(ChildChunk).where( db.session.execute(
ChildChunk.tenant_id == self._dataset.tenant_id, delete(ChildChunk).where(
ChildChunk.dataset_id == self._dataset.id, ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.document_id == self._document_id, ChildChunk.dataset_id == self._dataset.id,
ChildChunk.segment_id == segment_document.id, ChildChunk.document_id == self._document_id,
).delete() ChildChunk.segment_id == segment_document.id,
)
)
# add new child chunks # add new child chunks
for position, child in enumerate(doc.children, start=1): for position, child in enumerate(doc.children, start=1):
child_segment = ChildChunk( child_segment = ChildChunk(

View File

@ -6,6 +6,7 @@ from typing import Any, cast
import numpy as np import numpy as np
from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.entities.model_entities import ModelPropertyKey
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from configs import dify_config from configs import dify_config
@ -31,14 +32,14 @@ class CacheEmbedding(Embeddings):
embedding_queue_indices = [] embedding_queue_indices = []
for i, text in enumerate(texts): for i, text in enumerate(texts):
hash = helper.generate_text_hash(text) hash = helper.generate_text_hash(text)
embedding = ( embedding = db.session.scalar(
db.session.query(Embedding) select(Embedding)
.filter_by( .where(
model_name=self._model_instance.model_name, Embedding.model_name == self._model_instance.model_name,
hash=hash, Embedding.hash == hash,
provider_name=self._model_instance.provider, Embedding.provider_name == self._model_instance.provider,
) )
.first() .limit(1)
) )
if embedding: if embedding:
text_embeddings[i] = embedding.get_embedding() text_embeddings[i] = embedding.get_embedding()
@ -112,14 +113,14 @@ class CacheEmbedding(Embeddings):
embedding_queue_indices = [] embedding_queue_indices = []
for i, multimodel_document in enumerate(multimodel_documents): for i, multimodel_document in enumerate(multimodel_documents):
file_id = multimodel_document["file_id"] file_id = multimodel_document["file_id"]
embedding = ( embedding = db.session.scalar(
db.session.query(Embedding) select(Embedding)
.filter_by( .where(
model_name=self._model_instance.model_name, Embedding.model_name == self._model_instance.model_name,
hash=file_id, Embedding.hash == file_id,
provider_name=self._model_instance.provider, Embedding.provider_name == self._model_instance.provider,
) )
.first() .limit(1)
) )
if embedding: if embedding:
multimodel_embeddings[i] = embedding.get_embedding() multimodel_embeddings[i] = embedding.get_embedding()

View File

@ -4,6 +4,7 @@ import operator
from typing import Any, cast from typing import Any, cast
import httpx import httpx
from sqlalchemy import update
from configs import dify_config from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
@ -346,9 +347,11 @@ class NotionExtractor(BaseExtractor):
if data_source_info: if data_source_info:
data_source_info["last_edited_time"] = last_edited_time data_source_info["last_edited_time"] = last_edited_time
db.session.query(DocumentModel).filter_by(id=document_model.id).update( db.session.execute(
{DocumentModel.data_source_info: json.dumps(data_source_info)} update(DocumentModel)
) # type: ignore .where(DocumentModel.id == document_model.id)
.values(data_source_info=json.dumps(data_source_info))
)
db.session.commit() db.session.commit()
def get_notion_last_edited_time(self) -> str: def get_notion_last_edited_time(self) -> str:

View File

@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, NotRequired, Optional
from urllib.parse import unquote, urlparse from urllib.parse import unquote, urlparse
import httpx import httpx
from sqlalchemy import select
from typing_extensions import TypedDict from typing_extensions import TypedDict
from configs import dify_config from configs import dify_config
@ -200,7 +201,7 @@ class BaseIndexProcessor(ABC):
# Get unique IDs for database query # Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list)) unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all() upload_files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids))).all()
# Create a mapping from ID to UploadFile for quick lookup # Create a mapping from ID to UploadFile for quick lookup
upload_file_map = {upload_file.id: upload_file for upload_file in upload_files} upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
@ -312,7 +313,7 @@ class BaseIndexProcessor(ABC):
""" """
from services.file_service import FileService from services.file_service import FileService
tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first() tool_file = db.session.get(ToolFile, tool_file_id)
if not tool_file: if not tool_file:
return None return None
blob = storage.load_once(tool_file.file_key) blob = storage.load_once(tool_file.file_key)

View File

@ -18,6 +18,7 @@ from graphon.model_runtime.entities.message_entities import (
UserPromptMessage, UserPromptMessage,
) )
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from sqlalchemy import select
from core.app.file_access import DatabaseFileAccessController from core.app.file_access import DatabaseFileAccessController
from core.app.llm import deduct_llm_quota from core.app.llm import deduct_llm_quota
@ -145,14 +146,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if delete_summaries: if delete_summaries:
if node_ids: if node_ids:
# Find segments by index_node_id # Find segments by index_node_id
segments = ( segments = db.session.scalars(
db.session.query(DocumentSegment) select(DocumentSegment).where(
.filter(
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids), DocumentSegment.index_node_id.in_(node_ids),
) )
.all() ).all()
)
segment_ids = [segment.id for segment in segments] segment_ids = [segment.id for segment in segments]
if segment_ids: if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
@ -537,11 +536,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# Get unique IDs for database query # Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list)) unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = ( upload_files = db.session.scalars(
db.session.query(UploadFile) select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
.where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id) ).all()
.all()
)
# Create File objects from UploadFile records # Create File objects from UploadFile records
file_objects = [] file_objects = []

View File

@ -6,6 +6,8 @@ import uuid
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any
from sqlalchemy import delete, select
from configs import dify_config from configs import dify_config
from core.db.session_factory import session_factory from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail from core.entities.knowledge_entities import PreviewDetail
@ -177,17 +179,16 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = precomputed_child_node_ids child_node_ids = precomputed_child_node_ids
else: else:
# Fallback to original query (may fail if segments are already deleted) # Fallback to original query (may fail if segments are already deleted)
child_node_ids = ( rows = db.session.execute(
db.session.query(ChildChunk.index_node_id) select(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.where( .where(
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids), DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id, ChildChunk.dataset_id == dataset.id,
) )
.all() ).all()
) child_node_ids = [row[0] for row in rows if row[0]]
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]]
# Delete from vector index # Delete from vector index
if child_node_ids: if child_node_ids:
@ -195,18 +196,22 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
# Delete from database # Delete from database
if delete_child_chunks and child_node_ids: if delete_child_chunks and child_node_ids:
db.session.query(ChildChunk).where( db.session.execute(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) delete(ChildChunk).where(
).delete(synchronize_session=False) ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
)
)
db.session.commit() db.session.commit()
else: else:
vector.delete() vector.delete()
if delete_child_chunks: if delete_child_chunks:
# Use existing compound index: (tenant_id, dataset_id, ...) # Use existing compound index: (tenant_id, dataset_id, ...)
db.session.query(ChildChunk).where( db.session.execute(
ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id delete(ChildChunk).where(
).delete(synchronize_session=False) ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
)
)
db.session.commit() db.session.commit()
def retrieve( def retrieve(

View File

@ -134,9 +134,7 @@ class RerankModelRunner(BaseRerankRunner):
): ):
if document.metadata.get("doc_type") == DocType.IMAGE: if document.metadata.get("doc_type") == DocType.IMAGE:
# Query file info within db.session context to ensure thread-safe access # Query file info within db.session context to ensure thread-safe access
upload_file = ( upload_file = db.session.get(UploadFile, document.metadata["doc_id"])
db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first()
)
if upload_file: if upload_file:
blob = storage.load_once(upload_file.key) blob = storage.load_once(upload_file.key)
document_file_base64 = base64.b64encode(blob).decode() document_file_base64 = base64.b64encode(blob).decode()
@ -169,7 +167,7 @@ class RerankModelRunner(BaseRerankRunner):
return rerank_result, unique_documents return rerank_result, unique_documents
elif query_type == QueryType.IMAGE_QUERY: elif query_type == QueryType.IMAGE_QUERY:
# Query file info within db.session context to ensure thread-safe access # Query file info within db.session context to ensure thread-safe access
upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first() upload_file = db.session.get(UploadFile, query)
if upload_file: if upload_file:
blob = storage.load_once(upload_file.key) blob = storage.load_once(upload_file.key)
file_query = base64.b64encode(blob).decode() file_query = base64.b64encode(blob).decode()

View File

@ -1340,7 +1340,7 @@ class DatasetRetrieval:
metadata_filtering_conditions: MetadataFilteringCondition | None, metadata_filtering_conditions: MetadataFilteringCondition | None,
inputs: dict, inputs: dict,
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
document_query = db.session.query(DatasetDocument).where( document_query = select(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed", DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
@ -1411,7 +1411,7 @@ class DatasetRetrieval:
document_query = document_query.where(and_(*filters)) document_query = document_query.where(and_(*filters))
else: else:
document_query = document_query.where(or_(*filters)) document_query = document_query.where(or_(*filters))
documents = document_query.all() documents = db.session.scalars(document_query).all()
# group by dataset_id # group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents: for document in documents:

View File

@ -201,27 +201,23 @@ def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch,
document_id = _Field("document_id") document_id = _Field("document_id")
keyword = Jieba(_dataset(_dataset_keyword_table())) keyword = Jieba(_dataset(_dataset_keyword_table()))
query_stmt = _FakeQuery() patched_runtime.session.scalars.return_value.all.return_value = [
patched_runtime.session.query.return_value = query_stmt SimpleNamespace(
patched_runtime.session.execute.return_value = _FakeExecuteResult( index_node_id="node-2",
[ content="segment-content",
SimpleNamespace( index_node_hash="hash-2",
index_node_id="node-2", document_id="doc-2",
content="segment-content", dataset_id="dataset-1",
index_node_hash="hash-2", )
document_id="doc-2", ]
dataset_id="dataset-1",
)
]
)
monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment) monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment)
monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect())
monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"])) monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"]))
documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"]) documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"])
assert len(query_stmt.where_calls) == 2
assert len(documents) == 1 assert len(documents) == 1
assert documents[0].page_content == "segment-content" assert documents[0].page_content == "segment-content"
assert documents[0].metadata["doc_id"] == "node-2" assert documents[0].metadata["doc_id"] == "node-2"

View File

@ -714,13 +714,13 @@ class TestRetrievalServiceInternals:
dataset_id="dataset-id", dataset_id="dataset-id",
) )
dataset_query = Mock() scalars_result = Mock()
dataset_query.where.return_value.options.return_value.all.return_value = [ scalars_result.all.return_value = [
dataset_doc_parent, dataset_doc_parent,
dataset_doc_text, dataset_doc_text,
dataset_doc_parent_summary, dataset_doc_parent_summary,
] ]
monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(return_value=dataset_query)) monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(return_value=scalars_result))
monkeypatch.setattr(retrieval_service_module, "RetrievalChildChunk", _SimpleRetrievalChildChunk) monkeypatch.setattr(retrieval_service_module, "RetrievalChildChunk", _SimpleRetrievalChildChunk)
monkeypatch.setattr(retrieval_service_module, "RetrievalSegments", _SimpleRetrievalSegment) monkeypatch.setattr(retrieval_service_module, "RetrievalSegments", _SimpleRetrievalSegment)
@ -882,7 +882,7 @@ class TestRetrievalServiceInternals:
def test_format_retrieval_documents_rolls_back_and_raises_when_db_fails(self, monkeypatch): def test_format_retrieval_documents_rolls_back_and_raises_when_db_fails(self, monkeypatch):
rollback = Mock() rollback = Mock()
monkeypatch.setattr(retrieval_service_module.db.session, "rollback", rollback) monkeypatch.setattr(retrieval_service_module.db.session, "rollback", rollback)
monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(side_effect=RuntimeError("db error"))) monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(side_effect=RuntimeError("db error")))
documents = [Document(page_content="content", metadata={"document_id": "doc-1"}, provider="dify")] documents = [Document(page_content="content", metadata={"document_id": "doc-1"}, provider="dify")]

View File

@ -340,15 +340,13 @@ def test_search_by_file_handles_missing_and_existing_upload(vector_factory_modul
vector._embeddings = MagicMock() vector._embeddings = MagicMock()
vector._vector_processor = MagicMock() vector._vector_processor = MagicMock()
mock_session = SimpleNamespace(get=lambda _model, _id: None)
monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
monkeypatch.setattr( monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session))
vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query))
)
upload_query.first.return_value = None
assert vector.search_by_file("file-1") == [] assert vector.search_by_file("file-1") == []
upload_query.first.return_value = SimpleNamespace(key="blob-key") mock_session.get = lambda _model, _id: SimpleNamespace(key="blob-key")
monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes")) monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes"))
vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4] vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4]
vector._vector_processor.search_by_vector.return_value = ["hit"] vector._vector_processor.search_by_vector.return_value = ["hit"]

View File

@ -167,7 +167,7 @@ class TestDatasetDocumentStoreAddDocuments:
): ):
mock_session = MagicMock() mock_session = MagicMock()
mock_db.session = mock_session mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = None mock_db.session.scalar.return_value = None
mock_manager = MagicMock() mock_manager = MagicMock()
mock_manager.get_model_instance.return_value = mock_model_instance mock_manager.get_model_instance.return_value = mock_model_instance
@ -211,7 +211,7 @@ class TestDatasetDocumentStoreAddDocuments:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db: with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock() mock_session = MagicMock()
mock_db.session = mock_session mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 mock_db.session.scalar.return_value = 5
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@ -276,7 +276,7 @@ class TestDatasetDocumentStoreAddDocuments:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db: with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock() mock_session = MagicMock()
mock_db.session = mock_session mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = None mock_db.session.scalar.return_value = None
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@ -353,7 +353,7 @@ class TestDatasetDocumentStoreAddDocuments:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db: with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock() mock_session = MagicMock()
mock_db.session = mock_session mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = None mock_db.session.scalar.return_value = None
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@ -755,7 +755,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db: with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock() mock_session = MagicMock()
mock_db.session = mock_session mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 mock_db.session.scalar.return_value = 5
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@ -767,7 +767,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild:
store.add_documents([mock_doc], save_child=True) store.add_documents([mock_doc], save_child=True)
mock_db.session.query.return_value.where.return_value.delete.assert_called() mock_db.session.execute.assert_called()
mock_db.session.commit.assert_called() mock_db.session.commit.assert_called()
@ -798,7 +798,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateAnswer:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db: with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock() mock_session = MagicMock()
mock_db.session = mock_session mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 mock_db.session.scalar.return_value = 5
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):

View File

@ -69,7 +69,7 @@ class TestCacheEmbeddingMultimodalDocuments:
documents = [{"file_id": "file123", "content": "test content"}] documents = [{"file_id": "file123", "content": "test content"}]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result
result = cache_embedding.embed_multimodal_documents(documents) result = cache_embedding.embed_multimodal_documents(documents)
@ -114,7 +114,7 @@ class TestCacheEmbeddingMultimodalDocuments:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
result = cache_embedding.embed_multimodal_documents(documents) result = cache_embedding.embed_multimodal_documents(documents)
@ -134,7 +134,7 @@ class TestCacheEmbeddingMultimodalDocuments:
mock_cached_embedding.get_embedding.return_value = normalized_cached mock_cached_embedding.get_embedding.return_value = normalized_cached
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding mock_session.scalar.return_value = mock_cached_embedding
result = cache_embedding.embed_multimodal_documents(documents) result = cache_embedding.embed_multimodal_documents(documents)
@ -180,18 +180,7 @@ class TestCacheEmbeddingMultimodalDocuments:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
call_count = [0] mock_session.scalar.side_effect = [mock_cached_embedding, None, None]
def mock_filter_by(**kwargs):
call_count[0] += 1
mock_query = Mock()
if call_count[0] == 1:
mock_query.first.return_value = mock_cached_embedding
else:
mock_query.first.return_value = None
return mock_query
mock_session.query.return_value.filter_by = mock_filter_by
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
result = cache_embedding.embed_multimodal_documents(documents) result = cache_embedding.embed_multimodal_documents(documents)
@ -224,7 +213,7 @@ class TestCacheEmbeddingMultimodalDocuments:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
@ -265,7 +254,7 @@ class TestCacheEmbeddingMultimodalDocuments:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
batch_results = [create_batch_result(10), create_batch_result(10), create_batch_result(5)] batch_results = [create_batch_result(10), create_batch_result(10), create_batch_result(5)]
mock_model_instance.invoke_multimodal_embedding.side_effect = batch_results mock_model_instance.invoke_multimodal_embedding.side_effect = batch_results
@ -281,7 +270,7 @@ class TestCacheEmbeddingMultimodalDocuments:
documents = [{"file_id": "file123"}] documents = [{"file_id": "file123"}]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error") mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error")
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
@ -298,7 +287,7 @@ class TestCacheEmbeddingMultimodalDocuments:
documents = [{"file_id": "file123"}] documents = [{"file_id": "file123"}]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result
mock_session.commit.side_effect = IntegrityError("Duplicate key", None, None) mock_session.commit.side_effect = IntegrityError("Duplicate key", None, None)

View File

@ -139,7 +139,7 @@ class TestCacheEmbeddingDocuments:
# Mock database query to return no cached embedding (cache miss) # Mock database query to return no cached embedding (cache miss)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
# Mock model invocation # Mock model invocation
mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result
@ -203,7 +203,7 @@ class TestCacheEmbeddingDocuments:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act # Act
@ -240,7 +240,7 @@ class TestCacheEmbeddingDocuments:
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
# Mock database to return cached embedding (cache hit) # Mock database to return cached embedding (cache hit)
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding mock_session.scalar.return_value = mock_cached_embedding
# Act # Act
result = cache_embedding.embed_documents(texts) result = cache_embedding.embed_documents(texts)
@ -313,19 +313,7 @@ class TestCacheEmbeddingDocuments:
mock_hash.side_effect = generate_hash mock_hash.side_effect = generate_hash
# Mock database to return cached embedding only for first text (hash_1) # Mock database to return cached embedding only for first text (hash_1)
call_count = [0] mock_session.scalar.side_effect = [mock_cached_embedding, None, None]
def mock_filter_by(**kwargs):
call_count[0] += 1
mock_query = Mock()
# First call (hash_1) returns cached, others return None
if call_count[0] == 1:
mock_query.first.return_value = mock_cached_embedding
else:
mock_query.first.return_value = None
return mock_query
mock_session.query.return_value.filter_by = mock_filter_by
mock_model_instance.invoke_text_embedding.return_value = embedding_result mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act # Act
@ -392,7 +380,7 @@ class TestCacheEmbeddingDocuments:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
# Mock model to return appropriate batch results # Mock model to return appropriate batch results
batch_results = [ batch_results = [
@ -455,7 +443,7 @@ class TestCacheEmbeddingDocuments:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result mock_model_instance.invoke_text_embedding.return_value = embedding_result
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
@ -489,7 +477,7 @@ class TestCacheEmbeddingDocuments:
texts = ["Test text"] texts = ["Test text"]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
# Mock model to raise connection error # Mock model to raise connection error
mock_model_instance.invoke_text_embedding.side_effect = InvokeConnectionError("Failed to connect to API") mock_model_instance.invoke_text_embedding.side_effect = InvokeConnectionError("Failed to connect to API")
@ -515,7 +503,7 @@ class TestCacheEmbeddingDocuments:
texts = ["Test text"] texts = ["Test text"]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
# Mock model to raise rate limit error # Mock model to raise rate limit error
mock_model_instance.invoke_text_embedding.side_effect = InvokeRateLimitError("Rate limit exceeded") mock_model_instance.invoke_text_embedding.side_effect = InvokeRateLimitError("Rate limit exceeded")
@ -539,7 +527,7 @@ class TestCacheEmbeddingDocuments:
texts = ["Test text"] texts = ["Test text"]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
# Mock model to raise authorization error # Mock model to raise authorization error
mock_model_instance.invoke_text_embedding.side_effect = InvokeAuthorizationError("Invalid API key") mock_model_instance.invoke_text_embedding.side_effect = InvokeAuthorizationError("Invalid API key")
@ -564,7 +552,7 @@ class TestCacheEmbeddingDocuments:
texts = ["Test text"] texts = ["Test text"]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result
# Mock database commit to raise IntegrityError # Mock database commit to raise IntegrityError
@ -884,7 +872,7 @@ class TestEmbeddingModelSwitching:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
model_instance_ada.invoke_text_embedding.return_value = result_ada model_instance_ada.invoke_text_embedding.return_value = result_ada
model_instance_3_small.invoke_text_embedding.return_value = result_3_small model_instance_3_small.invoke_text_embedding.return_value = result_3_small
@ -1047,7 +1035,7 @@ class TestEmbeddingDimensionValidation:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act # Act
@ -1100,7 +1088,7 @@ class TestEmbeddingDimensionValidation:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act # Act
@ -1186,7 +1174,7 @@ class TestEmbeddingDimensionValidation:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
model_instance_ada.invoke_text_embedding.return_value = result_ada model_instance_ada.invoke_text_embedding.return_value = result_ada
model_instance_cohere.invoke_text_embedding.return_value = result_cohere model_instance_cohere.invoke_text_embedding.return_value = result_cohere
@ -1284,7 +1272,7 @@ class TestEmbeddingEdgeCases:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act # Act
@ -1327,7 +1315,7 @@ class TestEmbeddingEdgeCases:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act # Act
@ -1375,7 +1363,7 @@ class TestEmbeddingEdgeCases:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act # Act
@ -1427,7 +1415,7 @@ class TestEmbeddingEdgeCases:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act # Act
@ -1483,7 +1471,7 @@ class TestEmbeddingEdgeCases:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act # Act
@ -1551,7 +1539,7 @@ class TestEmbeddingEdgeCases:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act # Act
@ -1649,7 +1637,7 @@ class TestEmbeddingEdgeCases:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act # Act
@ -1728,7 +1716,7 @@ class TestEmbeddingCachePerformance:
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
# First call: cache miss # First call: cache miss
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
usage = EmbeddingUsage( usage = EmbeddingUsage(
tokens=5, tokens=5,
@ -1756,7 +1744,7 @@ class TestEmbeddingCachePerformance:
assert len(result1) == 1 assert len(result1) == 1
# Arrange - Second call: cache hit # Arrange - Second call: cache hit
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding mock_session.scalar.return_value = mock_cached_embedding
# Act - Second call (cache hit) # Act - Second call (cache hit)
result2 = cache_embedding.embed_documents([text]) result2 = cache_embedding.embed_documents([text])
@ -1816,7 +1804,7 @@ class TestEmbeddingCachePerformance:
) )
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.scalar.return_value = None
# Mock model to return appropriate batch results # Mock model to return appropriate batch results
batch_results = [ batch_results = [

View File

@ -405,35 +405,36 @@ class TestNotionMetadataAndCredentialMethods:
class FakeDocumentModel: class FakeDocumentModel:
data_source_info = "data_source_info" data_source_info = "data_source_info"
id = "id"
update_calls = [] execute_calls = []
class FakeQuery: class FakeUpdateStmt:
def filter_by(self, **kwargs): def where(self, *args):
return self return self
def update(self, payload): def values(self, **kwargs):
update_calls.append(payload) return self
class FakeSession: class FakeSession:
committed = False committed = False
def query(self, model): def execute(self, stmt):
assert model is FakeDocumentModel execute_calls.append(stmt)
return FakeQuery()
def commit(self): def commit(self):
self.committed = True self.committed = True
fake_db = SimpleNamespace(session=FakeSession()) fake_db = SimpleNamespace(session=FakeSession())
monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel) monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel)
monkeypatch.setattr(notion_extractor, "update", lambda model: FakeUpdateStmt())
monkeypatch.setattr(notion_extractor, "db", fake_db) monkeypatch.setattr(notion_extractor, "db", fake_db)
monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z") monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z")
doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"}) doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"})
extractor.update_last_edited_time(doc_model) extractor.update_last_edited_time(doc_model)
assert update_calls assert execute_calls
assert fake_db.session.committed is True assert fake_db.session.committed is True
def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture): def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture):

View File

@ -188,10 +188,10 @@ class TestParagraphIndexProcessor:
mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs) mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs)
def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
segment_query = Mock() scalars_result = Mock()
segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] scalars_result.all.return_value = [SimpleNamespace(id="seg-1")]
session = Mock() session = Mock()
session.query.return_value = segment_query session.scalars.return_value = scalars_result
with ( with (
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
@ -531,10 +531,10 @@ class TestParagraphIndexProcessor:
size=1, size=1,
key="key", key="key",
) )
query = Mock() scalars_result = Mock()
query.where.return_value.all.return_value = [image_upload, non_image_upload] scalars_result.all.return_value = [image_upload, non_image_upload]
session = Mock() session = Mock()
session.query.return_value = query session.scalars.return_value = scalars_result
with ( with (
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
@ -565,10 +565,10 @@ class TestParagraphIndexProcessor:
size=1, size=1,
key="key", key="key",
) )
query = Mock() scalars_result = Mock()
query.where.return_value.all.return_value = [image_upload] scalars_result.all.return_value = [image_upload]
session = Mock() session = Mock()
session.query.return_value = query session.scalars.return_value = scalars_result
with ( with (
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),

View File

@ -208,11 +208,7 @@ class TestParentChildIndexProcessor:
vector.create_multimodal.assert_called_once_with(multimodal_docs) vector.create_multimodal.assert_called_once_with(multimodal_docs)
def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
delete_query = Mock()
where_query = Mock()
where_query.delete.return_value = 2
session = Mock() session = Mock()
session.query.return_value.where.return_value = where_query
with ( with (
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
@ -227,16 +223,16 @@ class TestParentChildIndexProcessor:
) )
vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])
where_query.delete.assert_called_once_with(synchronize_session=False) session.execute.assert_called()
session.commit.assert_called_once() session.commit.assert_called_once()
def test_clean_queries_child_ids_when_not_precomputed( def test_clean_queries_child_ids_when_not_precomputed(
self, processor: ParentChildIndexProcessor, dataset: Mock self, processor: ParentChildIndexProcessor, dataset: Mock
) -> None: ) -> None:
child_query = Mock() execute_result = Mock()
child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)] execute_result.all.return_value = [("child-1",), (None,), ("child-2",)]
session = Mock() session = Mock()
session.query.return_value = child_query session.execute.return_value = execute_result
with ( with (
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
@ -248,10 +244,7 @@ class TestParentChildIndexProcessor:
vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])
def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
where_query = Mock()
where_query.delete.return_value = 3
session = Mock() session = Mock()
session.query.return_value.where.return_value = where_query
with ( with (
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
@ -261,7 +254,7 @@ class TestParentChildIndexProcessor:
processor.clean(dataset, None, delete_child_chunks=True) processor.clean(dataset, None, delete_child_chunks=True)
vector.delete.assert_called_once() vector.delete.assert_called_once()
where_query.delete.assert_called_once_with(synchronize_session=False) session.execute.assert_called()
session.commit.assert_called_once() session.commit.assert_called_once()
def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:

View File

@ -133,10 +133,10 @@ class TestBaseIndexProcessor:
upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png") upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png")
upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png") upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png")
upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png") upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png")
db_query = Mock() scalars_result = Mock()
db_query.where.return_value.all.return_value = [upload_a, upload_b, upload_tool, upload_remote] scalars_result.all.return_value = [upload_a, upload_b, upload_tool, upload_remote]
db_session = Mock() db_session = Mock()
db_session.query.return_value = db_query db_session.scalars.return_value = scalars_result
with ( with (
patch.object(processor, "_extract_markdown_images", return_value=images), patch.object(processor, "_extract_markdown_images", return_value=images),
@ -170,10 +170,10 @@ class TestBaseIndexProcessor:
def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None: def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None:
document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"] images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"]
db_query = Mock() scalars_result = Mock()
db_query.where.return_value.all.return_value = [] scalars_result.all.return_value = []
db_session = Mock() db_session = Mock()
db_session.query.return_value = db_query db_session.scalars.return_value = scalars_result
with ( with (
patch.object(processor, "_extract_markdown_images", return_value=images), patch.object(processor, "_extract_markdown_images", return_value=images),
@ -259,20 +259,16 @@ class TestBaseIndexProcessor:
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None: def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
db_query = Mock()
db_query.where.return_value.first.return_value = None
db_session = Mock() db_session = Mock()
db_session.query.return_value = db_query db_session.get.return_value = None
with patch("core.rag.index_processor.index_processor_base.db.session", db_session): with patch("core.rag.index_processor.index_processor_base.db.session", db_session):
assert processor._download_tool_file("tool-id", current_user=Mock()) is None assert processor._download_tool_file("tool-id", current_user=Mock()) is None
def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None: def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png") tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png")
db_query = Mock()
db_query.where.return_value.first.return_value = tool_file
db_session = Mock() db_session = Mock()
db_session.query.return_value = db_query db_session.get.return_value = tool_file
mock_db = Mock() mock_db = Mock()
mock_db.session = db_session mock_db.session = db_session
mock_db.engine = Mock() mock_db.engine = Mock()

View File

@ -473,12 +473,10 @@ class TestRerankModelRunnerMultimodal:
metadata={}, metadata={},
provider="external", provider="external",
) )
query = Mock()
query.where.return_value.first.return_value = SimpleNamespace(key="image-key")
rerank_result = RerankResult(model="rerank-model", docs=[]) rerank_result = RerankResult(model="rerank-model", docs=[])
with ( with (
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), patch("core.rag.rerank.rerank_model.db.session.get", return_value=SimpleNamespace(key="image-key")),
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once, patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once,
patch.object( patch.object(
rerank_runner, rerank_runner,
@ -504,12 +502,10 @@ class TestRerankModelRunnerMultimodal:
metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE}, metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE},
provider="dify", provider="dify",
) )
query = Mock()
query.where.return_value.first.return_value = None
rerank_result = RerankResult(model="rerank-model", docs=[]) rerank_result = RerankResult(model="rerank-model", docs=[])
with ( with (
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), patch("core.rag.rerank.rerank_model.db.session.get", return_value=None),
patch.object( patch.object(
rerank_runner, rerank_runner,
"fetch_text_rerank", "fetch_text_rerank",
@ -533,8 +529,6 @@ class TestRerankModelRunnerMultimodal:
metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT}, metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT},
provider="dify", provider="dify",
) )
query_chain = Mock()
query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key")
rerank_result = RerankResult( rerank_result = RerankResult(
model="rerank-model", model="rerank-model",
docs=[RerankDocument(index=0, text="text-content", score=0.77)], docs=[RerankDocument(index=0, text="text-content", score=0.77)],
@ -542,7 +536,7 @@ class TestRerankModelRunnerMultimodal:
mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result
session = MagicMock() session = MagicMock()
session.query.return_value = query_chain session.get.return_value = SimpleNamespace(key="query-image-key")
with ( with (
patch("core.rag.rerank.rerank_model.db.session", session), patch("core.rag.rerank.rerank_model.db.session", session),
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"), patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"),
@ -563,10 +557,7 @@ class TestRerankModelRunnerMultimodal:
assert "user" not in invoke_kwargs assert "user" not in invoke_kwargs
def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner): def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner):
query_chain = Mock() with patch("core.rag.rerank.rerank_model.db.session.get", return_value=None):
query_chain.where.return_value.first.return_value = None
with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain):
with pytest.raises(ValueError, match="Upload file not found for query"): with pytest.raises(ValueError, match="Upload file not found for query"):
rerank_runner.fetch_multimodal_rerank( rerank_runner.fetch_multimodal_rerank(
query="missing-upload-id", query="missing-upload-id",

View File

@ -3971,11 +3971,10 @@ class TestDatasetRetrievalAdditionalHelpers:
) )
def test_get_metadata_filter_condition(self, retrieval: DatasetRetrieval) -> None: def test_get_metadata_filter_condition(self, retrieval: DatasetRetrieval) -> None:
db_query = Mock() scalars_result = Mock()
db_query.where.return_value = db_query scalars_result.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")]
db_query.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")]
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
mapping, condition = retrieval.get_metadata_filter_condition( mapping, condition = retrieval.get_metadata_filter_condition(
dataset_ids=["d1"], dataset_ids=["d1"],
query="python", query="python",
@ -3991,7 +3990,7 @@ class TestDatasetRetrievalAdditionalHelpers:
automatic_filters = [{"condition": "contains", "metadata_name": "author", "value": "Alice"}] automatic_filters = [{"condition": "contains", "metadata_name": "author", "value": "Alice"}]
with ( with (
patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query), patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result),
patch.object(retrieval, "_automatic_metadata_filter_func", return_value=automatic_filters), patch.object(retrieval, "_automatic_metadata_filter_func", return_value=automatic_filters),
): ):
mapping, condition = retrieval.get_metadata_filter_condition( mapping, condition = retrieval.get_metadata_filter_condition(
@ -4012,7 +4011,7 @@ class TestDatasetRetrievalAdditionalHelpers:
logical_operator="and", logical_operator="and",
conditions=[AppCondition(name="author", comparison_operator="contains", value="{{name}}")], conditions=[AppCondition(name="author", comparison_operator="contains", value="{{name}}")],
) )
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
mapping, condition = retrieval.get_metadata_filter_condition( mapping, condition = retrieval.get_metadata_filter_condition(
dataset_ids=["d1"], dataset_ids=["d1"],
query="python", query="python",
@ -4027,7 +4026,7 @@ class TestDatasetRetrievalAdditionalHelpers:
assert condition is not None assert condition is not None
assert condition.conditions[0].value == "Alice" assert condition.conditions[0].value == "Alice"
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
with pytest.raises(ValueError, match="Invalid metadata filtering mode"): with pytest.raises(ValueError, match="Invalid metadata filtering mode"):
retrieval.get_metadata_filter_condition( retrieval.get_metadata_filter_condition(
dataset_ids=["d1"], dataset_ids=["d1"],