diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 095752ea8e..6f3e15d166 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -233,6 +233,12 @@ class AnalyticdbVectorOpenAPI: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause += f"metadata_->>'document_id' IN ({document_ids})" + score_threshold = kwargs.get("score_threshold") or 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, @@ -245,7 +251,7 @@ class AnalyticdbVectorOpenAPI: vector=query_vector, content=None, top_k=kwargs.get("top_k", 4), - filter=None, + filter=where_clause, ) response = self._client.query_collection_data(request) documents = [] @@ -265,6 +271,11 @@ class AnalyticdbVectorOpenAPI: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause += f"metadata_->>'document_id' IN ({document_ids})" score_threshold = float(kwargs.get("score_threshold") or 0.0) request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, @@ -277,7 +288,7 @@ class AnalyticdbVectorOpenAPI: vector=None, content=query, top_k=kwargs.get("top_k", 4), - filter=None, + filter=where_clause, ) response = self._client.query_collection_data(request) documents = []