diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index ec39dfdaaf..6e14d996ea 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -5,7 +5,7 @@ import uuid from collections.abc import Mapping from typing import Any, TypedDict -from sqlalchemy import desc, func +from sqlalchemy import delete, desc, func, select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -73,27 +73,28 @@ class TriggerProviderService: workflows_in_use_map: dict[str, int] = {} with Session(db.engine, expire_on_commit=False) as session: # Get all subscriptions - subscriptions_db = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id)) + subscriptions_db = session.scalars( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + ) .order_by(desc(TriggerSubscription.created_at)) - .all() - ) + ).all() subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db] if not subscriptions: return [] - usage_counts = ( - session.query( + usage_counts = session.execute( + select( WorkflowPluginTrigger.subscription_id, func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"), ) - .filter( + .where( WorkflowPluginTrigger.tenant_id == tenant_id, WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]), ) .group_by(WorkflowPluginTrigger.subscription_id) - .all() - ) + ).all() workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts} provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) @@ -156,9 +157,13 @@ class TriggerProviderService: with redis_client.lock(lock_key, timeout=20): # Check provider count limit provider_count = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id)) - .count() + session.scalar( + select(func.count(TriggerSubscription.id)).where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + ) + ) + or 0 ) if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__: @@ -168,10 +173,14 @@ class TriggerProviderService: ) # Check if name already exists - existing = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name) - .first() + existing = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + TriggerSubscription.name == name, + ) + .limit(1) ) if existing: raise ValueError(f"Credential name '{name}' already exists for this provider") @@ -248,8 +257,13 @@ class TriggerProviderService: # Use distributed lock to prevent race conditions on the same subscription lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}" with redis_client.lock(lock_key, timeout=20): - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) if not subscription: raise ValueError(f"Trigger subscription {subscription_id} not found") @@ -259,10 +273,14 @@ class TriggerProviderService: # Check for name uniqueness if name is being updated if name is not None and name != subscription.name: - existing = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name) - .first() + existing = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + TriggerSubscription.name == name, + ) + .limit(1) ) if existing: raise ValueError(f"Subscription name '{name}' already exists for this provider") @@ -320,11 +338,18 @@ class TriggerProviderService: with Session(db.engine, expire_on_commit=False) as session: subscription: TriggerSubscription | None = None if subscription_id: - subscription = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) else: - subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first() + subscription = session.scalar( + select(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant_id).limit(1) + ) if subscription: provider_controller = TriggerManager.get_trigger_provider( tenant_id, TriggerProviderID(subscription.provider_id) @@ -353,8 +378,13 @@ class TriggerProviderService: :param subscription_id: Subscription instance ID :return: Success response """ - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) if not subscription: raise ValueError(f"Trigger provider subscription {subscription_id} not found") @@ -406,7 +436,14 @@ class TriggerProviderService: :return: New token info """ with sessionmaker(bind=db.engine).begin() as session: - subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) + ) if not subscription: raise ValueError(f"Trigger provider subscription {subscription_id} not found") @@ -479,8 +516,13 @@ class TriggerProviderService: now_ts: int = int(now if now is not None else _time.time()) with sessionmaker(bind=db.engine).begin() as session: - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) if subscription is None: raise ValueError(f"Trigger provider subscription {subscription_id} not found") @@ -556,15 +598,15 @@ class TriggerProviderService: tenant_id=tenant_id, provider_id=provider_id ) with Session(db.engine, expire_on_commit=False) as session: - tenant_client: TriggerOAuthTenantClient | None = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - enabled=True, + tenant_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) oauth_params: Mapping[str, Any] | None = None @@ -582,10 +624,13 @@ class TriggerProviderService: return None # Check for system-level OAuth client - system_client: TriggerOAuthSystemClient | None = ( - session.query(TriggerOAuthSystemClient) - .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name) - .first() + system_client = session.scalar( + select(TriggerOAuthSystemClient) + .where( + TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id, + TriggerOAuthSystemClient.provider == provider_id.provider_name, + ) + .limit(1) ) if system_client: @@ -606,10 +651,13 @@ class TriggerProviderService: if not is_verified: return False with Session(db.engine, expire_on_commit=False) as session: - system_client: TriggerOAuthSystemClient | None = ( - session.query(TriggerOAuthSystemClient) - .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name) - .first() + system_client = session.scalar( + select(TriggerOAuthSystemClient) + .where( + TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id, + TriggerOAuthSystemClient.provider == provider_id.provider_name, + ) + .limit(1) ) return system_client is not None @@ -640,14 +688,14 @@ class TriggerProviderService: with sessionmaker(bind=db.engine).begin() as session: # Find existing custom client params - custom_client = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=provider_id.plugin_id, - provider=provider_id.provider_name, + custom_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, ) - .first() + .limit(1) ) # Create new record if doesn't exist @@ -694,14 +742,14 @@ class TriggerProviderService: :return: Masked OAuth client parameters """ with Session(db.engine) as session: - custom_client = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=provider_id.plugin_id, - provider=provider_id.provider_name, + custom_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, ) - .first() + .limit(1) ) if custom_client is None: @@ -731,11 +779,15 @@ class TriggerProviderService: :return: Success response """ with sessionmaker(bind=db.engine).begin() as session: - session.query(TriggerOAuthTenantClient).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ).delete() + session.execute( + delete(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + ) + .execution_options(synchronize_session=False) + ) return {"result": "success"} @@ -749,15 +801,15 @@ class TriggerProviderService: :return: True if enabled, False otherwise """ with Session(db.engine, expire_on_commit=False) as session: - custom_client = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=provider_id.plugin_id, - provider=provider_id.provider_name, - enabled=True, + custom_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, + TriggerOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) return custom_client is not None @@ -767,7 +819,9 @@ class TriggerProviderService: Get a trigger subscription by the endpoint ID. """ with Session(db.engine, expire_on_commit=False) as session: - subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first() + subscription = session.scalar( + select(TriggerSubscription).where(TriggerSubscription.endpoint_id == endpoint_id).limit(1) + ) if not subscription: return None provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( diff --git a/api/tests/unit_tests/services/test_trigger_provider_service.py b/api/tests/unit_tests/services/test_trigger_provider_service.py index 350ff718c1..bd2e936b62 100644 --- a/api/tests/unit_tests/services/test_trigger_provider_service.py +++ b/api/tests/unit_tests/services/test_trigger_provider_service.py @@ -124,9 +124,7 @@ def test_list_trigger_provider_subscriptions_should_return_empty_list_when_no_su provider_id: TriggerProviderID, ) -> None: # Arrange - query = MagicMock() - query.filter_by.return_value.order_by.return_value.all.return_value = [] - mock_session.query.return_value = query + mock_session.scalars.return_value.all.return_value = [] # Act result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id) @@ -152,11 +150,8 @@ def test_list_trigger_provider_subscriptions_should_mask_fields_and_attach_workf db_sub = SimpleNamespace(to_api_entity=lambda: api_sub) usage_row = SimpleNamespace(subscription_id="sub-1", app_count=2) - query_subs = MagicMock() - query_subs.filter_by.return_value.order_by.return_value.all.return_value = [db_sub] - query_usage = MagicMock() - query_usage.filter.return_value.group_by.return_value.all.return_value = [usage_row] - mock_session.query.side_effect = [query_subs, query_usage] + mock_session.scalars.return_value.all.return_value = [db_sub] + mock_session.execute.return_value.all.return_value = [usage_row] _mock_get_trigger_provider(mocker, provider_controller) cred_enc = _encrypter_mock(decrypted={"token": "plain"}, masked={"token": "****"}) @@ -188,11 +183,7 @@ def test_add_trigger_subscription_should_create_subscription_successfully_for_ap ) -> None: # Arrange _patch_redis_lock(mocker) - query_count = MagicMock() - query_count.filter_by.return_value.count.return_value = 0 - query_existing = MagicMock() - query_existing.filter_by.return_value.first.return_value = None - mock_session.query.side_effect = [query_count, query_existing] + mock_session.scalar.side_effect = [0, None] # count=0, no existing name _mock_get_trigger_provider(mocker, provider_controller) cred_enc = _encrypter_mock(encrypted={"api_key": "enc"}) @@ -228,11 +219,7 @@ def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorize ) -> None: # Arrange _patch_redis_lock(mocker) - query_count = MagicMock() - query_count.filter_by.return_value.count.return_value = 0 - query_existing = MagicMock() - query_existing.filter_by.return_value.first.return_value = None - mock_session.query.side_effect = [query_count, query_existing] + mock_session.scalar.side_effect = [0, None] # count=0, no existing name _mock_get_trigger_provider(mocker, provider_controller) prop_enc = _encrypter_mock(encrypted={"p": "enc"}) @@ -267,9 +254,7 @@ def test_add_trigger_subscription_should_raise_error_when_provider_limit_reached ) -> None: # Arrange _patch_redis_lock(mocker) - query_count = MagicMock() - query_count.filter_by.return_value.count.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__ - mock_session.query.return_value = query_count + mock_session.scalar.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__ _mock_get_trigger_provider(mocker, provider_controller) mock_logger = mocker.patch("services.trigger.trigger_provider_service.logger") @@ -297,11 +282,7 @@ def test_add_trigger_subscription_should_raise_error_when_name_exists( ) -> None: # Arrange _patch_redis_lock(mocker) - query_count = MagicMock() - query_count.filter_by.return_value.count.return_value = 0 - query_existing = MagicMock() - query_existing.filter_by.return_value.first.return_value = object() - mock_session.query.side_effect = [query_count, query_existing] + mock_session.scalar.side_effect = [0, object()] # count=0, existing name conflict _mock_get_trigger_provider(mocker, provider_controller) # Act + Assert @@ -325,9 +306,7 @@ def test_update_trigger_subscription_should_raise_error_when_subscription_not_fo ) -> None: # Arrange _patch_redis_lock(mocker) - query_sub = MagicMock() - query_sub.filter_by.return_value.first.return_value = None - mock_session.query.return_value = query_sub + mock_session.scalar.return_value = None # Act + Assert with pytest.raises(ValueError, match="not found"): @@ -347,11 +326,7 @@ def test_update_trigger_subscription_should_raise_error_when_name_conflicts( provider_id="langgenius/github/github", credential_type=CredentialType.API_KEY.value, ) - query_sub = MagicMock() - query_sub.filter_by.return_value.first.return_value = subscription - query_existing = MagicMock() - query_existing.filter_by.return_value.first.return_value = object() - mock_session.query.side_effect = [query_sub, query_existing] + mock_session.scalar.side_effect = [subscription, object()] # found sub, name conflict _mock_get_trigger_provider(mocker, provider_controller) # Act + Assert @@ -378,11 +353,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache( credential_expires_at=0, expires_at=0, ) - query_sub = MagicMock() - query_sub.filter_by.return_value.first.return_value = subscription - query_existing = MagicMock() - query_existing.filter_by.return_value.first.return_value = None - mock_session.query.side_effect = [query_sub, query_existing] + mock_session.scalar.side_effect = [subscription, None] # found sub, no name conflict _mock_get_trigger_provider(mocker, provider_controller) prop_enc = _encrypter_mock(decrypted={"project": "old-value"}, encrypted={"project": "new-value"}) @@ -417,7 +388,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache( def test_get_subscription_by_id_should_return_none_when_missing(mocker: MockerFixture, mock_session: MagicMock) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Act result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1") @@ -439,7 +410,7 @@ def test_get_subscription_by_id_should_decrypt_credentials_and_properties( credentials={"token": "enc"}, properties={"project": "enc"}, ) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription _mock_get_trigger_provider(mocker, provider_controller) cred_enc = _encrypter_mock(decrypted={"token": "plain"}) prop_enc = _encrypter_mock(decrypted={"project": "plain"}) @@ -466,7 +437,7 @@ def test_delete_trigger_provider_should_raise_error_when_subscription_missing( mock_session: MagicMock, ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Act + Assert with pytest.raises(ValueError, match="not found"): @@ -488,7 +459,7 @@ def test_delete_trigger_provider_should_delete_and_clear_cache_even_if_unsubscri credentials={"token": "enc"}, to_entity=lambda: SimpleNamespace(id="sub-1"), ) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription _mock_get_trigger_provider(mocker, provider_controller) cred_enc = _encrypter_mock(decrypted={"token": "plain"}) mocker.patch( @@ -524,7 +495,7 @@ def test_delete_trigger_provider_should_skip_unsubscribe_for_unauthorized( credentials={}, to_entity=lambda: SimpleNamespace(id="sub-2"), ) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription _mock_get_trigger_provider(mocker, provider_controller) mock_unsubscribe = mocker.patch("services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger") mocker.patch( @@ -544,7 +515,7 @@ def test_refresh_oauth_token_should_raise_error_when_subscription_missing( mocker: MockerFixture, mock_session: MagicMock ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Act + Assert with pytest.raises(ValueError, match="not found"): @@ -556,7 +527,7 @@ def test_refresh_oauth_token_should_raise_error_for_non_oauth_credentials( ) -> None: # Arrange subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription # Act + Assert with pytest.raises(ValueError, match="Only OAuth credentials can be refreshed"): @@ -577,7 +548,7 @@ def test_refresh_oauth_token_should_refresh_and_persist_new_credentials( credentials={"access_token": "enc"}, credential_expires_at=0, ) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription _mock_get_trigger_provider(mocker, provider_controller) cache = MagicMock() cred_enc = _encrypter_mock(decrypted={"access_token": "old"}, encrypted={"access_token": "new"}) @@ -606,7 +577,7 @@ def test_refresh_subscription_should_raise_error_when_subscription_missing( mocker: MockerFixture, mock_session: MagicMock ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Act + Assert with pytest.raises(ValueError, match="not found"): @@ -616,7 +587,7 @@ def test_refresh_subscription_should_raise_error_when_subscription_missing( def test_refresh_subscription_should_skip_when_not_due(mocker: MockerFixture, mock_session: MagicMock) -> None: # Arrange subscription = SimpleNamespace(expires_at=200) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription # Act result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) @@ -643,7 +614,7 @@ def test_refresh_subscription_should_refresh_and_persist_properties( credentials={"c": "enc"}, credential_type=CredentialType.API_KEY.value, ) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription _mock_get_trigger_provider(mocker, provider_controller) cred_enc = _encrypter_mock(decrypted={"c": "plain"}) prop_cache = MagicMock() @@ -681,10 +652,7 @@ def test_get_oauth_client_should_return_tenant_client_when_available( ) -> None: # Arrange tenant_client = SimpleNamespace(oauth_params={"client_id": "enc"}) - system_client = None - query_tenant = MagicMock() - query_tenant.filter_by.return_value.first.return_value = tenant_client - mock_session.query.return_value = query_tenant + mock_session.scalar.return_value = tenant_client _mock_get_trigger_provider(mocker, provider_controller) enc = _encrypter_mock(decrypted={"client_id": "plain"}) mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock())) @@ -703,11 +671,7 @@ def test_get_oauth_client_should_return_none_when_plugin_not_verified( provider_controller: MagicMock, ) -> None: # Arrange - query_tenant = MagicMock() - query_tenant.filter_by.return_value.first.return_value = None - query_system = MagicMock() - query_system.filter_by.return_value.first.return_value = None - mock_session.query.side_effect = [query_tenant, query_system] + mock_session.scalar.return_value = None # no tenant client; plugin not verified → early return _mock_get_trigger_provider(mocker, provider_controller) mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False) @@ -725,11 +689,7 @@ def test_get_oauth_client_should_return_decrypted_system_client_when_verified( provider_controller: MagicMock, ) -> None: # Arrange - query_tenant = MagicMock() - query_tenant.filter_by.return_value.first.return_value = None - query_system = MagicMock() - query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc") - mock_session.query.side_effect = [query_tenant, query_system] + mock_session.scalar.side_effect = [None, SimpleNamespace(encrypted_oauth_params="enc")] _mock_get_trigger_provider(mocker, provider_controller) mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) mocker.patch( @@ -751,11 +711,7 @@ def test_get_oauth_client_should_raise_error_when_system_decryption_fails( provider_controller: MagicMock, ) -> None: # Arrange - query_tenant = MagicMock() - query_tenant.filter_by.return_value.first.return_value = None - query_system = MagicMock() - query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc") - mock_session.query.side_effect = [query_tenant, query_system] + mock_session.scalar.side_effect = [None, SimpleNamespace(encrypted_oauth_params="enc")] _mock_get_trigger_provider(mocker, provider_controller) mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) mocker.patch( @@ -794,7 +750,7 @@ def test_is_oauth_system_client_exists_should_reflect_database_record( provider_controller: MagicMock, ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = object() if has_client else None + mock_session.scalar.return_value = object() if has_client else None _mock_get_trigger_provider(mocker, provider_controller) mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) @@ -823,11 +779,11 @@ def test_save_custom_oauth_client_params_should_create_record_and_clear_params_w provider_controller: MagicMock, ) -> None: # Arrange - query = MagicMock() - query.filter_by.return_value.first.return_value = None - mock_session.query.return_value = query + mock_session.scalar.return_value = None _mock_get_trigger_provider(mocker, provider_controller) fake_model = SimpleNamespace(encrypted_oauth_params="", enabled=False, oauth_params={}) + # Also mock select() so SQLAlchemy doesn't validate the patched TriggerOAuthTenantClient. + mocker.patch("services.trigger.trigger_provider_service.select", MagicMock(return_value=MagicMock())) mocker.patch("services.trigger.trigger_provider_service.TriggerOAuthTenantClient", return_value=fake_model) # Act @@ -853,7 +809,7 @@ def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_c ) -> None: # Arrange custom_client = SimpleNamespace(oauth_params={"client_id": "enc-old"}, enabled=False) - mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client + mock_session.scalar.return_value = custom_client _mock_get_trigger_provider(mocker, provider_controller) cache = MagicMock() enc = _encrypter_mock(decrypted={"client_id": "old-id"}, encrypted={"client_id": "new-id"}) @@ -882,7 +838,7 @@ def test_get_custom_oauth_client_params_should_return_empty_when_record_missing( provider_id: TriggerProviderID, ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Act result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id) @@ -899,7 +855,7 @@ def test_get_custom_oauth_client_params_should_return_masked_decrypted_values( ) -> None: # Arrange custom_client = SimpleNamespace(oauth_params={"client_id": "enc"}) - mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client + mock_session.scalar.return_value = custom_client _mock_get_trigger_provider(mocker, provider_controller) enc = _encrypter_mock(decrypted={"client_id": "plain"}, masked={"client_id": "pl***id"}) mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock())) @@ -916,9 +872,6 @@ def test_delete_custom_oauth_client_params_should_delete_record_and_commit( mock_session: MagicMock, provider_id: TriggerProviderID, ) -> None: - # Arrange - mock_session.query.return_value.filter_by.return_value.delete.return_value = 1 - # Act result = TriggerProviderService.delete_custom_oauth_client_params("tenant-1", provider_id) @@ -934,7 +887,7 @@ def test_is_oauth_custom_client_enabled_should_return_expected_boolean( provider_id: TriggerProviderID, ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = object() if exists else None + mock_session.scalar.return_value = object() if exists else None # Act result = TriggerProviderService.is_oauth_custom_client_enabled("tenant-1", provider_id) @@ -947,7 +900,7 @@ def test_get_subscription_by_endpoint_should_return_none_when_not_found( mocker: MockerFixture, mock_session: MagicMock ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Act result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1") @@ -968,7 +921,7 @@ def test_get_subscription_by_endpoint_should_decrypt_credentials_and_properties( credentials={"token": "enc"}, properties={"hook": "enc"}, ) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription _mock_get_trigger_provider(mocker, provider_controller) mocker.patch( "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription",