refactor(api): use sessionmaker in pgvecto_rs VDB service (#34818)

This commit is contained in:
carlos4s 2026-04-09 00:49:03 -05:00 committed by GitHub
parent 5f53748d07
commit d360929af1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 20 deletions

View File

@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator
from sqlalchemy import Float, create_engine, insert, select, text
from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column
from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker
from configs import dify_config
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
@ -55,9 +55,8 @@ class PGVectoRS(BaseVector):
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self._client = create_engine(self._url)
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
session.commit()
self._fields: list[str] = []
class _Table(CollectionORM):
@ -88,7 +87,7 @@ class PGVectoRS(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} (
id UUID PRIMARY KEY,
@ -111,12 +110,11 @@ class PGVectoRS(BaseVector):
$$);
""")
session.execute(index_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
pks = []
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
for document, embedding in zip(documents, embeddings):
pk = uuid4()
session.execute(
@ -128,7 +126,6 @@ class PGVectoRS(BaseVector):
),
)
pks.append(pk)
session.commit()
return pks
@ -145,10 +142,9 @@ class PGVectoRS(BaseVector):
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {"ids": ids})
session.commit()
def delete_by_ids(self, ids: list[str]):
with Session(self._client) as session:
@ -159,15 +155,13 @@ class PGVectoRS(BaseVector):
if result:
ids = [item[0] for item in result]
if ids:
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {"ids": ids})
session.commit()
def delete(self):
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
session.commit()
def text_exists(self, id: str) -> bool:
with Session(self._client) as session:

View File

@ -53,6 +53,31 @@ def _session_factory(calls, execute_results=None):
return _session
class _FakeBeginContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return None
def _sessionmaker_factory(calls, execute_results=None):
    """Build a drop-in replacement for ``sessionmaker`` used by these tests.

    The returned callable accepts and ignores any arguments. Each invocation
    creates a new ``_FakeSessionContext`` (given ``calls`` and
    ``execute_results``) and returns a ``MagicMock`` whose ``begin()`` yields
    a ``_FakeBeginContext`` wrapping that session.
    """

    def _sessionmaker(*args, **kwargs):
        fake_session = _FakeSessionContext(calls=calls, execute_results=execute_results)
        begin_context = _FakeBeginContext(fake_session)
        begin = MagicMock(return_value=begin_context)
        return MagicMock(begin=begin)

    return _sessionmaker
def _patch_both(monkeypatch, module, calls, execute_results=None):
    """Patch both Session and sessionmaker on the module with the same call tracker.

    ``Session`` is replaced first, then ``sessionmaker``; each replacement is
    built from the same ``calls`` and ``execute_results`` arguments.
    """
    replacements = (
        ("Session", _session_factory),
        ("sessionmaker", _sessionmaker_factory),
    )
    for attr_name, build_fake in replacements:
        monkeypatch.setattr(module, attr_name, build_fake(calls, execute_results))
@pytest.fixture
def pgvecto_module(monkeypatch):
for name, module in _build_fake_pgvecto_modules().items():
@ -105,7 +130,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
module, _ = pgvecto_module
session_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
_patch_both(monkeypatch, module, session_calls)
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
vector.create_collection = MagicMock()
@ -124,7 +149,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
module, _ = pgvecto_module
session_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
_patch_both(monkeypatch, module, session_calls)
lock = MagicMock()
lock.__enter__.return_value = None
@ -151,10 +176,10 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])]
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
_patch_both(monkeypatch, module, init_calls)
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results)))
_patch_both(monkeypatch, module, runtime_calls, execute_results=list(execute_results))
class _InsertBuilder:
def __init__(self, table):
@ -179,6 +204,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
"Session",
_session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]),
)
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
monkeypatch.setattr(
@ -204,12 +230,13 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
],
),
)
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
vector.delete_by_ids(["doc-1"])
assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls)
assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)
runtime_calls.clear()
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()]))
_patch_both(monkeypatch, module, runtime_calls, execute_results=[MagicMock()])
vector.delete()
assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
@ -218,7 +245,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
module, _ = pgvecto_module
init_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
_patch_both(monkeypatch, module, init_calls)
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
runtime_calls = []
@ -277,7 +304,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
(SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1),
(SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8),
]
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows]))
_patch_both(monkeypatch, module, runtime_calls, execute_results=[rows])
docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
assert len(docs) == 1