refactor: use sessionmaker in small services 2 (#34696)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
carlos4s 2026-04-08 00:06:50 -05:00 committed by GitHub
parent 909c062ee1
commit a65e1f71b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 86 additions and 53 deletions

View File

@ -9,7 +9,7 @@ from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter from pydantic import BaseModel, TypeAdapter
from sqlalchemy import delete, func, select, update from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
class InvitationData(TypedDict): class InvitationData(TypedDict):
@ -1516,7 +1516,7 @@ class RegisterService:
check_workspace_member_invite_permission(tenant.id) check_workspace_member_invite_permission(tenant.id)
with Session(db.engine) as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session) account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if not account: if not account:

View File

@ -11,7 +11,7 @@ from typing import Any, Union
from celery.result import AsyncResult from celery.result import AsyncResult
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from enums.quota_type import QuotaType from enums.quota_type import QuotaType
from extensions.ext_database import db from extensions.ext_database import db
@ -237,7 +237,7 @@ class AsyncWorkflowService:
Returns: Returns:
Trigger log as dictionary or None if not found Trigger log as dictionary or None if not found
""" """
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id) trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)
@ -263,7 +263,7 @@ class AsyncWorkflowService:
Returns: Returns:
List of trigger logs as dictionaries List of trigger logs as dictionaries
""" """
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
logs = trigger_log_repo.get_recent_logs( logs = trigger_log_repo.get_recent_logs(
tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
@ -286,7 +286,7 @@ class AsyncWorkflowService:
Returns: Returns:
List of failed trigger logs as dictionaries List of failed trigger logs as dictionaries
""" """
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
logs = trigger_log_repo.get_failed_for_retry( logs = trigger_log_repo.get_failed_for_retry(
tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit

View File

@ -346,7 +346,7 @@ class ClearFreePlanTenantExpiredLogs:
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24) started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
current_time = started_at current_time = started_at
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
total_tenant_count = session.query(Tenant.id).count() total_tenant_count = session.query(Tenant.id).count()
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white")) click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
@ -398,7 +398,7 @@ class ClearFreePlanTenantExpiredLogs:
# Initial interval of 1 day, will be dynamically adjusted based on tenant count # Initial interval of 1 day, will be dynamically adjusted based on tenant count
interval = datetime.timedelta(days=1) interval = datetime.timedelta(days=1)
# Process tenants in this batch # Process tenants in this batch
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
# Calculate tenant count in next batch with current interval # Calculate tenant count in next batch with current interval
# Try different intervals until we find one with a reasonable tenant count # Try different intervals until we find one with a reasonable tenant count
test_intervals = [ test_intervals = [

View File

@ -1,7 +1,7 @@
import logging import logging
from sqlalchemy import select, update from sqlalchemy import select, update
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from configs import dify_config from configs import dify_config
from core.errors.error import QuotaExceededError from core.errors.error import QuotaExceededError
@ -71,7 +71,7 @@ class CreditPoolService:
actual_credits = min(credits_required, pool.remaining_credits) actual_credits = min(credits_required, pool.remaining_credits)
try: try:
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
stmt = ( stmt = (
update(TenantCreditPool) update(TenantCreditPool)
.where( .where(
@ -81,7 +81,6 @@ class CreditPoolService:
.values(quota_used=TenantCreditPool.quota_used + actual_credits) .values(quota_used=TenantCreditPool.quota_used + actual_credits)
) )
session.execute(stmt) session.execute(stmt)
session.commit()
except Exception: except Exception:
logger.exception("Failed to deduct credits for tenant %s", tenant_id) logger.exception("Failed to deduct credits for tenant %s", tenant_id)
raise QuotaExceededError("Failed to deduct credits") raise QuotaExceededError("Failed to deduct credits")

View File

@ -15,7 +15,7 @@ from graphon.model_runtime.entities.model_entities import ModelFeature, ModelTyp
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from redis.exceptions import LockNotOwnedError from redis.exceptions import LockNotOwnedError
from sqlalchemy import delete, exists, func, select, update from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config from configs import dify_config
@ -551,7 +551,7 @@ class DatasetService:
external_knowledge_id: External knowledge identifier external_knowledge_id: External knowledge identifier
external_knowledge_api_id: External knowledge API identifier external_knowledge_api_id: External knowledge API identifier
""" """
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
external_knowledge_binding = ( external_knowledge_binding = (
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
) )
@ -559,14 +559,14 @@ class DatasetService:
if not external_knowledge_binding: if not external_knowledge_binding:
raise ValueError("External knowledge binding not found.") raise ValueError("External knowledge binding not found.")
# Update binding if values have changed # Update binding if values have changed
if ( if (
external_knowledge_binding.external_knowledge_id != external_knowledge_id external_knowledge_binding.external_knowledge_id != external_knowledge_id
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
): ):
external_knowledge_binding.external_knowledge_id = external_knowledge_id external_knowledge_binding.external_knowledge_id = external_knowledge_id
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
db.session.add(external_knowledge_binding) session.add(external_knowledge_binding)
@staticmethod @staticmethod
def _update_internal_dataset(dataset, data, user): def _update_internal_dataset(dataset, data, user):

View File

@ -2,7 +2,7 @@ import enum
import uuid import uuid
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from extensions.ext_database import db from extensions.ext_database import db
@ -29,7 +29,7 @@ class OAuthServerService:
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None: def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id) query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
with Session(db.engine) as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
return session.execute(query).scalar_one_or_none() return session.execute(query).scalar_one_or_none()
@staticmethod @staticmethod

View File

@ -1182,7 +1182,7 @@ class RagPipelineService:
workflow = db.session.get(Workflow, pipeline.workflow_id) workflow = db.session.get(Workflow, pipeline.workflow_id)
if not workflow: if not workflow:
raise ValueError("Workflow not found") raise ValueError("Workflow not found")
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
dataset = pipeline.retrieve_dataset(session=session) dataset = pipeline.retrieve_dataset(session=session)
if not dataset: if not dataset:
raise ValueError("Dataset not found") raise ValueError("Dataset not found")
@ -1209,7 +1209,7 @@ class RagPipelineService:
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
rag_pipeline_dsl_service = RagPipelineDslService(session) rag_pipeline_dsl_service = RagPipelineDslService(session)
dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True) dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
if args.get("icon_info") is None: if args.get("icon_info") is None:

View File

@ -834,7 +834,7 @@ class WorkflowService:
if workflow_node_execution is None: if workflow_node_execution is None:
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving") raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
outputs = workflow_node_execution.load_full_outputs(session, storage) outputs = workflow_node_execution.load_full_outputs(session, storage)
with Session(bind=db.engine) as session, session.begin(): with Session(bind=db.engine) as session, session.begin():

View File

@ -1427,16 +1427,18 @@ class TestRegisterService:
mock_tenant.name = "Test Workspace" mock_tenant.name = "Test Workspace"
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
# Mock database queries - need to mock the Session query # Mock database queries - need to mock the sessionmaker query
mock_session = MagicMock() mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with ( with (
patch("services.account_service.Session") as mock_session_class, patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
): ):
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = None mock_lookup.return_value = None
# Mock RegisterService.register # Mock RegisterService.register
@ -1485,12 +1487,14 @@ class TestRegisterService:
mixed_email = "Invitee@Example.com" mixed_email = "Invitee@Example.com"
mock_session = MagicMock() mock_session = MagicMock()
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with ( with (
patch("services.account_service.Session") as mock_session_class, patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
): ):
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = None mock_lookup.return_value = None
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock( mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
@ -1541,16 +1545,18 @@ class TestRegisterService:
account_id="existing-user-456", email="existing@example.com", status="pending" account_id="existing-user-456", email="existing@example.com", status="pending"
) )
# Mock database queries - need to mock the Session query # Mock database queries - need to mock the sessionmaker query
mock_session = MagicMock() mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with ( with (
patch("services.account_service.Session") as mock_session_class, patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
): ):
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = mock_existing_account mock_lookup.return_value = mock_existing_account
# Mock scalar for TenantAccountJoin lookup - no existing member # Mock scalar for TenantAccountJoin lookup - no existing member

View File

@ -357,11 +357,12 @@ class TestAsyncWorkflowService:
mock_session_context.__enter__.return_value = mock_session mock_session_context.__enter__.return_value = mock_session
mock_session_context.__exit__.return_value = None mock_session_context.__exit__.return_value = None
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value = mock_session_context
with ( with (
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=fake_engine)), patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=fake_engine)),
patch.object( patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
async_workflow_service_module, "Session", return_value=mock_session_context
) as mock_session_class,
patch.object( patch.object(
async_workflow_service_module, async_workflow_service_module,
"SQLAlchemyWorkflowTriggerLogRepository", "SQLAlchemyWorkflowTriggerLogRepository",
@ -373,7 +374,7 @@ class TestAsyncWorkflowService:
# Assert # Assert
assert result == expected assert result == expected
mock_session_class.assert_called_once_with(fake_engine) mock_sessionmaker.assert_called_once_with(fake_engine)
mock_repo.get_by_id.assert_called_once_with("trigger-log-123", "tenant-123") mock_repo.get_by_id.assert_called_once_with("trigger-log-123", "tenant-123")
def test_should_return_recent_logs_as_dict_list(self): def test_should_return_recent_logs_as_dict_list(self):
@ -391,9 +392,12 @@ class TestAsyncWorkflowService:
mock_session_context.__enter__.return_value = mock_session mock_session_context.__enter__.return_value = mock_session
mock_session_context.__exit__.return_value = None mock_session_context.__exit__.return_value = None
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value = mock_session_context
with ( with (
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())), patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context), patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
patch.object( patch.object(
async_workflow_service_module, async_workflow_service_module,
"SQLAlchemyWorkflowTriggerLogRepository", "SQLAlchemyWorkflowTriggerLogRepository",
@ -432,9 +436,12 @@ class TestAsyncWorkflowService:
mock_session_context.__enter__.return_value = mock_session mock_session_context.__enter__.return_value = mock_session
mock_session_context.__exit__.return_value = None mock_session_context.__exit__.return_value = None
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value = mock_session_context
with ( with (
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())), patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context), patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
patch.object( patch.object(
async_workflow_service_module, async_workflow_service_module,
"SQLAlchemyWorkflowTriggerLogRepository", "SQLAlchemyWorkflowTriggerLogRepository",

View File

@ -209,8 +209,22 @@ def _session_wrapper_for_no_autoflush(session: Mock) -> Mock:
return wrapper return wrapper
def _sessionmaker_wrapper_for_begin(session: Mock) -> Mock:
"""
ClearFreePlanTenantExpiredLogs.process uses: with sessionmaker(db.engine).begin() as session:
so sessionmaker(db.engine) must return an object with a begin() method that returns a context manager.
"""
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
begin_cm.__exit__.return_value = None
sessionmaker_result = MagicMock()
sessionmaker_result.begin.return_value = begin_cm
return sessionmaker_result
def _session_wrapper_for_direct(session: Mock) -> Mock: def _session_wrapper_for_direct(session: Mock) -> Mock:
"""ClearFreePlanTenantExpiredLogs.process uses: with Session(db.engine) as session:""" """ClearFreePlanTenantExpiredLogs.process uses: with Session(db.engine) as session: (for old code paths)"""
wrapper = MagicMock() wrapper = MagicMock()
wrapper.__enter__.return_value = session wrapper.__enter__.return_value = session
wrapper.__exit__.return_value = None wrapper.__exit__.return_value = None
@ -348,7 +362,7 @@ def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: py
count_query.count.return_value = 2 count_query.count.return_value = 2
count_session.query.return_value = count_query count_session.query.return_value = count_query
monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session)) monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
# Avoid LocalProxy usage # Avoid LocalProxy usage
flask_app = service_module.Flask("test-app") flask_app = service_module.Flask("test-app")
@ -438,8 +452,8 @@ def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pyt
batch_session.query.side_effect = [q1, q2, q3, q4, q_rs] batch_session.query.side_effect = [q1, q2, q3, q4, q_rs]
sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)] sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0)) monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
process_tenant_mock = MagicMock() process_tenant_mock = MagicMock()
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock) monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)
@ -457,7 +471,7 @@ def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.Mo
count_query = MagicMock() count_query = MagicMock()
count_query.count.return_value = 100 count_query.count.return_value = 100
count_session.query.return_value = count_query count_session.query.return_value = count_query
monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session)) monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
flask_app = service_module.Flask("test-app") flask_app = service_module.Flask("test-app")
monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app))
@ -523,8 +537,8 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon
batch_session.query.side_effect = [*count_queries, q_rs] batch_session.query.side_effect = [*count_queries, q_rs]
sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)] sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0)) monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
process_tenant_mock = MagicMock() process_tenant_mock = MagicMock()
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock) monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)

View File

@ -578,26 +578,33 @@ class TestDatasetServiceCreationAndUpdate:
binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api") binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
session = MagicMock() session = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = binding session.query.return_value.filter_by.return_value.first.return_value = binding
session.add = MagicMock()
session_context = _make_session_context(session) session_context = _make_session_context(session)
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value = session_context
with ( with (
patch("services.dataset_service.db") as mock_db, patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.Session", return_value=session_context), patch("services.dataset_service.sessionmaker", mock_sessionmaker),
): ):
DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api") DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api")
assert binding.external_knowledge_id == "new-knowledge" assert binding.external_knowledge_id == "new-knowledge"
assert binding.external_knowledge_api_id == "new-api" assert binding.external_knowledge_api_id == "new-api"
mock_db.session.add.assert_called_once_with(binding) session.add.assert_called_once_with(binding)
def test_update_external_knowledge_binding_raises_for_missing_binding(self): def test_update_external_knowledge_binding_raises_for_missing_binding(self):
session = MagicMock() session = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = None session.query.return_value.filter_by.return_value.first.return_value = None
session_context = _make_session_context(session) session_context = _make_session_context(session)
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value = session_context
with ( with (
patch("services.dataset_service.db"), patch("services.dataset_service.db"),
patch("services.dataset_service.Session", return_value=session_context), patch("services.dataset_service.sessionmaker", mock_sessionmaker),
): ):
with pytest.raises(ValueError, match="External knowledge binding not found"): with pytest.raises(ValueError, match="External knowledge binding not found"):
DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1") DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1")