From d9a0665b2c8bfee584c2caf59f0312d1e5ace4f1 Mon Sep 17 00:00:00 2001 From: Desel72 Date: Tue, 31 Mar 2026 16:09:18 +0300 Subject: [PATCH] refactor: use sessionmaker().begin() in console datasets controllers (#34283) --- .../console/datasets/data_source.py | 6 ++--- .../datasets/rag_pipeline/rag_pipeline.py | 4 +-- .../rag_pipeline/rag_pipeline_datasets.py | 4 +-- .../rag_pipeline_draft_variable.py | 8 +++--- .../rag_pipeline/rag_pipeline_import.py | 12 ++++----- .../rag_pipeline/rag_pipeline_workflow.py | 16 ++++-------- .../console/datasets/test_data_source.py | 26 +++++++++---------- 7 files changed, 34 insertions(+), 42 deletions(-) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index daef4e005a..ac14349045 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -6,7 +6,7 @@ from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import NotFound from controllers.common.schema import get_or_create_model, register_schema_model @@ -159,7 +159,7 @@ class DataSourceApi(Resource): @account_initialization_required def patch(self, binding_id, action: Literal["enable", "disable"]): binding_id = str(binding_id) - with Session(db.engine) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: data_source_binding = session.execute( select(DataSourceOauthBinding).filter_by(id=binding_id) ).scalar_one_or_none() @@ -211,7 +211,7 @@ class DataSourceNotionListApi(Resource): if not credential: raise NotFound("Credential not found.") exist_page_ids = [] - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: # import notion in the exist dataset if query.dataset_id: dataset = DatasetService.get_dataset(query.dataset_id) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 4f31093cfe..1758bad31d 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -3,7 +3,7 @@ import logging from flask import request from flask_restx import Resource from pydantic import BaseModel, Field -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console import console_ns @@ -85,7 +85,7 @@ class CustomizedPipelineTemplateApi(Resource): @account_initialization_required @enterprise_license_required def post(self, template_id: str): - with Session(db.engine) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: template = ( session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first() ) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index e65cb19b39..a6ca0689d0 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -1,6 +1,6 @@ from flask_restx import Resource, marshal from pydantic import BaseModel -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden import services @@ -54,7 +54,7 @@ class CreateRagPipelineDatasetApi(Resource): yaml_content=payload.yaml_content, ) try: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: rag_pipeline_dsl_service = RagPipelineDslService(session) import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset( tenant_id=current_tenant_id, diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index f12cbd3495..d635dcb530 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -5,7 +5,7 @@ from flask import Response, request from flask_restx import Resource, marshal, marshal_with from graphon.variables.types import SegmentType from pydantic import BaseModel, Field -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models @@ -96,7 +96,7 @@ class RagPipelineVariableCollectionApi(Resource): raise DraftWorkflowNotExist() # fetch draft workflow by app_model - with Session(bind=db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: draft_var_srv = WorkflowDraftVariableService( session=session, ) @@ -143,7 +143,7 @@ class RagPipelineNodeVariableCollectionApi(Resource): @marshal_with(workflow_draft_variable_list_model) def get(self, pipeline: Pipeline, node_id: str): validate_node_id(node_id) - with Session(bind=db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: draft_var_srv = WorkflowDraftVariableService( session=session, ) @@ -289,7 +289,7 @@ class RagPipelineVariableResetApi(Resource): def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList: - with Session(bind=db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: draft_var_srv = WorkflowDraftVariableService( session=session, ) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index af142b4646..732a6dc446 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -1,7 +1,7 @@ from flask import request from flask_restx import Resource, fields, marshal_with # type: ignore from pydantic import BaseModel, Field -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns @@ -68,7 +68,7 @@ class RagPipelineImportApi(Resource): payload = RagPipelineImportPayload.model_validate(console_ns.payload or {}) # Create service with session - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: import_service = RagPipelineDslService(session) # Import app account = current_user @@ -80,7 +80,6 @@ class RagPipelineImportApi(Resource): pipeline_id=payload.pipeline_id, dataset_name=payload.name, ) - session.commit() # Return appropriate status code based on result status = result.status @@ -102,12 +101,11 @@ class RagPipelineImportConfirmApi(Resource): current_user, _ = current_account_with_tenant() # Create service with session - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: import_service = RagPipelineDslService(session) # Confirm import account = current_user result = import_service.confirm_import(import_id=import_id, account=account) - session.commit() # Return appropriate status code based on result if result.status == ImportStatus.FAILED: @@ -124,7 +122,7 @@ class RagPipelineImportCheckDependenciesApi(Resource): @edit_permission_required @marshal_with(pipeline_import_check_dependencies_model) def get(self, pipeline: Pipeline): - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: import_service = RagPipelineDslService(session) result = import_service.check_dependencies(pipeline=pipeline) @@ -142,7 +140,7 @@ class RagPipelineExportApi(Resource): # Add include_secret params query = IncludeSecretQuery.model_validate(request.args.to_dict()) - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: export_service = RagPipelineDslService(session) result = export_service.export_rag_pipeline_dsl( pipeline=pipeline, include_secret=query.include_secret == "true" diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 8efb59a8e9..e08cb155b6 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -6,7 +6,7 @@ from flask import abort, request from flask_restx import Resource, marshal_with # type: ignore from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services @@ -608,7 +608,7 @@ class PublishedRagPipelineApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() rag_pipeline_service = RagPipelineService() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: pipeline = session.merge(pipeline) workflow = rag_pipeline_service.publish_workflow( session=session, @@ -620,8 +620,6 @@ class PublishedRagPipelineApi(Resource): session.add(pipeline) workflow_created_at = TimestampField().format(workflow.created_at) - session.commit() - return { "result": "success", "created_at": workflow_created_at, @@ -695,7 +693,7 @@ class PublishedAllRagPipelineApi(Resource): raise Forbidden() rag_pipeline_service = RagPipelineService() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: workflows, has_more = rag_pipeline_service.get_all_published_workflow( session=session, pipeline=pipeline, @@ -767,7 +765,7 @@ class RagPipelineByIdApi(Resource): rag_pipeline_service = RagPipelineService() # Create a session and manage the transaction - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: workflow = rag_pipeline_service.update_workflow( session=session, workflow_id=workflow_id, @@ -779,9 +777,6 @@ class RagPipelineByIdApi(Resource): if not workflow: raise NotFound("Workflow not found") - # Commit the transaction in the controller - session.commit() - return workflow @setup_required @@ -798,14 +793,13 @@ class RagPipelineByIdApi(Resource): workflow_service = WorkflowService() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: try: workflow_service.delete_workflow( session=session, workflow_id=workflow_id, tenant_id=pipeline.tenant_id, ) - session.commit() except WorkflowInUseError as e: abort(400, description=str(e)) except DraftWorkflowDeletionError as e: diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py index 1c07d4ca1c..1c4c6a899f 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py @@ -102,12 +102,12 @@ class TestDataSourceApi: with ( app.test_request_context("/"), - patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class, patch("controllers.console.datasets.data_source.db.session.add"), patch("controllers.console.datasets.data_source.db.session.commit"), ): mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.execute.return_value.scalar_one_or_none.return_value = binding response, status = method(api, "b1", "enable") @@ -123,12 +123,12 @@ class TestDataSourceApi: with ( app.test_request_context("/"), - patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class, patch("controllers.console.datasets.data_source.db.session.add"), patch("controllers.console.datasets.data_source.db.session.commit"), ): mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.execute.return_value.scalar_one_or_none.return_value = binding response, status = method(api, "b1", "disable") @@ -142,10 +142,10 @@ class TestDataSourceApi: with ( app.test_request_context("/"), - patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class, ): mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.execute.return_value.scalar_one_or_none.return_value = None with pytest.raises(NotFound): @@ -159,10 +159,10 @@ class TestDataSourceApi: with ( app.test_request_context("/"), - patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class, ): mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.execute.return_value.scalar_one_or_none.return_value = binding with pytest.raises(ValueError): @@ -176,10 +176,10 @@ class TestDataSourceApi: with ( app.test_request_context("/"), - patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class, ): mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.execute.return_value.scalar_one_or_none.return_value = binding with pytest.raises(ValueError): @@ -282,7 +282,7 @@ class TestDataSourceNotionListApi: "controllers.console.datasets.data_source.DatasetService.get_dataset", return_value=dataset, ), - patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class, patch( "core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", return_value=MagicMock( @@ -292,7 +292,7 @@ class TestDataSourceNotionListApi: ), ): mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.scalars.return_value.all.return_value = [document] response, status = method(api) @@ -315,7 +315,7 @@ class TestDataSourceNotionListApi: "controllers.console.datasets.data_source.DatasetService.get_dataset", return_value=dataset, ), - patch("controllers.console.datasets.data_source.Session"), + patch("controllers.console.datasets.data_source.sessionmaker"), ): with pytest.raises(ValueError): method(api)