refactor: use sessionmaker in controllers, events, models, and tasks 1 (#34693)

This commit is contained in:
carlos4s 2026-04-07 18:47:20 -05:00 committed by GitHub
parent 624db69f12
commit ae9fcc2969
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 32 additions and 36 deletions

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@ -71,7 +71,7 @@ class AppImportApi(Resource):
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
import_service = AppDslService(session)
# Import app
account = current_user

View File

@ -9,7 +9,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required
@ -55,7 +55,7 @@ class EnterpriseAppDSLImport(Resource):
account.set_tenant_id(workspace_id)
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
dsl_service = AppDslService(session)
result = dsl_service.import_app(
account=account,
@ -64,7 +64,6 @@ class EnterpriseAppDSLImport(Resource):
name=args.name,
description=args.description,
)
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400

View File

@ -2,7 +2,7 @@ import logging
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate
from events.app_event import app_published_workflow_was_updated
@ -45,7 +45,7 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
Returns:
Updated or created WorkflowSchedulePlan, or None if no schedule node
"""
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
schedule_config = ScheduleService.extract_schedule_config(workflow)
existing_plan = session.scalar(
@ -59,7 +59,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
if existing_plan:
logger.info("No schedule node in workflow for app %s, removing schedule plan", app_id)
ScheduleService.delete_schedule(session=session, schedule_id=existing_plan.id)
session.commit()
return None
if existing_plan:
@ -73,7 +72,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
schedule_id=existing_plan.id,
updates=updates,
)
session.commit()
return updated_plan
else:
new_plan = ScheduleService.create_schedule(
@ -82,5 +80,4 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
app_id=app_id,
config=schedule_config,
)
session.commit()
return new_plan

View File

@ -1,7 +1,7 @@
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from core.trigger.constants import TRIGGER_NODE_TYPES
from events.app_event import app_published_workflow_was_updated
@ -31,7 +31,7 @@ def handle(sender, **kwargs):
# Extract trigger info from workflow
trigger_infos = get_trigger_infos_from_workflow(published_workflow)
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
# Get existing app triggers
existing_triggers = (
session.execute(
@ -79,8 +79,6 @@ def handle(sender, **kwargs):
existing_trigger.title = new_title
session.add(existing_trigger)
session.commit()
def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]:
"""

View File

@ -354,11 +354,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
) -> WorkflowRun | None:
"""Fallback to PostgreSQL query for records not in LogStore (with tenant isolation)."""
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
stmt = select(WorkflowRun).where(
WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id
)
@ -439,11 +439,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None:
"""Fallback to PostgreSQL query for records not in LogStore."""
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
return session.scalar(stmt)

View File

@ -18,7 +18,7 @@ from graphon.enums import WorkflowExecutionStatus
from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from graphon.file import helpers as file_helpers
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
@ -524,7 +524,7 @@ class App(Base):
if not api_provider_ids and not builtin_provider_ids:
return []
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
if api_provider_ids:
existing_api_providers = [
str(api_provider.id)

View File

@ -92,7 +92,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
# ============ Step 3: Delete metadata binding (separate short transaction) ============
try:
with session_factory.create_session() as session:
deleted_count = (
deleted_count = int(
session.query(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.dataset_id == dataset_id,

View File

@ -12,7 +12,7 @@ import click
from celery import group, shared_task
from flask import current_app, g
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
@ -131,7 +131,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
with Session(db.engine) as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
# Load required entities
account = session.scalar(select(Account).where(Account.id == user_id).limit(1))
if not account:

View File

@ -102,14 +102,16 @@ class TestEnterpriseAppDSLImport:
@pytest.fixture
def _mock_import_deps(self):
"""Patch db, Session, and AppDslService for import handler tests."""
"""Patch db, sessionmaker, and AppDslService for import handler tests."""
mock_session_ctx = MagicMock()
mock_session_ctx.__enter__ = MagicMock(return_value=MagicMock())
mock_session_ctx.__exit__ = MagicMock(return_value=False)
mock_sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session_ctx)))
with (
patch("controllers.inner_api.app.dsl.db"),
patch("controllers.inner_api.app.dsl.Session") as mock_session,
patch("controllers.inner_api.app.dsl.sessionmaker", mock_sessionmaker),
patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls,
):
mock_session.return_value.__enter__ = MagicMock(return_value=MagicMock())
mock_session.return_value.__exit__ = MagicMock(return_value=False)
self._mock_dsl = MagicMock()
mock_dsl_cls.return_value = self._mock_dsl
yield

View File

@ -690,8 +690,8 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
mock_db.engine = MagicMock()
mock_session.__enter__ = MagicMock(return_value=mock_session)
mock_session.__exit__ = MagicMock(return_value=None)
Session = MagicMock(return_value=mock_session)
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session)))
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.sessionmaker", sessionmaker):
mock_session.scalar.return_value = None # No existing plan
# Mock extract_schedule_config to return a ScheduleConfig object
@ -709,7 +709,7 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
assert result == mock_new_plan
mock_service.create_schedule.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.commit.assert_not_called()
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
@ -720,9 +720,9 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
mock_db.engine = MagicMock()
mock_session.__enter__ = MagicMock(return_value=mock_session)
mock_session.__exit__ = MagicMock(return_value=None)
Session = MagicMock(return_value=mock_session)
sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session)))
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.sessionmaker", sessionmaker):
mock_existing_plan = Mock(spec=WorkflowSchedulePlan)
mock_existing_plan.id = "existing-plan-id"
mock_session.scalar.return_value = mock_existing_plan
@ -751,7 +751,7 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
assert updates_obj.node_id == "start"
assert updates_obj.cron_expression == "0 12 * * *"
assert updates_obj.timezone == "America/New_York"
mock_session.commit.assert_called_once()
mock_session.commit.assert_not_called()
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
@ -762,9 +762,9 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
mock_db.engine = MagicMock()
mock_session.__enter__ = MagicMock(return_value=mock_session)
mock_session.__exit__ = MagicMock(return_value=None)
Session = MagicMock(return_value=mock_session)
sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session)))
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.sessionmaker", sessionmaker):
mock_existing_plan = Mock(spec=WorkflowSchedulePlan)
mock_existing_plan.id = "existing-plan-id"
mock_session.scalar.return_value = mock_existing_plan
@ -777,7 +777,7 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
assert result is None
# Now using ScheduleService.delete_schedule instead of session.delete
mock_service.delete_schedule.assert_called_once_with(session=mock_session, schedule_id="existing-plan-id")
mock_session.commit.assert_called_once()
mock_session.commit.assert_not_called()
@pytest.fixture