refactor: use sessionmaker in small services 2 (#34696)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
carlos4s 2026-04-08 00:06:50 -05:00 committed by GitHub
parent 909c062ee1
commit a65e1f71b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 86 additions and 53 deletions

View File

@ -9,7 +9,7 @@ from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
class InvitationData(TypedDict):
@ -1516,7 +1516,7 @@ class RegisterService:
check_workspace_member_invite_permission(tenant.id)
with Session(db.engine) as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if not account:

View File

@ -11,7 +11,7 @@ from typing import Any, Union
from celery.result import AsyncResult
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from enums.quota_type import QuotaType
from extensions.ext_database import db
@ -237,7 +237,7 @@ class AsyncWorkflowService:
Returns:
Trigger log as dictionary or None if not found
"""
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)
@ -263,7 +263,7 @@ class AsyncWorkflowService:
Returns:
List of trigger logs as dictionaries
"""
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
logs = trigger_log_repo.get_recent_logs(
tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
@ -286,7 +286,7 @@ class AsyncWorkflowService:
Returns:
List of failed trigger logs as dictionaries
"""
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
logs = trigger_log_repo.get_failed_for_retry(
tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit

View File

@ -346,7 +346,7 @@ class ClearFreePlanTenantExpiredLogs:
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
current_time = started_at
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
total_tenant_count = session.query(Tenant.id).count()
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
@ -398,7 +398,7 @@ class ClearFreePlanTenantExpiredLogs:
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
interval = datetime.timedelta(days=1)
# Process tenants in this batch
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
# Calculate tenant count in next batch with current interval
# Try different intervals until we find one with a reasonable tenant count
test_intervals = [

View File

@ -1,7 +1,7 @@
import logging
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.errors.error import QuotaExceededError
@ -71,7 +71,7 @@ class CreditPoolService:
actual_credits = min(credits_required, pool.remaining_credits)
try:
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
stmt = (
update(TenantCreditPool)
.where(
@ -81,7 +81,6 @@ class CreditPoolService:
.values(quota_used=TenantCreditPool.quota_used + actual_credits)
)
session.execute(stmt)
session.commit()
except Exception:
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
raise QuotaExceededError("Failed to deduct credits")

View File

@ -15,7 +15,7 @@ from graphon.model_runtime.entities.model_entities import ModelFeature, ModelTyp
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from redis.exceptions import LockNotOwnedError
from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
@ -551,7 +551,7 @@ class DatasetService:
external_knowledge_id: External knowledge identifier
external_knowledge_api_id: External knowledge API identifier
"""
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
external_knowledge_binding = (
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
)
@ -559,14 +559,14 @@ class DatasetService:
if not external_knowledge_binding:
raise ValueError("External knowledge binding not found.")
# Update binding if values have changed
if (
external_knowledge_binding.external_knowledge_id != external_knowledge_id
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
):
external_knowledge_binding.external_knowledge_id = external_knowledge_id
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
db.session.add(external_knowledge_binding)
# Update binding if values have changed
if (
external_knowledge_binding.external_knowledge_id != external_knowledge_id
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
):
external_knowledge_binding.external_knowledge_id = external_knowledge_id
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
session.add(external_knowledge_binding)
@staticmethod
def _update_internal_dataset(dataset, data, user):

View File

@ -2,7 +2,7 @@ import enum
import uuid
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest
from extensions.ext_database import db
@ -29,7 +29,7 @@ class OAuthServerService:
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
with Session(db.engine) as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
return session.execute(query).scalar_one_or_none()
@staticmethod

View File

@ -1182,7 +1182,7 @@ class RagPipelineService:
workflow = db.session.get(Workflow, pipeline.workflow_id)
if not workflow:
raise ValueError("Workflow not found")
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
dataset = pipeline.retrieve_dataset(session=session)
if not dataset:
raise ValueError("Dataset not found")
@ -1209,7 +1209,7 @@ class RagPipelineService:
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
if args.get("icon_info") is None:

View File

@ -834,7 +834,7 @@ class WorkflowService:
if workflow_node_execution is None:
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
outputs = workflow_node_execution.load_full_outputs(session, storage)
with Session(bind=db.engine) as session, session.begin():

View File

@ -1427,16 +1427,18 @@ class TestRegisterService:
mock_tenant.name = "Test Workspace"
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
# Mock database queries - need to mock the Session query
# Mock database queries - need to mock the sessionmaker query
mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with (
patch("services.account_service.Session") as mock_session_class,
patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = None
# Mock RegisterService.register
@ -1485,12 +1487,14 @@ class TestRegisterService:
mixed_email = "Invitee@Example.com"
mock_session = MagicMock()
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with (
patch("services.account_service.Session") as mock_session_class,
patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = None
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
@ -1541,16 +1545,18 @@ class TestRegisterService:
account_id="existing-user-456", email="existing@example.com", status="pending"
)
# Mock database queries - need to mock the Session query
# Mock database queries - need to mock the sessionmaker query
mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with (
patch("services.account_service.Session") as mock_session_class,
patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = mock_existing_account
# Mock scalar for TenantAccountJoin lookup - no existing member

View File

@ -357,11 +357,12 @@ class TestAsyncWorkflowService:
mock_session_context.__enter__.return_value = mock_session
mock_session_context.__exit__.return_value = None
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value = mock_session_context
with (
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=fake_engine)),
patch.object(
async_workflow_service_module, "Session", return_value=mock_session_context
) as mock_session_class,
patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
patch.object(
async_workflow_service_module,
"SQLAlchemyWorkflowTriggerLogRepository",
@ -373,7 +374,7 @@ class TestAsyncWorkflowService:
# Assert
assert result == expected
mock_session_class.assert_called_once_with(fake_engine)
mock_sessionmaker.assert_called_once_with(fake_engine)
mock_repo.get_by_id.assert_called_once_with("trigger-log-123", "tenant-123")
def test_should_return_recent_logs_as_dict_list(self):
@ -391,9 +392,12 @@ class TestAsyncWorkflowService:
mock_session_context.__enter__.return_value = mock_session
mock_session_context.__exit__.return_value = None
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value = mock_session_context
with (
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context),
patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
patch.object(
async_workflow_service_module,
"SQLAlchemyWorkflowTriggerLogRepository",
@ -432,9 +436,12 @@ class TestAsyncWorkflowService:
mock_session_context.__enter__.return_value = mock_session
mock_session_context.__exit__.return_value = None
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value = mock_session_context
with (
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context),
patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
patch.object(
async_workflow_service_module,
"SQLAlchemyWorkflowTriggerLogRepository",

View File

@ -209,8 +209,22 @@ def _session_wrapper_for_no_autoflush(session: Mock) -> Mock:
return wrapper
def _sessionmaker_wrapper_for_begin(session: Mock) -> Mock:
"""
ClearFreePlanTenantExpiredLogs.process uses: with sessionmaker(db.engine).begin() as session:
so sessionmaker(db.engine) must return an object with a begin() method that returns a context manager.
"""
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
begin_cm.__exit__.return_value = None
sessionmaker_result = MagicMock()
sessionmaker_result.begin.return_value = begin_cm
return sessionmaker_result
def _session_wrapper_for_direct(session: Mock) -> Mock:
"""ClearFreePlanTenantExpiredLogs.process uses: with Session(db.engine) as session:"""
"""ClearFreePlanTenantExpiredLogs.process uses: with Session(db.engine) as session: (for old code paths)"""
wrapper = MagicMock()
wrapper.__enter__.return_value = session
wrapper.__exit__.return_value = None
@ -348,7 +362,7 @@ def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: py
count_query.count.return_value = 2
count_session.query.return_value = count_query
monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session))
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
# Avoid LocalProxy usage
flask_app = service_module.Flask("test-app")
@ -438,8 +452,8 @@ def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pyt
batch_session.query.side_effect = [q1, q2, q3, q4, q_rs]
sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)]
monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0))
sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
process_tenant_mock = MagicMock()
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)
@ -457,7 +471,7 @@ def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.Mo
count_query = MagicMock()
count_query.count.return_value = 100
count_session.query.return_value = count_query
monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session))
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
flask_app = service_module.Flask("test-app")
monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app))
@ -523,8 +537,8 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon
batch_session.query.side_effect = [*count_queries, q_rs]
sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)]
monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0))
sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
process_tenant_mock = MagicMock()
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)

View File

@ -578,26 +578,33 @@ class TestDatasetServiceCreationAndUpdate:
binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
session = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = binding
session.add = MagicMock()
session_context = _make_session_context(session)
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value = session_context
with (
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.Session", return_value=session_context),
patch("services.dataset_service.sessionmaker", mock_sessionmaker),
):
DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api")
assert binding.external_knowledge_id == "new-knowledge"
assert binding.external_knowledge_api_id == "new-api"
mock_db.session.add.assert_called_once_with(binding)
session.add.assert_called_once_with(binding)
def test_update_external_knowledge_binding_raises_for_missing_binding(self):
session = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = None
session_context = _make_session_context(session)
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value = session_context
with (
patch("services.dataset_service.db"),
patch("services.dataset_service.Session", return_value=session_context),
patch("services.dataset_service.sessionmaker", mock_sessionmaker),
):
with pytest.raises(ValueError, match="External knowledge binding not found"):
DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1")