NOTE(review): this patch was reflowed from a whitespace-collapsed copy. Leading whitespace of
context lines is assumed (4-space Python indentation) - confirm against the target file.
The post-image hash on the `index` line predates the comments/docstrings added on `+` lines.
diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py
index 17aac25b87..2bafe999c8 100644
--- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py
+++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py
@@ -33,6 +33,19 @@ class SortOrder(StrEnum):
 
 
 class MyScaleVector(BaseVector):
+    # Only these metadata keys may be interpolated into SQL as ``metadata.<key>``.
+    _METADATA_KEY_WHITELIST = {
+        "annotation_id",
+        "app_id",
+        "batch",
+        "dataset_id",
+        "doc_hash",
+        "doc_id",
+        "document_id",
+        "lang",
+        "source",
+    }
+
     def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
         super().__init__(collection_name)
         self._config = config
@@ -45,10 +58,18 @@ class MyScaleVector(BaseVector):
             password=config.password,
         )
         self._client.command("SET allow_experimental_object_type=1")
+        self._qualified_table = f"{self._config.database}.{self._collection_name}"
 
     def get_type(self) -> str:
         return VectorType.MYSCALE
 
+    @classmethod
+    def _validate_metadata_key(cls, key: str) -> str:
+        """Return ``key`` unchanged if it is whitelisted; raise ``ValueError`` otherwise."""
+        if key not in cls._METADATA_KEY_WHITELIST:
+            raise ValueError(f"Unsupported metadata key: {key!r}")
+        return key
+
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         dimension = len(embeddings[0])
         self._create_collection(dimension)
@@ -59,7 +80,7 @@ class MyScaleVector(BaseVector):
         self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
         fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
         sql = f"""
-            CREATE TABLE IF NOT EXISTS {self._config.database}.{self._collection_name}(
+            CREATE TABLE IF NOT EXISTS {self._qualified_table}(
                 id String,
                 text String,
                 vector Array(Float32),
@@ -74,23 +95,22 @@ class MyScaleVector(BaseVector):
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         ids = []
         columns = ["id", "text", "vector", "metadata"]
-        values = []
+        rows = []
         for i, doc in enumerate(documents):
             if doc.metadata is not None:
                 doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
-                row = (
-                    doc_id,
-                    self.escape_str(doc.page_content),
-                    embeddings[i],
-                    json.dumps(doc.metadata) if doc.metadata else {},
+                rows.append(
+                    (
+                        doc_id,
+                        # NOTE(review): escape_str() is lossy; presumably unneeded with insert() - confirm.
+                        self.escape_str(doc.page_content),
+                        embeddings[i],
+                        json.dumps(doc.metadata or {}),
+                    )
                 )
-                values.append(str(row))
                 ids.append(doc_id)
-        sql = f"""
-            INSERT INTO {self._config.database}.{self._collection_name}
-            ({",".join(columns)}) VALUES {",".join(values)}
-        """
-        self._client.command(sql)
+        if rows:
+            self._client.insert(self._qualified_table, rows, column_names=columns)
         return ids
 
     @staticmethod
@@ -98,49 +118,81 @@ class MyScaleVector(BaseVector):
         return "".join(" " if c in {"\\", "'"} else c for c in str(value))
 
     def text_exists(self, id: str) -> bool:
-        results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
+        results = self._client.query(
+            f"SELECT id FROM {self._qualified_table} WHERE id = %(id)s LIMIT 1",
+            parameters={"id": id},
+        )
         return results.row_count > 0
 
     def delete_by_ids(self, ids: list[str]):
         if not ids:
             return
+        placeholders, params = self._build_in_params("id", ids)
         self._client.command(
-            f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
+            f"DELETE FROM {self._qualified_table} WHERE id IN ({placeholders})",
+            parameters=params,
         )
 
     def get_ids_by_metadata_field(self, key: str, value: str):
+        safe_key = self._validate_metadata_key(key)
         rows = self._client.query(
-            f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
+            f"SELECT DISTINCT id FROM {self._qualified_table} WHERE metadata.{safe_key} = %(value)s",
+            parameters={"value": value},
         ).result_rows
         return [row[0] for row in rows]
 
     def delete_by_metadata_field(self, key: str, value: str):
+        safe_key = self._validate_metadata_key(key)
         self._client.command(
-            f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
+            f"DELETE FROM {self._qualified_table} WHERE metadata.{safe_key} = %(value)s",
+            parameters={"value": value},
         )
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
+        return self._search(
+            "TextSearch('enable_nlq=false')(text, %(query)s)",
+            SortOrder.DESC,
+            parameters={"query": query},
+            **kwargs,
+        )
 
-    def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
+    @staticmethod
+    def _build_in_params(prefix: str, values: list[str]) -> tuple[str, dict[str, str]]:
+        """Return ``%(name)s`` placeholders joined by ", " and the matching parameter dict."""
+        params: dict[str, str] = {}
+        placeholders = []
+        for i, value in enumerate(values):
+            name = f"{prefix}_{i}"
+            placeholders.append(f"%({name})s")
+            params[name] = value
+        return ", ".join(placeholders), params
+
+    def _search(
+        self,
+        dist: str,
+        order: SortOrder,
+        parameters: dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> list[Document]:
         top_k = kwargs.get("top_k", 4)
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError("top_k must be a positive integer")
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
-        where_str = (
-            f"WHERE dist < {1 - score_threshold}"
-            if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
-            else ""
-        )
+        where_clauses = []
+        if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0:
+            where_clauses.append(f"dist < {1 - score_threshold}")
         document_ids_filter = kwargs.get("document_ids_filter")
+        query_params = dict(parameters or {})
         if document_ids_filter:
-            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
-            where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
+            placeholders, params = self._build_in_params("document_id", document_ids_filter)
+            where_clauses.append(f"metadata['document_id'] IN ({placeholders})")
+            query_params.update(params)
+        where_str = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
         sql = f"""
-            SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
+            SELECT text, vector, metadata, {dist} as dist FROM {self._qualified_table}
             {where_str} ORDER BY dist {order.value} LIMIT {top_k}
         """
         try:
@@ -150,14 +202,14 @@ class MyScaleVector(BaseVector):
                     vector=r["vector"],
                     metadata=r["metadata"],
                 )
-                for r in self._client.query(sql).named_results()
+                for r in self._client.query(sql, parameters=query_params).named_results()
             ]
         except Exception:
             logger.exception("Vector search operation failed")
             return []
 
     def delete(self):
-        self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")
+        self._client.command(f"DROP TABLE IF EXISTS {self._qualified_table}")
 
 
 class MyScaleVectorFactory(AbstractVectorFactory):