From ae9fcc2969fc16887ca29836e3a54432336c6362 Mon Sep 17 00:00:00 2001 From: carlos4s <71615127+carlos4s@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:47:20 -0500 Subject: [PATCH] refactor: use sessionmaker in controllers, events, models, and tasks 1 (#34693) --- api/controllers/console/app/app_import.py | 4 ++-- api/controllers/inner_api/app/dsl.py | 5 ++--- ...ync_workflow_schedule_when_app_published.py | 7 ++----- ...gers_when_app_published_workflow_updated.py | 6 ++---- .../logstore_api_workflow_run_repository.py | 8 ++++---- api/models/model.py | 4 ++-- api/tasks/batch_clean_document_task.py | 2 +- .../rag_pipeline/rag_pipeline_run_task.py | 4 ++-- .../controllers/inner_api/app/test_dsl.py | 10 ++++++---- .../services/test_schedule_service.py | 18 +++++++++--------- 10 files changed, 32 insertions(+), 36 deletions(-) diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index c2805f765b..16e1fa3245 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,6 +1,6 @@ from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import sessionmaker from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( @@ -71,7 +71,7 @@ class AppImportApi(Resource): args = AppImportPayload.model_validate(console_ns.payload) # Create service with session - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: import_service = AppDslService(session) # Import app account = current_user diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py index 3b673d6e1d..b1986b2557 100644 --- a/api/controllers/inner_api/app/dsl.py +++ b/api/controllers/inner_api/app/dsl.py @@ -9,7 +9,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_model from controllers.console.wraps import setup_required @@ -55,7 +55,7 @@ class EnterpriseAppDSLImport(Resource): account.set_tenant_id(workspace_id) - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: dsl_service = AppDslService(session) result = dsl_service.import_app( account=account, @@ -64,7 +64,6 @@ class EnterpriseAppDSLImport(Resource): name=args.name, description=args.description, ) - session.commit() if result.status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 diff --git a/api/events/event_handlers/sync_workflow_schedule_when_app_published.py b/api/events/event_handlers/sync_workflow_schedule_when_app_published.py index 168513fc04..5f8fcd8617 100644 --- a/api/events/event_handlers/sync_workflow_schedule_when_app_published.py +++ b/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @@ -2,7 +2,7 @@ import logging from typing import cast from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate from events.app_event import app_published_workflow_was_updated @@ -45,7 +45,7 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow) Returns: Updated or created WorkflowSchedulePlan, or None if no schedule node """ - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: schedule_config = ScheduleService.extract_schedule_config(workflow) existing_plan = session.scalar( @@ -59,7 +59,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow) if existing_plan: logger.info("No schedule node in workflow for app %s, removing schedule plan", app_id) ScheduleService.delete_schedule(session=session, schedule_id=existing_plan.id) - session.commit() return None if existing_plan: @@ -73,7 +72,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow) schedule_id=existing_plan.id, updates=updates, ) - session.commit() return updated_plan else: new_plan = ScheduleService.create_schedule( @@ -82,5 +80,4 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow) app_id=app_id, config=schedule_config, ) - session.commit() return new_plan diff --git a/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py index b3917d5622..d55fe262fb 100644 --- a/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @@ -1,7 +1,7 @@ from typing import cast from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.trigger.constants import TRIGGER_NODE_TYPES from events.app_event import app_published_workflow_was_updated @@ -31,7 +31,7 @@ def handle(sender, **kwargs): # Extract trigger info from workflow trigger_infos = get_trigger_infos_from_workflow(published_workflow) - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: # Get existing app triggers existing_triggers = ( session.execute( @@ -79,8 +79,6 @@ def handle(sender, **kwargs): existing_trigger.title = new_title session.add(existing_trigger) - session.commit() - def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]: """ diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 3c83ab4f84..2745141431 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -354,11 +354,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): ) -> WorkflowRun | None: """Fallback to PostgreSQL query for records not in LogStore (with tenant isolation).""" from sqlalchemy import select - from sqlalchemy.orm import Session + from sqlalchemy.orm import sessionmaker from extensions.ext_database import db - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: stmt = select(WorkflowRun).where( WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id ) @@ -439,11 +439,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None: """Fallback to PostgreSQL query for records not in LogStore.""" from sqlalchemy import select - from sqlalchemy.orm import Session + from sqlalchemy.orm import sessionmaker from extensions.ext_database import db - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: stmt = select(WorkflowRun).where(WorkflowRun.id == run_id) return session.scalar(stmt) diff --git a/api/models/model.py b/api/models/model.py index 1d73aadf09..43ddf344d2 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -18,7 +18,7 @@ from graphon.enums import WorkflowExecutionStatus from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from graphon.file import helpers as file_helpers from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text -from sqlalchemy.orm import Mapped, Session, mapped_column +from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS @@ -524,7 +524,7 @@ class App(Base): if not api_provider_ids and not builtin_provider_ids: return [] - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: if api_provider_ids: existing_api_providers = [ str(api_provider.id) diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 747106d373..75e6437f3f 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -92,7 +92,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form # ============ Step 3: Delete metadata binding (separate short transaction) ============ try: with session_factory.create_session() as session: - deleted_count = ( + deleted_count = int( session.query(DatasetMetadataBinding) .where( DatasetMetadataBinding.dataset_id == dataset_id, diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index db04b3375b..8e1e096ed0 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -12,7 +12,7 @@ import click from celery import group, shared_task from flask import current_app, g from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import sessionmaker from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity @@ -131,7 +131,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity - with Session(db.engine) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: # Load required entities account = session.scalar(select(Account).where(Account.id == user_id).limit(1)) if not account: diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py index 4a5f91cc5d..974d8f7bc6 100644 --- a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -102,14 +102,16 @@ class TestEnterpriseAppDSLImport: @pytest.fixture def _mock_import_deps(self): - """Patch db, Session, and AppDslService for import handler tests.""" + """Patch db, sessionmaker, and AppDslService for import handler tests.""" + mock_session_ctx = MagicMock() + mock_session_ctx.__enter__ = MagicMock(return_value=MagicMock()) + mock_session_ctx.__exit__ = MagicMock(return_value=False) + mock_sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session_ctx))) with ( patch("controllers.inner_api.app.dsl.db"), - patch("controllers.inner_api.app.dsl.Session") as mock_session, + patch("controllers.inner_api.app.dsl.sessionmaker", mock_sessionmaker), patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls, ): - mock_session.return_value.__enter__ = MagicMock(return_value=MagicMock()) - mock_session.return_value.__exit__ = MagicMock(return_value=False) self._mock_dsl = MagicMock() mock_dsl_cls.return_value = self._mock_dsl yield diff --git a/api/tests/unit_tests/services/test_schedule_service.py b/api/tests/unit_tests/services/test_schedule_service.py index 2a78876da6..334062242b 100644 --- a/api/tests/unit_tests/services/test_schedule_service.py +++ b/api/tests/unit_tests/services/test_schedule_service.py @@ -690,8 +690,8 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase): mock_db.engine = MagicMock() mock_session.__enter__ = MagicMock(return_value=mock_session) mock_session.__exit__ = MagicMock(return_value=None) - Session = MagicMock(return_value=mock_session) - with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session): + sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session))) + with patch("events.event_handlers.sync_workflow_schedule_when_app_published.sessionmaker", sessionmaker): mock_session.scalar.return_value = None # No existing plan # Mock extract_schedule_config to return a ScheduleConfig object @@ -709,7 +709,7 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase): assert result == mock_new_plan mock_service.create_schedule.assert_called_once() - mock_session.commit.assert_called_once() + mock_session.commit.assert_not_called() @patch("events.event_handlers.sync_workflow_schedule_when_app_published.db") @patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService") @@ -720,9 +720,9 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase): mock_db.engine = MagicMock() mock_session.__enter__ = MagicMock(return_value=mock_session) mock_session.__exit__ = MagicMock(return_value=None) - Session = MagicMock(return_value=mock_session) + sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session))) - with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session): + with patch("events.event_handlers.sync_workflow_schedule_when_app_published.sessionmaker", sessionmaker): mock_existing_plan = Mock(spec=WorkflowSchedulePlan) mock_existing_plan.id = "existing-plan-id" mock_session.scalar.return_value = mock_existing_plan @@ -751,7 +751,7 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase): assert updates_obj.node_id == "start" assert updates_obj.cron_expression == "0 12 * * *" assert updates_obj.timezone == "America/New_York" - mock_session.commit.assert_called_once() + mock_session.commit.assert_not_called() @patch("events.event_handlers.sync_workflow_schedule_when_app_published.db") @patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService") @@ -762,9 +762,9 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase): mock_db.engine = MagicMock() mock_session.__enter__ = MagicMock(return_value=mock_session) mock_session.__exit__ = MagicMock(return_value=None) - Session = MagicMock(return_value=mock_session) + sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session))) - with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session): + with patch("events.event_handlers.sync_workflow_schedule_when_app_published.sessionmaker", sessionmaker): mock_existing_plan = Mock(spec=WorkflowSchedulePlan) mock_existing_plan.id = "existing-plan-id" mock_session.scalar.return_value = mock_existing_plan @@ -777,7 +777,7 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase): assert result is None # Now using ScheduleService.delete_schedule instead of session.delete mock_service.delete_schedule.assert_called_once_with(session=mock_session, schedule_id="existing-plan-id") - mock_session.commit.assert_called_once() + mock_session.commit.assert_not_called() @pytest.fixture