From 67eb632f1a8002a6ba893f7e622836f826ba9585 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 9 Dec 2025 23:52:00 +0800 Subject: [PATCH] add qdrant migrate to tidb --- api/commands.py | 29 +++----------- .../alibabacloud_mysql_vector.py | 14 +++++++ .../vdb/analyticdb/analyticdb_vector.py | 3 ++ .../analyticdb/analyticdb_vector_openapi.py | 28 +++++++++++++ .../vdb/analyticdb/analyticdb_vector_sql.py | 17 ++++++++ .../rag/datasource/vdb/baidu/baidu_vector.py | 28 +++++++++++++ .../datasource/vdb/chroma/chroma_vector.py | 24 ++++++++++++ .../vdb/clickzetta/clickzetta_vector.py | 32 +++++++++++++++ .../vdb/couchbase/couchbase_vector.py | 16 ++++++++ .../vdb/elasticsearch/elasticsearch_vector.py | 15 +++++++ .../vdb/huawei/huawei_cloud_vector.py | 15 +++++++ .../datasource/vdb/lindorm/lindorm_vector.py | 25 ++++++++++++ .../vdb/matrixone/matrixone_vector.py | 21 ++++++++++ .../datasource/vdb/milvus/milvus_vector.py | 24 ++++++++++++ .../datasource/vdb/myscale/myscale_vector.py | 18 +++++++++ .../vdb/oceanbase/oceanbase_vector.py | 36 +++++++++++++++++ .../rag/datasource/vdb/opengauss/opengauss.py | 12 ++++++ .../vdb/opensearch/opensearch_vector.py | 13 +++++++ .../rag/datasource/vdb/oracle/oraclevector.py | 14 +++++++ .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 13 +++++++ .../rag/datasource/vdb/pgvector/pgvector.py | 12 ++++++ .../vdb/pyvastbase/vastbase_vector.py | 12 ++++++ .../datasource/vdb/qdrant/qdrant_vector.py | 39 +++++++++++++++++-- .../rag/datasource/vdb/relyt/relyt_vector.py | 21 ++++++++++ .../vdb/tablestore/tablestore_vector.py | 39 +++++++++++++++++++ .../datasource/vdb/tencent/tencent_vector.py | 22 +++++++++++ .../tidb_on_qdrant/tidb_on_qdrant_vector.py | 36 +++++++++++++++++ .../datasource/vdb/tidb_vector/tidb_vector.py | 15 +++++++ .../datasource/vdb/upstash/upstash_vector.py | 18 +++++++++ api/core/rag/datasource/vdb/vector_base.py | 4 ++ api/core/rag/datasource/vdb/vector_factory.py | 17 ++++++++ .../vdb/vikingdb/vikingdb_vector.py | 24 ++++++++++++ .../vdb/weaviate/weaviate_vector.py | 24 ++++++++++++ 33 files changed, 654 insertions(+), 26 deletions(-) diff --git a/api/commands.py b/api/commands.py index 3ac09a9415..4d15171034 100644 --- a/api/commands.py +++ b/api/commands.py @@ -395,29 +395,12 @@ def migrate_knowledge_vector_database(): documents = [] segments_count = 0 for dataset_document in dataset_documents: - - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - ) - ).all() - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - - documents.append(document) - segments_count = segments_count + 1 + single_documents = vector.search_by_metadata_field("document_id", dataset_document.id) + if single_documents: + documents.extend(single_documents) + segments_count += len(single_documents) if documents: try: click.echo( @@ -427,12 +410,12 @@ def migrate_knowledge_vector_database(): fg="green", ) ) - vector.create(documents) + vector.create_with_vectors(documents) click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green")) except Exception as e: click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red")) raise e - db.session.add(dataset) + db.session.add(instance=dataset) db.session.commit() click.echo(f"Successfully migrated dataset {dataset.id}.") create_count += 1 diff --git a/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py b/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py index fdb5ffebfc..73d96604ea 100644 --- a/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py +++ b/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py @@ -313,6 +313,20 @@ class AlibabaCloudMySQLVector(BaseVector): docs.append(Document(page_content=record["text"], metadata=metadata)) return docs + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + with self._get_cursor() as cur: + cur.execute( + f"SELECT meta, text, embedding FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s", + (f"$.{key}", value), + ) + docs = [] + for record in cur: + metadata = record["meta"] + if isinstance(metadata, str): + metadata = json.loads(metadata) + docs.append(Document(page_content=record["text"], vector=record["embedding"], metadata=metadata)) + return docs + def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index ddb549ba9d..4b84001f84 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -58,6 +58,9 @@ class AnalyticdbVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self.analyticdb_vector.search_by_full_text(query, **kwargs) + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + return self.analyticdb_vector.search_by_metadata_field(key, value, **kwargs) + def delete(self): self.analyticdb_vector.delete() diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 77a0fa6cf2..6bee6c0704 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -305,6 +305,34 @@ class AnalyticdbVectorOpenAPI: documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + + request = gpdb_20160503_models.QueryCollectionDataRequest( + dbinstance_id=self.config.instance_id, + region_id=self.config.region_id, + namespace=self.config.namespace, + namespace_password=self.config.namespace_password, + collection=self._collection_name, + include_values=True, + metrics=self.config.metrics, + vector=None, # ty: ignore [invalid-argument-type] + content=None, # ty: ignore [invalid-argument-type] + top_k=999999, + filter=f"metadata_->>'{key}' = '{value}'", + ) + response = self._client.query_collection_data(request) + documents = [] + for match in response.body.matches.match: + metadata = json.loads(match.metadata.get("metadata_")) + doc = Document( + page_content=match.metadata.get("page_content"), + vector=match.values.value if match.values else None, + metadata=metadata, + ) + documents.append(doc) + return documents + def delete(self): try: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index 12126f32d6..0c86d9b1b5 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -270,6 +270,23 @@ class AnalyticdbVectorBySql: documents.append(doc) return documents + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + with self._get_cursor() as cur: + cur.execute( + f"SELECT id, embedding, page_content, metadata_ FROM {self.table_name} WHERE metadata_->>%s = %s", + (key, value), + ) + documents = [] + for record in cur: + _, vector, page_content, metadata = record + doc = Document( + page_content=page_content, + vector=vector, + metadata=metadata, + ) + documents.append(doc) + return documents + def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 144d834495..3e6c2d1a30 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -198,6 +198,34 @@ class BaiduVector(BaseVector): docs.append(doc) return docs + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + # Escape double quotes in value to prevent injection + escaped_value = value.replace('"', '\\"') + filter = f'metadata["{key}"] = "{escaped_value}"' + + res = self._db.table(self._collection_name).select( + projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY, VDBField.VECTOR_KEY], + filter=filter, + ) + + docs = [] + for row in res.rows: + row_data = row.get("row", {}) + meta = row_data.get(VDBField.METADATA_KEY, {}) + + if isinstance(meta, str): + try: + meta = json.loads(meta) + except (json.JSONDecodeError, TypeError): + meta = {} + elif not isinstance(meta, dict): + meta = {} + + vector = row_data.get(VDBField.VECTOR_KEY) + doc = Document(page_content=row_data.get(VDBField.CONTENT_KEY), vector=vector, metadata=meta) + docs.append(doc) + return docs + def delete(self): try: self._db.drop_table(table_name=self._collection_name) diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index de1572410c..e0d894319e 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -135,6 +135,30 @@ class ChromaVector(BaseVector): # chroma does not support BM25 full text searching return [] + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + collection = self._client.get_or_create_collection(self._collection_name) + + # FIXME: fix the type error later + results = collection.get( + where={key: {"$eq": value}}, # type: ignore + include=["documents", "metadatas", "embeddings"], + ) + + if not results["ids"] or not results["documents"] or not results["metadatas"]: + return [] + + docs = [] + for i, doc_id in enumerate(results["ids"]): + metadata = dict(results["metadatas"][i]) if results["metadatas"][i] else {} + vector = results["embeddings"][i] if results.get("embeddings") else None + doc = Document( + page_content=results["documents"][i], + vector=vector, + metadata=metadata, + ) + docs.append(doc) + return docs + class ChromaVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector: diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index a306f9ba0c..cb21cca610 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -1025,6 +1025,38 @@ class ClickzettaVector(BaseVector): with connection.cursor() as cursor: cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}") + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + """Search for documents by metadata field.""" + # Check if table exists first + if not self._table_exists(): + logger.warning( + "Table %s.%s does not exist, returning empty results", + self._config.schema_name, + self._table_name, + ) + return [] + + # Use json_extract_string function for ClickZetta compatibility + search_sql = f""" + SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, {Field.VECTOR} + FROM {self._config.schema_name}.{self._table_name} + WHERE json_extract_string({Field.METADATA_KEY}, '$.{key}') = ? + """ + + documents = [] + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(search_sql, binding_params=[value]) + results = cursor.fetchall() + + for row in results: + metadata = self._parse_metadata(row[2], row[0]) + vector = row[3] if len(row) > 3 else None + doc = Document(page_content=row[1], vector=vector, metadata=metadata) + documents.append(doc) + + return documents + def _format_vector_simple(self, vector: list[float]) -> str: """Simple vector formatting for SQL queries.""" return ",".join(map(str, vector)) diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py index 6df909ca94..d9d3c0eb63 100644 --- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py +++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py @@ -325,6 +325,22 @@ class CouchbaseVector(BaseVector): return docs + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + query = f""" + SELECT text, metadata, embedding FROM + `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name} + WHERE metadata.{key} = $value + """ + result = self._cluster.query(query, named_parameters={"value": value}).execute() + docs = [] + for row in result: + text = row.get("text", "") + metadata = row.get("metadata", {}) + vector = row.get("embedding") + doc = Document(page_content=text, vector=vector, metadata=metadata) + docs.append(doc) + return docs + def delete(self): manager = self._bucket.collections() scopes = manager.get_all_scopes() diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 1470713b88..7fe38811c3 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -249,6 +249,21 @@ class ElasticSearchVector(BaseVector): return docs + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + query_str = {"query": {"match": {f"metadata.{key}": value}}} + results = self._client.search(index=self._collection_name, body=query_str, size=999999) + + docs = [] + for hit in results["hits"]["hits"]: + docs.append( + Document( + page_content=hit["_source"][Field.CONTENT_KEY], + vector=hit["_source"].get(Field.VECTOR), + metadata=hit["_source"][Field.METADATA_KEY], + ) + ) + return docs + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index c7b6593a8f..3b1d6c5115 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -149,6 +149,21 @@ class HuaweiCloudVector(BaseVector): return docs + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + query_str = {"query": {"match": {f"metadata.{key}": value}}} + results = self._client.search(index=self._collection_name, body=query_str, size=999999) + + docs = [] + for hit in results["hits"]["hits"]: + docs.append( + Document( + page_content=hit["_source"][Field.CONTENT_KEY], + vector=hit["_source"].get(Field.VECTOR), + metadata=hit["_source"].get(Field.METADATA_KEY, {}), + ) + ) + return docs + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index bfcb620618..993d07896e 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -326,6 +326,31 @@ class LindormVectorStore(BaseVector): return docs + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + query: dict[str, Any] = { + "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}} + } + if self._using_ugc: + query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}}) + + try: + params: dict[str, Any] = {"timeout": self._client_config.request_timeout} + if self._using_ugc: + params["routing"] = self._routing + response = self._client.search(index=self._collection_name, body=query, params=params, size=999999) + except Exception: + logger.exception("Error executing metadata field search, query: %s", query) + raise + + docs = [] + for hit in response["hits"]["hits"]: + metadata = hit["_source"].get(Field.METADATA_KEY) or {} + vector = hit["_source"].get(Field.VECTOR) + page_content = hit["_source"].get(Field.CONTENT_KEY) + doc = Document(page_content=page_content, vector=vector, metadata=metadata) + docs.append(doc) + return docs + def create_collection( self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None ): diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index 14955c8d7c..e59c0c7d66 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -217,6 +217,27 @@ class MatrixoneVector(BaseVector): assert self.client is not None self.client.delete() + @ensure_client + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + assert self.client is not None + + results = self.client.query_by_metadata(filter={key: value}) + + docs = [] + for result in results: + metadata = result.metadata + if isinstance(metadata, str): + metadata = json.loads(metadata) + vector = result.embedding if hasattr(result, "embedding") else None + docs.append( + Document( + page_content=result.document, + vector=vector, + metadata=metadata, + ) + ) + return docs + class MatrixoneVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector: diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 96eb465401..a05fc4a6cf 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -291,6 +291,30 @@ class MilvusVector(BaseVector): score_threshold=float(kwargs.get("score_threshold") or 0.0), ) + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + """ + Search for documents by metadata field key and value. + """ + if not self._client.has_collection(self._collection_name): + return [] + + result = self._client.query( + collection_name=self._collection_name, + filter=f'metadata["{key}"] == "{value}"', + output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY, Field.VECTOR], + ) + + docs = [] + for item in result: + metadata = item.get(Field.METADATA_KEY, {}) + doc = Document( + page_content=item.get(Field.CONTENT_KEY, ""), + vector=item.get(Field.VECTOR), + metadata=metadata, + ) + docs.append(doc) + return docs + def create_collection( self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None ): diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 17aac25b87..c356162b1b 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -156,6 +156,24 @@ class MyScaleVector(BaseVector): logger.exception("Vector search operation failed") return [] + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + sql = f""" + SELECT text, vector, metadata FROM {self._config.database}.{self._collection_name} + WHERE metadata['{key}']='{value}' + """ + try: + return [ + Document( + page_content=r["text"], + vector=r["vector"], + metadata=r["metadata"], + ) + for r in self._client.query(sql).named_results() + ] + except Exception: + logger.exception("Metadata field search operation failed") + return [] + def delete(self): self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}") diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index dc3b70140b..934a246fc6 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -440,6 +440,42 @@ class OceanBaseVector(BaseVector): logger.exception("Failed to delete collection '%s'", self._collection_name) raise Exception(f"Failed to delete collection '{self._collection_name}'") from e + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + try: + import re + + from sqlalchemy import text + + # Validate key to prevent injection in JSON path + if not re.match(r"^[a-zA-Z0-9_.]+$", key): + raise ValueError(f"Invalid characters in metadata key: {key}") + + # Use parameterized query to prevent SQL injection + sql = text( + f"SELECT text, metadata, vector FROM `{self._collection_name}` " + f"WHERE metadata->>'$.{key}' = :value" + ) + + with self._client.engine.connect() as conn: + result = conn.execute(sql, {"value": value}) + rows = result.fetchall() + + docs = [] + for row in rows: + text_content, metadata, vector = row + if isinstance(metadata, str): + metadata = json.loads(metadata) + docs.append(Document(page_content=text_content, vector=vector, metadata=metadata)) + return docs + except Exception as e: + logger.exception( + "Failed to search by metadata field '%s'='%s' in collection '%s'", + key, + value, + self._collection_name, + ) + raise Exception(f"Failed to search by metadata field '{key}'") from e + class OceanBaseVectorFactory(AbstractVectorFactory): def init_vector( diff --git a/api/core/rag/datasource/vdb/opengauss/opengauss.py b/api/core/rag/datasource/vdb/opengauss/opengauss.py index f9dbfbeeaf..af313a0924 100644 --- a/api/core/rag/datasource/vdb/opengauss/opengauss.py +++ b/api/core/rag/datasource/vdb/opengauss/opengauss.py @@ -222,6 +222,18 @@ class OpenGauss(BaseVector): return docs + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + with self._get_cursor() as cur: + cur.execute( + f"SELECT meta, text, embedding FROM {self.table_name} WHERE meta->>%s = %s", + (key, value), + ) + docs = [] + for record in cur: + metadata, text, embedding = record + docs.append(Document(page_content=text, vector=embedding, metadata=metadata)) + return docs + def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 2f77776807..d327986657 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -236,6 +236,19 @@ class OpenSearchVector(BaseVector): return docs + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + query = {"query": {"term": {f"{Field.METADATA_KEY}.{key}": value}}} + response = self._client.search(index=self._collection_name.lower(), body=query, size=999999) + + docs = [] + for hit in response["hits"]["hits"]: + metadata = hit["_source"].get(Field.METADATA_KEY) or {} + vector = hit["_source"].get(Field.VECTOR) + page_content = hit["_source"].get(Field.CONTENT_KEY) + doc = Document(page_content=page_content, vector=vector, metadata=metadata) + docs.append(doc) + return docs + def create_collection( self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None ): diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index d82ab89a34..74a79164b9 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -338,6 +338,20 @@ class OracleVector(BaseVector): else: return [Document(page_content="", metadata={})] + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute( + f"SELECT meta, text, embedding FROM {self.table_name} WHERE JSON_VALUE(meta, '$.{key}') = :1", + (value,), + ) + docs = [] + for record in cur: + metadata, text, embedding = record + docs.append(Document(page_content=text, vector=embedding, metadata=metadata)) + conn.close() + return docs + def delete(self): with self._get_connection() as conn: with conn.cursor() as cur: diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index b986c79e3a..79046ad441 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -210,6 +210,19 @@ class PGVectoRS(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + with Session(self._client) as session: + select_statement = sql_text( + f"SELECT text, meta, embedding FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'" + ) + result = session.execute(select_statement).fetchall() + + docs = [] + for record in result: + doc = Document(page_content=record[0], vector=record[2], metadata=record[1]) + docs.append(doc) + return docs + class PGVectoRSFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS: diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 445a0a7f8b..f7ef3895a0 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -242,6 +242,18 @@ class PGVector(BaseVector): return docs + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + with self._get_cursor() as cur: + cur.execute( + f"SELECT meta, text, embedding FROM {self.table_name} WHERE meta->>%s = %s", + (key, value), + ) + docs = [] + for record in cur: + metadata, text, embedding = record + docs.append(Document(page_content=text, vector=embedding, metadata=metadata)) + return docs + def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py index 86b6ace3f6..16b6a755bb 100644 --- a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py +++ b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py @@ -199,6 +199,18 @@ class VastbaseVector(BaseVector): return docs + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + with self._get_cursor() as cur: + cur.execute( + f"SELECT meta, text, embedding FROM {self.table_name} WHERE meta->>%s = %s", + (key, value), + ) + docs = [] + for record in cur: + metadata, text, embedding = record + docs.append(Document(page_content=text, vector=embedding, metadata=metadata)) + return docs + def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index f8c62b908a..0ab594ae27 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -434,9 +434,42 @@ class QdrantVector(BaseVector): return documents - def _reload_if_needed(self): - if isinstance(self._client, QdrantLocal): - self._client._load() + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + from qdrant_client.http import models + from qdrant_client.http.exceptions import UnexpectedResponse + + try: + scroll_filter = models.Filter( + must=[ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self._group_id), + ), + models.FieldCondition( + key=f"metadata.{key}", + match=models.MatchValue(value=value), + ), + ] + ) + + response = self._client.scroll( + collection_name=self._collection_name, + scroll_filter=scroll_filter, + limit=999999, + with_payload=True, + with_vectors=True, + ) + results = response[0] + documents = [] + for result in results: + if result: + doc = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY) + documents.append(doc) + return documents + except UnexpectedResponse as e: + if e.status_code == 404: + return [] + raise e @classmethod def _document_from_scored_point( diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 70857b3e3c..f3b0ff5b1a 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -294,6 +294,27 @@ class RelytVector(BaseVector): # milvus/zilliz/relyt doesn't support bm25 search return [] + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + sql_query = f""" + SELECT document, metadata, embedding + FROM "{self._collection_name}" + WHERE metadata->>'{key}' = :value + """ + params = {"value": value} + + with self.client.connect() as conn: + results = conn.execute(sql_text(sql_query), params).fetchall() + + docs = [] + for result in results: + doc = Document( + page_content=result.document, + vector=result.embedding, + metadata=result.metadata, + ) + docs.append(doc) + return docs + class RelytVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector: diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index f2156afa59..a994476b61 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -390,6 +390,45 @@ class TableStoreVector(BaseVector): documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + # Search using tags field which stores key=value pairs + tag_value = f"{key}={value}" + + query = tablestore.SearchQuery( + tablestore.TermQuery(self._tags_field, tag_value), + limit=999999, + get_total_count=False, + ) + + search_response = self._tablestore_client.search( + table_name=self._table_name, + index_name=self._index_name, + search_query=query, + columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), + ) + + documents = [] + for search_hit in search_response.search_hits: + ots_column_map = {} + for col in search_hit.row[1]: + ots_column_map[col[0]] = col[1] + + metadata_str = ots_column_map.get(Field.METADATA_KEY) + metadata = json.loads(metadata_str) if metadata_str else {} + + vector_str = ots_column_map.get(Field.VECTOR) + # TableStore stores vector as JSON string, need to parse it + vector = json.loads(vector_str) if vector_str else None + + documents.append( + Document( + page_content=ots_column_map.get(Field.CONTENT_KEY) or "", + vector=vector, + metadata=metadata, + ) + ) + return documents + class TableStoreVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TableStoreVector: diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 291d047c04..d3cabf287c 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -299,6 +299,28 @@ class TencentVector(BaseVector): docs.append(doc) return docs + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + filter = Filter(Filter.In(f"metadata.{key}", [value])) + res = self._client.query( + database_name=self._client_config.database, + collection_name=self.collection_name, + filter=filter, + retrieve_vector=True, + ) + + docs: list[Document] = [] + if res is None or len(res) == 0: + return docs + + for result in res: + meta = result.get(self.field_metadata) + if isinstance(meta, str): + meta = json.loads(meta) + vector = result.get(self.field_vector) + doc = Document(page_content=result.get(self.field_text), vector=vector, metadata=meta) + docs.append(doc) + return docs + def delete(self): if self._has_collection(): self._client.drop_collection( diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 56ffb36a2b..b22e322268 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -393,6 +393,42 @@ class TidbOnQdrantVector(BaseVector): return documents + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + from qdrant_client.http import models + from qdrant_client.http.exceptions import UnexpectedResponse + + try: + scroll_filter = models.Filter( + must=[ + models.FieldCondition( + key=f"metadata.{key}", + match=models.MatchValue(value=value), + ), + ] + ) + + response = self._client.scroll( + collection_name=self._collection_name, + scroll_filter=scroll_filter, + limit=999999, + with_payload=True, + with_vectors=True, + ) + results = response[0] + documents = [] + for result in results: + if result: + metadata = result.payload.get(Field.METADATA_KEY) if result.payload else {} + page_content = result.payload.get(Field.CONTENT_KEY, "") if result.payload else "" + vector = result.vector if hasattr(result, "vector") else None + documents.append(Document(page_content=page_content, vector=vector, metadata=metadata)) + + return documents + except UnexpectedResponse as e: + if e.status_code == 404: + return [] + raise e + def _reload_if_needed(self): if isinstance(self._client, QdrantLocal): self._client._load() diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 27ae038a06..14a42d0f9d 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -237,6 +237,21 @@ class TiDBVector(BaseVector): # tidb doesn't support bm25 search return [] + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + with Session(self._engine) as session: + select_statement = sql_text(f""" + SELECT meta, text, vector FROM {self._collection_name} + WHERE meta->>'$.{key}' = :value + """) + res = session.execute(select_statement, params={"value": value}) + results = [(row[0], row[1], row[2]) for row in res] + + docs = [] + for meta, text, vector in results: + metadata = json.loads(meta) + docs.append(Document(page_content=text, vector=vector, metadata=metadata)) + return docs + def delete(self): with Session(self._engine) as session: session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) diff --git a/api/core/rag/datasource/vdb/upstash/upstash_vector.py b/api/core/rag/datasource/vdb/upstash/upstash_vector.py index 289d971853..8b577f48f5 100644 --- a/api/core/rag/datasource/vdb/upstash/upstash_vector.py +++ b/api/core/rag/datasource/vdb/upstash/upstash_vector.py @@ -117,6 +117,24 @@ class UpstashVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + query_result = self.index.query( + vector=[1.001 * i for i in range(self._get_index_dimension())], + include_metadata=True, + include_data=True, + include_vectors=True, + top_k=999999, + filter=f"{key} = '{value}'", + ) + docs = [] + for record in query_result: + metadata = record.metadata + text = record.data + vector = record.vector if hasattr(record, "vector") else None + if metadata is not None and text is not None: + docs.append(Document(page_content=text, vector=vector, metadata=metadata)) + return docs + def delete(self): self.index.reset() diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index 469978224a..e3c959e4d5 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -45,6 +45,10 @@ class BaseVector(ABC): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError + @abstractmethod + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + raise NotImplementedError + @abstractmethod def delete(self): raise NotImplementedError diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index c3fec09f0f..53b8af604e 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -207,6 +207,23 @@ class Vector: self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs) logger.info("Embedding %s texts took %s s", len(texts), time.time() - start) + def create_with_vectors(self, texts: list[Document], **kwargs): + """ + Create documents with vectors. + + Args: + texts: List of documents. + **kwargs: Keyword arguments. + """ + embeddings = [] + embedding_texts = [] + for text in texts: + if text.vector: + embeddings.append(text.vector) + embedding_texts.append(text) + if embeddings and embedding_texts: + self._vector_processor.create(texts=embedding_texts, embeddings=embeddings, **kwargs) + def create_multimodal(self, file_documents: list | None = None, **kwargs): if file_documents: start = time.time() diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index e5feecf2bc..f7cdd9929d 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -202,6 +202,30 @@ class VikingDBVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + # Query by metadata field using filter on group_id and matching metadata + results = self._client.get_index(self._collection_name, self._index_name).search( + filter={"op": "must", "field": vdb_Field.GROUP_KEY, "conds": [self._group_id]}, + limit=5000, # max value is 5000 + ) + + if not results: + return [] + + docs = [] + for result in results: + metadata = result.fields.get(vdb_Field.METADATA_KEY) + if metadata is not None: + if isinstance(metadata, str): + metadata = json.loads(metadata) + if metadata.get(key) == value: + vector = result.fields.get(vdb_Field.VECTOR_KEY) + doc = Document( + page_content=result.fields.get(vdb_Field.CONTENT_KEY), vector=vector, metadata=metadata + ) + docs.append(doc) + return docs + def delete(self): if self._has_index(): self._client.drop_index(self._collection_name, self._index_name) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 84d1e26b34..f770d9d246 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -442,6 +442,30 @@ class WeaviateVector(BaseVector): return value.isoformat() return value + def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]: + """Searches for documents matching a specific metadata field value.""" + if not self._client.collections.exists(self._collection_name): + return [] + + col = self._client.collections.use(self._collection_name) + props = list({*self._attributes, Field.TEXT_KEY.value}) + + res = col.query.fetch_objects( + filters=Filter.by_property(key).equal(value), + limit=999999, + return_properties=props, + include_vector=True, + ) + + docs: list[Document] = [] + for obj in res.objects: + properties = dict(obj.properties or {}) + text = properties.pop(Field.TEXT_KEY.value, "") + vector = obj.vector.get("default") if obj.vector else None + docs.append(Document(page_content=text, vector=vector, metadata=properties)) + + return docs + class WeaviateVectorFactory(AbstractVectorFactory): """Factory class for creating WeaviateVector instances."""