From 345ac8333cd9ceac4b5cb2fcaa7883f7bb1ece94 Mon Sep 17 00:00:00 2001 From: Shili Cao Date: Mon, 22 Sep 2025 10:17:35 +0800 Subject: [PATCH] Add Full-Text & Hybrid Search Support to Baidu Vector DB and Update SDK, Closes #25982 (#25983) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- api/.env.example | 2 + .../middleware/vdb/baidu_vector_config.py | 10 + api/controllers/console/datasets/datasets.py | 4 +- .../rag/datasource/vdb/baidu/baidu_vector.py | 183 ++++++++++++------ api/pyproject.toml | 2 +- .../vdb/__mock/baiduvectordb.py | 8 +- api/uv.lock | 10 +- docker/.env.example | 2 + docker/docker-compose.yaml | 2 + 9 files changed, 156 insertions(+), 67 deletions(-) diff --git a/api/.env.example b/api/.env.example index b89111a8e3..78a363e506 100644 --- a/api/.env.example +++ b/api/.env.example @@ -304,6 +304,8 @@ BAIDU_VECTOR_DB_API_KEY=dify BAIDU_VECTOR_DB_DATABASE=dify BAIDU_VECTOR_DB_SHARD=1 BAIDU_VECTOR_DB_REPLICAS=3 +BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER +BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE # Upstash configuration UPSTASH_VECTOR_URL=your-server-url diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py index 4b6ddb3bde..8f956745b1 100644 --- a/api/configs/middleware/vdb/baidu_vector_config.py +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -41,3 +41,13 @@ class BaiduVectorDBConfig(BaseSettings): description="Number of replicas for the Baidu Vector Database (default is 3)", default=3, ) + + BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: str = Field( + description="Analyzer type for inverted index in Baidu Vector Database (default is DEFAULT_ANALYZER)", + default="DEFAULT_ANALYZER", + ) + + BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: str = Field( + description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)", + default="COARSE_MODE", + ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index a1ae941d4b..2affbd6a42 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -782,7 +782,6 @@ class DatasetRetrievalSettingApi(Resource): | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.PGVECTO_RS - | VectorType.BAIDU | VectorType.VIKINGDB | VectorType.UPSTASH ): @@ -809,6 +808,7 @@ class DatasetRetrievalSettingApi(Resource): | VectorType.TENCENT | VectorType.MATRIXONE | VectorType.CLICKZETTA + | VectorType.BAIDU ): return { "retrieval_method": [ @@ -838,7 +838,6 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.PGVECTO_RS - | VectorType.BAIDU | VectorType.VIKINGDB | VectorType.UPSTASH ): @@ -863,6 +862,7 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.HUAWEI_CLOUD | VectorType.MATRIXONE | VectorType.CLICKZETTA + | VectorType.BAIDU ): return { "retrieval_method": [ diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index aa980f3835..144d834495 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -1,4 +1,5 @@ import json +import logging import time import uuid from typing import Any @@ -9,11 +10,24 @@ from pymochow import MochowClient # type: ignore from pymochow.auth.bce_credentials import BceCredentials # type: ignore from pymochow.configuration import Configuration # type: ignore from pymochow.exception import ServerError # type: ignore +from pymochow.model.database import Database from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore -from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore -from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore +from pymochow.model.schema import ( + Field, + FilteringIndex, + HNSWParams, + InvertedIndex, + InvertedIndexAnalyzer, + InvertedIndexFieldAttribute, + InvertedIndexParams, + InvertedIndexParseMode, + Schema, + VectorIndex, +) # type: ignore +from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams, Partition, Row # type: ignore from configs import dify_config +from core.rag.datasource.vdb.field import Field as VDBField from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -22,6 +36,8 @@ from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset +logger = logging.getLogger(__name__) + class BaiduConfig(BaseModel): endpoint: str @@ -30,9 +46,11 @@ class BaiduConfig(BaseModel): api_key: str database: str index_type: str = "HNSW" - metric_type: str = "L2" + metric_type: str = "IP" shard: int = 1 replicas: int = 3 + inverted_index_analyzer: str = "DEFAULT_ANALYZER" + inverted_index_parser_mode: str = "COARSE_MODE" @model_validator(mode="before") @classmethod @@ -49,13 +67,9 @@ class BaiduConfig(BaseModel): class BaiduVector(BaseVector): - field_id: str = "id" - field_vector: str = "vector" - field_text: str = "text" - field_metadata: str = "metadata" - field_app_id: str = "app_id" - field_annotation_id: str = "annotation_id" - index_vector: str = "vector_idx" + vector_index: str = "vector_idx" + filtering_index: str = "filtering_idx" + inverted_index: str = "content_inverted_idx" def __init__(self, collection_name: str, config: BaiduConfig): super().__init__(collection_name) @@ -74,8 +88,6 @@ class BaiduVector(BaseVector): self.add_texts(texts, embeddings) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents if doc.metadata is not None] total_count = len(documents) batch_size = 1000 @@ -84,29 +96,31 @@ class BaiduVector(BaseVector): for start in range(0, total_count, batch_size): end = min(start + batch_size, total_count) rows = [] - assert len(metadatas) == total_count, "metadatas length should be equal to total_count" for i in range(start, end, 1): + metadata = documents[i].metadata row = Row( - id=metadatas[i].get("doc_id", str(uuid.uuid4())), + id=metadata.get("doc_id", str(uuid.uuid4())), + page_content=documents[i].page_content, + metadata=metadata, vector=embeddings[i], - text=texts[i], - metadata=json.dumps(metadatas[i]), - app_id=metadatas[i].get("app_id", ""), - annotation_id=metadatas[i].get("annotation_id", ""), ) rows.append(row) table.upsert(rows=rows) # rebuild vector index after upsert finished - table.rebuild_index(self.index_vector) + table.rebuild_index(self.vector_index) + timeout = 3600 # 1 hour timeout + start_time = time.time() while True: time.sleep(1) - index = table.describe_index(self.index_vector) + index = table.describe_index(self.vector_index) if index.state == IndexState.NORMAL: break + if time.time() - start_time > timeout: + raise TimeoutError(f"Index rebuild timeout after {timeout} seconds") def text_exists(self, id: str) -> bool: - res = self._db.table(self._collection_name).query(primary_key={self.field_id: id}) + res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id}) if res and res.code == 0: return True return False @@ -115,53 +129,73 @@ class BaiduVector(BaseVector): if not ids: return quoted_ids = [f"'{id}'" for id in ids] - self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})") + self._db.table(self._collection_name).delete(filter=f"{VDBField.PRIMARY_KEY} IN({', '.join(quoted_ids)})") def delete_by_metadata_field(self, key: str, value: str): - self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'") + # Escape double quotes in value to prevent injection + escaped_value = value.replace('"', '\\"') + self._db.table(self._collection_name).delete(filter=f'metadata["{key}"] = "{escaped_value}"') def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector] document_ids_filter = kwargs.get("document_ids_filter") + filter = "" if document_ids_filter: document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) - anns = AnnSearch( - vector_field=self.field_vector, - vector_floats=query_vector, - params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), - filter=f"document_id IN ({document_ids})", - ) - else: - anns = AnnSearch( - vector_field=self.field_vector, - vector_floats=query_vector, - params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), - ) + filter = f'metadata["document_id"] IN({document_ids})' + anns = AnnSearch( + vector_field=VDBField.VECTOR, + vector_floats=query_vector, + params=HNSWSearchParams(ef=kwargs.get("ef", 20), limit=kwargs.get("top_k", 4)), + filter=filter, + ) res = self._db.table(self._collection_name).search( anns=anns, - projections=[self.field_id, self.field_text, self.field_metadata], - retrieve_vector=True, + projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY], + retrieve_vector=False, ) score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(res, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - # baidu vector database doesn't support bm25 search on current version - return [] + # document ids filter + document_ids_filter = kwargs.get("document_ids_filter") + filter = "" + if document_ids_filter: + document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + filter = f'metadata["document_id"] IN({document_ids})' + + request = BM25SearchRequest( + index_name=self.inverted_index, search_text=query, limit=kwargs.get("top_k", 4), filter=filter + ) + res = self._db.table(self._collection_name).bm25_search( + request=request, projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY] + ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + return self._get_search_res(res, score_threshold) def _get_search_res(self, res, score_threshold) -> list[Document]: docs = [] for row in res.rows: row_data = row.get("row", {}) - meta = row_data.get(self.field_metadata) - if meta is not None: - meta = json.loads(meta) score = row.get("score", 0.0) + meta = row_data.get(VDBField.METADATA_KEY, {}) + + # Handle both JSON string and dict formats for backward compatibility + if isinstance(meta, str): + try: + import json + + meta = json.loads(meta) + except (json.JSONDecodeError, TypeError): + meta = {} + elif not isinstance(meta, dict): + meta = {} + if score >= score_threshold: meta["score"] = score - doc = Document(page_content=row_data.get(self.field_text), metadata=meta) + doc = Document(page_content=row_data.get(VDBField.CONTENT_KEY), metadata=meta) docs.append(doc) - return docs def delete(self): @@ -178,7 +212,7 @@ class BaiduVector(BaseVector): client = MochowClient(config) return client - def _init_database(self): + def _init_database(self) -> Database: exists = False for db in self._client.list_databases(): if db.database_name == self._client_config.database: @@ -192,10 +226,10 @@ class BaiduVector(BaseVector): self._client.create_database(database_name=self._client_config.database) except ServerError as e: if e.code == ServerErrCode.DB_ALREADY_EXIST: - pass + return self._client.database(self._client_config.database) else: raise - return + return self._client.database(self._client_config.database) def _table_existed(self) -> bool: tables = self._db.list_table() @@ -232,7 +266,7 @@ class BaiduVector(BaseVector): fields = [] fields.append( Field( - self.field_id, + VDBField.PRIMARY_KEY, FieldType.STRING, primary_key=True, partition_key=True, @@ -240,24 +274,57 @@ class BaiduVector(BaseVector): not_null=True, ) ) - fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True)) - fields.append(Field(self.field_app_id, FieldType.STRING)) - fields.append(Field(self.field_annotation_id, FieldType.STRING)) - fields.append(Field(self.field_text, FieldType.TEXT, not_null=True)) - fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension)) + fields.append(Field(VDBField.CONTENT_KEY, FieldType.TEXT, not_null=False)) + fields.append(Field(VDBField.METADATA_KEY, FieldType.JSON, not_null=False)) + fields.append(Field(VDBField.VECTOR, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension)) # Construct vector index params indexes = [] indexes.append( VectorIndex( - index_name="vector_idx", + index_name=self.vector_index, index_type=index_type, - field="vector", + field=VDBField.VECTOR, metric_type=metric_type, params=HNSWParams(m=16, efconstruction=200), ) ) + # Filtering index + indexes.append( + FilteringIndex( + index_name=self.filtering_index, + fields=[VDBField.METADATA_KEY], + ) + ) + + # Get analyzer and parse_mode from config + analyzer = getattr( + InvertedIndexAnalyzer, + self._client_config.inverted_index_analyzer, + InvertedIndexAnalyzer.DEFAULT_ANALYZER, + ) + + parse_mode = getattr( + InvertedIndexParseMode, + self._client_config.inverted_index_parser_mode, + InvertedIndexParseMode.COARSE_MODE, + ) + + # Inverted index + indexes.append( + InvertedIndex( + index_name=self.inverted_index, + fields=[VDBField.CONTENT_KEY], + params=InvertedIndexParams( + analyzer=analyzer, + parse_mode=parse_mode, + case_sensitive=True, + ), + field_attributes=[InvertedIndexFieldAttribute.ANALYZED], + ) + ) + # Create table self._db.create_table( table_name=self._collection_name, @@ -268,11 +335,15 @@ class BaiduVector(BaseVector): ) # Wait for table created + timeout = 300 # 5 minutes timeout + start_time = time.time() while True: time.sleep(1) table = self._db.describe_table(self._collection_name) if table.state == TableState.NORMAL: break + if time.time() - start_time > timeout: + raise TimeoutError(f"Table creation timeout after {timeout} seconds") redis_client.set(table_exist_cache_key, 1, ex=3600) @@ -296,5 +367,7 @@ class BaiduVectorFactory(AbstractVectorFactory): database=dify_config.BAIDU_VECTOR_DB_DATABASE or "", shard=dify_config.BAIDU_VECTOR_DB_SHARD, replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, + inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER, + inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE, ), ) diff --git a/api/pyproject.toml b/api/pyproject.toml index 42024936ec..2ebd830345 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -211,7 +211,7 @@ vdb = [ "pgvecto-rs[sqlalchemy]~=0.2.1", "pgvector==0.2.5", "pymilvus~=2.5.0", - "pymochow==1.3.1", + "pymochow==2.2.9", "pyobvector~=0.2.15", "qdrant-client==1.9.0", "tablestore==6.2.0", diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index f9f9f4f369..6d2aff5197 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -100,8 +100,8 @@ class MockBaiduVectorDBClass: "row": { "id": primary_key.get("id"), "vector": [0.23432432, 0.8923744, 0.89238432], - "text": "text", - "metadata": '{"doc_id": "doc_id_001"}', + "page_content": "text", + "metadata": {"doc_id": "doc_id_001"}, }, "code": 0, "msg": "Success", @@ -127,8 +127,8 @@ class MockBaiduVectorDBClass: "row": { "id": "doc_id_001", "vector": [0.23432432, 0.8923744, 0.89238432], - "text": "text", - "metadata": '{"doc_id": "doc_id_001"}', + "page_content": "text", + "metadata": {"doc_id": "doc_id_001"}, }, "distance": 0.1, "score": 0.5, diff --git a/api/uv.lock b/api/uv.lock index 5cb637457b..2070ebc5fd 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.11, <3.13" resolution-markers = [ "python_full_version >= '3.12.4' and sys_platform == 'linux'", @@ -1670,7 +1670,7 @@ vdb = [ { name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" }, { name = "pgvector", specifier = "==0.2.5" }, { name = "pymilvus", specifier = "~=2.5.0" }, - { name = "pymochow", specifier = "==1.3.1" }, + { name = "pymochow", specifier = "==2.2.9" }, { name = "pyobvector", specifier = "~=0.2.15" }, { name = "qdrant-client", specifier = "==1.9.0" }, { name = "tablestore", specifier = "==6.2.0" }, @@ -4935,16 +4935,16 @@ wheels = [ [[package]] name = "pymochow" -version = "1.3.1" +version = "2.2.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "future" }, { name = "orjson" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cc/da/3027eeeaf7a7db9b0ca761079de4e676a002e1cc2c4260dab0ce812972b8/pymochow-1.3.1.tar.gz", hash = "sha256:1693d10cd0bb7bce45327890a90adafb503155922ccc029acb257699a73a20ba", size = 30800, upload-time = "2024-09-11T12:06:37.88Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/29/d9b112684ce490057b90bddede3fb6a69cf2787a3fd7736bdce203e77388/pymochow-2.2.9.tar.gz", hash = "sha256:5a28058edc8861deb67524410e786814571ed9fe0700c8c9fc0bc2ad5835b06c", size = 50079, upload-time = "2025-06-05T08:33:19.59Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6b/74/4b6227717f6baa37e7288f53e0fd55764939abc4119342eed4924a98f477/pymochow-1.3.1-py3-none-any.whl", hash = "sha256:a7f3b34fd6ea5d1d8413650bb6678365aa148fc396ae945e4ccb4f2365a52327", size = 42697, upload-time = "2024-09-11T12:06:36.114Z" }, + { url = "https://files.pythonhosted.org/packages/bf/9b/be18f9709dfd8187ff233be5acb253a9f4f1b07f1db0e7b09d84197c28e2/pymochow-2.2.9-py3-none-any.whl", hash = "sha256:639192b97f143d4a22fc163872be12aee19523c46f12e22416e8f289f1354d15", size = 77899, upload-time = "2025-06-05T08:33:17.424Z" }, ] [[package]] diff --git a/docker/.env.example b/docker/.env.example index af72ce8213..d4e8ab3beb 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -635,6 +635,8 @@ BAIDU_VECTOR_DB_API_KEY=dify BAIDU_VECTOR_DB_DATABASE=dify BAIDU_VECTOR_DB_SHARD=1 BAIDU_VECTOR_DB_REPLICAS=3 +BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER +BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE # VikingDB configurations, only available when VECTOR_STORE is `vikingdb` VIKINGDB_ACCESS_KEY=your-ak diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index acec6adf10..2d6ba572e6 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -286,6 +286,8 @@ x-shared-env: &shared-api-worker-env BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify} BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1} BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3} + BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: ${BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER:-DEFAULT_ANALYZER} + BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: ${BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE:-COARSE_MODE} VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-your-ak} VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-your-sk} VIKINGDB_REGION: ${VIKINGDB_REGION:-cn-shanghai}