mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor: use sessionmaker in controllers, events, models, and tasks 1 (#34693)
This commit is contained in:
parent
624db69f12
commit
ae9fcc2969
@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
@ -71,7 +71,7 @@ class AppImportApi(Resource):
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
|
||||
@ -9,7 +9,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.wraps import setup_required
|
||||
@ -55,7 +55,7 @@ class EnterpriseAppDSLImport(Resource):
|
||||
|
||||
account.set_tenant_id(workspace_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
dsl_service = AppDslService(session)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
@ -64,7 +64,6 @@ class EnterpriseAppDSLImport(Resource):
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate
|
||||
from events.app_event import app_published_workflow_was_updated
|
||||
@ -45,7 +45,7 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
|
||||
Returns:
|
||||
Updated or created WorkflowSchedulePlan, or None if no schedule node
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
schedule_config = ScheduleService.extract_schedule_config(workflow)
|
||||
|
||||
existing_plan = session.scalar(
|
||||
@ -59,7 +59,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
|
||||
if existing_plan:
|
||||
logger.info("No schedule node in workflow for app %s, removing schedule plan", app_id)
|
||||
ScheduleService.delete_schedule(session=session, schedule_id=existing_plan.id)
|
||||
session.commit()
|
||||
return None
|
||||
|
||||
if existing_plan:
|
||||
@ -73,7 +72,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
|
||||
schedule_id=existing_plan.id,
|
||||
updates=updates,
|
||||
)
|
||||
session.commit()
|
||||
return updated_plan
|
||||
else:
|
||||
new_plan = ScheduleService.create_schedule(
|
||||
@ -82,5 +80,4 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
|
||||
app_id=app_id,
|
||||
config=schedule_config,
|
||||
)
|
||||
session.commit()
|
||||
return new_plan
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from events.app_event import app_published_workflow_was_updated
|
||||
@ -31,7 +31,7 @@ def handle(sender, **kwargs):
|
||||
# Extract trigger info from workflow
|
||||
trigger_infos = get_trigger_infos_from_workflow(published_workflow)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# Get existing app triggers
|
||||
existing_triggers = (
|
||||
session.execute(
|
||||
@ -79,8 +79,6 @@ def handle(sender, **kwargs):
|
||||
existing_trigger.title = new_title
|
||||
session.add(existing_trigger)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]:
|
||||
"""
|
||||
|
||||
@ -354,11 +354,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
) -> WorkflowRun | None:
|
||||
"""Fallback to PostgreSQL query for records not in LogStore (with tenant isolation)."""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
stmt = select(WorkflowRun).where(
|
||||
WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id
|
||||
)
|
||||
@ -439,11 +439,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None:
|
||||
"""Fallback to PostgreSQL query for records not in LogStore."""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
|
||||
return session.scalar(stmt)
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from graphon.file import helpers as file_helpers
|
||||
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
@ -524,7 +524,7 @@ class App(Base):
|
||||
if not api_provider_ids and not builtin_provider_ids:
|
||||
return []
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
if api_provider_ids:
|
||||
existing_api_providers = [
|
||||
str(api_provider.id)
|
||||
|
||||
@ -92,7 +92,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
|
||||
# ============ Step 3: Delete metadata binding (separate short transaction) ============
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
deleted_count = (
|
||||
deleted_count = int(
|
||||
session.query(DatasetMetadataBinding)
|
||||
.where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
|
||||
@ -12,7 +12,7 @@ import click
|
||||
from celery import group, shared_task
|
||||
from flask import current_app, g
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
@ -131,7 +131,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
|
||||
workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
|
||||
application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Load required entities
|
||||
account = session.scalar(select(Account).where(Account.id == user_id).limit(1))
|
||||
if not account:
|
||||
|
||||
@ -102,14 +102,16 @@ class TestEnterpriseAppDSLImport:
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_import_deps(self):
|
||||
"""Patch db, Session, and AppDslService for import handler tests."""
|
||||
"""Patch db, sessionmaker, and AppDslService for import handler tests."""
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__ = MagicMock(return_value=MagicMock())
|
||||
mock_session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session_ctx)))
|
||||
with (
|
||||
patch("controllers.inner_api.app.dsl.db"),
|
||||
patch("controllers.inner_api.app.dsl.Session") as mock_session,
|
||||
patch("controllers.inner_api.app.dsl.sessionmaker", mock_sessionmaker),
|
||||
patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls,
|
||||
):
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=MagicMock())
|
||||
mock_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
self._mock_dsl = MagicMock()
|
||||
mock_dsl_cls.return_value = self._mock_dsl
|
||||
yield
|
||||
|
||||
@ -690,8 +690,8 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session.__exit__ = MagicMock(return_value=None)
|
||||
Session = MagicMock(return_value=mock_session)
|
||||
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
|
||||
sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session)))
|
||||
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.sessionmaker", sessionmaker):
|
||||
mock_session.scalar.return_value = None # No existing plan
|
||||
|
||||
# Mock extract_schedule_config to return a ScheduleConfig object
|
||||
@ -709,7 +709,7 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
|
||||
|
||||
assert result == mock_new_plan
|
||||
mock_service.create_schedule.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
|
||||
@ -720,9 +720,9 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session.__exit__ = MagicMock(return_value=None)
|
||||
Session = MagicMock(return_value=mock_session)
|
||||
sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session)))
|
||||
|
||||
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
|
||||
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.sessionmaker", sessionmaker):
|
||||
mock_existing_plan = Mock(spec=WorkflowSchedulePlan)
|
||||
mock_existing_plan.id = "existing-plan-id"
|
||||
mock_session.scalar.return_value = mock_existing_plan
|
||||
@ -751,7 +751,7 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
|
||||
assert updates_obj.node_id == "start"
|
||||
assert updates_obj.cron_expression == "0 12 * * *"
|
||||
assert updates_obj.timezone == "America/New_York"
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
|
||||
@ -762,9 +762,9 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session.__exit__ = MagicMock(return_value=None)
|
||||
Session = MagicMock(return_value=mock_session)
|
||||
sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session)))
|
||||
|
||||
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
|
||||
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.sessionmaker", sessionmaker):
|
||||
mock_existing_plan = Mock(spec=WorkflowSchedulePlan)
|
||||
mock_existing_plan.id = "existing-plan-id"
|
||||
mock_session.scalar.return_value = mock_existing_plan
|
||||
@ -777,7 +777,7 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
|
||||
assert result is None
|
||||
# Now using ScheduleService.delete_schedule instead of session.delete
|
||||
mock_service.delete_schedule.assert_called_once_with(session=mock_session, schedule_id="existing-plan-id")
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Loading…
Reference in New Issue
Block a user