mirror of https://github.com/langgenius/dify.git
add qdrant migrate to tidb
This commit is contained in:
parent
5ae1f62daf
commit
67eb632f1a
|
|
@ -395,29 +395,12 @@ def migrate_knowledge_vector_database():
|
|||
documents = []
|
||||
segments_count = 0
|
||||
for dataset_document in dataset_documents:
|
||||
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
segments_count = segments_count + 1
|
||||
single_documents = vector.search_by_metadata_field("document_id", dataset_document.id)
|
||||
|
||||
if single_documents:
|
||||
documents.extend(single_documents)
|
||||
segments_count += len(single_documents)
|
||||
if documents:
|
||||
try:
|
||||
click.echo(
|
||||
|
|
@ -427,12 +410,12 @@ def migrate_knowledge_vector_database():
|
|||
fg="green",
|
||||
)
|
||||
)
|
||||
vector.create(documents)
|
||||
vector.create_with_vectors(documents)
|
||||
click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
|
||||
raise e
|
||||
db.session.add(dataset)
|
||||
db.session.add(instance=dataset)
|
||||
db.session.commit()
|
||||
click.echo(f"Successfully migrated dataset {dataset.id}.")
|
||||
create_count += 1
|
||||
|
|
|
|||
|
|
@ -313,6 +313,20 @@ class AlibabaCloudMySQLVector(BaseVector):
|
|||
docs.append(Document(page_content=record["text"], metadata=metadata))
|
||||
return docs
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, embedding FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s",
|
||||
(f"$.{key}", value),
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata = record["meta"]
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
docs.append(Document(page_content=record["text"], vector=record["embedding"], metadata=metadata))
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
|
|
|||
|
|
@ -58,6 +58,9 @@ class AnalyticdbVector(BaseVector):
|
|||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
return self.analyticdb_vector.search_by_metadata_field(key, value, **kwargs)
|
||||
|
||||
def delete(self):
|
||||
self.analyticdb_vector.delete()
|
||||
|
||||
|
|
|
|||
|
|
@ -305,6 +305,34 @@ class AnalyticdbVectorOpenAPI:
|
|||
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||
return documents
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
include_values=True,
|
||||
metrics=self.config.metrics,
|
||||
vector=None, # ty: ignore [invalid-argument-type]
|
||||
content=None, # ty: ignore [invalid-argument-type]
|
||||
top_k=999999,
|
||||
filter=f"metadata_->>'{key}' = '{value}'",
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
metadata = json.loads(match.metadata.get("metadata_"))
|
||||
doc = Document(
|
||||
page_content=match.metadata.get("page_content"),
|
||||
vector=match.values.value if match.values else None,
|
||||
metadata=metadata,
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
def delete(self):
|
||||
try:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
|
|
|||
|
|
@ -270,6 +270,23 @@ class AnalyticdbVectorBySql:
|
|||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT id, embedding, page_content, metadata_ FROM {self.table_name} WHERE metadata_->>%s = %s",
|
||||
(key, value),
|
||||
)
|
||||
documents = []
|
||||
for record in cur:
|
||||
_, vector, page_content, metadata = record
|
||||
doc = Document(
|
||||
page_content=page_content,
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
def delete(self):
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
|
|
|||
|
|
@ -198,6 +198,34 @@ class BaiduVector(BaseVector):
|
|||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
# Escape double quotes in value to prevent injection
|
||||
escaped_value = value.replace('"', '\\"')
|
||||
filter = f'metadata["{key}"] = "{escaped_value}"'
|
||||
|
||||
res = self._db.table(self._collection_name).select(
|
||||
projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY, VDBField.VECTOR_KEY],
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
docs = []
|
||||
for row in res.rows:
|
||||
row_data = row.get("row", {})
|
||||
meta = row_data.get(VDBField.METADATA_KEY, {})
|
||||
|
||||
if isinstance(meta, str):
|
||||
try:
|
||||
meta = json.loads(meta)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
meta = {}
|
||||
elif not isinstance(meta, dict):
|
||||
meta = {}
|
||||
|
||||
vector = row_data.get(VDBField.VECTOR_KEY)
|
||||
doc = Document(page_content=row_data.get(VDBField.CONTENT_KEY), vector=vector, metadata=meta)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
try:
|
||||
self._db.drop_table(table_name=self._collection_name)
|
||||
|
|
|
|||
|
|
@ -135,6 +135,30 @@ class ChromaVector(BaseVector):
|
|||
# chroma does not support BM25 full text searching
|
||||
return []
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
|
||||
# FIXME: fix the type error later
|
||||
results = collection.get(
|
||||
where={key: {"$eq": value}}, # type: ignore
|
||||
include=["documents", "metadatas", "embeddings"],
|
||||
)
|
||||
|
||||
if not results["ids"] or not results["documents"] or not results["metadatas"]:
|
||||
return []
|
||||
|
||||
docs = []
|
||||
for i, doc_id in enumerate(results["ids"]):
|
||||
metadata = dict(results["metadatas"][i]) if results["metadatas"][i] else {}
|
||||
vector = results["embeddings"][i] if results.get("embeddings") else None
|
||||
doc = Document(
|
||||
page_content=results["documents"][i],
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
|
||||
class ChromaVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
|
||||
|
|
|
|||
|
|
@ -1025,6 +1025,38 @@ class ClickzettaVector(BaseVector):
|
|||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
"""Search for documents by metadata field."""
|
||||
# Check if table exists first
|
||||
if not self._table_exists():
|
||||
logger.warning(
|
||||
"Table %s.%s does not exist, returning empty results",
|
||||
self._config.schema_name,
|
||||
self._table_name,
|
||||
)
|
||||
return []
|
||||
|
||||
# Use json_extract_string function for ClickZetta compatibility
|
||||
search_sql = f"""
|
||||
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, {Field.VECTOR}
|
||||
FROM {self._config.schema_name}.{self._table_name}
|
||||
WHERE json_extract_string({Field.METADATA_KEY}, '$.{key}') = ?
|
||||
"""
|
||||
|
||||
documents = []
|
||||
with self.get_connection_context() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(search_sql, binding_params=[value])
|
||||
results = cursor.fetchall()
|
||||
|
||||
for row in results:
|
||||
metadata = self._parse_metadata(row[2], row[0])
|
||||
vector = row[3] if len(row) > 3 else None
|
||||
doc = Document(page_content=row[1], vector=vector, metadata=metadata)
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
|
||||
def _format_vector_simple(self, vector: list[float]) -> str:
|
||||
"""Simple vector formatting for SQL queries."""
|
||||
return ",".join(map(str, vector))
|
||||
|
|
|
|||
|
|
@ -325,6 +325,22 @@ class CouchbaseVector(BaseVector):
|
|||
|
||||
return docs
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
query = f"""
|
||||
SELECT text, metadata, embedding FROM
|
||||
`{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
|
||||
WHERE metadata.{key} = $value
|
||||
"""
|
||||
result = self._cluster.query(query, named_parameters={"value": value}).execute()
|
||||
docs = []
|
||||
for row in result:
|
||||
text = row.get("text", "")
|
||||
metadata = row.get("metadata", {})
|
||||
vector = row.get("embedding")
|
||||
doc = Document(page_content=text, vector=vector, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
manager = self._bucket.collections()
|
||||
scopes = manager.get_all_scopes()
|
||||
|
|
|
|||
|
|
@ -249,6 +249,21 @@ class ElasticSearchVector(BaseVector):
|
|||
|
||||
return docs
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
query_str = {"query": {"match": {f"metadata.{key}": value}}}
|
||||
results = self._client.search(index=self._collection_name, body=query_str, size=999999)
|
||||
|
||||
docs = []
|
||||
for hit in results["hits"]["hits"]:
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY],
|
||||
vector=hit["_source"].get(Field.VECTOR),
|
||||
metadata=hit["_source"][Field.METADATA_KEY],
|
||||
)
|
||||
)
|
||||
return docs
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||
self.create_collection(embeddings, metadatas)
|
||||
|
|
|
|||
|
|
@ -149,6 +149,21 @@ class HuaweiCloudVector(BaseVector):
|
|||
|
||||
return docs
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
query_str = {"query": {"match": {f"metadata.{key}": value}}}
|
||||
results = self._client.search(index=self._collection_name, body=query_str, size=999999)
|
||||
|
||||
docs = []
|
||||
for hit in results["hits"]["hits"]:
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY],
|
||||
vector=hit["_source"].get(Field.VECTOR),
|
||||
metadata=hit["_source"].get(Field.METADATA_KEY, {}),
|
||||
)
|
||||
)
|
||||
return docs
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||
self.create_collection(embeddings, metadatas)
|
||||
|
|
|
|||
|
|
@ -326,6 +326,31 @@ class LindormVectorStore(BaseVector):
|
|||
|
||||
return docs
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
query: dict[str, Any] = {
|
||||
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
|
||||
}
|
||||
if self._using_ugc:
|
||||
query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
|
||||
|
||||
try:
|
||||
params: dict[str, Any] = {"timeout": self._client_config.request_timeout}
|
||||
if self._using_ugc:
|
||||
params["routing"] = self._routing
|
||||
response = self._client.search(index=self._collection_name, body=query, params=params, size=999999)
|
||||
except Exception:
|
||||
logger.exception("Error executing metadata field search, query: %s", query)
|
||||
raise
|
||||
|
||||
docs = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
metadata = hit["_source"].get(Field.METADATA_KEY) or {}
|
||||
vector = hit["_source"].get(Field.VECTOR)
|
||||
page_content = hit["_source"].get(Field.CONTENT_KEY)
|
||||
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
|
||||
):
|
||||
|
|
|
|||
|
|
@ -217,6 +217,27 @@ class MatrixoneVector(BaseVector):
|
|||
assert self.client is not None
|
||||
self.client.delete()
|
||||
|
||||
@ensure_client
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
assert self.client is not None
|
||||
|
||||
results = self.client.query_by_metadata(filter={key: value})
|
||||
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
vector = result.embedding if hasattr(result, "embedding") else None
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=result.document,
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
return docs
|
||||
|
||||
|
||||
class MatrixoneVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
|
||||
|
|
|
|||
|
|
@ -291,6 +291,30 @@ class MilvusVector(BaseVector):
|
|||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Search for documents by metadata field key and value.
|
||||
"""
|
||||
if not self._client.has_collection(self._collection_name):
|
||||
return []
|
||||
|
||||
result = self._client.query(
|
||||
collection_name=self._collection_name,
|
||||
filter=f'metadata["{key}"] == "{value}"',
|
||||
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY, Field.VECTOR],
|
||||
)
|
||||
|
||||
docs = []
|
||||
for item in result:
|
||||
metadata = item.get(Field.METADATA_KEY, {})
|
||||
doc = Document(
|
||||
page_content=item.get(Field.CONTENT_KEY, ""),
|
||||
vector=item.get(Field.VECTOR),
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
|
||||
):
|
||||
|
|
|
|||
|
|
@ -156,6 +156,24 @@ class MyScaleVector(BaseVector):
|
|||
logger.exception("Vector search operation failed")
|
||||
return []
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
sql = f"""
|
||||
SELECT text, vector, metadata FROM {self._config.database}.{self._collection_name}
|
||||
WHERE metadata['{key}']='{value}'
|
||||
"""
|
||||
try:
|
||||
return [
|
||||
Document(
|
||||
page_content=r["text"],
|
||||
vector=r["vector"],
|
||||
metadata=r["metadata"],
|
||||
)
|
||||
for r in self._client.query(sql).named_results()
|
||||
]
|
||||
except Exception:
|
||||
logger.exception("Metadata field search operation failed")
|
||||
return []
|
||||
|
||||
def delete(self):
|
||||
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")
|
||||
|
||||
|
|
|
|||
|
|
@ -440,6 +440,42 @@ class OceanBaseVector(BaseVector):
|
|||
logger.exception("Failed to delete collection '%s'", self._collection_name)
|
||||
raise Exception(f"Failed to delete collection '{self._collection_name}'") from e
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
try:
|
||||
import re
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
# Validate key to prevent injection in JSON path
|
||||
if not re.match(r"^[a-zA-Z0-9_.]+$", key):
|
||||
raise ValueError(f"Invalid characters in metadata key: {key}")
|
||||
|
||||
# Use parameterized query to prevent SQL injection
|
||||
sql = text(
|
||||
f"SELECT text, metadata, vector FROM `{self._collection_name}` "
|
||||
f"WHERE metadata->>'$.{key}' = :value"
|
||||
)
|
||||
|
||||
with self._client.engine.connect() as conn:
|
||||
result = conn.execute(sql, {"value": value})
|
||||
rows = result.fetchall()
|
||||
|
||||
docs = []
|
||||
for row in rows:
|
||||
text_content, metadata, vector = row
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
docs.append(Document(page_content=text_content, vector=vector, metadata=metadata))
|
||||
return docs
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to search by metadata field '%s'='%s' in collection '%s'",
|
||||
key,
|
||||
value,
|
||||
self._collection_name,
|
||||
)
|
||||
raise Exception(f"Failed to search by metadata field '{key}'") from e
|
||||
|
||||
|
||||
class OceanBaseVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(
|
||||
|
|
|
|||
|
|
@ -222,6 +222,18 @@ class OpenGauss(BaseVector):
|
|||
|
||||
return docs
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, embedding FROM {self.table_name} WHERE meta->>%s = %s",
|
||||
(key, value),
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata, text, embedding = record
|
||||
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
|
|
|||
|
|
@ -236,6 +236,19 @@ class OpenSearchVector(BaseVector):
|
|||
|
||||
return docs
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
query = {"query": {"term": {f"{Field.METADATA_KEY}.{key}": value}}}
|
||||
response = self._client.search(index=self._collection_name.lower(), body=query, size=999999)
|
||||
|
||||
docs = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
metadata = hit["_source"].get(Field.METADATA_KEY) or {}
|
||||
vector = hit["_source"].get(Field.VECTOR)
|
||||
page_content = hit["_source"].get(Field.CONTENT_KEY)
|
||||
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
|
||||
):
|
||||
|
|
|
|||
|
|
@ -338,6 +338,20 @@ class OracleVector(BaseVector):
|
|||
else:
|
||||
return [Document(page_content="", metadata={})]
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, embedding FROM {self.table_name} WHERE JSON_VALUE(meta, '$.{key}') = :1",
|
||||
(value,),
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata, text, embedding = record
|
||||
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
|
||||
conn.close()
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
|
|
|
|||
|
|
@ -210,6 +210,19 @@ class PGVectoRS(BaseVector):
|
|||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return []
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
with Session(self._client) as session:
|
||||
select_statement = sql_text(
|
||||
f"SELECT text, meta, embedding FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'"
|
||||
)
|
||||
result = session.execute(select_statement).fetchall()
|
||||
|
||||
docs = []
|
||||
for record in result:
|
||||
doc = Document(page_content=record[0], vector=record[2], metadata=record[1])
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
|
||||
class PGVectoRSFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS:
|
||||
|
|
|
|||
|
|
@ -242,6 +242,18 @@ class PGVector(BaseVector):
|
|||
|
||||
return docs
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, embedding FROM {self.table_name} WHERE meta->>%s = %s",
|
||||
(key, value),
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata, text, embedding = record
|
||||
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
|
|
|||
|
|
@ -199,6 +199,18 @@ class VastbaseVector(BaseVector):
|
|||
|
||||
return docs
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, embedding FROM {self.table_name} WHERE meta->>%s = %s",
|
||||
(key, value),
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata, text, embedding = record
|
||||
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
|
|
|||
|
|
@ -434,9 +434,42 @@ class QdrantVector(BaseVector):
|
|||
|
||||
return documents
|
||||
|
||||
def _reload_if_needed(self):
|
||||
if isinstance(self._client, QdrantLocal):
|
||||
self._client._load()
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
try:
|
||||
scroll_filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self._group_id),
|
||||
),
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
limit=999999,
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
)
|
||||
results = response[0]
|
||||
documents = []
|
||||
for result in results:
|
||||
if result:
|
||||
doc = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
except UnexpectedResponse as e:
|
||||
if e.status_code == 404:
|
||||
return []
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def _document_from_scored_point(
|
||||
|
|
|
|||
|
|
@ -294,6 +294,27 @@ class RelytVector(BaseVector):
|
|||
# milvus/zilliz/relyt doesn't support bm25 search
|
||||
return []
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
sql_query = f"""
|
||||
SELECT document, metadata, embedding
|
||||
FROM "{self._collection_name}"
|
||||
WHERE metadata->>'{key}' = :value
|
||||
"""
|
||||
params = {"value": value}
|
||||
|
||||
with self.client.connect() as conn:
|
||||
results = conn.execute(sql_text(sql_query), params).fetchall()
|
||||
|
||||
docs = []
|
||||
for result in results:
|
||||
doc = Document(
|
||||
page_content=result.document,
|
||||
vector=result.embedding,
|
||||
metadata=result.metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
|
||||
class RelytVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
|
||||
|
|
|
|||
|
|
@ -390,6 +390,45 @@ class TableStoreVector(BaseVector):
|
|||
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||
return documents
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
# Search using tags field which stores key=value pairs
|
||||
tag_value = f"{key}={value}"
|
||||
|
||||
query = tablestore.SearchQuery(
|
||||
tablestore.TermQuery(self._tags_field, tag_value),
|
||||
limit=999999,
|
||||
get_total_count=False,
|
||||
)
|
||||
|
||||
search_response = self._tablestore_client.search(
|
||||
table_name=self._table_name,
|
||||
index_name=self._index_name,
|
||||
search_query=query,
|
||||
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
||||
)
|
||||
|
||||
documents = []
|
||||
for search_hit in search_response.search_hits:
|
||||
ots_column_map = {}
|
||||
for col in search_hit.row[1]:
|
||||
ots_column_map[col[0]] = col[1]
|
||||
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY)
|
||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR)
|
||||
# TableStore stores vector as JSON string, need to parse it
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
return documents
|
||||
|
||||
|
||||
class TableStoreVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TableStoreVector:
|
||||
|
|
|
|||
|
|
@ -299,6 +299,28 @@ class TencentVector(BaseVector):
|
|||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
filter = Filter(Filter.In(f"metadata.{key}", [value]))
|
||||
res = self._client.query(
|
||||
database_name=self._client_config.database,
|
||||
collection_name=self.collection_name,
|
||||
filter=filter,
|
||||
retrieve_vector=True,
|
||||
)
|
||||
|
||||
docs: list[Document] = []
|
||||
if res is None or len(res) == 0:
|
||||
return docs
|
||||
|
||||
for result in res:
|
||||
meta = result.get(self.field_metadata)
|
||||
if isinstance(meta, str):
|
||||
meta = json.loads(meta)
|
||||
vector = result.get(self.field_vector)
|
||||
doc = Document(page_content=result.get(self.field_text), vector=vector, metadata=meta)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
if self._has_collection():
|
||||
self._client.drop_collection(
|
||||
|
|
|
|||
|
|
@ -393,6 +393,42 @@ class TidbOnQdrantVector(BaseVector):
|
|||
|
||||
return documents
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
try:
|
||||
scroll_filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
limit=999999,
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
)
|
||||
results = response[0]
|
||||
documents = []
|
||||
for result in results:
|
||||
if result:
|
||||
metadata = result.payload.get(Field.METADATA_KEY) if result.payload else {}
|
||||
page_content = result.payload.get(Field.CONTENT_KEY, "") if result.payload else ""
|
||||
vector = result.vector if hasattr(result, "vector") else None
|
||||
documents.append(Document(page_content=page_content, vector=vector, metadata=metadata))
|
||||
|
||||
return documents
|
||||
except UnexpectedResponse as e:
|
||||
if e.status_code == 404:
|
||||
return []
|
||||
raise e
|
||||
|
||||
def _reload_if_needed(self):
|
||||
if isinstance(self._client, QdrantLocal):
|
||||
self._client._load()
|
||||
|
|
|
|||
|
|
@ -237,6 +237,21 @@ class TiDBVector(BaseVector):
|
|||
# tidb doesn't support bm25 search
|
||||
return []
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
with Session(self._engine) as session:
|
||||
select_statement = sql_text(f"""
|
||||
SELECT meta, text, vector FROM {self._collection_name}
|
||||
WHERE meta->>'$.{key}' = :value
|
||||
""")
|
||||
res = session.execute(select_statement, params={"value": value})
|
||||
results = [(row[0], row[1], row[2]) for row in res]
|
||||
|
||||
docs = []
|
||||
for meta, text, vector in results:
|
||||
metadata = json.loads(meta)
|
||||
docs.append(Document(page_content=text, vector=vector, metadata=metadata))
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
with Session(self._engine) as session:
|
||||
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
|
||||
|
|
|
|||
|
|
@ -117,6 +117,24 @@ class UpstashVector(BaseVector):
|
|||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return []
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
query_result = self.index.query(
|
||||
vector=[1.001 * i for i in range(self._get_index_dimension())],
|
||||
include_metadata=True,
|
||||
include_data=True,
|
||||
include_vectors=True,
|
||||
top_k=999999,
|
||||
filter=f"{key} = '{value}'",
|
||||
)
|
||||
docs = []
|
||||
for record in query_result:
|
||||
metadata = record.metadata
|
||||
text = record.data
|
||||
vector = record.vector if hasattr(record, "vector") else None
|
||||
if metadata is not None and text is not None:
|
||||
docs.append(Document(page_content=text, vector=vector, metadata=metadata))
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
self.index.reset()
|
||||
|
||||
|
|
|
|||
|
|
@ -45,6 +45,10 @@ class BaseVector(ABC):
|
|||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -207,6 +207,23 @@ class Vector:
|
|||
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
|
||||
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
|
||||
|
||||
def create_with_vectors(self, texts: list[Document], **kwargs):
|
||||
"""
|
||||
Create documents with vectors.
|
||||
|
||||
Args:
|
||||
texts: List of documents.
|
||||
**kwargs: Keyword arguments.
|
||||
"""
|
||||
embeddings = []
|
||||
embedding_texts = []
|
||||
for text in texts:
|
||||
if text.vector:
|
||||
embeddings.append(text.vector)
|
||||
embedding_texts.append(text)
|
||||
if embeddings and embedding_texts:
|
||||
self._vector_processor.create(texts=embedding_texts, embeddings=embeddings, **kwargs)
|
||||
|
||||
def create_multimodal(self, file_documents: list | None = None, **kwargs):
|
||||
if file_documents:
|
||||
start = time.time()
|
||||
|
|
|
|||
|
|
@ -202,6 +202,30 @@ class VikingDBVector(BaseVector):
|
|||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return []
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
# Query by metadata field using filter on group_id and matching metadata
|
||||
results = self._client.get_index(self._collection_name, self._index_name).search(
|
||||
filter={"op": "must", "field": vdb_Field.GROUP_KEY, "conds": [self._group_id]},
|
||||
limit=5000, # max value is 5000
|
||||
)
|
||||
|
||||
if not results:
|
||||
return []
|
||||
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.fields.get(vdb_Field.METADATA_KEY)
|
||||
if metadata is not None:
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
if metadata.get(key) == value:
|
||||
vector = result.fields.get(vdb_Field.VECTOR_KEY)
|
||||
doc = Document(
|
||||
page_content=result.fields.get(vdb_Field.CONTENT_KEY), vector=vector, metadata=metadata
|
||||
)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
if self._has_index():
|
||||
self._client.drop_index(self._collection_name, self._index_name)
|
||||
|
|
|
|||
|
|
@ -442,6 +442,30 @@ class WeaviateVector(BaseVector):
|
|||
return value.isoformat()
|
||||
return value
|
||||
|
||||
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
|
||||
"""Searches for documents matching a specific metadata field value."""
|
||||
if not self._client.collections.exists(self._collection_name):
|
||||
return []
|
||||
|
||||
col = self._client.collections.use(self._collection_name)
|
||||
props = list({*self._attributes, Field.TEXT_KEY.value})
|
||||
|
||||
res = col.query.fetch_objects(
|
||||
filters=Filter.by_property(key).equal(value),
|
||||
limit=999999,
|
||||
return_properties=props,
|
||||
include_vector=True,
|
||||
)
|
||||
|
||||
docs: list[Document] = []
|
||||
for obj in res.objects:
|
||||
properties = dict(obj.properties or {})
|
||||
text = properties.pop(Field.TEXT_KEY.value, "")
|
||||
vector = obj.vector.get("default") if obj.vector else None
|
||||
docs.append(Document(page_content=text, vector=vector, metadata=properties))
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
class WeaviateVectorFactory(AbstractVectorFactory):
|
||||
"""Factory class for creating WeaviateVector instances."""
|
||||
|
|
|
|||
Loading…
Reference in New Issue