refactor(api): use sessionmaker in relyt & tidb_vector VDB services (#34848)

This commit is contained in:
carlos4s 2026-04-09 21:16:25 -06:00 committed by GitHub
parent d826ac7099
commit 86fd94767c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 36 additions and 24 deletions

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator
from sqlalchemy import Column, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@ -79,7 +79,7 @@ class RelytVector(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
with Session(self.client) as session:
with sessionmaker(bind=self.client).begin() as session:
drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
session.execute(drop_statement)
create_statement = sql_text(f"""
@ -104,7 +104,6 @@ class RelytVector(BaseVector):
$$);
""")
session.execute(index_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -208,9 +207,8 @@ class RelytVector(BaseVector):
self.delete_by_uuids(ids)
def delete(self):
with Session(self.client) as session:
with sessionmaker(bind=self.client).begin() as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
session.commit()
def text_exists(self, id: str) -> bool:
with Session(self.client) as session:

View File

@ -6,7 +6,7 @@ import sqlalchemy
from pydantic import BaseModel, model_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from configs import dify_config
from core.rag.datasource.vdb.field import Field, parse_metadata_json
@ -97,8 +97,7 @@ class TiDBVector(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
tidb_dist_func = self._get_distance_func()
with Session(self._engine) as session:
session.begin()
with sessionmaker(bind=self._engine).begin() as session:
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} (
id CHAR(36) PRIMARY KEY,
@ -115,7 +114,6 @@ class TiDBVector(BaseVector):
);
""")
session.execute(create_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -238,9 +236,8 @@ class TiDBVector(BaseVector):
return []
def delete(self):
with Session(self._engine) as session:
with sessionmaker(bind=self._engine).begin() as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
session.commit()
def _get_distance_func(self) -> str:
match self._distance_func:

View File

@ -39,6 +39,25 @@ class _FakeSession:
return None
class _FakeBeginContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return None
def _patch_both(monkeypatch, module, session):
    """Patch both ``Session`` and ``sessionmaker`` on ``module``.

    After this call, ``module.Session(client)`` returns ``session`` directly,
    and ``module.sessionmaker(**kwargs).begin()`` returns a
    ``_FakeBeginContext`` whose ``__enter__`` yields the same ``session``.

    Args:
        monkeypatch: object exposing ``setattr(target, name, value)``
            (the pytest ``monkeypatch`` fixture in the tests that call this).
        module: the module object whose attributes are replaced.
        session: the fake session that both code paths should receive.
    """

    def _fake_session(_client):
        # Ignores the engine/client argument; always hands back the fake.
        return session

    def _fake_sessionmaker(**kwargs):
        # A fresh factory mock per call; only ``begin()`` is given a
        # meaningful return value.
        begin = MagicMock(return_value=_FakeBeginContext(session))
        return MagicMock(begin=begin)

    monkeypatch.setattr(module, "Session", _fake_session)
    monkeypatch.setattr(module, "sessionmaker", _fake_sessionmaker)
@pytest.fixture
def relyt_module(monkeypatch):
for name, module in _build_fake_relyt_modules().items():
@ -108,13 +127,13 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1))
session = _FakeSession()
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
_patch_both(monkeypatch, relyt_module, session)
vector.create_collection(3)
session.execute.assert_not_called()
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None))
session = _FakeSession()
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
_patch_both(monkeypatch, relyt_module, session)
vector.create_collection(3)
executed_sql = [str(call.args[0]) for call in session.execute.call_args_list]
assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql)
@ -265,15 +284,15 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module):
# 8. delete drops the table
def test_delete_commits_session(relyt_module, monkeypatch):
def test_delete_drops_table(relyt_module, monkeypatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector.client = MagicMock()
vector.embedding_dimension = 3
session = _FakeSession()
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
_patch_both(monkeypatch, relyt_module, session)
vector.delete()
session.commit.assert_called_once()
session.execute.assert_called_once()
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):

View File

@ -137,14 +137,15 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
session = MagicMock()
class _SessionCtx:
class _BeginCtx:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
vector._collection_name = "collection_1"
@ -153,11 +154,9 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
vector._create_collection(3)
session.begin.assert_called_once()
sql = str(session.execute.call_args.args[0])
assert "VECTOR<FLOAT>(3)" in sql
assert "VEC_L2_DISTANCE" in sql
session.commit.assert_called_once()
tidb_module.redis_client.set.assert_called_once()
@ -396,23 +395,22 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
def test_delete_drops_table(tidb_module, monkeypatch):
session = MagicMock()
session.execute.return_value = None
session.commit = MagicMock()
class _SessionCtx:
class _BeginCtx:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
vector._collection_name = "collection_1"
vector._engine = MagicMock()
vector.delete()
drop_sql = str(session.execute.call_args.args[0])
assert "DROP TABLE IF EXISTS collection_1" in drop_sql
session.commit.assert_called_once()
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):