mirror of https://github.com/langgenius/dify.git
feat(api): optimize OceanBase vector store performance and configurability (#32263)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent c0ffb6db2a
commit 16df9851a2
@@ -1,3 +1,5 @@
+from typing import Literal
+
 from pydantic import Field, PositiveInt
 from pydantic_settings import BaseSettings
@@ -49,3 +51,43 @@ class OceanBaseVectorConfig(BaseSettings):
         ),
         default="ik",
     )
+
+    OCEANBASE_VECTOR_BATCH_SIZE: PositiveInt = Field(
+        description="Number of documents to insert per batch",
+        default=100,
+    )
+
+    OCEANBASE_VECTOR_METRIC_TYPE: Literal["l2", "cosine", "inner_product"] = Field(
+        description="Distance metric type for vector index: l2, cosine, or inner_product",
+        default="l2",
+    )
+
+    OCEANBASE_HNSW_M: PositiveInt = Field(
+        description="HNSW M parameter (max number of connections per node)",
+        default=16,
+    )
+
+    OCEANBASE_HNSW_EF_CONSTRUCTION: PositiveInt = Field(
+        description="HNSW efConstruction parameter (index build-time search width)",
+        default=256,
+    )
+
+    OCEANBASE_HNSW_EF_SEARCH: int = Field(
+        description="HNSW efSearch parameter (query-time search width, -1 uses server default)",
+        default=-1,
+    )
+
+    OCEANBASE_VECTOR_POOL_SIZE: PositiveInt = Field(
+        description="SQLAlchemy connection pool size",
+        default=5,
+    )
+
+    OCEANBASE_VECTOR_MAX_OVERFLOW: int = Field(
+        description="SQLAlchemy connection pool max overflow connections",
+        default=10,
+    )
+
+    OCEANBASE_HNSW_REFRESH_THRESHOLD: int = Field(
+        description="Minimum number of inserted documents to trigger an automatic HNSW index refresh (0 to disable)",
+        default=1000,
+    )
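The new fields above are pydantic-settings fields, so each can be overridden with an environment variable of the same name. A minimal standalone sketch of that behavior (the demo class below is illustrative and not part of the commit):

import os
from typing import Literal

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class OceanBaseTuningDemo(BaseSettings):
    # Two of the new settings, with the defaults from the hunk above.
    OCEANBASE_VECTOR_BATCH_SIZE: PositiveInt = Field(default=100)
    OCEANBASE_VECTOR_METRIC_TYPE: Literal["l2", "cosine", "inner_product"] = Field(default="l2")


os.environ["OCEANBASE_VECTOR_BATCH_SIZE"] = "500"
os.environ["OCEANBASE_VECTOR_METRIC_TYPE"] = "cosine"
print(OceanBaseTuningDemo().OCEANBASE_VECTOR_BATCH_SIZE)   # 500
print(OceanBaseTuningDemo().OCEANBASE_VECTOR_METRIC_TYPE)  # cosine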
@@ -1,12 +1,13 @@
 import json
 import logging
-import math
-from typing import Any
+import re
+from typing import Any, Literal
 
 from pydantic import BaseModel, model_validator
-from pyobvector import VECTOR, ObVecClient, l2_distance  # type: ignore
+from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_distance  # type: ignore
 from sqlalchemy import JSON, Column, String
 from sqlalchemy.dialects.mysql import LONGTEXT
+from sqlalchemy.exc import SQLAlchemyError
 
 from configs import dify_config
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -19,10 +20,14 @@ from models.dataset import Dataset
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256}
-DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64}
 OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW"
-DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"
+_VALID_TABLE_NAME_RE = re.compile(r"^[a-zA-Z0-9_]+$")
+
+_DISTANCE_FUNC_MAP = {
+    "l2": l2_distance,
+    "cosine": cosine_distance,
+    "inner_product": inner_product,
+}
 
 
 class OceanBaseVectorConfig(BaseModel):
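The new _VALID_TABLE_NAME_RE guard is used below to reject collection names that could escape the backtick-quoted identifiers in raw SQL. A quick standalone illustration (the example names are made up):

import re

_VALID_TABLE_NAME_RE = re.compile(r"^[a-zA-Z0-9_]+$")

print(bool(_VALID_TABLE_NAME_RE.match("Vector_index_1234abcd_Node")))        # True
print(bool(_VALID_TABLE_NAME_RE.match("docs`; DROP TABLE embeddings; --")))  # False (backticks, spaces, semicolons rejected)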
@@ -32,6 +37,14 @@ class OceanBaseVectorConfig(BaseModel):
     password: str
     database: str
     enable_hybrid_search: bool = False
+    batch_size: int = 100
+    metric_type: Literal["l2", "cosine", "inner_product"] = "l2"
+    hnsw_m: int = 16
+    hnsw_ef_construction: int = 256
+    hnsw_ef_search: int = -1
+    pool_size: int = 5
+    max_overflow: int = 10
+    hnsw_refresh_threshold: int = 1000
 
     @model_validator(mode="before")
     @classmethod
@@ -49,14 +62,23 @@ class OceanBaseVectorConfig(BaseModel):
 
 class OceanBaseVector(BaseVector):
     def __init__(self, collection_name: str, config: OceanBaseVectorConfig):
+        if not _VALID_TABLE_NAME_RE.match(collection_name):
+            raise ValueError(
+                f"Invalid collection name '{collection_name}': "
+                "only alphanumeric characters and underscores are allowed."
+            )
         super().__init__(collection_name)
         self._config = config
-        self._hnsw_ef_search = -1
+        self._hnsw_ef_search = self._config.hnsw_ef_search
        self._client = ObVecClient(
             uri=f"{self._config.host}:{self._config.port}",
             user=self._config.user,
             password=self._config.password,
             db_name=self._config.database,
+            pool_size=self._config.pool_size,
+            max_overflow=self._config.max_overflow,
+            pool_recycle=3600,
+            pool_pre_ping=True,
         )
         self._fields: list[str] = []  # List of fields in the collection
         if self._client.check_table_exists(collection_name):
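The four new ObVecClient keyword arguments are standard SQLAlchemy pooling options; the sketch below shows what they mean at the engine level, assuming pyobvector forwards them to SQLAlchemy's create_engine (the DSN is a placeholder built from the benchmark defaults further down, not part of the commit):

from sqlalchemy import create_engine

engine = create_engine(
    "mysql+pymysql://root%40test:difyai123456@127.0.0.1:2881/test",  # placeholder DSN
    pool_size=5,         # connections kept open in the pool
    max_overflow=10,     # extra connections allowed under bursts (upper bound = pool_size + max_overflow)
    pool_recycle=3600,   # recycle connections older than an hour to avoid server-side idle timeouts
    pool_pre_ping=True,  # validate each pooled connection before handing it out
)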
@@ -136,8 +158,8 @@
                 field_name="vector",
                 index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE,
                 index_name="vector_index",
-                metric_type=DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
-                params=DEFAULT_OCEANBASE_HNSW_BUILD_PARAM,
+                metric_type=self._config.metric_type,
+                params={"M": self._config.hnsw_m, "efConstruction": self._config.hnsw_ef_construction},
             )
 
             self._client.create_table_with_index_params(
@@ -178,6 +200,17 @@
             else:
                 logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)
 
+            try:
+                self._client.perform_raw_text_sql(
+                    f"CREATE INDEX IF NOT EXISTS idx_metadata_doc_id ON `{self._collection_name}` "
+                    f"((CAST(metadata->>'$.document_id' AS CHAR(64))))"
+                )
+            except SQLAlchemyError:
+                logger.warning(
+                    "Failed to create metadata functional index on '%s'; metadata queries may be slow without it.",
+                    self._collection_name,
+                )
+
             self._client.refresh_metadata([self._collection_name])
             self._load_collection_fields()
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
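For a concrete collection name, the two f-string fragments above join into one statement; a quick rendering (the collection name is made up):

collection_name = "vector_index_demo"
sql = (
    f"CREATE INDEX IF NOT EXISTS idx_metadata_doc_id ON `{collection_name}` "
    f"((CAST(metadata->>'$.document_id' AS CHAR(64))))"
)
print(sql)
# CREATE INDEX IF NOT EXISTS idx_metadata_doc_id ON `vector_index_demo` ((CAST(metadata->>'$.document_id' AS CHAR(64))))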
@@ -205,24 +238,49 @@
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         ids = self._get_uuids(documents)
-        for id, doc, emb in zip(ids, documents, embeddings):
+        batch_size = self._config.batch_size
+        total = len(documents)
+
+        all_data = [
+            {
+                "id": doc_id,
+                "vector": emb,
+                "text": doc.page_content,
+                "metadata": doc.metadata,
+            }
+            for doc_id, doc, emb in zip(ids, documents, embeddings)
+        ]
+
+        for start in range(0, total, batch_size):
+            batch = all_data[start : start + batch_size]
             try:
                 self._client.insert(
                     table_name=self._collection_name,
-                    data={
-                        "id": id,
-                        "vector": emb,
-                        "text": doc.page_content,
-                        "metadata": doc.metadata,
-                    },
+                    data=batch,
                 )
             except Exception as e:
                 logger.exception(
-                    "Failed to insert document with id '%s' in collection '%s'",
-                    id,
+                    "Failed to insert batch [%d:%d] into collection '%s'",
+                    start,
+                    start + len(batch),
                     self._collection_name,
                 )
-                raise Exception(f"Failed to insert document with id '{id}'") from e
+                raise Exception(
+                    f"Failed to insert batch [{start}:{start + len(batch)}] into collection '{self._collection_name}'"
+                ) from e
+
+        if self._config.hnsw_refresh_threshold > 0 and total >= self._config.hnsw_refresh_threshold:
+            try:
+                self._client.refresh_index(
+                    table_name=self._collection_name,
+                    index_name="vector_index",
+                )
+            except SQLAlchemyError:
+                logger.warning(
+                    "Failed to refresh HNSW index after inserting %d documents into '%s'",
+                    total,
+                    self._collection_name,
+                )
 
     def text_exists(self, id: str) -> bool:
         try:
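The batching above is the plain range/slice chunking pattern; a tiny standalone illustration on toy data (not part of the commit):

rows = [f"doc-{i}" for i in range(7)]
batch_size = 3

batches = [rows[start : start + batch_size] for start in range(0, len(rows), batch_size)]
print(batches)  # [['doc-0', 'doc-1', 'doc-2'], ['doc-3', 'doc-4', 'doc-5'], ['doc-6']]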
@@ -412,7 +470,7 @@
                 vec_column_name="vector",
                 vec_data=query_vector,
                 topk=topk,
-                distance_func=l2_distance,
+                distance_func=self._get_distance_func(),
                 output_column_names=["text", "metadata"],
                 with_dist=True,
                 where_clause=_where_clause,
@@ -424,14 +482,31 @@
             )
             raise Exception(f"Vector search failed for collection '{self._collection_name}'") from e
 
         # Convert distance to score and prepare results for processing
         results = []
         for _text, metadata_str, distance in cur:
-            score = 1 - distance / math.sqrt(2)
+            score = self._distance_to_score(distance)
             results.append((_text, metadata_str, score))
 
         return self._process_search_results(results, score_threshold=score_threshold)
 
+    def _get_distance_func(self):
+        func = _DISTANCE_FUNC_MAP.get(self._config.metric_type)
+        if func is None:
+            raise ValueError(
+                f"Unsupported metric_type '{self._config.metric_type}'. Supported: {', '.join(_DISTANCE_FUNC_MAP)}"
+            )
+        return func
+
+    def _distance_to_score(self, distance: float) -> float:
+        metric = self._config.metric_type
+        if metric == "l2":
+            return 1.0 / (1.0 + distance)
+        elif metric == "cosine":
+            return 1.0 - distance
+        elif metric == "inner_product":
+            return -distance
+        raise ValueError(f"Unsupported metric_type '{metric}'")
+
     def delete(self):
         try:
             self._client.drop_table_if_exist(self._collection_name)
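Scores handed back to the RAG layer are on a different scale per metric; a few worked values of the mapping above (standalone copy of the same logic, for illustration only):

def distance_to_score(metric: str, distance: float) -> float:
    # Same mapping as OceanBaseVector._distance_to_score in the hunk above.
    if metric == "l2":
        return 1.0 / (1.0 + distance)
    if metric == "cosine":
        return 1.0 - distance
    if metric == "inner_product":
        return -distance
    raise ValueError(f"Unsupported metric_type '{metric}'")


print(distance_to_score("l2", 0.0))              # 1.0  (zero distance -> top score)
print(distance_to_score("l2", 3.0))              # 0.25 (larger distance -> lower score)
print(distance_to_score("cosine", 0.2))          # 0.8
print(distance_to_score("inner_product", -4.0))  # 4.0  (score is the negated distance)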
@@ -464,5 +539,13 @@ class OceanBaseVectorFactory(AbstractVectorFactory):
                 password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
                 database=dify_config.OCEANBASE_VECTOR_DATABASE or "",
                 enable_hybrid_search=dify_config.OCEANBASE_ENABLE_HYBRID_SEARCH or False,
+                batch_size=dify_config.OCEANBASE_VECTOR_BATCH_SIZE,
+                metric_type=dify_config.OCEANBASE_VECTOR_METRIC_TYPE,
+                hnsw_m=dify_config.OCEANBASE_HNSW_M,
+                hnsw_ef_construction=dify_config.OCEANBASE_HNSW_EF_CONSTRUCTION,
+                hnsw_ef_search=dify_config.OCEANBASE_HNSW_EF_SEARCH,
+                pool_size=dify_config.OCEANBASE_VECTOR_POOL_SIZE,
+                max_overflow=dify_config.OCEANBASE_VECTOR_MAX_OVERFLOW,
+                hnsw_refresh_threshold=dify_config.OCEANBASE_HNSW_REFRESH_THRESHOLD,
             ),
         )
@@ -0,0 +1,241 @@
"""
Benchmark: OceanBase vector store — old (single-row) vs new (batch) insertion,
metadata query with/without functional index, and vector search across metrics.

Usage:
    uv run --project api python -m tests.integration_tests.vdb.oceanbase.bench_oceanbase
"""

import json
import random
import statistics
import time
import uuid

from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_distance
from sqlalchemy import JSON, Column, String, text
from sqlalchemy.dialects.mysql import LONGTEXT

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
HOST = "127.0.0.1"
PORT = 2881
USER = "root@test"
PASSWORD = "difyai123456"
DATABASE = "test"

VEC_DIM = 1536
HNSW_BUILD = {"M": 16, "efConstruction": 256}
DISTANCE_FUNCS = {"l2": l2_distance, "cosine": cosine_distance, "inner_product": inner_product}


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_client(**extra):
    return ObVecClient(
        uri=f"{HOST}:{PORT}",
        user=USER,
        password=PASSWORD,
        db_name=DATABASE,
        **extra,
    )


def _rand_vec():
    return [random.uniform(-1, 1) for _ in range(VEC_DIM)]  # noqa: S311


def _drop(client, table):
    client.drop_table_if_exist(table)


def _create_table(client, table, metric="l2"):
    cols = [
        Column("id", String(36), primary_key=True, autoincrement=False),
        Column("vector", VECTOR(VEC_DIM)),
        Column("text", LONGTEXT),
        Column("metadata", JSON),
    ]
    vidx = client.prepare_index_params()
    vidx.add_index(
        field_name="vector",
        index_type="HNSW",
        index_name="vector_index",
        metric_type=metric,
        params=HNSW_BUILD,
    )
    client.create_table_with_index_params(table_name=table, columns=cols, vidxs=vidx)
    client.refresh_metadata([table])


def _gen_rows(n):
    doc_id = str(uuid.uuid4())
    rows = []
    for _ in range(n):
        rows.append(
            {
                "id": str(uuid.uuid4()),
                "vector": _rand_vec(),
                "text": f"benchmark text {uuid.uuid4().hex[:12]}",
                "metadata": json.dumps({"document_id": doc_id, "dataset_id": str(uuid.uuid4())}),
            }
        )
    return rows, doc_id


# ---------------------------------------------------------------------------
# Benchmark: Insertion
# ---------------------------------------------------------------------------
def bench_insert_single(client, table, rows):
    """Old approach: one INSERT per row."""
    t0 = time.perf_counter()
    for row in rows:
        client.insert(table_name=table, data=row)
    return time.perf_counter() - t0


def bench_insert_batch(client, table, rows, batch_size=100):
    """New approach: batch INSERT."""
    t0 = time.perf_counter()
    for start in range(0, len(rows), batch_size):
        batch = rows[start : start + batch_size]
        client.insert(table_name=table, data=batch)
    return time.perf_counter() - t0


# ---------------------------------------------------------------------------
# Benchmark: Metadata query
# ---------------------------------------------------------------------------
def bench_metadata_query(client, table, doc_id, with_index=False):
    """Query by metadata->>'$.document_id' with/without functional index."""
    if with_index:
        try:
            client.perform_raw_text_sql(f"CREATE INDEX idx_metadata_doc_id ON `{table}` ((metadata->>'$.document_id'))")
        except Exception:
            pass  # already exists

    sql = text(f"SELECT id FROM `{table}` WHERE metadata->>'$.document_id' = :val")
    times = []
    with client.engine.connect() as conn:
        for _ in range(10):
            t0 = time.perf_counter()
            result = conn.execute(sql, {"val": doc_id})
            _ = result.fetchall()
            times.append(time.perf_counter() - t0)
    return times


# ---------------------------------------------------------------------------
# Benchmark: Vector search
# ---------------------------------------------------------------------------
def bench_vector_search(client, table, metric, topk=10, n_queries=20):
    dist_func = DISTANCE_FUNCS[metric]
    times = []
    for _ in range(n_queries):
        q = _rand_vec()
        t0 = time.perf_counter()
        cur = client.ann_search(
            table_name=table,
            vec_column_name="vector",
            vec_data=q,
            topk=topk,
            distance_func=dist_func,
            output_column_names=["text", "metadata"],
            with_dist=True,
        )
        _ = list(cur)
        times.append(time.perf_counter() - t0)
    return times


def _fmt(times):
    """Format list of durations as 'mean ± stdev'."""
    m = statistics.mean(times) * 1000
    s = statistics.stdev(times) * 1000 if len(times) > 1 else 0
    return f"{m:.1f} ± {s:.1f} ms"


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    client = _make_client()
    client_pooled = _make_client(pool_size=5, max_overflow=10, pool_recycle=3600, pool_pre_ping=True)

    print("=" * 70)
    print("OceanBase Vector Store — Performance Benchmark")
    print(f"  Endpoint : {HOST}:{PORT}")
    print(f"  Vec dim  : {VEC_DIM}")
    print("=" * 70)

    # ------------------------------------------------------------------
    # 1. Insertion benchmark
    # ------------------------------------------------------------------
    for n_docs in [100, 500, 1000]:
        rows, doc_id = _gen_rows(n_docs)
        tbl_single = f"bench_single_{n_docs}"
        tbl_batch = f"bench_batch_{n_docs}"

        _drop(client, tbl_single)
        _drop(client, tbl_batch)
        _create_table(client, tbl_single)
        _create_table(client, tbl_batch)

        t_single = bench_insert_single(client, tbl_single, rows)
        t_batch = bench_insert_batch(client_pooled, tbl_batch, rows, batch_size=100)

        speedup = t_single / t_batch if t_batch > 0 else float("inf")
        print(f"\n[Insert {n_docs} docs]")
        print(f"  Single-row : {t_single:.2f}s")
        print(f"  Batch(100) : {t_batch:.2f}s")
        print(f"  Speedup    : {speedup:.1f}x")

    # ------------------------------------------------------------------
    # 2. Metadata query benchmark (use the 1000-doc batch table)
    # ------------------------------------------------------------------
    tbl_meta = "bench_batch_1000"
    rows_1000, doc_id_1000 = _gen_rows(1000)
    # The table already has 1000 rows from step 1; use that doc_id.
    # Re-query doc_id from one of the rows we inserted.
    with client.engine.connect() as conn:
        res = conn.execute(text(f"SELECT metadata->>'$.document_id' FROM `{tbl_meta}` LIMIT 1"))
        doc_id_1000 = res.fetchone()[0]

    print("\n[Metadata filter query — 1000 rows, by document_id]")
    times_no_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=False)
    print(f"  Without index : {_fmt(times_no_idx)}")
    times_with_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=True)
    print(f"  With index    : {_fmt(times_with_idx)}")

    # ------------------------------------------------------------------
    # 3. Vector search benchmark — across metrics
    # ------------------------------------------------------------------
    print("\n[Vector search — top-10, 20 queries each, on 1000 rows]")

    for metric in ["l2", "cosine", "inner_product"]:
        tbl_vs = f"bench_vs_{metric}"
        _drop(client_pooled, tbl_vs)
        _create_table(client_pooled, tbl_vs, metric=metric)
        # Insert 1000 rows
        rows_vs, _ = _gen_rows(1000)
        bench_insert_batch(client_pooled, tbl_vs, rows_vs, batch_size=100)
        times = bench_vector_search(client_pooled, tbl_vs, metric, topk=10, n_queries=20)
        print(f"  {metric:15s}: {_fmt(times)}")
        _drop(client_pooled, tbl_vs)

    # ------------------------------------------------------------------
    # Cleanup
    # ------------------------------------------------------------------
    for n in [100, 500, 1000]:
        _drop(client, f"bench_single_{n}")
        _drop(client, f"bench_batch_{n}")

    print("\n" + "=" * 70)
    print("Benchmark complete.")
    print("=" * 70)


if __name__ == "__main__":
    main()
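As a sanity check on the reporting helper, _fmt averages a list of per-call durations (in seconds) and prints mean ± stdev in milliseconds; for example (the input durations are illustrative only):

import statistics


def _fmt(times):
    m = statistics.mean(times) * 1000
    s = statistics.stdev(times) * 1000 if len(times) > 1 else 0
    return f"{m:.1f} ± {s:.1f} ms"


print(_fmt([0.010, 0.011, 0.012]))  # 11.0 ± 1.0 ms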
@@ -21,6 +21,7 @@ def oceanbase_vector():
             database="test",
             password="difyai123456",
             enable_hybrid_search=True,
+            batch_size=10,
         ),
     )