refactor: migrate service_api and inner_api to sessionmaker pattern (#34379)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Tim Ren 2026-04-01 22:53:41 +08:00 committed by GitHub
parent e41965061c
commit 391007d02e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 47 additions and 33 deletions

View File

@ -6,7 +6,7 @@ from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user from libs.login import current_user
@ -33,7 +33,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
try: try:
with Session(db.engine) as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
user_model = None user_model = None
if is_anonymous: if is_anonymous:
@ -56,7 +56,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
session_id=user_id, session_id=user_id,
) )
session.add(user_model) session.add(user_model)
session.commit() session.flush()
session.refresh(user_model) session.refresh(user_model)
except Exception: except Exception:

View File

@ -3,7 +3,7 @@ from typing import Any, Literal
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
import services import services
@ -116,7 +116,7 @@ class ConversationApi(Resource):
last_id = str(query_args.last_id) if query_args.last_id else None last_id = str(query_args.last_id) if query_args.last_id else None
try: try:
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
pagination = ConversationService.pagination_by_last_id( pagination = ConversationService.pagination_by_last_id(
session=session, session=session,
app_model=app_model, app_model=app_model,

View File

@ -8,7 +8,7 @@ from graphon.enums import WorkflowExecutionStatus
from graphon.graph_engine.manager import GraphEngineManager from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
@ -314,7 +314,7 @@ class WorkflowAppLogApi(Resource):
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session, session=session,
app_model=app_model, app_model=app_model,

View File

@ -41,15 +41,15 @@ class TestGetUser:
"""Test get_user function""" """Test get_user function"""
@patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session") @patch("controllers.inner_api.plugin.wraps.sessionmaker")
@patch("controllers.inner_api.plugin.wraps.db") @patch("controllers.inner_api.plugin.wraps.db")
def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask): def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
"""Test returning existing user when found by ID""" """Test returning existing user when found by ID"""
# Arrange # Arrange
mock_user = MagicMock() mock_user = MagicMock()
mock_user.id = "user123" mock_user.id = "user123"
mock_session = MagicMock() mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.get.return_value = mock_user mock_session.get.return_value = mock_user
# Act # Act
@ -61,17 +61,17 @@ class TestGetUser:
mock_session.get.assert_called_once() mock_session.get.assert_called_once()
@patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session") @patch("controllers.inner_api.plugin.wraps.sessionmaker")
@patch("controllers.inner_api.plugin.wraps.db") @patch("controllers.inner_api.plugin.wraps.db")
def test_should_return_existing_anonymous_user_by_session_id( def test_should_return_existing_anonymous_user_by_session_id(
self, mock_db, mock_session_class, mock_enduser_class, app: Flask self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask
): ):
"""Test returning existing anonymous user by session_id""" """Test returning existing anonymous user by session_id"""
# Arrange # Arrange
mock_user = MagicMock() mock_user = MagicMock()
mock_user.session_id = "anonymous_session" mock_user.session_id = "anonymous_session"
mock_session = MagicMock() mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
# non-anonymous path uses session.get(); anonymous uses session.scalar() # non-anonymous path uses session.get(); anonymous uses session.scalar()
mock_session.get.return_value = mock_user mock_session.get.return_value = mock_user
@ -83,13 +83,13 @@ class TestGetUser:
assert result == mock_user assert result == mock_user
@patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session") @patch("controllers.inner_api.plugin.wraps.sessionmaker")
@patch("controllers.inner_api.plugin.wraps.db") @patch("controllers.inner_api.plugin.wraps.db")
def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask): def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
"""Test creating new user when not found in database""" """Test creating new user when not found in database"""
# Arrange # Arrange
mock_session = MagicMock() mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.get.return_value = None mock_session.get.return_value = None
mock_new_user = MagicMock() mock_new_user = MagicMock()
mock_enduser_class.return_value = mock_new_user mock_enduser_class.return_value = mock_new_user
@ -101,21 +101,20 @@ class TestGetUser:
# Assert # Assert
assert result == mock_new_user assert result == mock_new_user
mock_session.add.assert_called_once() mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once() mock_session.refresh.assert_called_once()
@patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.select")
@patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session") @patch("controllers.inner_api.plugin.wraps.sessionmaker")
@patch("controllers.inner_api.plugin.wraps.db") @patch("controllers.inner_api.plugin.wraps.db")
def test_should_use_default_session_id_when_user_id_none( def test_should_use_default_session_id_when_user_id_none(
self, mock_db, mock_session_class, mock_enduser_class, mock_select, app: Flask self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
): ):
"""Test using default session ID when user_id is None""" """Test using default session ID when user_id is None"""
# Arrange # Arrange
mock_user = MagicMock() mock_user = MagicMock()
mock_session = MagicMock() mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
# When user_id is None, is_anonymous=True, so session.scalar() is used # When user_id is None, is_anonymous=True, so session.scalar() is used
mock_session.scalar.return_value = mock_user mock_session.scalar.return_value = mock_user
@ -127,15 +126,13 @@ class TestGetUser:
assert result == mock_user assert result == mock_user
@patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session") @patch("controllers.inner_api.plugin.wraps.sessionmaker")
@patch("controllers.inner_api.plugin.wraps.db") @patch("controllers.inner_api.plugin.wraps.db")
def test_should_raise_error_on_database_exception( def test_should_raise_error_on_database_exception(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
):
"""Test raising ValueError when database operation fails""" """Test raising ValueError when database operation fails"""
# Arrange # Arrange
mock_session = MagicMock() mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.get.side_effect = Exception("Database error") mock_session.get.side_effect = Exception("Database error")
# Act & Assert # Act & Assert

View File

@ -433,13 +433,20 @@ class TestConversationApiController:
handler(api, app_model=app_model, end_user=end_user) handler(api, app_model=app_model, end_user=end_user)
def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
class _SessionStub: class _BeginStub:
def __enter__(self): def __enter__(self):
return SimpleNamespace() return SimpleNamespace()
def __exit__(self, exc_type, exc, tb): def __exit__(self, exc_type, exc, tb):
return False return False
class _SessionMakerStub:
def __init__(self, *args, **kwargs):
pass
def begin(self):
return _BeginStub()
monkeypatch.setattr( monkeypatch.setattr(
ConversationService, ConversationService,
"pagination_by_last_id", "pagination_by_last_id",
@ -447,7 +454,7 @@ class TestConversationApiController:
) )
conversation_module = sys.modules["controllers.service_api.app.conversation"] conversation_module = sys.modules["controllers.service_api.app.conversation"]
monkeypatch.setattr(conversation_module, "db", SimpleNamespace(engine=object())) monkeypatch.setattr(conversation_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(conversation_module, "Session", lambda *_args, **_kwargs: _SessionStub()) monkeypatch.setattr(conversation_module, "sessionmaker", _SessionMakerStub)
api = ConversationApi() api = ConversationApi()
handler = _unwrap(api.get) handler = _unwrap(api.get)

View File

@ -470,16 +470,23 @@ class TestWorkflowTaskStopApi:
class TestWorkflowAppLogApi: class TestWorkflowAppLogApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
class _SessionStub: class _BeginStub:
def __enter__(self): def __enter__(self):
return SimpleNamespace() return SimpleNamespace()
def __exit__(self, exc_type, exc, tb): def __exit__(self, exc_type, exc, tb):
return False return False
class _SessionMakerStub:
def __init__(self, *args, **kwargs):
pass
def begin(self):
return _BeginStub()
workflow_module = sys.modules["controllers.service_api.app.workflow"] workflow_module = sys.modules["controllers.service_api.app.workflow"]
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(workflow_module, "Session", lambda *_args, **_kwargs: _SessionStub()) monkeypatch.setattr(workflow_module, "sessionmaker", _SessionMakerStub)
monkeypatch.setattr( monkeypatch.setattr(
WorkflowAppService, WorkflowAppService,
"get_paginate_workflow_app_logs", "get_paginate_workflow_app_logs",
@ -635,11 +642,14 @@ class TestWorkflowAppLogApiGet:
mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
mock_wf_svc_cls.return_value = mock_svc_instance mock_wf_svc_cls.return_value = mock_svc_instance
# Mock Session context manager # Mock sessionmaker(...).begin() context manager
mock_session = Mock() mock_session = Mock()
mock_db.engine = Mock() mock_db.engine = Mock()
mock_session.__enter__ = Mock(return_value=mock_session) mock_begin = Mock()
mock_session.__exit__ = Mock(return_value=False) mock_begin.__enter__ = Mock(return_value=mock_session)
mock_begin.__exit__ = Mock(return_value=False)
mock_session_factory = Mock()
mock_session_factory.begin.return_value = mock_begin
from controllers.service_api.app.workflow import WorkflowAppLogApi from controllers.service_api.app.workflow import WorkflowAppLogApi
@ -647,7 +657,7 @@ class TestWorkflowAppLogApiGet:
"/workflows/logs?page=1&limit=20", "/workflows/logs?page=1&limit=20",
method="GET", method="GET",
): ):
with patch("controllers.service_api.app.workflow.Session", return_value=mock_session): with patch("controllers.service_api.app.workflow.sessionmaker", return_value=mock_session_factory):
api = WorkflowAppLogApi() api = WorkflowAppLogApi()
result = _unwrap(api.get)(api, app_model=mock_workflow_app) result = _unwrap(api.get)(api, app_model=mock_workflow_app)