From 7b1b5c2445a4543748fd479db8ac664a1df5e227 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 26 Feb 2026 00:22:20 +0900 Subject: [PATCH] test: example for [Refactor/Chore] use Testcontainers to do sql test #32454 (#32459) --- .../models/test_types_enum_text.py | 69 +++++++++++-------- 1 file changed, 42 insertions(+), 27 deletions(-) rename api/tests/{unit_tests => test_containers_integration_tests}/models/test_types_enum_text.py (76%) diff --git a/api/tests/unit_tests/models/test_types_enum_text.py b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py similarity index 76% rename from api/tests/unit_tests/models/test_types_enum_text.py rename to api/tests/test_containers_integration_tests/models/test_types_enum_text.py index c59afcf0db..206c84c750 100644 --- a/api/tests/unit_tests/models/test_types_enum_text.py +++ b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py @@ -6,11 +6,15 @@ import pytest import sqlalchemy as sa from sqlalchemy import exc as sa_exc from sqlalchemy import insert +from sqlalchemy.engine import Connection, Engine from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from sqlalchemy.sql.sqltypes import VARCHAR from models.types import EnumText +_USER_TABLE = "enum_text_users" +_COLUMN_TABLE = "enum_text_column_test" + _user_type_admin = "admin" _user_type_normal = "normal" @@ -30,7 +34,7 @@ class _EnumWithLongValue(StrEnum): class _User(_Base): - __tablename__ = "users" + __tablename__ = _USER_TABLE id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) name: Mapped[str] = mapped_column(sa.String(length=255), nullable=False) @@ -41,7 +45,7 @@ class _User(_Base): class _ColumnTest(_Base): - __tablename__ = "column_test" + __tablename__ = _COLUMN_TABLE id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) @@ -64,13 +68,30 @@ def _first(it: Iterable[_T]) -> _T: return ls[0] -class TestEnumText: - def test_column_impl(self): - engine = sa.create_engine("sqlite://", echo=False) - _Base.metadata.create_all(engine) +def _resolve_engine(bind: Engine | Connection) -> Engine: + if isinstance(bind, Engine): + return bind + return bind.engine - inspector = sa.inspect(engine) - columns = inspector.get_columns(_ColumnTest.__tablename__) + +@pytest.fixture +def engine_with_containers(db_session_with_containers: Session) -> Engine: + return _resolve_engine(db_session_with_containers.get_bind()) + + +@pytest.fixture(autouse=True) +def _enum_text_schema(engine_with_containers: Engine) -> Iterable[None]: + _Base.metadata.create_all(engine_with_containers) + try: + yield + finally: + _Base.metadata.drop_all(engine_with_containers) + + +class TestEnumText: + def test_column_impl(self, engine_with_containers: Engine): + inspector = sa.inspect(engine_with_containers) + columns = inspector.get_columns(_COLUMN_TABLE) user_type_column = _first(c for c in columns if c["name"] == "user_type") sql_type = user_type_column["type"] @@ -89,11 +110,8 @@ class TestEnumText: assert isinstance(sql_type, VARCHAR) assert sql_type.length == len(_EnumWithLongValue.a_really_long_enum_values) - def test_insert_and_select(self): - engine = sa.create_engine("sqlite://", echo=False) - _Base.metadata.create_all(engine) - - with Session(engine) as session: + def test_insert_and_select(self, engine_with_containers: Engine): + with Session(engine_with_containers) as session: admin_user = _User( name="admin", user_type=_UserType.admin, @@ -113,17 +131,17 @@ class TestEnumText: normal_user_id = normal_user.id session.commit() - with Session(engine) as session: + with Session(engine_with_containers) as session: user = session.query(_User).where(_User.id == admin_user_id).first() assert user.user_type == _UserType.admin assert user.user_type_nullable is None - with Session(engine) as session: + with Session(engine_with_containers) as session: user = session.query(_User).where(_User.id == normal_user_id).first() assert user.user_type == _UserType.normal assert user.user_type_nullable == _UserType.normal - def test_insert_invalid_values(self): + def test_insert_invalid_values(self, engine_with_containers: Engine): def _session_insert_with_value(sess: Session, user_type: Any): user = _User(name="test_user", user_type=user_type) sess.add(user) @@ -143,8 +161,6 @@ class TestEnumText: action: Callable[[Session], None] exc_type: type[Exception] - engine = sa.create_engine("sqlite://", echo=False) - _Base.metadata.create_all(engine) cases = [ TestCase( name="session insert with invalid value", @@ -169,23 +185,22 @@ class TestEnumText: ] for idx, c in enumerate(cases, 1): with pytest.raises(sa_exc.StatementError) as exc: - with Session(engine) as session: + with Session(engine_with_containers) as session: c.action(session) assert isinstance(exc.value.orig, c.exc_type), f"test case {idx} failed, name={c.name}" - def test_select_invalid_values(self): - engine = sa.create_engine("sqlite://", echo=False) - _Base.metadata.create_all(engine) - - insertion_sql = """ - INSERT INTO users (id, name, user_type) VALUES + def test_select_invalid_values(self, engine_with_containers: Engine): + insertion_sql = f""" + INSERT INTO {_USER_TABLE} (id, name, user_type) VALUES (1, 'invalid_value', 'invalid'); """ - with Session(engine) as session: + with Session(engine_with_containers) as session: session.execute(sa.text(insertion_sql)) session.commit() with pytest.raises(ValueError) as exc: - with Session(engine) as session: + with Session(engine_with_containers) as session: _user = session.query(_User).where(_User.id == 1).first() + + assert str(exc.value) == "'invalid' is not a valid _UserType"