From 0af8a7b958dd96425b4b8659558f324c30fed8e2 Mon Sep 17 00:00:00 2001 From: Conner Mo Date: Mon, 1 Dec 2025 09:51:47 +0800 Subject: [PATCH] feat: enhance OceanBase vector database with SQL injection fixes, unified processing, and improved error handling (#28951) --- .../vdb/oceanbase/oceanbase_vector.py | 260 +++++++++++++----- 1 file changed, 196 insertions(+), 64 deletions(-) diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index 7b53f47419..dc3b70140b 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -58,11 +58,39 @@ class OceanBaseVector(BaseVector): password=self._config.password, db_name=self._config.database, ) + self._fields: list[str] = [] # List of fields in the collection + if self._client.check_table_exists(collection_name): + self._load_collection_fields() self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported def get_type(self) -> str: return VectorType.OCEANBASE + def _load_collection_fields(self): + """ + Load collection fields from the database table. + This method populates the _fields list with column names from the table. + """ + try: + if self._collection_name in self._client.metadata_obj.tables: + table = self._client.metadata_obj.tables[self._collection_name] + # Store all column names except 'id' (primary key) + self._fields = [column.name for column in table.columns if column.name != "id"] + logger.debug("Loaded fields for collection '%s': %s", self._collection_name, self._fields) + else: + logger.warning("Collection '%s' not found in metadata", self._collection_name) + except Exception as e: + logger.warning("Failed to load collection fields for '%s': %s", self._collection_name, str(e)) + + def field_exists(self, field: str) -> bool: + """ + Check if a field exists in the collection. + + :param field: Field name to check + :return: True if field exists, False otherwise + """ + return field in self._fields + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): self._vec_dim = len(embeddings[0]) self._create_collection() @@ -151,6 +179,7 @@ class OceanBaseVector(BaseVector): logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name) self._client.refresh_metadata([self._collection_name]) + self._load_collection_fields() redis_client.set(collection_exist_cache_key, 1, ex=3600) def _check_hybrid_search_support(self) -> bool: @@ -177,42 +206,134 @@ class OceanBaseVector(BaseVector): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): ids = self._get_uuids(documents) for id, doc, emb in zip(ids, documents, embeddings): - self._client.insert( - table_name=self._collection_name, - data={ - "id": id, - "vector": emb, - "text": doc.page_content, - "metadata": doc.metadata, - }, - ) + try: + self._client.insert( + table_name=self._collection_name, + data={ + "id": id, + "vector": emb, + "text": doc.page_content, + "metadata": doc.metadata, + }, + ) + except Exception as e: + logger.exception( + "Failed to insert document with id '%s' in collection '%s'", + id, + self._collection_name, + ) + raise Exception(f"Failed to insert document with id '{id}'") from e def text_exists(self, id: str) -> bool: - cur = self._client.get(table_name=self._collection_name, ids=id) - return bool(cur.rowcount != 0) + try: + cur = self._client.get(table_name=self._collection_name, ids=id) + return bool(cur.rowcount != 0) + except Exception as e: + logger.exception( + "Failed to check if text exists with id '%s' in collection '%s'", + id, + self._collection_name, + ) + raise Exception(f"Failed to check text existence for id '{id}'") from e def delete_by_ids(self, ids: list[str]): if not ids: return - self._client.delete(table_name=self._collection_name, ids=ids) + try: + self._client.delete(table_name=self._collection_name, ids=ids) + logger.debug("Deleted %d documents from collection '%s'", len(ids), self._collection_name) + except Exception as e: + logger.exception( + "Failed to delete %d documents from collection '%s'", + len(ids), + self._collection_name, + ) + raise Exception(f"Failed to delete documents from collection '{self._collection_name}'") from e def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: - from sqlalchemy import text + try: + import re - cur = self._client.get( - table_name=self._collection_name, - ids=None, - where_clause=[text(f"metadata->>'$.{key}' = '{value}'")], - output_column_name=["id"], - ) - return [row[0] for row in cur] + from sqlalchemy import text + + # Validate key to prevent injection in JSON path + if not re.match(r"^[a-zA-Z0-9_.]+$", key): + raise ValueError(f"Invalid characters in metadata key: {key}") + + # Use parameterized query to prevent SQL injection + sql = text(f"SELECT id FROM `{self._collection_name}` WHERE metadata->>'$.{key}' = :value") + + with self._client.engine.connect() as conn: + result = conn.execute(sql, {"value": value}) + ids = [row[0] for row in result] + + logger.debug( + "Found %d documents with metadata field '%s'='%s' in collection '%s'", + len(ids), + key, + value, + self._collection_name, + ) + return ids + except Exception as e: + logger.exception( + "Failed to get IDs by metadata field '%s'='%s' in collection '%s'", + key, + value, + self._collection_name, + ) + raise Exception(f"Failed to query documents by metadata field '{key}'") from e def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) - self.delete_by_ids(ids) + if ids: + self.delete_by_ids(ids) + else: + logger.debug("No documents found to delete with metadata field '%s'='%s'", key, value) + + def _process_search_results( + self, results: list[tuple], score_threshold: float = 0.0, score_key: str = "score" + ) -> list[Document]: + """ + Common method to process search results + + :param results: Search results as list of tuples (text, metadata, score) + :param score_threshold: Score threshold for filtering + :param score_key: Key name for score in metadata + :return: List of documents + """ + docs = [] + for row in results: + text, metadata_str, score = row[0], row[1], row[2] + + # Parse metadata JSON + try: + metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str + except json.JSONDecodeError: + logger.warning("Invalid JSON metadata: %s", metadata_str) + metadata = {} + + # Add score to metadata + metadata[score_key] = score + + # Filter by score threshold + if score >= score_threshold: + docs.append(Document(page_content=text, metadata=metadata)) + + return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: if not self._hybrid_search_enabled: + logger.warning( + "Full-text search is disabled: set OCEANBASE_ENABLE_HYBRID_SEARCH=true (requires OceanBase >= 4.3.5.1)." + ) + return [] + if not self.field_exists("text"): + logger.warning( + "Full-text search unavailable: collection '%s' missing 'text' field; " + "recreate the collection after enabling OCEANBASE_ENABLE_HYBRID_SEARCH to add fulltext index.", + self._collection_name, + ) return [] try: @@ -220,13 +341,24 @@ class OceanBaseVector(BaseVector): if not isinstance(top_k, int) or top_k <= 0: raise ValueError("top_k must be a positive integer") - document_ids_filter = kwargs.get("document_ids_filter") - where_clause = "" - if document_ids_filter: - document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) - where_clause = f" AND metadata->>'$.document_id' IN ({document_ids})" + score_threshold = float(kwargs.get("score_threshold") or 0.0) - full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score + # Build parameterized query to prevent SQL injection + from sqlalchemy import text + + document_ids_filter = kwargs.get("document_ids_filter") + params = {"query": query} + where_clause = "" + + if document_ids_filter: + # Create parameterized placeholders for document IDs + placeholders = ", ".join(f":doc_id_{i}" for i in range(len(document_ids_filter))) + where_clause = f" AND metadata->>'$.document_id' IN ({placeholders})" + # Add document IDs to parameters + for i, doc_id in enumerate(document_ids_filter): + params[f"doc_id_{i}"] = doc_id + + full_sql = f"""SELECT text, metadata, MATCH (text) AGAINST (:query) AS score FROM {self._collection_name} WHERE MATCH (text) AGAINST (:query) > 0 {where_clause} @@ -235,35 +367,35 @@ class OceanBaseVector(BaseVector): with self._client.engine.connect() as conn: with conn.begin(): - from sqlalchemy import text - - result = conn.execute(text(full_sql), {"query": query}) + result = conn.execute(text(full_sql), params) rows = result.fetchall() - docs = [] - for row in rows: - metadata_str, _text, score = row - try: - metadata = json.loads(metadata_str) - except json.JSONDecodeError: - logger.warning("Invalid JSON metadata: %s", metadata_str) - metadata = {} - metadata["score"] = score - docs.append(Document(page_content=_text, metadata=metadata)) - - return docs + return self._process_search_results(rows, score_threshold=score_threshold) except Exception as e: - logger.warning("Failed to fulltext search: %s.", str(e)) - return [] + logger.exception( + "Failed to perform full-text search on collection '%s' with query '%s'", + self._collection_name, + query, + ) + raise Exception(f"Full-text search failed for collection '{self._collection_name}'") from e def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + from sqlalchemy import text + document_ids_filter = kwargs.get("document_ids_filter") _where_clause = None if document_ids_filter: + # Validate document IDs to prevent SQL injection + # Document IDs should be alphanumeric with hyphens and underscores + import re + + for doc_id in document_ids_filter: + if not isinstance(doc_id, str) or not re.match(r"^[a-zA-Z0-9_-]+$", doc_id): + raise ValueError(f"Invalid document ID format: {doc_id}") + + # Safe to use in query after validation document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) where_clause = f"metadata->>'$.document_id' in ({document_ids})" - from sqlalchemy import text - _where_clause = [text(where_clause)] ef_search = kwargs.get("ef_search", self._hnsw_ef_search) if ef_search != self._hnsw_ef_search: @@ -286,27 +418,27 @@ class OceanBaseVector(BaseVector): where_clause=_where_clause, ) except Exception as e: - raise Exception("Failed to search by vector. ", e) - docs = [] - for _text, metadata, distance in cur: + logger.exception( + "Failed to perform vector search on collection '%s'", + self._collection_name, + ) + raise Exception(f"Vector search failed for collection '{self._collection_name}'") from e + + # Convert distance to score and prepare results for processing + results = [] + for _text, metadata_str, distance in cur: score = 1 - distance / math.sqrt(2) - if score >= score_threshold: - try: - metadata = json.loads(metadata) - except json.JSONDecodeError: - logger.warning("Invalid JSON metadata: %s", metadata) - metadata = {} - metadata["score"] = score - docs.append( - Document( - page_content=_text, - metadata=metadata, - ) - ) - return docs + results.append((_text, metadata_str, score)) + + return self._process_search_results(results, score_threshold=score_threshold) def delete(self): - self._client.drop_table_if_exist(self._collection_name) + try: + self._client.drop_table_if_exist(self._collection_name) + logger.debug("Dropped collection '%s'", self._collection_name) + except Exception as e: + logger.exception("Failed to delete collection '%s'", self._collection_name) + raise Exception(f"Failed to delete collection '{self._collection_name}'") from e class OceanBaseVectorFactory(AbstractVectorFactory):