mirror of
https://github.com/langgenius/dify.git
synced 2026-04-12 14:10:42 +08:00
refactor: migrate service_api and inner_api to sessionmaker pattern (#34379)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e41965061c
commit
391007d02e
@ -6,7 +6,7 @@ from flask import current_app, request
|
|||||||
from flask_login import user_logged_in
|
from flask_login import user_logged_in
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
@ -33,7 +33,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
|||||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||||
try:
|
try:
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
user_model = None
|
user_model = None
|
||||||
|
|
||||||
if is_anonymous:
|
if is_anonymous:
|
||||||
@ -56,7 +56,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
|||||||
session_id=user_id,
|
session_id=user_id,
|
||||||
)
|
)
|
||||||
session.add(user_model)
|
session.add(user_model)
|
||||||
session.commit()
|
session.flush()
|
||||||
session.refresh(user_model)
|
session.refresh(user_model)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from typing import Any, Literal
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
@ -116,7 +116,7 @@ class ConversationApi(Resource):
|
|||||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
pagination = ConversationService.pagination_by_last_id(
|
pagination = ConversationService.pagination_by_last_id(
|
||||||
session=session,
|
session=session,
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from graphon.enums import WorkflowExecutionStatus
|
|||||||
from graphon.graph_engine.manager import GraphEngineManager
|
from graphon.graph_engine.manager import GraphEngineManager
|
||||||
from graphon.model_runtime.errors.invoke import InvokeError
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
@ -314,7 +314,7 @@ class WorkflowAppLogApi(Resource):
|
|||||||
|
|
||||||
# get paginate workflow app logs
|
# get paginate workflow app logs
|
||||||
workflow_app_service = WorkflowAppService()
|
workflow_app_service = WorkflowAppService()
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||||
session=session,
|
session=session,
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
|
|||||||
@ -41,15 +41,15 @@ class TestGetUser:
|
|||||||
"""Test get_user function"""
|
"""Test get_user function"""
|
||||||
|
|
||||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
|
||||||
@patch("controllers.inner_api.plugin.wraps.db")
|
@patch("controllers.inner_api.plugin.wraps.db")
|
||||||
def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
|
def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
|
||||||
"""Test returning existing user when found by ID"""
|
"""Test returning existing user when found by ID"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_user = MagicMock()
|
mock_user = MagicMock()
|
||||||
mock_user.id = "user123"
|
mock_user.id = "user123"
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||||
mock_session.get.return_value = mock_user
|
mock_session.get.return_value = mock_user
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
@ -61,17 +61,17 @@ class TestGetUser:
|
|||||||
mock_session.get.assert_called_once()
|
mock_session.get.assert_called_once()
|
||||||
|
|
||||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
|
||||||
@patch("controllers.inner_api.plugin.wraps.db")
|
@patch("controllers.inner_api.plugin.wraps.db")
|
||||||
def test_should_return_existing_anonymous_user_by_session_id(
|
def test_should_return_existing_anonymous_user_by_session_id(
|
||||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask
|
||||||
):
|
):
|
||||||
"""Test returning existing anonymous user by session_id"""
|
"""Test returning existing anonymous user by session_id"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_user = MagicMock()
|
mock_user = MagicMock()
|
||||||
mock_user.session_id = "anonymous_session"
|
mock_user.session_id = "anonymous_session"
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||||
# non-anonymous path uses session.get(); anonymous uses session.scalar()
|
# non-anonymous path uses session.get(); anonymous uses session.scalar()
|
||||||
mock_session.get.return_value = mock_user
|
mock_session.get.return_value = mock_user
|
||||||
|
|
||||||
@ -83,13 +83,13 @@ class TestGetUser:
|
|||||||
assert result == mock_user
|
assert result == mock_user
|
||||||
|
|
||||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
|
||||||
@patch("controllers.inner_api.plugin.wraps.db")
|
@patch("controllers.inner_api.plugin.wraps.db")
|
||||||
def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
|
def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
|
||||||
"""Test creating new user when not found in database"""
|
"""Test creating new user when not found in database"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||||
mock_session.get.return_value = None
|
mock_session.get.return_value = None
|
||||||
mock_new_user = MagicMock()
|
mock_new_user = MagicMock()
|
||||||
mock_enduser_class.return_value = mock_new_user
|
mock_enduser_class.return_value = mock_new_user
|
||||||
@ -101,21 +101,20 @@ class TestGetUser:
|
|||||||
# Assert
|
# Assert
|
||||||
assert result == mock_new_user
|
assert result == mock_new_user
|
||||||
mock_session.add.assert_called_once()
|
mock_session.add.assert_called_once()
|
||||||
mock_session.commit.assert_called_once()
|
|
||||||
mock_session.refresh.assert_called_once()
|
mock_session.refresh.assert_called_once()
|
||||||
|
|
||||||
@patch("controllers.inner_api.plugin.wraps.select")
|
@patch("controllers.inner_api.plugin.wraps.select")
|
||||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
|
||||||
@patch("controllers.inner_api.plugin.wraps.db")
|
@patch("controllers.inner_api.plugin.wraps.db")
|
||||||
def test_should_use_default_session_id_when_user_id_none(
|
def test_should_use_default_session_id_when_user_id_none(
|
||||||
self, mock_db, mock_session_class, mock_enduser_class, mock_select, app: Flask
|
self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
|
||||||
):
|
):
|
||||||
"""Test using default session ID when user_id is None"""
|
"""Test using default session ID when user_id is None"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_user = MagicMock()
|
mock_user = MagicMock()
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||||
# When user_id is None, is_anonymous=True, so session.scalar() is used
|
# When user_id is None, is_anonymous=True, so session.scalar() is used
|
||||||
mock_session.scalar.return_value = mock_user
|
mock_session.scalar.return_value = mock_user
|
||||||
|
|
||||||
@ -127,15 +126,13 @@ class TestGetUser:
|
|||||||
assert result == mock_user
|
assert result == mock_user
|
||||||
|
|
||||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
|
||||||
@patch("controllers.inner_api.plugin.wraps.db")
|
@patch("controllers.inner_api.plugin.wraps.db")
|
||||||
def test_should_raise_error_on_database_exception(
|
def test_should_raise_error_on_database_exception(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
|
||||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
|
||||||
):
|
|
||||||
"""Test raising ValueError when database operation fails"""
|
"""Test raising ValueError when database operation fails"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||||
mock_session.get.side_effect = Exception("Database error")
|
mock_session.get.side_effect = Exception("Database error")
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
|
|||||||
@ -433,13 +433,20 @@ class TestConversationApiController:
|
|||||||
handler(api, app_model=app_model, end_user=end_user)
|
handler(api, app_model=app_model, end_user=end_user)
|
||||||
|
|
||||||
def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
class _SessionStub:
|
class _BeginStub:
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return SimpleNamespace()
|
return SimpleNamespace()
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
class _SessionMakerStub:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def begin(self):
|
||||||
|
return _BeginStub()
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
ConversationService,
|
ConversationService,
|
||||||
"pagination_by_last_id",
|
"pagination_by_last_id",
|
||||||
@ -447,7 +454,7 @@ class TestConversationApiController:
|
|||||||
)
|
)
|
||||||
conversation_module = sys.modules["controllers.service_api.app.conversation"]
|
conversation_module = sys.modules["controllers.service_api.app.conversation"]
|
||||||
monkeypatch.setattr(conversation_module, "db", SimpleNamespace(engine=object()))
|
monkeypatch.setattr(conversation_module, "db", SimpleNamespace(engine=object()))
|
||||||
monkeypatch.setattr(conversation_module, "Session", lambda *_args, **_kwargs: _SessionStub())
|
monkeypatch.setattr(conversation_module, "sessionmaker", _SessionMakerStub)
|
||||||
|
|
||||||
api = ConversationApi()
|
api = ConversationApi()
|
||||||
handler = _unwrap(api.get)
|
handler = _unwrap(api.get)
|
||||||
|
|||||||
@ -470,16 +470,23 @@ class TestWorkflowTaskStopApi:
|
|||||||
|
|
||||||
class TestWorkflowAppLogApi:
|
class TestWorkflowAppLogApi:
|
||||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
class _SessionStub:
|
class _BeginStub:
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return SimpleNamespace()
|
return SimpleNamespace()
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
class _SessionMakerStub:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def begin(self):
|
||||||
|
return _BeginStub()
|
||||||
|
|
||||||
workflow_module = sys.modules["controllers.service_api.app.workflow"]
|
workflow_module = sys.modules["controllers.service_api.app.workflow"]
|
||||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||||
monkeypatch.setattr(workflow_module, "Session", lambda *_args, **_kwargs: _SessionStub())
|
monkeypatch.setattr(workflow_module, "sessionmaker", _SessionMakerStub)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
WorkflowAppService,
|
WorkflowAppService,
|
||||||
"get_paginate_workflow_app_logs",
|
"get_paginate_workflow_app_logs",
|
||||||
@ -635,11 +642,14 @@ class TestWorkflowAppLogApiGet:
|
|||||||
mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
|
mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
|
||||||
mock_wf_svc_cls.return_value = mock_svc_instance
|
mock_wf_svc_cls.return_value = mock_svc_instance
|
||||||
|
|
||||||
# Mock Session context manager
|
# Mock sessionmaker(...).begin() context manager
|
||||||
mock_session = Mock()
|
mock_session = Mock()
|
||||||
mock_db.engine = Mock()
|
mock_db.engine = Mock()
|
||||||
mock_session.__enter__ = Mock(return_value=mock_session)
|
mock_begin = Mock()
|
||||||
mock_session.__exit__ = Mock(return_value=False)
|
mock_begin.__enter__ = Mock(return_value=mock_session)
|
||||||
|
mock_begin.__exit__ = Mock(return_value=False)
|
||||||
|
mock_session_factory = Mock()
|
||||||
|
mock_session_factory.begin.return_value = mock_begin
|
||||||
|
|
||||||
from controllers.service_api.app.workflow import WorkflowAppLogApi
|
from controllers.service_api.app.workflow import WorkflowAppLogApi
|
||||||
|
|
||||||
@ -647,7 +657,7 @@ class TestWorkflowAppLogApiGet:
|
|||||||
"/workflows/logs?page=1&limit=20",
|
"/workflows/logs?page=1&limit=20",
|
||||||
method="GET",
|
method="GET",
|
||||||
):
|
):
|
||||||
with patch("controllers.service_api.app.workflow.Session", return_value=mock_session):
|
with patch("controllers.service_api.app.workflow.sessionmaker", return_value=mock_session_factory):
|
||||||
api = WorkflowAppLogApi()
|
api = WorkflowAppLogApi()
|
||||||
result = _unwrap(api.get)(api, app_model=mock_workflow_app)
|
result = _unwrap(api.get)(api, app_model=mock_workflow_app)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user