mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor(api): use sessionmaker in pgvecto_rs VDB service (#34818)
This commit is contained in:
parent
5f53748d07
commit
d360929af1
@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Float, create_engine, insert, select, text
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
|
||||
@ -55,9 +55,8 @@ class PGVectoRS(BaseVector):
|
||||
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
|
||||
)
|
||||
self._client = create_engine(self._url)
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
|
||||
session.commit()
|
||||
self._fields: list[str] = []
|
||||
|
||||
class _Table(CollectionORM):
|
||||
@ -88,7 +87,7 @@ class PGVectoRS(BaseVector):
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
index_name = f"{self._collection_name}_embedding_index"
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
create_statement = sql_text(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
||||
id UUID PRIMARY KEY,
|
||||
@ -111,12 +110,11 @@ class PGVectoRS(BaseVector):
|
||||
$$);
|
||||
""")
|
||||
session.execute(index_statement)
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
pks = []
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
for document, embedding in zip(documents, embeddings):
|
||||
pk = uuid4()
|
||||
session.execute(
|
||||
@ -128,7 +126,6 @@ class PGVectoRS(BaseVector):
|
||||
),
|
||||
)
|
||||
pks.append(pk)
|
||||
session.commit()
|
||||
|
||||
return pks
|
||||
|
||||
@ -145,10 +142,9 @@ class PGVectoRS(BaseVector):
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
|
||||
session.execute(select_statement, {"ids": ids})
|
||||
session.commit()
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
with Session(self._client) as session:
|
||||
@ -159,15 +155,13 @@ class PGVectoRS(BaseVector):
|
||||
if result:
|
||||
ids = [item[0] for item in result]
|
||||
if ids:
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
|
||||
session.execute(select_statement, {"ids": ids})
|
||||
session.commit()
|
||||
|
||||
def delete(self):
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
|
||||
session.commit()
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with Session(self._client) as session:
|
||||
|
||||
@ -53,6 +53,31 @@ def _session_factory(calls, execute_results=None):
|
||||
return _session
|
||||
|
||||
|
||||
class _FakeBeginContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
def _sessionmaker_factory(calls, execute_results=None):
|
||||
def _sessionmaker(*args, **kwargs):
|
||||
session = _FakeSessionContext(calls=calls, execute_results=execute_results)
|
||||
return MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
|
||||
|
||||
return _sessionmaker
|
||||
|
||||
|
||||
def _patch_both(monkeypatch, module, calls, execute_results=None):
|
||||
"""Patch both Session and sessionmaker on the module with the same call tracker."""
|
||||
monkeypatch.setattr(module, "Session", _session_factory(calls, execute_results))
|
||||
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(calls, execute_results))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pgvecto_module(monkeypatch):
|
||||
for name, module in _build_fake_pgvecto_modules().items():
|
||||
@ -105,7 +130,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
|
||||
module, _ = pgvecto_module
|
||||
session_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
|
||||
_patch_both(monkeypatch, module, session_calls)
|
||||
|
||||
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
|
||||
vector.create_collection = MagicMock()
|
||||
@ -124,7 +149,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
|
||||
module, _ = pgvecto_module
|
||||
session_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
|
||||
_patch_both(monkeypatch, module, session_calls)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
@ -151,10 +176,10 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])]
|
||||
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
|
||||
_patch_both(monkeypatch, module, init_calls)
|
||||
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
|
||||
|
||||
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results)))
|
||||
_patch_both(monkeypatch, module, runtime_calls, execute_results=list(execute_results))
|
||||
|
||||
class _InsertBuilder:
|
||||
def __init__(self, table):
|
||||
@ -179,6 +204,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
"Session",
|
||||
_session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]),
|
||||
)
|
||||
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
|
||||
assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
|
||||
|
||||
monkeypatch.setattr(
|
||||
@ -204,12 +230,13 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
],
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
|
||||
vector.delete_by_ids(["doc-1"])
|
||||
assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls)
|
||||
assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)
|
||||
|
||||
runtime_calls.clear()
|
||||
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()]))
|
||||
_patch_both(monkeypatch, module, runtime_calls, execute_results=[MagicMock()])
|
||||
vector.delete()
|
||||
assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
|
||||
|
||||
@ -218,7 +245,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
|
||||
module, _ = pgvecto_module
|
||||
init_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
|
||||
_patch_both(monkeypatch, module, init_calls)
|
||||
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
|
||||
|
||||
runtime_calls = []
|
||||
@ -277,7 +304,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
|
||||
(SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1),
|
||||
(SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8),
|
||||
]
|
||||
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows]))
|
||||
_patch_both(monkeypatch, module, runtime_calls, execute_results=[rows])
|
||||
|
||||
docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
|
||||
assert len(docs) == 1
|
||||
|
||||
Loading…
Reference in New Issue
Block a user