mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor(api): use sessionmaker in relyt & tidb_vector VDB services (#34848)
This commit is contained in:
parent
d826ac7099
commit
86fd94767c
@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Column, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects.postgresql import JSON, TEXT
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -79,7 +79,7 @@ class RelytVector(BaseVector):
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
index_name = f"{self._collection_name}_embedding_index"
|
||||
with Session(self.client) as session:
|
||||
with sessionmaker(bind=self.client).begin() as session:
|
||||
drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
|
||||
session.execute(drop_statement)
|
||||
create_statement = sql_text(f"""
|
||||
@ -104,7 +104,6 @@ class RelytVector(BaseVector):
|
||||
$$);
|
||||
""")
|
||||
session.execute(index_statement)
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
@ -208,9 +207,8 @@ class RelytVector(BaseVector):
|
||||
self.delete_by_uuids(ids)
|
||||
|
||||
def delete(self):
|
||||
with Session(self.client) as session:
|
||||
with sessionmaker(bind=self.client).begin() as session:
|
||||
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
|
||||
session.commit()
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with Session(self.client) as session:
|
||||
|
||||
@ -6,7 +6,7 @@ import sqlalchemy
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.orm import Session, declarative_base
|
||||
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field, parse_metadata_json
|
||||
@ -97,8 +97,7 @@ class TiDBVector(BaseVector):
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
tidb_dist_func = self._get_distance_func()
|
||||
with Session(self._engine) as session:
|
||||
session.begin()
|
||||
with sessionmaker(bind=self._engine).begin() as session:
|
||||
create_statement = sql_text(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
||||
id CHAR(36) PRIMARY KEY,
|
||||
@ -115,7 +114,6 @@ class TiDBVector(BaseVector):
|
||||
);
|
||||
""")
|
||||
session.execute(create_statement)
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
@ -238,9 +236,8 @@ class TiDBVector(BaseVector):
|
||||
return []
|
||||
|
||||
def delete(self):
|
||||
with Session(self._engine) as session:
|
||||
with sessionmaker(bind=self._engine).begin() as session:
|
||||
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
|
||||
session.commit()
|
||||
|
||||
def _get_distance_func(self) -> str:
|
||||
match self._distance_func:
|
||||
|
||||
@ -39,6 +39,25 @@ class _FakeSession:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeBeginContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
def _patch_both(monkeypatch, module, session):
|
||||
"""Patch both Session and sessionmaker on the module."""
|
||||
monkeypatch.setattr(module, "Session", lambda _client: session)
|
||||
monkeypatch.setattr(
|
||||
module, "sessionmaker", lambda **kwargs: MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def relyt_module(monkeypatch):
|
||||
for name, module in _build_fake_relyt_modules().items():
|
||||
@ -108,13 +127,13 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
|
||||
|
||||
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1))
|
||||
session = _FakeSession()
|
||||
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
|
||||
_patch_both(monkeypatch, relyt_module, session)
|
||||
vector.create_collection(3)
|
||||
session.execute.assert_not_called()
|
||||
|
||||
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None))
|
||||
session = _FakeSession()
|
||||
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
|
||||
_patch_both(monkeypatch, relyt_module, session)
|
||||
vector.create_collection(3)
|
||||
executed_sql = [str(call.args[0]) for call in session.execute.call_args_list]
|
||||
assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql)
|
||||
@ -265,15 +284,15 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module):
|
||||
|
||||
|
||||
# 8. delete commits session
|
||||
def test_delete_commits_session(relyt_module, monkeypatch):
|
||||
def test_delete_drops_table(relyt_module, monkeypatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.client = MagicMock()
|
||||
vector.embedding_dimension = 3
|
||||
session = _FakeSession()
|
||||
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
|
||||
_patch_both(monkeypatch, relyt_module, session)
|
||||
vector.delete()
|
||||
session.commit.assert_called_once()
|
||||
session.execute.assert_called_once()
|
||||
|
||||
|
||||
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):
|
||||
|
||||
@ -137,14 +137,15 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
|
||||
|
||||
session = MagicMock()
|
||||
|
||||
class _SessionCtx:
|
||||
class _BeginCtx:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
|
||||
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
|
||||
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
|
||||
|
||||
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
|
||||
vector._collection_name = "collection_1"
|
||||
@ -153,11 +154,9 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
|
||||
|
||||
vector._create_collection(3)
|
||||
|
||||
session.begin.assert_called_once()
|
||||
sql = str(session.execute.call_args.args[0])
|
||||
assert "VECTOR<FLOAT>(3)" in sql
|
||||
assert "VEC_L2_DISTANCE" in sql
|
||||
session.commit.assert_called_once()
|
||||
tidb_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
@ -396,23 +395,22 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
|
||||
def test_delete_drops_table(tidb_module, monkeypatch):
|
||||
session = MagicMock()
|
||||
session.execute.return_value = None
|
||||
session.commit = MagicMock()
|
||||
|
||||
class _SessionCtx:
|
||||
class _BeginCtx:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
|
||||
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
|
||||
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
|
||||
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._engine = MagicMock()
|
||||
vector.delete()
|
||||
drop_sql = str(session.execute.call_args.args[0])
|
||||
assert "DROP TABLE IF EXISTS collection_1" in drop_sql
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user