diff --git a/api/configs/middleware/vdb/oceanbase_config.py b/api/configs/middleware/vdb/oceanbase_config.py index 7c9376f86b..27ec99e56a 100644 --- a/api/configs/middleware/vdb/oceanbase_config.py +++ b/api/configs/middleware/vdb/oceanbase_config.py @@ -1,3 +1,5 @@ +from typing import Literal + from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -49,3 +51,43 @@ class OceanBaseVectorConfig(BaseSettings): ), default="ik", ) + + OCEANBASE_VECTOR_BATCH_SIZE: PositiveInt = Field( + description="Number of documents to insert per batch", + default=100, + ) + + OCEANBASE_VECTOR_METRIC_TYPE: Literal["l2", "cosine", "inner_product"] = Field( + description="Distance metric type for vector index: l2, cosine, or inner_product", + default="l2", + ) + + OCEANBASE_HNSW_M: PositiveInt = Field( + description="HNSW M parameter (max number of connections per node)", + default=16, + ) + + OCEANBASE_HNSW_EF_CONSTRUCTION: PositiveInt = Field( + description="HNSW efConstruction parameter (index build-time search width)", + default=256, + ) + + OCEANBASE_HNSW_EF_SEARCH: int = Field( + description="HNSW efSearch parameter (query-time search width, -1 uses server default)", + default=-1, + ) + + OCEANBASE_VECTOR_POOL_SIZE: PositiveInt = Field( + description="SQLAlchemy connection pool size", + default=5, + ) + + OCEANBASE_VECTOR_MAX_OVERFLOW: int = Field( + description="SQLAlchemy connection pool max overflow connections", + default=10, + ) + + OCEANBASE_HNSW_REFRESH_THRESHOLD: int = Field( + description="Minimum number of inserted documents to trigger an automatic HNSW index refresh (0 to disable)", + default=1000, + ) diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index dc3b70140b..86c1e65f47 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -1,12 +1,13 @@ import json import logging -import math -from typing import Any +import re +from typing import Any, Literal from pydantic import BaseModel, model_validator -from pyobvector import VECTOR, ObVecClient, l2_distance # type: ignore +from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_distance # type: ignore from sqlalchemy import JSON, Column, String from sqlalchemy.dialects.mysql import LONGTEXT +from sqlalchemy.exc import SQLAlchemyError from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector @@ -19,10 +20,14 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) -DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256} -DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64} OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW" -DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2" +_VALID_TABLE_NAME_RE = re.compile(r"^[a-zA-Z0-9_]+$") + +_DISTANCE_FUNC_MAP = { + "l2": l2_distance, + "cosine": cosine_distance, + "inner_product": inner_product, +} class OceanBaseVectorConfig(BaseModel): @@ -32,6 +37,14 @@ class OceanBaseVectorConfig(BaseModel): password: str database: str enable_hybrid_search: bool = False + batch_size: int = 100 + metric_type: Literal["l2", "cosine", "inner_product"] = "l2" + hnsw_m: int = 16 + hnsw_ef_construction: int = 256 + hnsw_ef_search: int = -1 + pool_size: int = 5 + max_overflow: int = 10 + hnsw_refresh_threshold: int = 1000 @model_validator(mode="before") @classmethod @@ -49,14 +62,23 @@ class OceanBaseVectorConfig(BaseModel): class 
OceanBaseVector(BaseVector): def __init__(self, collection_name: str, config: OceanBaseVectorConfig): + if not _VALID_TABLE_NAME_RE.match(collection_name): + raise ValueError( + f"Invalid collection name '{collection_name}': " + "only alphanumeric characters and underscores are allowed." + ) super().__init__(collection_name) self._config = config - self._hnsw_ef_search = -1 + self._hnsw_ef_search = self._config.hnsw_ef_search self._client = ObVecClient( uri=f"{self._config.host}:{self._config.port}", user=self._config.user, password=self._config.password, db_name=self._config.database, + pool_size=self._config.pool_size, + max_overflow=self._config.max_overflow, + pool_recycle=3600, + pool_pre_ping=True, ) self._fields: list[str] = [] # List of fields in the collection if self._client.check_table_exists(collection_name): @@ -136,8 +158,8 @@ class OceanBaseVector(BaseVector): field_name="vector", index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE, index_name="vector_index", - metric_type=DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE, - params=DEFAULT_OCEANBASE_HNSW_BUILD_PARAM, + metric_type=self._config.metric_type, + params={"M": self._config.hnsw_m, "efConstruction": self._config.hnsw_ef_construction}, ) self._client.create_table_with_index_params( @@ -178,6 +200,17 @@ class OceanBaseVector(BaseVector): else: logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name) + try: + self._client.perform_raw_text_sql( + f"CREATE INDEX IF NOT EXISTS idx_metadata_doc_id ON `{self._collection_name}` " + f"((CAST(metadata->>'$.document_id' AS CHAR(64))))" + ) + except SQLAlchemyError: + logger.warning( + "Failed to create metadata functional index on '%s'; metadata queries may be slow without it.", + self._collection_name, + ) + self._client.refresh_metadata([self._collection_name]) self._load_collection_fields() redis_client.set(collection_exist_cache_key, 1, ex=3600) @@ -205,24 +238,49 @@ class OceanBaseVector(BaseVector): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): ids = self._get_uuids(documents) - for id, doc, emb in zip(ids, documents, embeddings): + batch_size = self._config.batch_size + total = len(documents) + + all_data = [ + { + "id": doc_id, + "vector": emb, + "text": doc.page_content, + "metadata": doc.metadata, + } + for doc_id, doc, emb in zip(ids, documents, embeddings) + ] + + for start in range(0, total, batch_size): + batch = all_data[start : start + batch_size] try: self._client.insert( table_name=self._collection_name, - data={ - "id": id, - "vector": emb, - "text": doc.page_content, - "metadata": doc.metadata, - }, + data=batch, ) except Exception as e: logger.exception( - "Failed to insert document with id '%s' in collection '%s'", - id, + "Failed to insert batch [%d:%d] into collection '%s'", + start, + start + len(batch), + self._collection_name, + ) + raise Exception( + f"Failed to insert batch [{start}:{start + len(batch)}] into collection '{self._collection_name}'" + ) from e + + if self._config.hnsw_refresh_threshold > 0 and total >= self._config.hnsw_refresh_threshold: + try: + self._client.refresh_index( + table_name=self._collection_name, + index_name="vector_index", + ) + except SQLAlchemyError: + logger.warning( + "Failed to refresh HNSW index after inserting %d documents into '%s'", + total, self._collection_name, ) - raise Exception(f"Failed to insert document with id '{id}'") from e def text_exists(self, id: str) -> bool: try: @@ -412,7 +470,7 @@ class OceanBaseVector(BaseVector): 
vec_column_name="vector", vec_data=query_vector, topk=topk, - distance_func=l2_distance, + distance_func=self._get_distance_func(), output_column_names=["text", "metadata"], with_dist=True, where_clause=_where_clause, @@ -424,14 +482,31 @@ class OceanBaseVector(BaseVector): ) raise Exception(f"Vector search failed for collection '{self._collection_name}'") from e - # Convert distance to score and prepare results for processing results = [] for _text, metadata_str, distance in cur: - score = 1 - distance / math.sqrt(2) + score = self._distance_to_score(distance) results.append((_text, metadata_str, score)) return self._process_search_results(results, score_threshold=score_threshold) + def _get_distance_func(self): + func = _DISTANCE_FUNC_MAP.get(self._config.metric_type) + if func is None: + raise ValueError( + f"Unsupported metric_type '{self._config.metric_type}'. Supported: {', '.join(_DISTANCE_FUNC_MAP)}" + ) + return func + + def _distance_to_score(self, distance: float) -> float: + metric = self._config.metric_type + if metric == "l2": + return 1.0 / (1.0 + distance) + elif metric == "cosine": + return 1.0 - distance + elif metric == "inner_product": + return -distance + raise ValueError(f"Unsupported metric_type '{metric}'") + def delete(self): try: self._client.drop_table_if_exist(self._collection_name) @@ -464,5 +539,13 @@ class OceanBaseVectorFactory(AbstractVectorFactory): password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""), database=dify_config.OCEANBASE_VECTOR_DATABASE or "", enable_hybrid_search=dify_config.OCEANBASE_ENABLE_HYBRID_SEARCH or False, + batch_size=dify_config.OCEANBASE_VECTOR_BATCH_SIZE, + metric_type=dify_config.OCEANBASE_VECTOR_METRIC_TYPE, + hnsw_m=dify_config.OCEANBASE_HNSW_M, + hnsw_ef_construction=dify_config.OCEANBASE_HNSW_EF_CONSTRUCTION, + hnsw_ef_search=dify_config.OCEANBASE_HNSW_EF_SEARCH, + pool_size=dify_config.OCEANBASE_VECTOR_POOL_SIZE, + max_overflow=dify_config.OCEANBASE_VECTOR_MAX_OVERFLOW, + hnsw_refresh_threshold=dify_config.OCEANBASE_HNSW_REFRESH_THRESHOLD, ), ) diff --git a/api/tests/integration_tests/vdb/oceanbase/bench_oceanbase.py b/api/tests/integration_tests/vdb/oceanbase/bench_oceanbase.py new file mode 100644 index 0000000000..8b57be08c5 --- /dev/null +++ b/api/tests/integration_tests/vdb/oceanbase/bench_oceanbase.py @@ -0,0 +1,241 @@ +""" +Benchmark: OceanBase vector store — old (single-row) vs new (batch) insertion, +metadata query with/without functional index, and vector search across metrics. 
+ +Usage: + uv run --project api python -m tests.integration_tests.vdb.oceanbase.bench_oceanbase +""" + +import json +import random +import statistics +import time +import uuid + +from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_distance +from sqlalchemy import JSON, Column, String, text +from sqlalchemy.dialects.mysql import LONGTEXT + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- +HOST = "127.0.0.1" +PORT = 2881 +USER = "root@test" +PASSWORD = "difyai123456" +DATABASE = "test" + +VEC_DIM = 1536 +HNSW_BUILD = {"M": 16, "efConstruction": 256} +DISTANCE_FUNCS = {"l2": l2_distance, "cosine": cosine_distance, "inner_product": inner_product} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _make_client(**extra): + return ObVecClient( + uri=f"{HOST}:{PORT}", + user=USER, + password=PASSWORD, + db_name=DATABASE, + **extra, + ) + + +def _rand_vec(): + return [random.uniform(-1, 1) for _ in range(VEC_DIM)] # noqa: S311 + + +def _drop(client, table): + client.drop_table_if_exist(table) + + +def _create_table(client, table, metric="l2"): + cols = [ + Column("id", String(36), primary_key=True, autoincrement=False), + Column("vector", VECTOR(VEC_DIM)), + Column("text", LONGTEXT), + Column("metadata", JSON), + ] + vidx = client.prepare_index_params() + vidx.add_index( + field_name="vector", + index_type="HNSW", + index_name="vector_index", + metric_type=metric, + params=HNSW_BUILD, + ) + client.create_table_with_index_params(table_name=table, columns=cols, vidxs=vidx) + client.refresh_metadata([table]) + + +def _gen_rows(n): + doc_id = str(uuid.uuid4()) + rows = [] + for _ in range(n): + rows.append( + { + "id": str(uuid.uuid4()), + "vector": _rand_vec(), + "text": f"benchmark text {uuid.uuid4().hex[:12]}", + "metadata": json.dumps({"document_id": doc_id, "dataset_id": str(uuid.uuid4())}), + } + ) + return rows, doc_id + + +# --------------------------------------------------------------------------- +# Benchmark: Insertion +# --------------------------------------------------------------------------- +def bench_insert_single(client, table, rows): + """Old approach: one INSERT per row.""" + t0 = time.perf_counter() + for row in rows: + client.insert(table_name=table, data=row) + return time.perf_counter() - t0 + + +def bench_insert_batch(client, table, rows, batch_size=100): + """New approach: batch INSERT.""" + t0 = time.perf_counter() + for start in range(0, len(rows), batch_size): + batch = rows[start : start + batch_size] + client.insert(table_name=table, data=batch) + return time.perf_counter() - t0 + + +# --------------------------------------------------------------------------- +# Benchmark: Metadata query +# --------------------------------------------------------------------------- +def bench_metadata_query(client, table, doc_id, with_index=False): + """Query by metadata->>'$.document_id' with/without functional index.""" + if with_index: + try: + # Match the production index expression: MySQL-compatible servers cannot + # index the raw LONGTEXT returned by ->>, so wrap it in a CAST. + client.perform_raw_text_sql( + f"CREATE INDEX idx_metadata_doc_id ON `{table}` " + f"((CAST(metadata->>'$.document_id' AS CHAR(64))))" + ) + except Exception: + pass # already exists + + sql = text(f"SELECT id FROM `{table}` WHERE metadata->>'$.document_id' = :val") + times = [] + with client.engine.connect() as conn: + for _ in range(10): + t0 = time.perf_counter() + result = conn.execute(sql, {"val": 
doc_id}) + _ = result.fetchall() + times.append(time.perf_counter() - t0) + return times + + +# --------------------------------------------------------------------------- +# Benchmark: Vector search +# --------------------------------------------------------------------------- +def bench_vector_search(client, table, metric, topk=10, n_queries=20): + dist_func = DISTANCE_FUNCS[metric] + times = [] + for _ in range(n_queries): + q = _rand_vec() + t0 = time.perf_counter() + cur = client.ann_search( + table_name=table, + vec_column_name="vector", + vec_data=q, + topk=topk, + distance_func=dist_func, + output_column_names=["text", "metadata"], + with_dist=True, + ) + _ = list(cur) + times.append(time.perf_counter() - t0) + return times + + +def _fmt(times): + """Format a list of durations as 'mean ± stdev' in milliseconds.""" + m = statistics.mean(times) * 1000 + s = statistics.stdev(times) * 1000 if len(times) > 1 else 0 + return f"{m:.1f} ± {s:.1f} ms" + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +def main(): + client = _make_client() + client_pooled = _make_client(pool_size=5, max_overflow=10, pool_recycle=3600, pool_pre_ping=True) + + print("=" * 70) + print("OceanBase Vector Store — Performance Benchmark") + print(f" Endpoint : {HOST}:{PORT}") + print(f" Vec dim : {VEC_DIM}") + print("=" * 70) + + # ------------------------------------------------------------------ + # 1. Insertion benchmark + # ------------------------------------------------------------------ + for n_docs in [100, 500, 1000]: + rows, _ = _gen_rows(n_docs) + tbl_single = f"bench_single_{n_docs}" + tbl_batch = f"bench_batch_{n_docs}" + + _drop(client, tbl_single) + _drop(client, tbl_batch) + _create_table(client, tbl_single) + _create_table(client, tbl_batch) + + # The batch path also uses the pooled client, mirroring the new code path. + t_single = bench_insert_single(client, tbl_single, rows) + t_batch = bench_insert_batch(client_pooled, tbl_batch, rows, batch_size=100) + + speedup = t_single / t_batch if t_batch > 0 else float("inf") + print(f"\n[Insert {n_docs} docs]") + print(f" Single-row : {t_single:.2f}s") + print(f" Batch(100) : {t_batch:.2f}s") + print(f" Speedup : {speedup:.1f}x") + + # ------------------------------------------------------------------ + # 2. Metadata query benchmark (use the 1000-doc batch table) + # ------------------------------------------------------------------ + tbl_meta = "bench_batch_1000" + # The table already holds 1000 rows from step 1; fetch the doc_id of one of them. + with client.engine.connect() as conn: + res = conn.execute(text(f"SELECT metadata->>'$.document_id' FROM `{tbl_meta}` LIMIT 1")) + doc_id_1000 = res.fetchone()[0] + + print("\n[Metadata filter query — 1000 rows, by document_id]") + times_no_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=False) + print(f" Without index : {_fmt(times_no_idx)}") + times_with_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=True) + print(f" With index : {_fmt(times_with_idx)}") + + # ------------------------------------------------------------------ + # 3.
Vector search benchmark — across metrics + # ------------------------------------------------------------------ + print("\n[Vector search — top-10, 20 queries each, on 1000 rows]") + + for metric in ["l2", "cosine", "inner_product"]: + tbl_vs = f"bench_vs_{metric}" + _drop(client_pooled, tbl_vs) + _create_table(client_pooled, tbl_vs, metric=metric) + # Insert 1000 rows + rows_vs, _ = _gen_rows(1000) + bench_insert_batch(client_pooled, tbl_vs, rows_vs, batch_size=100) + times = bench_vector_search(client_pooled, tbl_vs, metric, topk=10, n_queries=20) + print(f" {metric:15s}: {_fmt(times)}") + _drop(client_pooled, tbl_vs) + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + for n in [100, 500, 1000]: + _drop(client, f"bench_single_{n}") + _drop(client, f"bench_batch_{n}") + + print("\n" + "=" * 70) + print("Benchmark complete.") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py b/api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py index 8fbbbe61b8..2db6732354 100644 --- a/api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py +++ b/api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py @@ -21,6 +21,7 @@ def oceanbase_vector(): database="test", password="difyai123456", enable_hybrid_search=True, + batch_size=10, ), )
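Reference sketch (not part of the diff): minimal end-to-end wiring of the new configuration surface. The connection values below are placeholders copied from the benchmark defaults; the field names and semantics come straight from the OceanBaseVectorConfig additions above, and the snippet assumes a reachable OceanBase instance.

from core.rag.datasource.vdb.oceanbase.oceanbase_vector import (
    OceanBaseVector,
    OceanBaseVectorConfig,
)

config = OceanBaseVectorConfig(
    host="127.0.0.1",  # placeholder endpoint, as in the benchmark
    port=2881,
    user="root@test",
    password="difyai123456",
    database="test",
    batch_size=100,  # rows per INSERT issued by add_texts()
    metric_type="cosine",  # "l2", "cosine", or "inner_product"
    hnsw_m=16,  # HNSW graph connectivity
    hnsw_ef_construction=256,  # index build-time search width
    hnsw_ef_search=-1,  # -1 keeps the server default
    pool_size=5,  # SQLAlchemy connection pool size
    max_overflow=10,  # extra connections allowed beyond the pool
    hnsw_refresh_threshold=1000,  # refresh the HNSW index after inserts this large
)

# Collection names must match ^[a-zA-Z0-9_]+$, otherwise the constructor raises ValueError.
store = OceanBaseVector("test_collection_1", config)

Note that the scores returned by search are metric-dependent per _distance_to_score: 1/(1+d) for l2, 1-d for cosine, and -d for inner_product, so any score_threshold must be chosen with the configured metric in mind.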