diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index a6ca0689d0..39c8aaa451 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -1,6 +1,6 @@ from flask_restx import Resource, marshal from pydantic import BaseModel -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden import services @@ -54,12 +54,13 @@ class CreateRagPipelineDatasetApi(Resource): yaml_content=payload.yaml_content, ) try: - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: rag_pipeline_dsl_service = RagPipelineDslService(session) import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset( tenant_id=current_tenant_id, rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, ) + session.commit() if rag_pipeline_dataset_create_entity.permission == "partial_members": DatasetPermissionService.update_partial_member_list( current_tenant_id, diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index aa27458176..cf92218508 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -1,7 +1,7 @@ from flask import request from flask_restx import Resource, fields, marshal_with # type: ignore from pydantic import BaseModel, Field -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns @@ -67,10 +67,12 @@ class RagPipelineImportApi(Resource): current_user, _ = current_account_with_tenant() payload = RagPipelineImportPayload.model_validate(console_ns.payload or {}) - # Create service with session - with sessionmaker(db.engine).begin() as session: + # Use a plain Session so that caught exceptions inside the service + # (which return FAILED status instead of re-raising) do not leave the + # transaction in a closed state that a .begin() context manager cannot + # handle. See app_import.py for the canonical pattern. + with Session(db.engine, expire_on_commit=False) as session: import_service = RagPipelineDslService(session) - # Import app account = current_user result = import_service.import_rag_pipeline( account=account, @@ -80,6 +82,10 @@ class RagPipelineImportApi(Resource): pipeline_id=payload.pipeline_id, dataset_name=payload.name, ) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() # Return appropriate status code based on result status = result.status @@ -102,12 +108,14 @@ class RagPipelineImportConfirmApi(Resource): def post(self, import_id): current_user, _ = current_account_with_tenant() - # Create service with session - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = RagPipelineDslService(session) - # Confirm import account = current_user result = import_service.confirm_import(import_id=import_id, account=account) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() # Return appropriate status code based on result if result.status == ImportStatus.FAILED: @@ -124,7 +132,7 @@ class RagPipelineImportCheckDependenciesApi(Resource): @edit_permission_required @marshal_with(pipeline_import_check_dependencies_model) def get(self, pipeline: Pipeline): - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = RagPipelineDslService(session) result = import_service.check_dependencies(pipeline=pipeline) @@ -142,7 +150,7 @@ class RagPipelineExportApi(Resource): # Add include_secret params query = IncludeSecretQuery.model_validate(request.args.to_dict()) - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: export_service = RagPipelineDslService(session) result = export_service.export_rag_pipeline_dsl( pipeline=pipeline, include_secret=query.include_secret == "true" diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 37ebffbeb4..99fd3f5628 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -78,9 +78,9 @@ class CheckDependenciesPendingData(BaseModel): class RagPipelineDslService: """Import, export, and inspect RAG pipeline DSL using the caller-owned session. - Controllers wrap this service in a SQLAlchemy transaction context, so methods must only flush interim changes when - generated IDs are needed. Committing inside the service would close the caller's transaction and break later work in - the same context manager. + Callers pass a plain ``Session`` (not wrapped in ``.begin()``) and are responsible for calling + ``session.commit()`` on success or ``session.rollback()`` on failure. Methods here only flush + when generated IDs are needed mid-operation; they never commit or rollback. """ def __init__(self, session: Session):