mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 09:57:03 +08:00
fix: Compatibility issues with the summary index feature when using the weaviate vector database (#35052)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
This commit is contained in:
parent
0f643bca76
commit
b0c4d8c541
@ -20,7 +20,7 @@ from pydantic import BaseModel, model_validator
|
||||
from weaviate.classes.data import DataObject
|
||||
from weaviate.classes.init import Auth
|
||||
from weaviate.classes.query import Filter, MetadataQuery
|
||||
from weaviate.exceptions import UnexpectedStatusCodeError
|
||||
from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateQueryError
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@ -230,6 +230,8 @@ class WeaviateVector(BaseVector):
|
||||
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
|
||||
wc.Property(name="doc_type", data_type=wc.DataType.TEXT),
|
||||
wc.Property(name="chunk_index", data_type=wc.DataType.INT),
|
||||
wc.Property(name="is_summary", data_type=wc.DataType.BOOL),
|
||||
wc.Property(name="original_chunk_id", data_type=wc.DataType.TEXT),
|
||||
],
|
||||
vector_config=wc.Configure.Vectors.self_provided(),
|
||||
)
|
||||
@ -262,6 +264,10 @@ class WeaviateVector(BaseVector):
|
||||
to_add.append(wc.Property(name="doc_type", data_type=wc.DataType.TEXT))
|
||||
if "chunk_index" not in existing:
|
||||
to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT))
|
||||
if "is_summary" not in existing:
|
||||
to_add.append(wc.Property(name="is_summary", data_type=wc.DataType.BOOL))
|
||||
if "original_chunk_id" not in existing:
|
||||
to_add.append(wc.Property(name="original_chunk_id", data_type=wc.DataType.TEXT))
|
||||
|
||||
for prop in to_add:
|
||||
try:
|
||||
@ -400,15 +406,27 @@ class WeaviateVector(BaseVector):
|
||||
top_k = int(kwargs.get("top_k", 4))
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
res = col.query.near_vector(
|
||||
near_vector=query_vector,
|
||||
limit=top_k,
|
||||
return_properties=props,
|
||||
return_metadata=MetadataQuery(distance=True),
|
||||
include_vector=False,
|
||||
filters=where,
|
||||
target_vector="default",
|
||||
)
|
||||
try:
|
||||
res = col.query.near_vector(
|
||||
near_vector=query_vector,
|
||||
limit=top_k,
|
||||
return_properties=props,
|
||||
return_metadata=MetadataQuery(distance=True),
|
||||
include_vector=False,
|
||||
filters=where,
|
||||
target_vector="default",
|
||||
)
|
||||
except WeaviateQueryError:
|
||||
self._ensure_properties()
|
||||
res = col.query.near_vector(
|
||||
near_vector=query_vector,
|
||||
limit=top_k,
|
||||
return_properties=props,
|
||||
return_metadata=MetadataQuery(distance=True),
|
||||
include_vector=False,
|
||||
filters=where,
|
||||
target_vector="default",
|
||||
)
|
||||
|
||||
docs: list[Document] = []
|
||||
for obj in res.objects:
|
||||
@ -446,14 +464,25 @@ class WeaviateVector(BaseVector):
|
||||
|
||||
top_k = int(kwargs.get("top_k", 4))
|
||||
|
||||
res = col.query.bm25(
|
||||
query=query,
|
||||
query_properties=[Field.TEXT_KEY.value],
|
||||
limit=top_k,
|
||||
return_properties=props,
|
||||
include_vector=True,
|
||||
filters=where,
|
||||
)
|
||||
try:
|
||||
res = col.query.bm25(
|
||||
query=query,
|
||||
query_properties=[Field.TEXT_KEY.value],
|
||||
limit=top_k,
|
||||
return_properties=props,
|
||||
include_vector=True,
|
||||
filters=where,
|
||||
)
|
||||
except WeaviateQueryError:
|
||||
self._ensure_properties()
|
||||
res = col.query.bm25(
|
||||
query=query,
|
||||
query_properties=[Field.TEXT_KEY.value],
|
||||
limit=top_k,
|
||||
return_properties=props,
|
||||
include_vector=True,
|
||||
filters=where,
|
||||
)
|
||||
|
||||
docs: list[Document] = []
|
||||
for obj in res.objects:
|
||||
|
||||
@ -326,7 +326,7 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
|
||||
add_calls = mock_col.config.add_property.call_args_list
|
||||
added_names = [call.args[0].name for call in add_calls]
|
||||
assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"]
|
||||
assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index", "is_summary", "original_chunk_id"]
|
||||
|
||||
@patch("dify_vdb_weaviate.weaviate_vector.weaviate")
|
||||
def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module):
|
||||
@ -346,6 +346,8 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
SimpleNamespace(name="doc_id"),
|
||||
SimpleNamespace(name="doc_type"),
|
||||
SimpleNamespace(name="chunk_index"),
|
||||
SimpleNamespace(name="is_summary"),
|
||||
SimpleNamespace(name="original_chunk_id"),
|
||||
]
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.properties = existing_props
|
||||
@ -383,7 +385,7 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
with patch.object(weaviate_vector_module.logger, "warning") as mock_warning:
|
||||
wv._ensure_properties()
|
||||
|
||||
assert mock_warning.call_count == 4
|
||||
assert mock_warning.call_count == 6
|
||||
|
||||
@patch("dify_vdb_weaviate.weaviate_vector.weaviate")
|
||||
def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module):
|
||||
@ -484,6 +486,56 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
|
||||
assert wv.search_by_vector(query_vector=[0.1] * 3) == []
|
||||
|
||||
@patch("dify_vdb_weaviate.weaviate_vector.weaviate")
|
||||
def test_search_by_vector_retries_on_weaviate_query_error(self, mock_weaviate_module):
|
||||
"""Test that search_by_vector catches WeaviateQueryError, calls _ensure_properties, and retries."""
|
||||
from weaviate.exceptions import WeaviateQueryError
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.is_ready.return_value = True
|
||||
mock_weaviate_module.connect_to_custom.return_value = mock_client
|
||||
|
||||
mock_client.collections.exists.return_value = True
|
||||
mock_col = MagicMock()
|
||||
mock_client.collections.use.return_value = mock_col
|
||||
|
||||
# First call raises WeaviateQueryError, second call succeeds
|
||||
mock_obj = MagicMock()
|
||||
mock_obj.properties = {"text": "retry result", "document_id": "doc-1"}
|
||||
mock_obj.metadata.distance = 0.2
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.objects = [mock_obj]
|
||||
|
||||
mock_col.query.near_vector.side_effect = [
|
||||
WeaviateQueryError("missing property", "gRPC"),
|
||||
mock_result,
|
||||
]
|
||||
|
||||
# Mock _ensure_properties dependencies
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.properties = [
|
||||
SimpleNamespace(name="text"),
|
||||
SimpleNamespace(name="document_id"),
|
||||
SimpleNamespace(name="doc_id"),
|
||||
SimpleNamespace(name="doc_type"),
|
||||
SimpleNamespace(name="chunk_index"),
|
||||
SimpleNamespace(name="is_summary"),
|
||||
SimpleNamespace(name="original_chunk_id"),
|
||||
]
|
||||
mock_col.config.get.return_value = mock_cfg
|
||||
|
||||
wv = WeaviateVector(
|
||||
collection_name=self.collection_name,
|
||||
config=self.config,
|
||||
attributes=self.attributes,
|
||||
)
|
||||
docs = wv.search_by_vector(query_vector=[0.1] * 3, top_k=1)
|
||||
|
||||
assert mock_col.query.near_vector.call_count == 2
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata["score"] == pytest.approx(0.8)
|
||||
|
||||
@patch("dify_vdb_weaviate.weaviate_vector.weaviate")
|
||||
def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module):
|
||||
"""Test that search_by_full_text also returns doc_type in document metadata."""
|
||||
@ -569,6 +621,56 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
|
||||
assert wv.search_by_full_text(query="missing") == []
|
||||
|
||||
@patch("dify_vdb_weaviate.weaviate_vector.weaviate")
|
||||
def test_search_by_full_text_retries_on_weaviate_query_error(self, mock_weaviate_module):
|
||||
"""Test that search_by_full_text catches WeaviateQueryError, calls _ensure_properties, and retries."""
|
||||
from weaviate.exceptions import WeaviateQueryError
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.is_ready.return_value = True
|
||||
mock_weaviate_module.connect_to_custom.return_value = mock_client
|
||||
|
||||
mock_client.collections.exists.return_value = True
|
||||
mock_col = MagicMock()
|
||||
mock_client.collections.use.return_value = mock_col
|
||||
|
||||
# First call raises WeaviateQueryError, second call succeeds
|
||||
mock_obj = MagicMock()
|
||||
mock_obj.properties = {"text": "retry bm25 result", "doc_id": "segment-1"}
|
||||
mock_obj.vector = {"default": [0.5, 0.6]}
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.objects = [mock_obj]
|
||||
|
||||
mock_col.query.bm25.side_effect = [
|
||||
WeaviateQueryError("missing property", "gRPC"),
|
||||
mock_result,
|
||||
]
|
||||
|
||||
# Mock _ensure_properties dependencies
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.properties = [
|
||||
SimpleNamespace(name="text"),
|
||||
SimpleNamespace(name="document_id"),
|
||||
SimpleNamespace(name="doc_id"),
|
||||
SimpleNamespace(name="doc_type"),
|
||||
SimpleNamespace(name="chunk_index"),
|
||||
SimpleNamespace(name="is_summary"),
|
||||
SimpleNamespace(name="original_chunk_id"),
|
||||
]
|
||||
mock_col.config.get.return_value = mock_cfg
|
||||
|
||||
wv = WeaviateVector(
|
||||
collection_name=self.collection_name,
|
||||
config=self.config,
|
||||
attributes=self.attributes,
|
||||
)
|
||||
docs = wv.search_by_full_text(query="retry", top_k=1)
|
||||
|
||||
assert mock_col.query.bm25.call_count == 2
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "retry bm25 result"
|
||||
|
||||
@patch("dify_vdb_weaviate.weaviate_vector.weaviate")
|
||||
def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module):
|
||||
"""Test that add_texts includes doc_type from document metadata in stored properties."""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user