refactor: use sessionmaker().begin() in console app controllers (#34282)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Desel72 2026-03-31 16:10:16 +03:00 committed by GitHub
parent d9a0665b2c
commit cf50d7c7b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 51 additions and 43 deletions

View File

@ -9,7 +9,7 @@ from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from controllers.common.helpers import FileInfo from controllers.common.helpers import FileInfo
@ -642,7 +642,7 @@ class AppCopyApi(Resource):
args = CopyAppPayload.model_validate(console_ns.payload or {}) args = CopyAppPayload.model_validate(console_ns.payload or {})
with Session(db.engine) as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
import_service = AppDslService(session) import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True) yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
result = import_service.import_app( result = import_service.import_app(
@ -655,7 +655,6 @@ class AppCopyApi(Resource):
icon=args.icon, icon=args.icon,
icon_background=args.icon_background, icon_background=args.icon_background,
) )
session.commit()
# Inherit web app permission from original app # Inherit web app permission from original app
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
@ -71,7 +71,7 @@ class AppImportApi(Resource):
args = AppImportPayload.model_validate(console_ns.payload) args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session # Create service with session
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
import_service = AppDslService(session) import_service = AppDslService(session)
# Import app # Import app
account = current_user account = current_user
@ -87,7 +87,6 @@ class AppImportApi(Resource):
icon_background=args.icon_background, icon_background=args.icon_background,
app_id=args.app_id, app_id=args.app_id,
) )
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private # update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
@ -112,12 +111,11 @@ class AppImportConfirmApi(Resource):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
# Create service with session # Create service with session
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
import_service = AppDslService(session) import_service = AppDslService(session)
# Confirm import # Confirm import
account = current_user account = current_user
result = import_service.confirm_import(import_id=import_id, account=account) result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result # Return appropriate status code based on result
if result.status == ImportStatus.FAILED: if result.status == ImportStatus.FAILED:
@ -134,7 +132,7 @@ class AppImportCheckDependenciesApi(Resource):
@marshal_with(app_import_check_dependencies_model) @marshal_with(app_import_check_dependencies_model)
@edit_permission_required @edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
import_service = AppDslService(session) import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model) result = import_service.check_dependencies(app_model=app_model)

View File

@ -2,7 +2,7 @@ from flask import request
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -69,7 +69,7 @@ class ConversationVariablesApi(Resource):
page_size = 100 page_size = 100
stmt = stmt.limit(page_size).offset((page - 1) * page_size) stmt = stmt.limit(page_size).offset((page - 1) * page_size)
with Session(db.engine) as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
rows = session.scalars(stmt).all() rows = session.scalars(stmt).all()
return { return {

View File

@ -10,7 +10,7 @@ from graphon.file import File
from graphon.graph_engine.manager import GraphEngineManager from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services import services
@ -840,7 +840,7 @@ class PublishedWorkflowApi(Resource):
args = PublishWorkflowPayload.model_validate(console_ns.payload or {}) args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService() workflow_service = WorkflowService()
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
workflow = workflow_service.publish_workflow( workflow = workflow_service.publish_workflow(
session=session, session=session,
app_model=app_model, app_model=app_model,
@ -858,8 +858,6 @@ class PublishedWorkflowApi(Resource):
workflow_created_at = TimestampField().format(workflow.created_at) workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
return { return {
"result": "success", "result": "success",
"created_at": workflow_created_at, "created_at": workflow_created_at,
@ -982,7 +980,7 @@ class PublishedAllWorkflowApi(Resource):
raise Forbidden() raise Forbidden()
workflow_service = WorkflowService() workflow_service = WorkflowService()
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
workflows, has_more = workflow_service.get_all_published_workflow( workflows, has_more = workflow_service.get_all_published_workflow(
session=session, session=session,
app_model=app_model, app_model=app_model,
@ -1072,7 +1070,7 @@ class WorkflowByIdApi(Resource):
workflow_service = WorkflowService() workflow_service = WorkflowService()
# Create a session and manage the transaction # Create a session and manage the transaction
with Session(db.engine, expire_on_commit=False) as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
workflow = workflow_service.update_workflow( workflow = workflow_service.update_workflow(
session=session, session=session,
workflow_id=workflow_id, workflow_id=workflow_id,
@ -1084,9 +1082,6 @@ class WorkflowByIdApi(Resource):
if not workflow: if not workflow:
raise NotFound("Workflow not found") raise NotFound("Workflow not found")
# Commit the transaction in the controller
session.commit()
return workflow return workflow
@setup_required @setup_required
@ -1101,13 +1096,11 @@ class WorkflowByIdApi(Resource):
workflow_service = WorkflowService() workflow_service = WorkflowService()
# Create a session and manage the transaction # Create a session and manage the transaction
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
try: try:
workflow_service.delete_workflow( workflow_service.delete_workflow(
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
) )
# Commit the transaction in the controller
session.commit()
except WorkflowInUseError as e: except WorkflowInUseError as e:
abort(400, description=str(e)) abort(400, description=str(e))
except DraftWorkflowDeletionError as e: except DraftWorkflowDeletionError as e:

View File

@ -5,7 +5,7 @@ from flask import request
from flask_restx import Resource, marshal_with from flask_restx import Resource, marshal_with
from graphon.enums import WorkflowExecutionStatus from graphon.enums import WorkflowExecutionStatus
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource):
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session, session=session,
app_model=app_model, app_model=app_model,
@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs( workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
session=session, session=session,
app_model=app_model, app_model=app_model,

View File

@ -10,7 +10,7 @@ from graphon.variables.segment_group import SegmentGroup
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
from graphon.variables.types import SegmentType from graphon.variables.types import SegmentType
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
@ -244,7 +244,7 @@ class WorkflowVariableCollectionApi(Resource):
raise DraftWorkflowNotExist() raise DraftWorkflowNotExist()
# fetch draft workflow by app_model # fetch draft workflow by app_model
with Session(bind=db.engine, expire_on_commit=False) as session: with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=session, session=session,
) )
@ -298,7 +298,7 @@ class NodeVariableCollectionApi(Resource):
@marshal_with(workflow_draft_variable_list_model) @marshal_with(workflow_draft_variable_list_model)
def get(self, app_model: App, node_id: str): def get(self, app_model: App, node_id: str):
validate_node_id(node_id) validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session: with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=session, session=session,
) )
@ -465,7 +465,7 @@ class VariableResetApi(Resource):
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList: def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
with Session(bind=db.engine, expire_on_commit=False) as session: with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=session, session=session,
) )

View File

@ -4,7 +4,7 @@ from flask import request
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from configs import dify_config from configs import dify_config
@ -64,7 +64,7 @@ class WebhookTriggerApi(Resource):
node_id = args.node_id node_id = args.node_id
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
# Get webhook trigger for this app and node # Get webhook trigger for this app and node
webhook_trigger = ( webhook_trigger = (
session.query(WorkflowWebhookTrigger) session.query(WorkflowWebhookTrigger)
@ -95,7 +95,7 @@ class AppTriggersApi(Resource):
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None assert current_user.current_tenant_id is not None
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
# Get all triggers for this app using select API # Get all triggers for this app using select API
triggers = ( triggers = (
session.execute( session.execute(
@ -137,7 +137,7 @@ class AppTriggerEnableApi(Resource):
assert current_user.current_tenant_id is not None assert current_user.current_tenant_id is not None
trigger_id = args.trigger_id trigger_id = args.trigger_id
with Session(db.engine) as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
# Find the trigger using select # Find the trigger using select
trigger = session.execute( trigger = session.execute(
select(AppTrigger).where( select(AppTrigger).where(
@ -153,9 +153,6 @@ class AppTriggerEnableApi(Resource):
# Update status based on enable_trigger boolean # Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
# Add computed icon field # Add computed icon field
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
if trigger.trigger_type == "trigger-plugin": if trigger.trigger_type == "trigger-plugin":

View File

@ -383,14 +383,21 @@ class TestWorkflowAppLogEndpoints:
monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock())) monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock()))
class DummySession: class DummySessionCtx:
def __enter__(self): def __enter__(self):
return "session" return "session"
def __exit__(self, exc_type, exc, tb): def __exit__(self, exc_type, exc, tb):
return False return False
monkeypatch.setattr(workflow_app_log_module, "Session", lambda *args, **kwargs: DummySession()) class DummySessionMaker:
def __init__(self, *args, **kwargs):
pass
def begin(self):
return DummySessionCtx()
monkeypatch.setattr(workflow_app_log_module, "sessionmaker", DummySessionMaker)
def fake_get_paginate(self, **_kwargs): def fake_get_paginate(self, **_kwargs):
return {"items": [], "total": 0} return {"items": [], "total": 0}
@ -423,13 +430,20 @@ class TestWorkflowDraftVariableEndpoints:
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock())) monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1")) monkeypatch.setattr(workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1"))
class DummySession: class DummySessionCtx:
def __enter__(self): def __enter__(self):
return "session" return "session"
def __exit__(self, exc_type, exc, tb): def __exit__(self, exc_type, exc, tb):
return False return False
class DummySessionMaker:
def __init__(self, *args, **kwargs):
pass
def begin(self):
return DummySessionCtx()
class DummyDraftService: class DummyDraftService:
def __init__(self, session): def __init__(self, session):
self.session = session self.session = session
@ -437,7 +451,7 @@ class TestWorkflowDraftVariableEndpoints:
def list_variables_without_values(self, **_kwargs): def list_variables_without_values(self, **_kwargs):
return {"items": [], "total": 0} return {"items": [], "total": 0}
monkeypatch.setattr(workflow_draft_variable_module, "Session", lambda *args, **kwargs: DummySession()) monkeypatch.setattr(workflow_draft_variable_module, "sessionmaker", DummySessionMaker)
class DummyWorkflowService: class DummyWorkflowService:
def is_workflow_exist(self, *args, **kwargs): def is_workflow_exist(self, *args, **kwargs):
@ -543,14 +557,21 @@ class TestWorkflowTriggerEndpoints:
session = MagicMock() session = MagicMock()
session.query.return_value.where.return_value.first.return_value = trigger session.query.return_value.where.return_value.first.return_value = trigger
class DummySession: class DummySessionCtx:
def __enter__(self): def __enter__(self):
return session return session
def __exit__(self, exc_type, exc, tb): def __exit__(self, exc_type, exc, tb):
return False return False
monkeypatch.setattr(workflow_trigger_module, "Session", lambda *_args, **_kwargs: DummySession()) class DummySessionMaker:
def __init__(self, *args, **kwargs):
pass
def begin(self):
return DummySessionCtx()
monkeypatch.setattr(workflow_trigger_module, "sessionmaker", DummySessionMaker)
with app.test_request_context("/?node_id=node-1"): with app.test_request_context("/?node_id=node-1"):
result = method(app_model=SimpleNamespace(id="app-1")) result = method(app_model=SimpleNamespace(id="app-1"))