mirror of https://github.com/langgenius/dify.git
chore: tablestore full text search support score normalization (#23255)
Co-authored-by: xiaozhiqing.xzq <xiaozhiqing.xzq@alibaba-inc.com>
This commit is contained in:
parent
c33741a5e9
commit
da5c003f97
|
|
@ -232,6 +232,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com
|
||||||
TABLESTORE_INSTANCE_NAME=instance-name
|
TABLESTORE_INSTANCE_NAME=instance-name
|
||||||
TABLESTORE_ACCESS_KEY_ID=xxx
|
TABLESTORE_ACCESS_KEY_ID=xxx
|
||||||
TABLESTORE_ACCESS_KEY_SECRET=xxx
|
TABLESTORE_ACCESS_KEY_SECRET=xxx
|
||||||
|
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false
|
||||||
|
|
||||||
# Tidb Vector configuration
|
# Tidb Vector configuration
|
||||||
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
|
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
|
||||||
|
|
|
||||||
|
|
@ -28,3 +28,8 @@ class TableStoreConfig(BaseSettings):
|
||||||
description="AccessKey secret for the instance name",
|
description="AccessKey secret for the instance name",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: bool = Field(
|
||||||
|
description="Whether to normalize full-text search scores to [0, 1]",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import tablestore # type: ignore
|
import tablestore # type: ignore
|
||||||
|
|
@ -22,6 +23,7 @@ class TableStoreConfig(BaseModel):
|
||||||
access_key_secret: Optional[str] = None
|
access_key_secret: Optional[str] = None
|
||||||
instance_name: Optional[str] = None
|
instance_name: Optional[str] = None
|
||||||
endpoint: Optional[str] = None
|
endpoint: Optional[str] = None
|
||||||
|
normalize_full_text_bm25_score: Optional[bool] = False
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -47,6 +49,7 @@ class TableStoreVector(BaseVector):
|
||||||
config.access_key_secret,
|
config.access_key_secret,
|
||||||
config.instance_name,
|
config.instance_name,
|
||||||
)
|
)
|
||||||
|
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
|
||||||
self._table_name = f"{collection_name}"
|
self._table_name = f"{collection_name}"
|
||||||
self._index_name = f"{collection_name}_idx"
|
self._index_name = f"{collection_name}_idx"
|
||||||
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
|
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
|
||||||
|
|
@ -131,8 +134,8 @@ class TableStoreVector(BaseVector):
|
||||||
filtered_list = None
|
filtered_list = None
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
filtered_list = ["document_id=" + item for item in document_ids_filter]
|
filtered_list = ["document_id=" + item for item in document_ids_filter]
|
||||||
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
return self._search_by_full_text(query, filtered_list, top_k)
|
return self._search_by_full_text(query, filtered_list, top_k, score_threshold)
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self) -> None:
|
||||||
self._delete_table_if_exist()
|
self._delete_table_if_exist()
|
||||||
|
|
@ -318,7 +321,19 @@ class TableStoreVector(BaseVector):
|
||||||
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
|
@staticmethod
def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float:
    """
    Map a raw BM25 score onto [0, 1] using an exponential saturation curve.

    The value is computed as ``1 - exp(-k * score)`` and clamped to [0, 1].

    Args:
        score: BM25 search score.
        k: decay factor, the larger the k, the steeper the low score end

    Returns:
        The normalized score in [0, 1]. Inputs for which the curve would be
        zero or negative (``k * score <= 0``) and NaN inputs yield 0.0.
    """
    exponent = -k * score
    # A non-negative exponent means 1 - exp(exponent) <= 0, which clamps to 0.0
    # anyway; returning early avoids math.exp raising OverflowError for large
    # positive exponents. NaN fails the comparison too, so it maps to 0.0
    # instead of slipping through the min/max clamp as a perfect score.
    if not exponent < 0:
        return 0.0
    normalized_score = 1 - math.exp(exponent)
    return max(0.0, min(1.0, normalized_score))
|
||||||
|
|
||||||
|
def _search_by_full_text(
|
||||||
|
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
|
||||||
|
) -> list[Document]:
|
||||||
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
|
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
|
||||||
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
|
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
|
||||||
|
|
||||||
|
|
@ -339,15 +354,27 @@ class TableStoreVector(BaseVector):
|
||||||
|
|
||||||
documents = []
|
documents = []
|
||||||
for search_hit in search_response.search_hits:
|
for search_hit in search_response.search_hits:
|
||||||
|
score = None
|
||||||
|
if self._normalize_full_text_bm25_score:
|
||||||
|
score = self._normalize_score_exp_decay(search_hit.score)
|
||||||
|
|
||||||
|
# skip hits whose normalized score is at or below the threshold (only applies when score normalization is enabled)
|
||||||
|
if score and score <= score_threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
ots_column_map = {}
|
ots_column_map = {}
|
||||||
for col in search_hit.row[1]:
|
for col in search_hit.row[1]:
|
||||||
ots_column_map[col[0]] = col[1]
|
ots_column_map[col[0]] = col[1]
|
||||||
|
|
||||||
vector_str = ots_column_map.get(Field.VECTOR.value)
|
|
||||||
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
|
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
|
||||||
vector = json.loads(vector_str) if vector_str else None
|
|
||||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||||
|
|
||||||
|
vector_str = ots_column_map.get(Field.VECTOR.value)
|
||||||
|
vector = json.loads(vector_str) if vector_str else None
|
||||||
|
|
||||||
|
if score:
|
||||||
|
metadata["score"] = score
|
||||||
|
|
||||||
documents.append(
|
documents.append(
|
||||||
Document(
|
Document(
|
||||||
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
|
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
|
||||||
|
|
@ -355,6 +382,8 @@ class TableStoreVector(BaseVector):
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if self._normalize_full_text_bm25_score:
|
||||||
|
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -375,5 +404,6 @@ class TableStoreVectorFactory(AbstractVectorFactory):
|
||||||
instance_name=dify_config.TABLESTORE_INSTANCE_NAME,
|
instance_name=dify_config.TABLESTORE_INSTANCE_NAME,
|
||||||
access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID,
|
access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID,
|
||||||
access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET,
|
access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET,
|
||||||
|
normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import os
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import tablestore
|
import tablestore
|
||||||
|
from _pytest.python_api import approx
|
||||||
|
|
||||||
from core.rag.datasource.vdb.tablestore.tablestore_vector import (
|
from core.rag.datasource.vdb.tablestore.tablestore_vector import (
|
||||||
TableStoreConfig,
|
TableStoreConfig,
|
||||||
|
|
@ -16,7 +17,7 @@ from tests.integration_tests.vdb.test_vector_store import (
|
||||||
|
|
||||||
|
|
||||||
class TableStoreVectorTest(AbstractVectorTest):
|
class TableStoreVectorTest(AbstractVectorTest):
|
||||||
def __init__(self):
|
def __init__(self, normalize_full_text_score: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vector = TableStoreVector(
|
self.vector = TableStoreVector(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
|
|
@ -25,6 +26,7 @@ class TableStoreVectorTest(AbstractVectorTest):
|
||||||
instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"),
|
instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"),
|
||||||
access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"),
|
access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"),
|
||||||
access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"),
|
access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"),
|
||||||
|
normalize_full_text_bm25_score=normalize_full_text_score,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -64,7 +66,21 @@ class TableStoreVectorTest(AbstractVectorTest):
|
||||||
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
|
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
|
||||||
assert len(docs) == 1
|
assert len(docs) == 1
|
||||||
assert docs[0].metadata["doc_id"] == self.example_doc_id
|
assert docs[0].metadata["doc_id"] == self.example_doc_id
|
||||||
assert not hasattr(docs[0], "score")
|
if self.vector._config.normalize_full_text_bm25_score:
|
||||||
|
assert docs[0].metadata["score"] == approx(0.1214, abs=1e-3)
|
||||||
|
else:
|
||||||
|
assert docs[0].metadata.get("score") is None
|
||||||
|
|
||||||
|
# expect no results when normalization is enabled: the normalized score (~0.12) is below score_threshold=0.5
|
||||||
|
docs = self.vector.search_by_full_text(
|
||||||
|
get_example_text(), document_ids_filter=[self.example_doc_id], score_threshold=0.5
|
||||||
|
)
|
||||||
|
if self.vector._config.normalize_full_text_bm25_score:
|
||||||
|
assert len(docs) == 0
|
||||||
|
else:
|
||||||
|
assert len(docs) == 1
|
||||||
|
assert docs[0].metadata["doc_id"] == self.example_doc_id
|
||||||
|
assert docs[0].metadata.get("score") is None
|
||||||
|
|
||||||
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
|
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
|
||||||
assert len(docs) == 0
|
assert len(docs) == 0
|
||||||
|
|
@ -80,3 +96,5 @@ class TableStoreVectorTest(AbstractVectorTest):
|
||||||
|
|
||||||
def test_tablestore_vector(setup_mock_redis):
|
def test_tablestore_vector(setup_mock_redis):
|
||||||
TableStoreVectorTest().run_all_tests()
|
TableStoreVectorTest().run_all_tests()
|
||||||
|
TableStoreVectorTest(normalize_full_text_score=True).run_all_tests()
|
||||||
|
TableStoreVectorTest(normalize_full_text_score=False).run_all_tests()
|
||||||
|
|
|
||||||
|
|
@ -653,6 +653,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com
|
||||||
TABLESTORE_INSTANCE_NAME=instance-name
|
TABLESTORE_INSTANCE_NAME=instance-name
|
||||||
TABLESTORE_ACCESS_KEY_ID=xxx
|
TABLESTORE_ACCESS_KEY_ID=xxx
|
||||||
TABLESTORE_ACCESS_KEY_SECRET=xxx
|
TABLESTORE_ACCESS_KEY_SECRET=xxx
|
||||||
|
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# Knowledge Configuration
|
# Knowledge Configuration
|
||||||
|
|
|
||||||
|
|
@ -312,6 +312,7 @@ x-shared-env: &shared-api-worker-env
|
||||||
TABLESTORE_INSTANCE_NAME: ${TABLESTORE_INSTANCE_NAME:-instance-name}
|
TABLESTORE_INSTANCE_NAME: ${TABLESTORE_INSTANCE_NAME:-instance-name}
|
||||||
TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx}
|
TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx}
|
||||||
TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx}
|
TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx}
|
||||||
|
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false}
|
||||||
UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
|
UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
|
||||||
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
|
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
|
||||||
ETL_TYPE: ${ETL_TYPE:-dify}
|
ETL_TYPE: ${ETL_TYPE:-dify}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue