feat: enhance OceanBase vector database with SQL injection fixes, unified processing, and improved error handling (#28951)

This commit is contained in:
Conner Mo 2025-12-01 09:51:47 +08:00 committed by GitHub
parent 861098714b
commit 0af8a7b958
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 196 additions and 64 deletions

View File

@ -58,11 +58,39 @@ class OceanBaseVector(BaseVector):
password=self._config.password,
db_name=self._config.database,
)
self._fields: list[str] = [] # List of fields in the collection
if self._client.check_table_exists(collection_name):
self._load_collection_fields()
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def get_type(self) -> str:
return VectorType.OCEANBASE
def _load_collection_fields(self):
"""
Load collection fields from the database table.
This method populates the _fields list with column names from the table.
"""
try:
if self._collection_name in self._client.metadata_obj.tables:
table = self._client.metadata_obj.tables[self._collection_name]
# Store all column names except 'id' (primary key)
self._fields = [column.name for column in table.columns if column.name != "id"]
logger.debug("Loaded fields for collection '%s': %s", self._collection_name, self._fields)
else:
logger.warning("Collection '%s' not found in metadata", self._collection_name)
except Exception as e:
logger.warning("Failed to load collection fields for '%s': %s", self._collection_name, str(e))
def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
:param field: Field name to check
:return: True if field exists, False otherwise
"""
return field in self._fields
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._vec_dim = len(embeddings[0])
self._create_collection()
@ -151,6 +179,7 @@ class OceanBaseVector(BaseVector):
logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)
self._client.refresh_metadata([self._collection_name])
self._load_collection_fields()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _check_hybrid_search_support(self) -> bool:
@ -177,42 +206,134 @@ class OceanBaseVector(BaseVector):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = self._get_uuids(documents)
for id, doc, emb in zip(ids, documents, embeddings):
self._client.insert(
table_name=self._collection_name,
data={
"id": id,
"vector": emb,
"text": doc.page_content,
"metadata": doc.metadata,
},
)
try:
self._client.insert(
table_name=self._collection_name,
data={
"id": id,
"vector": emb,
"text": doc.page_content,
"metadata": doc.metadata,
},
)
except Exception as e:
logger.exception(
"Failed to insert document with id '%s' in collection '%s'",
id,
self._collection_name,
)
raise Exception(f"Failed to insert document with id '{id}'") from e
def text_exists(self, id: str) -> bool:
cur = self._client.get(table_name=self._collection_name, ids=id)
return bool(cur.rowcount != 0)
try:
cur = self._client.get(table_name=self._collection_name, ids=id)
return bool(cur.rowcount != 0)
except Exception as e:
logger.exception(
"Failed to check if text exists with id '%s' in collection '%s'",
id,
self._collection_name,
)
raise Exception(f"Failed to check text existence for id '{id}'") from e
def delete_by_ids(self, ids: list[str]):
if not ids:
return
self._client.delete(table_name=self._collection_name, ids=ids)
try:
self._client.delete(table_name=self._collection_name, ids=ids)
logger.debug("Deleted %d documents from collection '%s'", len(ids), self._collection_name)
except Exception as e:
logger.exception(
"Failed to delete %d documents from collection '%s'",
len(ids),
self._collection_name,
)
raise Exception(f"Failed to delete documents from collection '{self._collection_name}'") from e
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
from sqlalchemy import text
try:
import re
cur = self._client.get(
table_name=self._collection_name,
ids=None,
where_clause=[text(f"metadata->>'$.{key}' = '{value}'")],
output_column_name=["id"],
)
return [row[0] for row in cur]
from sqlalchemy import text
# Validate key to prevent injection in JSON path
if not re.match(r"^[a-zA-Z0-9_.]+$", key):
raise ValueError(f"Invalid characters in metadata key: {key}")
# Use parameterized query to prevent SQL injection
sql = text(f"SELECT id FROM `{self._collection_name}` WHERE metadata->>'$.{key}' = :value")
with self._client.engine.connect() as conn:
result = conn.execute(sql, {"value": value})
ids = [row[0] for row in result]
logger.debug(
"Found %d documents with metadata field '%s'='%s' in collection '%s'",
len(ids),
key,
value,
self._collection_name,
)
return ids
except Exception as e:
logger.exception(
"Failed to get IDs by metadata field '%s'='%s' in collection '%s'",
key,
value,
self._collection_name,
)
raise Exception(f"Failed to query documents by metadata field '{key}'") from e
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
self.delete_by_ids(ids)
if ids:
self.delete_by_ids(ids)
else:
logger.debug("No documents found to delete with metadata field '%s'='%s'", key, value)
def _process_search_results(
self, results: list[tuple], score_threshold: float = 0.0, score_key: str = "score"
) -> list[Document]:
"""
Common method to process search results
:param results: Search results as list of tuples (text, metadata, score)
:param score_threshold: Score threshold for filtering
:param score_key: Key name for score in metadata
:return: List of documents
"""
docs = []
for row in results:
text, metadata_str, score = row[0], row[1], row[2]
# Parse metadata JSON
try:
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str
except json.JSONDecodeError:
logger.warning("Invalid JSON metadata: %s", metadata_str)
metadata = {}
# Add score to metadata
metadata[score_key] = score
# Filter by score threshold
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
if not self._hybrid_search_enabled:
logger.warning(
"Full-text search is disabled: set OCEANBASE_ENABLE_HYBRID_SEARCH=true (requires OceanBase >= 4.3.5.1)."
)
return []
if not self.field_exists("text"):
logger.warning(
"Full-text search unavailable: collection '%s' missing 'text' field; "
"recreate the collection after enabling OCEANBASE_ENABLE_HYBRID_SEARCH to add fulltext index.",
self._collection_name,
)
return []
try:
@ -220,13 +341,24 @@ class OceanBaseVector(BaseVector):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'$.document_id' IN ({document_ids})"
score_threshold = float(kwargs.get("score_threshold") or 0.0)
full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score
# Build parameterized query to prevent SQL injection
from sqlalchemy import text
document_ids_filter = kwargs.get("document_ids_filter")
params = {"query": query}
where_clause = ""
if document_ids_filter:
# Create parameterized placeholders for document IDs
placeholders = ", ".join(f":doc_id_{i}" for i in range(len(document_ids_filter)))
where_clause = f" AND metadata->>'$.document_id' IN ({placeholders})"
# Add document IDs to parameters
for i, doc_id in enumerate(document_ids_filter):
params[f"doc_id_{i}"] = doc_id
full_sql = f"""SELECT text, metadata, MATCH (text) AGAINST (:query) AS score
FROM {self._collection_name}
WHERE MATCH (text) AGAINST (:query) > 0
{where_clause}
@ -235,35 +367,35 @@ class OceanBaseVector(BaseVector):
with self._client.engine.connect() as conn:
with conn.begin():
from sqlalchemy import text
result = conn.execute(text(full_sql), {"query": query})
result = conn.execute(text(full_sql), params)
rows = result.fetchall()
docs = []
for row in rows:
metadata_str, _text, score = row
try:
metadata = json.loads(metadata_str)
except json.JSONDecodeError:
logger.warning("Invalid JSON metadata: %s", metadata_str)
metadata = {}
metadata["score"] = score
docs.append(Document(page_content=_text, metadata=metadata))
return docs
return self._process_search_results(rows, score_threshold=score_threshold)
except Exception as e:
logger.warning("Failed to fulltext search: %s.", str(e))
return []
logger.exception(
"Failed to perform full-text search on collection '%s' with query '%s'",
self._collection_name,
query,
)
raise Exception(f"Full-text search failed for collection '{self._collection_name}'") from e
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from sqlalchemy import text
document_ids_filter = kwargs.get("document_ids_filter")
_where_clause = None
if document_ids_filter:
# Validate document IDs to prevent SQL injection
# Document IDs should be alphanumeric with hyphens and underscores
import re
for doc_id in document_ids_filter:
if not isinstance(doc_id, str) or not re.match(r"^[a-zA-Z0-9_-]+$", doc_id):
raise ValueError(f"Invalid document ID format: {doc_id}")
# Safe to use in query after validation
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"metadata->>'$.document_id' in ({document_ids})"
from sqlalchemy import text
_where_clause = [text(where_clause)]
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
if ef_search != self._hnsw_ef_search:
@ -286,27 +418,27 @@ class OceanBaseVector(BaseVector):
where_clause=_where_clause,
)
except Exception as e:
raise Exception("Failed to search by vector. ", e)
docs = []
for _text, metadata, distance in cur:
logger.exception(
"Failed to perform vector search on collection '%s'",
self._collection_name,
)
raise Exception(f"Vector search failed for collection '{self._collection_name}'") from e
# Convert distance to score and prepare results for processing
results = []
for _text, metadata_str, distance in cur:
score = 1 - distance / math.sqrt(2)
if score >= score_threshold:
try:
metadata = json.loads(metadata)
except json.JSONDecodeError:
logger.warning("Invalid JSON metadata: %s", metadata)
metadata = {}
metadata["score"] = score
docs.append(
Document(
page_content=_text,
metadata=metadata,
)
)
return docs
results.append((_text, metadata_str, score))
return self._process_search_results(results, score_threshold=score_threshold)
def delete(self):
self._client.drop_table_if_exist(self._collection_name)
try:
self._client.drop_table_if_exist(self._collection_name)
logger.debug("Dropped collection '%s'", self._collection_name)
except Exception as e:
logger.exception("Failed to delete collection '%s'", self._collection_name)
raise Exception(f"Failed to delete collection '{self._collection_name}'") from e
class OceanBaseVectorFactory(AbstractVectorFactory):