diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 90d9173409..387e918c76 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator from sqlalchemy import Float, create_engine, insert, select, text from sqlalchemy import text as sql_text from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import Mapped, Session, mapped_column +from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker from configs import dify_config from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM @@ -55,9 +55,8 @@ class PGVectoRS(BaseVector): f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" ) self._client = create_engine(self._url) - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) - session.commit() self._fields: list[str] = [] class _Table(CollectionORM): @@ -88,7 +87,7 @@ class PGVectoRS(BaseVector): if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: create_statement = sql_text(f""" CREATE TABLE IF NOT EXISTS {self._collection_name} ( id UUID PRIMARY KEY, @@ -111,12 +110,11 @@ class PGVectoRS(BaseVector): $$); """) session.execute(index_statement) - session.commit() redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): pks = [] - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: for document, embedding in zip(documents, embeddings): pk = uuid4() session.execute( @@ -128,7 +126,6 @@ class PGVectoRS(BaseVector): ), ) pks.append(pk) - session.commit() return pks @@ -145,10 +142,9 @@ class PGVectoRS(BaseVector): def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) if ids: - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") session.execute(select_statement, {"ids": ids}) - session.commit() def delete_by_ids(self, ids: list[str]): with Session(self._client) as session: @@ -159,15 +155,13 @@ class PGVectoRS(BaseVector): if result: ids = [item[0] for item in result] if ids: - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") session.execute(select_statement, {"ids": ids}) - session.commit() def delete(self): - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}")) - session.commit() def text_exists(self, id: str) -> bool: with Session(self._client) as session: diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py index 1aec81b8ac..5b9ec8002a 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -53,6 +53,31 @@ def _session_factory(calls, execute_results=None): return _session +class _FakeBeginContext: + def __init__(self, session): + self._session = session + + def __enter__(self): + return self._session + + def __exit__(self, exc_type, exc, tb): + return None + + +def _sessionmaker_factory(calls, execute_results=None): + def _sessionmaker(*args, **kwargs): + session = _FakeSessionContext(calls=calls, execute_results=execute_results) + return MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session))) + + return _sessionmaker + + +def _patch_both(monkeypatch, module, calls, execute_results=None): + """Patch both Session and sessionmaker on the module with the same call tracker.""" + monkeypatch.setattr(module, "Session", _session_factory(calls, execute_results)) + monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(calls, execute_results)) + + @pytest.fixture def pgvecto_module(monkeypatch): for name, module in _build_fake_pgvecto_modules().items(): @@ -105,7 +130,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch): module, _ = pgvecto_module session_calls = [] monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) - monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + _patch_both(monkeypatch, module, session_calls) vector = module.PGVectoRS("collection_1", _config(module), dim=3) vector.create_collection = MagicMock() @@ -124,7 +149,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch): module, _ = pgvecto_module session_calls = [] monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) - monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + _patch_both(monkeypatch, module, session_calls) lock = MagicMock() lock.__enter__.return_value = None @@ -151,10 +176,10 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])] monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) - monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + _patch_both(monkeypatch, module, init_calls) vector = module.PGVectoRS("collection_1", _config(module), dim=3) - monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results))) + _patch_both(monkeypatch, module, runtime_calls, execute_results=list(execute_results)) class _InsertBuilder: def __init__(self, table): @@ -179,6 +204,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): "Session", _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]), ) + monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls)) assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] monkeypatch.setattr( @@ -204,12 +230,13 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): ], ), ) + monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls)) vector.delete_by_ids(["doc-1"]) assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls) assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) runtime_calls.clear() - monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()])) + _patch_both(monkeypatch, module, runtime_calls, execute_results=[MagicMock()]) vector.delete() assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls) @@ -218,7 +245,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch): module, _ = pgvecto_module init_calls = [] monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) - monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + _patch_both(monkeypatch, module, init_calls) vector = module.PGVectoRS("collection_1", _config(module), dim=3) runtime_calls = [] @@ -277,7 +304,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch): (SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1), (SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8), ] - monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows])) + _patch_both(monkeypatch, module, runtime_calls, execute_results=[rows]) docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) assert len(docs) == 1