diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 2b4c1b59ab..c9ee67863d 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -557,11 +557,9 @@ class TestPauseStatePersistenceLayerTestContainers: self.session.refresh(self.test_workflow_run) assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING - pause_states = ( - self.session.query(WorkflowPauseModel) - .filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id) - .all() - ) + pause_states = self.session.scalars( + select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id) + ).all() assert len(pause_states) == 0 def test_layer_requires_initialization(self, db_session_with_containers): diff --git a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py index 8aec6b6acc..957b7145d3 100644 --- a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py +++ b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py @@ -6,7 +6,7 @@ import pytest import sqlalchemy as sa from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import exc as sa_exc -from sqlalchemy import insert +from sqlalchemy import insert, select from sqlalchemy.engine import Connection, Engine from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from sqlalchemy.sql.sqltypes import VARCHAR @@ -137,12 +137,12 @@ class TestEnumText: session.commit() with Session(engine_with_containers) as session: - user = session.query(_User).where(_User.id == admin_user_id).first() + user = session.scalar(select(_User).where(_User.id == admin_user_id).limit(1)) assert user.user_type == _UserType.admin assert user.user_type_nullable is None with Session(engine_with_containers) as session: - user = session.query(_User).where(_User.id == normal_user_id).first() + user = session.scalar(select(_User).where(_User.id == normal_user_id).limit(1)) assert user.user_type == _UserType.normal assert user.user_type_nullable == _UserType.normal @@ -206,7 +206,7 @@ class TestEnumText: with pytest.raises(ValueError) as exc: with Session(engine_with_containers) as session: - _user = session.query(_User).where(_User.id == 1).first() + _user = session.scalar(select(_User).where(_User.id == 1).limit(1)) assert str(exc.value) == "'invalid' is not a valid _UserType" @@ -222,7 +222,7 @@ class TestEnumText: session.commit() with Session(engine_with_containers) as session: - records = session.query(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id).all() + records = session.scalars(select(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id)).all() assert [record.model_type for record in records] == [ ModelType.LLM, diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 159ab51304..4bc022c415 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -26,7 +26,7 @@ from datetime import timedelta import pytest from graphon.entities import WorkflowExecution from graphon.enums import WorkflowExecutionStatus -from sqlalchemy import delete, select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage @@ -679,9 +679,12 @@ class TestWorkflowPauseIntegration: # Verify only 3 were deleted remaining_count = ( - self.session.query(WorkflowPauseModel) - .filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities])) - .count() + self.session.scalar( + select(func.count(WorkflowPauseModel.id)).where( + WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]) + ) + ) + or 0 ) assert remaining_count == 2