From a03d5b8ed365bef4cd9e4c2a46b3ca432aafd5ab Mon Sep 17 00:00:00 2001 From: hj24 Date: Mon, 27 Apr 2026 09:49:40 +0800 Subject: [PATCH] refactor: quota v3 integration (#35436) Co-authored-by: Yansong Zhang <916125788@qq.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/enums/quota_type.py | 188 ------- api/services/app_generate_service.py | 332 ++++++----- api/services/async_workflow_service.py | 44 +- api/services/billing_service.py | 100 +++- api/services/feature_service.py | 2 +- api/services/quota_service.py | 233 ++++++++ api/services/trigger/webhook_service.py | 73 +-- api/services/workflow_service.py | 35 +- api/tasks/trigger_processing_tasks.py | 99 ++-- api/tasks/workflow_schedule_tasks.py | 35 +- .../services/test_app_generate_service.py | 23 +- .../test_webhook_service_relationships.py | 517 ++++++++++++++++++ .../trigger/test_trigger_e2e.py | 4 +- api/tests/unit_tests/enums/__init__.py | 0 api/tests/unit_tests/enums/test_quota_type.py | 349 ++++++++++++ .../services/test_app_generate_service.py | 12 +- .../services/test_async_workflow_service.py | 31 +- .../services/test_billing_service.py | 160 +++++- .../services/test_workflow_service.py | 49 +- .../tasks/test_trigger_processing_tasks.py | 204 +++++++ 20 files changed, 1961 insertions(+), 529 deletions(-) create mode 100644 api/services/quota_service.py create mode 100644 api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py create mode 100644 api/tests/unit_tests/enums/__init__.py create mode 100644 api/tests/unit_tests/enums/test_quota_type.py create mode 100644 api/tests/unit_tests/tasks/test_trigger_processing_tasks.py diff --git a/api/enums/quota_type.py b/api/enums/quota_type.py index 9f511b88ef..a10ac21f69 100644 --- a/api/enums/quota_type.py +++ b/api/enums/quota_type.py @@ -1,56 +1,17 @@ -import logging -from dataclasses import dataclass from enum import StrEnum, auto -logger = logging.getLogger(__name__) - - -@dataclass -class QuotaCharge: - """ - Result of a quota consumption operation. - - Attributes: - success: Whether the quota charge succeeded - charge_id: UUID for refund, or None if failed/disabled - """ - - success: bool - charge_id: str | None - _quota_type: "QuotaType" - - def refund(self) -> None: - """ - Refund this quota charge. - - Safe to call even if charge failed or was disabled. - This method guarantees no exceptions will be raised. - """ - if self.charge_id: - self._quota_type.refund(self.charge_id) - logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id) - class QuotaType(StrEnum): """ Supported quota types for tenant feature usage. - - Add additional types here whenever new billable features become available. """ - # Trigger execution quota TRIGGER = auto() - - # Workflow execution quota WORKFLOW = auto() - UNLIMITED = auto() @property def billing_key(self) -> str: - """ - Get the billing key for the feature. - """ match self: case QuotaType.TRIGGER: return "trigger_event" @@ -58,152 +19,3 @@ class QuotaType(StrEnum): return "api_rate_limit" case _: raise ValueError(f"Invalid quota type: {self}") - - def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge: - """ - Consume quota for the feature. - - Args: - tenant_id: The tenant identifier - amount: Amount to consume (default: 1) - - Returns: - QuotaCharge with success status and charge_id for refund - - Raises: - QuotaExceededError: When quota is insufficient - """ - from configs import dify_config - from services.billing_service import BillingService - from services.errors.app import QuotaExceededError - - if not dify_config.BILLING_ENABLED: - logger.debug("Billing disabled, allowing request for %s", tenant_id) - return QuotaCharge(success=True, charge_id=None, _quota_type=self) - - logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id) - - if amount <= 0: - raise ValueError("Amount to consume must be greater than 0") - - try: - response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount) - - if response.get("result") != "success": - logger.warning( - "Failed to consume quota for %s, feature %s details: %s", - tenant_id, - self.value, - response.get("detail"), - ) - raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount) - - charge_id = response.get("history_id") - logger.debug( - "Successfully consumed %d %s quota for tenant %s, charge_id: %s", - amount, - self.value, - tenant_id, - charge_id, - ) - return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self) - - except QuotaExceededError: - raise - except Exception: - # fail-safe: allow request on billing errors - logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value) - return unlimited() - - def check(self, tenant_id: str, amount: int = 1) -> bool: - """ - Check if tenant has sufficient quota without consuming. - - Args: - tenant_id: The tenant identifier - amount: Amount to check (default: 1) - - Returns: - True if quota is sufficient, False otherwise - """ - from configs import dify_config - - if not dify_config.BILLING_ENABLED: - return True - - if amount <= 0: - raise ValueError("Amount to check must be greater than 0") - - try: - remaining = self.get_remaining(tenant_id) - return remaining >= amount if remaining != -1 else True - except Exception: - logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value) - # fail-safe: allow request on billing errors - return True - - def refund(self, charge_id: str) -> None: - """ - Refund quota using charge_id from consume(). - - This method guarantees no exceptions will be raised. - All errors are logged but silently handled. - - Args: - charge_id: The UUID returned from consume() - """ - try: - from configs import dify_config - from services.billing_service import BillingService - - if not dify_config.BILLING_ENABLED: - return - - if not charge_id: - logger.warning("Cannot refund: charge_id is empty") - return - - logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id) - - response = BillingService.refund_tenant_feature_plan_usage(charge_id) - if response.get("result") == "success": - logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id) - else: - logger.warning("Refund failed for charge_id: %s", charge_id) - - except Exception: - # Catch ALL exceptions - refund must never fail - logger.exception("Failed to refund quota for charge_id: %s", charge_id) - # Don't raise - refund is best-effort and must be silent - - def get_remaining(self, tenant_id: str) -> int: - """ - Get remaining quota for the tenant. - - Args: - tenant_id: The tenant identifier - - Returns: - Remaining quota amount - """ - from services.billing_service import BillingService - - try: - usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key) - # Assuming the API returns a dict with 'remaining' or 'limit' and 'used' - if isinstance(usage_info, dict): - return usage_info.get("remaining", 0) - # If it returns a simple number, treat it as remaining - return int(usage_info) if usage_info else 0 - except Exception: - logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value) - return -1 - - -def unlimited() -> QuotaCharge: - """ - Return a quota charge for unlimited quota. - - This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type. - """ - return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 40013f2b66..d6c01e9dcc 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -4,7 +4,7 @@ import logging import threading import uuid from collections.abc import Callable, Generator, Mapping -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any from configs import dify_config from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator @@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit from core.app.features.rate_limiting.rate_limit import rate_limit_context from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.db import session_factory -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow, WorkflowRun from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError +from services.quota_service import QuotaService, unlimited from services.workflow_service import WorkflowService from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task @@ -88,7 +89,7 @@ class AppGenerateService: def generate( cls, app_model: App, - user: Union[Account, EndUser], + user: Account | EndUser, args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, @@ -106,7 +107,7 @@ class AppGenerateService: quota_charge = unlimited() if dify_config.BILLING_ENABLED: try: - quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id) except QuotaExceededError: raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}") @@ -116,139 +117,150 @@ class AppGenerateService: request_id = RateLimit.gen_request_key() try: request_id = rate_limit.enter(request_id) - if app_model.mode == AppMode.COMPLETION: - return rate_limit.generate( - CompletionAppGenerator.convert_to_event_stream( - CompletionAppGenerator().generate( - app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming - ), - ), - request_id=request_id, - ) - elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: - return rate_limit.generate( - AgentChatAppGenerator.convert_to_event_stream( - AgentChatAppGenerator().generate( - app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming - ), - ), - request_id, - ) - elif app_model.mode == AppMode.CHAT: - return rate_limit.generate( - ChatAppGenerator.convert_to_event_stream( - ChatAppGenerator().generate( - app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming - ), - ), - request_id=request_id, - ) - elif app_model.mode == AppMode.ADVANCED_CHAT: - workflow_id = args.get("workflow_id") - workflow = cls._get_workflow(app_model, invoke_from, workflow_id) - - if streaming: - # Streaming mode: subscribe to SSE and enqueue the execution on first subscriber - with rate_limit_context(rate_limit, request_id): - payload = AppExecutionParams.new( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - streaming=True, - call_depth=0, - ) - payload_json = payload.model_dump_json() - - def on_subscribe(): - workflow_based_app_execution_task.delay(payload_json) - - on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) - generator = AdvancedChatAppGenerator() + quota_charge.commit() + effective_mode = ( + AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode + ) + match effective_mode: + case AppMode.COMPLETION: return rate_limit.generate( - generator.convert_to_event_stream( - generator.retrieve_events( - AppMode.ADVANCED_CHAT, - payload.workflow_run_id, - on_subscribe=on_subscribe, + CompletionAppGenerator.convert_to_event_stream( + CompletionAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming ), ), request_id=request_id, ) - else: - # Blocking mode: run synchronously and return JSON instead of SSE - # Keep behaviour consistent with WORKFLOW blocking branch. - advanced_generator = AdvancedChatAppGenerator() + case AppMode.AGENT_CHAT: return rate_limit.generate( - advanced_generator.convert_to_event_stream( - advanced_generator.generate( + AgentChatAppGenerator.convert_to_event_stream( + AgentChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming + ), + ), + request_id, + ) + case AppMode.CHAT: + return rate_limit.generate( + ChatAppGenerator.convert_to_event_stream( + ChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming + ), + ), + request_id=request_id, + ) + case AppMode.ADVANCED_CHAT: + workflow_id = args.get("workflow_id") + workflow = cls._get_workflow(app_model, invoke_from, workflow_id) + + if streaming: + # Streaming mode: subscribe to SSE and enqueue the execution on first subscriber + with rate_limit_context(rate_limit, request_id): + payload = AppExecutionParams.new( app_model=app_model, workflow=workflow, user=user, args=args, invoke_from=invoke_from, + streaming=True, + call_depth=0, workflow_run_id=str(uuid.uuid4()), - streaming=False, ) - ), - request_id=request_id, - ) - elif app_model.mode == AppMode.WORKFLOW: - workflow_id = args.get("workflow_id") - workflow = cls._get_workflow(app_model, invoke_from, workflow_id) - if streaming: - with rate_limit_context(rate_limit, request_id): - payload = AppExecutionParams.new( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - streaming=True, - call_depth=0, - root_node_id=root_node_id, - workflow_run_id=str(uuid.uuid4()), + payload_json = payload.model_dump_json() + + def on_subscribe(): + workflow_based_app_execution_task.delay(payload_json) + + on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) + generator = AdvancedChatAppGenerator() + return rate_limit.generate( + generator.convert_to_event_stream( + generator.retrieve_events( + AppMode.ADVANCED_CHAT, + payload.workflow_run_id, + on_subscribe=on_subscribe, + ), + ), + request_id=request_id, ) - payload_json = payload.model_dump_json() + else: + # Blocking mode: run synchronously and return JSON instead of SSE + # Keep behaviour consistent with WORKFLOW blocking branch. + pause_config = PauseStateLayerConfig( + session_factory=session_factory.get_session_maker(), + state_owner_user_id=workflow.created_by, + ) + advanced_generator = AdvancedChatAppGenerator() + return rate_limit.generate( + advanced_generator.convert_to_event_stream( + advanced_generator.generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + workflow_run_id=str(uuid.uuid4()), + streaming=False, + pause_state_config=pause_config, + ) + ), + request_id=request_id, + ) + case AppMode.WORKFLOW: + workflow_id = args.get("workflow_id") + workflow = cls._get_workflow(app_model, invoke_from, workflow_id) + if streaming: + with rate_limit_context(rate_limit, request_id): + payload = AppExecutionParams.new( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=True, + call_depth=0, + root_node_id=root_node_id, + workflow_run_id=str(uuid.uuid4()), + ) + payload_json = payload.model_dump_json() - def on_subscribe(): - workflow_based_app_execution_task.delay(payload_json) + def on_subscribe(): + workflow_based_app_execution_task.delay(payload_json) - on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) + on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) + return rate_limit.generate( + WorkflowAppGenerator.convert_to_event_stream( + MessageBasedAppGenerator.retrieve_events( + AppMode.WORKFLOW, + payload.workflow_run_id, + on_subscribe=on_subscribe, + ), + ), + request_id, + ) + + pause_config = PauseStateLayerConfig( + session_factory=session_factory.get_session_maker(), + state_owner_user_id=workflow.created_by, + ) return rate_limit.generate( WorkflowAppGenerator.convert_to_event_stream( - MessageBasedAppGenerator.retrieve_events( - AppMode.WORKFLOW, - payload.workflow_run_id, - on_subscribe=on_subscribe, + WorkflowAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=False, + root_node_id=root_node_id, + call_depth=0, + pause_state_config=pause_config, ), ), request_id, ) - - pause_config = PauseStateLayerConfig( - session_factory=session_factory.get_session_maker(), - state_owner_user_id=workflow.created_by, - ) - return rate_limit.generate( - WorkflowAppGenerator.convert_to_event_stream( - WorkflowAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - streaming=False, - root_node_id=root_node_id, - call_depth=0, - pause_state_config=pause_config, - ), - ), - request_id, - ) - else: - raise ValueError(f"Invalid app mode {app_model.mode}") + case _: + raise ValueError(f"Invalid app mode {app_model.mode}") except Exception: quota_charge.refund() rate_limit.exit(request_id) @@ -280,53 +292,83 @@ class AppGenerateService: @classmethod def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): - if app_model.mode == AppMode.ADVANCED_CHAT: - workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( - AdvancedChatAppGenerator().single_iteration_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + match app_model.mode: + case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT: + raise ValueError(f"Invalid app mode {app_model.mode}") + case AppMode.ADVANCED_CHAT: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator.convert_to_event_stream( + AdvancedChatAppGenerator().single_iteration_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + streaming=streaming, + ) ) - ) - elif app_model.mode == AppMode.WORKFLOW: - workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( - WorkflowAppGenerator().single_iteration_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + case AppMode.WORKFLOW: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator.convert_to_event_stream( + WorkflowAppGenerator().single_iteration_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + streaming=streaming, + ) ) - ) - else: - raise ValueError(f"Invalid app mode {app_model.mode}") + case AppMode.CHANNEL | AppMode.RAG_PIPELINE: + raise ValueError(f"Invalid app mode {app_model.mode}") + case _: + raise ValueError(f"Invalid app mode {app_model.mode}") @classmethod def generate_single_loop( cls, app_model: App, user: Account, node_id: str, args: LoopNodeRunPayload, streaming: bool = True ): - if app_model.mode == AppMode.ADVANCED_CHAT: - workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( - AdvancedChatAppGenerator().single_loop_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + match app_model.mode: + case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT: + raise ValueError(f"Invalid app mode {app_model.mode}") + case AppMode.ADVANCED_CHAT: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator.convert_to_event_stream( + AdvancedChatAppGenerator().single_loop_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + streaming=streaming, + ) ) - ) - elif app_model.mode == AppMode.WORKFLOW: - workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( - WorkflowAppGenerator().single_loop_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + case AppMode.WORKFLOW: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator.convert_to_event_stream( + WorkflowAppGenerator().single_loop_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + streaming=streaming, + ) ) - ) - else: - raise ValueError(f"Invalid app mode {app_model.mode}") + case AppMode.CHANNEL | AppMode.RAG_PIPELINE: + raise ValueError(f"Invalid app mode {app_model.mode}") + case _: + raise ValueError(f"Invalid app mode {app_model.mode}") @classmethod def generate_more_like_this( cls, app_model: App, - user: Union[Account, EndUser], + user: Account | EndUser, message_id: str, invoke_from: InvokeFrom, streaming: bool = True, - ) -> Union[Mapping, Generator]: + ) -> Mapping | Generator: """ Generate more like this :param app_model: app model diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index 0133634e5a..63533f6236 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict from models.workflow import Workflow from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError +from services.quota_service import QuotaService, unlimited from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority from services.workflow_service import WorkflowService @@ -88,7 +89,10 @@ class AsyncWorkflowService: raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}") # 2. Get workflow - workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id) + workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session) + + # commit read only session before starting the billig rpc call + session.commit() # 3. Get dispatcher based on tenant subscription dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id) @@ -131,9 +135,10 @@ class AsyncWorkflowService: trigger_log = trigger_log_repo.create(trigger_log) session.commit() - # 7. Check and consume quota + # 7. Reserve quota (commit after successful dispatch) + quota_charge = unlimited() try: - QuotaType.WORKFLOW.consume(trigger_data.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id) except QuotaExceededError as e: # Update trigger log status trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED @@ -153,13 +158,18 @@ class AsyncWorkflowService: # 9. Dispatch to appropriate queue task_data_dict = task_data.model_dump(mode="json") - task: AsyncResult[Any] | None = None - if queue_name == QueuePriority.PROFESSIONAL: - task = execute_workflow_professional.delay(task_data_dict) - elif queue_name == QueuePriority.TEAM: - task = execute_workflow_team.delay(task_data_dict) - else: # SANDBOX - task = execute_workflow_sandbox.delay(task_data_dict) + try: + task: AsyncResult[Any] | None = None + if queue_name == QueuePriority.PROFESSIONAL: + task = execute_workflow_professional.delay(task_data_dict) + elif queue_name == QueuePriority.TEAM: + task = execute_workflow_team.delay(task_data_dict) + else: # SANDBOX + task = execute_workflow_sandbox.delay(task_data_dict) + quota_charge.commit() + except Exception: + quota_charge.refund() + raise # 10. Update trigger log with task info trigger_log.status = WorkflowTriggerStatus.QUEUED @@ -295,13 +305,21 @@ class AsyncWorkflowService: return [log.to_dict() for log in logs] @staticmethod - def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow: + def _get_workflow( + workflow_service: WorkflowService, + app_model: App, + workflow_id: str | None = None, + session: Session | None = None, + ) -> Workflow: """ Get workflow for the app Args: app_model: App model instance workflow_id: Optional specific workflow ID + session: Reuse this SQLAlchemy session for the lookup when provided, + so the caller's explicit session bears the connection cost + instead of Flask's request-scoped ``db.session``. Returns: Workflow instance @@ -311,12 +329,12 @@ class AsyncWorkflowService: """ if workflow_id: # Get specific published workflow - workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id) + workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session) if not workflow: raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}") else: # Get default published workflow - workflow = workflow_service.get_published_workflow(app_model) + workflow = workflow_service.get_published_workflow(app_model, session=session) if not workflow: raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index a5804d3ab5..6cbc97d8bd 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -25,6 +25,50 @@ class SubscriptionPlan(TypedDict): expiration_date: int +class QuotaReserveResult(TypedDict): + reservation_id: str + available: int + reserved: int + + +class QuotaCommitResult(TypedDict): + available: int + reserved: int + refunded: int + + +class QuotaReleaseResult(TypedDict): + available: int + reserved: int + released: int + + +_quota_reserve_adapter = TypeAdapter(QuotaReserveResult) +_quota_commit_adapter = TypeAdapter(QuotaCommitResult) +_quota_release_adapter = TypeAdapter(QuotaReleaseResult) + + +class _TenantFeatureQuota(TypedDict): + usage: int + limit: int + reset_date: NotRequired[int] + + +class TenantFeatureQuotaInfo(TypedDict): + """Response of /quota/info. + + NOTE (hj24): + - Same convention as BillingInfo: billing may return int fields as str, + always keep non-strict mode to auto-coerce. + """ + + trigger_event: _TenantFeatureQuota + api_rate_limit: _TenantFeatureQuota + + +_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo) + + class _BillingQuota(TypedDict): size: int limit: int @@ -142,13 +186,65 @@ class BillingService: @classmethod def get_tenant_feature_plan_usage_info(cls, tenant_id: str): + """Deprecated: Use get_quota_info instead.""" params = {"tenant_id": tenant_id} - usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params) return usage_info @classmethod - def get_knowledge_rate_limit(cls, tenant_id: str): + def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo: + params = {"tenant_id": tenant_id} + return _tenant_feature_quota_info_adapter.validate_python( + cls._send_request("GET", "/quota/info", params=params) + ) + + @classmethod + def quota_reserve( + cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None + ) -> QuotaReserveResult: + """Reserve quota before task execution.""" + payload: dict = { + "tenant_id": tenant_id, + "feature_key": feature_key, + "request_id": request_id, + "amount": amount, + } + if meta: + payload["meta"] = meta + return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload)) + + @classmethod + def quota_commit( + cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None + ) -> QuotaCommitResult: + """Commit a reservation with actual consumption.""" + payload: dict = { + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + "actual_amount": actual_amount, + } + if meta: + payload["meta"] = meta + return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload)) + + @classmethod + def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult: + """Release a reservation (cancel, return frozen quota).""" + return _quota_release_adapter.validate_python( + cls._send_request( + "POST", + "/quota/release", + json={ + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + }, + ) + ) + + @classmethod + def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict: params = {"tenant_id": tenant_id} knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index df653e0ba7..9216a7fb99 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -281,7 +281,7 @@ class FeatureService: def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) - features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id) + features_usage_info = BillingService.get_quota_info(tenant_id) features.billing.enabled = billing_info["enabled"] features.billing.subscription.plan = billing_info["subscription"]["plan"] diff --git a/api/services/quota_service.py b/api/services/quota_service.py new file mode 100644 index 0000000000..4c784315c7 --- /dev/null +++ b/api/services/quota_service.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import logging +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from enums.quota_type import QuotaType + +logger = logging.getLogger(__name__) + + +@dataclass +class QuotaCharge: + """ + Result of a quota reservation (Reserve phase). + + Lifecycle: + charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id) + try: + do_work() + charge.commit() # Confirm consumption + except: + charge.refund() # Release frozen quota + + If neither commit() nor refund() is called, the billing system's + cleanup CronJob will auto-release the reservation within ~75 seconds. + """ + + success: bool + charge_id: str | None # reservation_id + _quota_type: QuotaType + _tenant_id: str | None = None + _feature_key: str | None = None + _amount: int = 0 + _committed: bool = field(default=False, repr=False) + + def commit(self, actual_amount: int | None = None) -> None: + """ + Confirm the consumption with actual amount. + + Args: + actual_amount: Actual amount consumed. Defaults to the reserved amount. + If less than reserved, the difference is refunded automatically. + """ + if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key: + return + + try: + from services.billing_service import BillingService + + amount = actual_amount if actual_amount is not None else self._amount + BillingService.quota_commit( + tenant_id=self._tenant_id, + feature_key=self._feature_key, + reservation_id=self.charge_id, + actual_amount=amount, + ) + self._committed = True + logger.debug( + "Committed %s quota for tenant %s, reservation_id: %s, amount: %d", + self._quota_type, + self._tenant_id, + self.charge_id, + amount, + ) + except Exception: + logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id) + + def refund(self) -> None: + """ + Release the reserved quota (cancel the charge). + + Safe to call even if: + - charge failed or was disabled (charge_id is None) + - already committed (Release after Commit is a no-op) + - already refunded (idempotent) + + This method guarantees no exceptions will be raised. + """ + if not self.charge_id or not self._tenant_id or not self._feature_key: + return + + QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key) + + +def unlimited() -> QuotaCharge: + from enums.quota_type import QuotaType + + return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) + + +class QuotaService: + """Orchestrates quota reserve / commit / release lifecycle via BillingService.""" + + @staticmethod + def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve + immediate Commit (one-shot mode). + + The returned QuotaCharge supports .refund() which calls Release. + For two-phase usage (e.g. streaming), use reserve() directly. + """ + charge = QuotaService.reserve(quota_type, tenant_id, amount) + if charge.success and charge.charge_id: + charge.commit() + return charge + + @staticmethod + def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve quota before task execution (Reserve phase only). + + The caller MUST call charge.commit() after the task succeeds, + or charge.refund() if the task fails. + + Raises: + QuotaExceededError: When quota is insufficient + """ + from services.billing_service import BillingService + from services.errors.app import QuotaExceededError + + if not dify_config.BILLING_ENABLED: + logger.debug("Billing disabled, allowing request for %s", tenant_id) + return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type) + + logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id) + + if amount <= 0: + raise ValueError("Amount to reserve must be greater than 0") + + request_id = str(uuid.uuid4()) + feature_key = quota_type.billing_key + + try: + reserve_resp = BillingService.quota_reserve( + tenant_id=tenant_id, + feature_key=feature_key, + request_id=request_id, + amount=amount, + ) + + reservation_id = reserve_resp.get("reservation_id") + if not reservation_id: + logger.warning( + "Reserve returned no reservation_id for %s, feature %s, response: %s", + tenant_id, + quota_type.value, + reserve_resp, + ) + raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount) + + logger.debug( + "Reserved %d %s quota for tenant %s, reservation_id: %s", + amount, + quota_type.value, + tenant_id, + reservation_id, + ) + return QuotaCharge( + success=True, + charge_id=reservation_id, + _quota_type=quota_type, + _tenant_id=tenant_id, + _feature_key=feature_key, + _amount=amount, + ) + + except QuotaExceededError: + raise + except ValueError: + raise + except Exception: + logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value) + return unlimited() + + @staticmethod + def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool: + if not dify_config.BILLING_ENABLED: + return True + + if amount <= 0: + raise ValueError("Amount to check must be greater than 0") + + try: + remaining = QuotaService.get_remaining(quota_type, tenant_id) + return remaining >= amount if remaining != -1 else True + except Exception: + logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value) + return True + + @staticmethod + def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None: + """Release a reservation. Guarantees no exceptions.""" + try: + from services.billing_service import BillingService + + if not dify_config.BILLING_ENABLED: + return + + if not reservation_id: + return + + logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id) + BillingService.quota_release( + tenant_id=tenant_id, + feature_key=feature_key, + reservation_id=reservation_id, + ) + except Exception: + logger.exception("Failed to release quota, reservation_id: %s", reservation_id) + + @staticmethod + def get_remaining(quota_type: QuotaType, tenant_id: str) -> int: + from services.billing_service import BillingService + + try: + usage_info = BillingService.get_quota_info(tenant_id) + if isinstance(usage_info, dict): + feature_info = usage_info.get(quota_type.billing_key, {}) + if isinstance(feature_info, dict): + limit = feature_info.get("limit", 0) + usage = feature_info.get("usage", 0) + if limit == -1: + return -1 + return max(0, limit - usage) + return 0 + except Exception: + logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value) + return -1 diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 3c1a4cc747..3fd12ce9cf 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -37,6 +37,7 @@ from models.workflow import Workflow from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService from services.trigger.app_trigger_service import AppTriggerService from services.workflow.entities import WebhookTriggerData @@ -758,45 +759,47 @@ class WebhookService: Exception: If workflow execution fails """ try: - with Session(db.engine) as session: - # Prepare inputs for the webhook node - # The webhook node expects webhook_data in the inputs - workflow_inputs = cls.build_workflow_inputs(webhook_data) + workflow_inputs = cls.build_workflow_inputs(webhook_data) - # Create trigger data - trigger_data = WebhookTriggerData( - app_id=webhook_trigger.app_id, - workflow_id=workflow.id, - root_node_id=webhook_trigger.node_id, # Start from the webhook node - inputs=workflow_inputs, - tenant_id=webhook_trigger.tenant_id, + trigger_data = WebhookTriggerData( + app_id=webhook_trigger.app_id, + workflow_id=workflow.id, + root_node_id=webhook_trigger.node_id, + inputs=workflow_inputs, + tenant_id=webhook_trigger.tenant_id, + ) + + end_user = EndUserService.get_or_create_end_user_by_type( + type=InvokeFrom.TRIGGER, + tenant_id=webhook_trigger.tenant_id, + app_id=webhook_trigger.app_id, + user_id=None, + ) + + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id) + logger.info( + "Tenant %s rate limited, skipping webhook trigger %s", + webhook_trigger.tenant_id, + webhook_trigger.webhook_id, ) + raise - end_user = EndUserService.get_or_create_end_user_by_type( - type=InvokeFrom.TRIGGER, - tenant_id=webhook_trigger.tenant_id, - app_id=webhook_trigger.app_id, - user_id=None, - ) - - # consume quota before triggering workflow execution - try: - QuotaType.TRIGGER.consume(webhook_trigger.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id) - logger.info( - "Tenant %s rate limited, skipping webhook trigger %s", - webhook_trigger.tenant_id, - webhook_trigger.webhook_id, + try: + # NOTE: don not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()` + # trigger_workflow_async need to handle multipe session commits internally + with Session(db.engine, expire_on_commit=False) as session: + AsyncWorkflowService.trigger_workflow_async( + session, + end_user, + trigger_data, ) - raise - - # Trigger workflow execution asynchronously - AsyncWorkflowService.trigger_workflow_async( - session, - end_user, - trigger_data, - ) + quota_charge.commit() + except Exception: + quota_charge.refund() + raise except Exception: logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 66976058c0..5b5604aeb6 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -132,31 +132,38 @@ class WorkflowService: if workflow_id: return self.get_published_workflow_by_id(app_model, workflow_id) # fetch draft workflow by app_model - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == Workflow.VERSION_DRAFT, ) - .first() + .limit(1) ) # return draft workflow return workflow - def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None: + def get_published_workflow_by_id( + self, app_model: App, workflow_id: str, session: Session | None = None + ) -> Workflow | None: """ fetch published workflow by workflow_id + + When ``session`` is provided, reuse it so callers that already hold a + Session avoid checking out an extra request-scoped ``db.session`` + connection. Falls back to ``db.session`` for backward compatibility. """ - workflow = ( - db.session.query(Workflow) + bind = session if session is not None else db.session + workflow = bind.scalar( + select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id, ) - .first() + .limit(1) ) if not workflow: return None @@ -167,23 +174,27 @@ class WorkflowService: ) return workflow - def get_published_workflow(self, app_model: App) -> Workflow | None: + def get_published_workflow(self, app_model: App, session: Session | None = None) -> Workflow | None: """ Get published workflow + + When ``session`` is provided, reuse it so callers that already hold a + Session avoid checking out an extra request-scoped ``db.session`` + connection. Falls back to ``db.session`` for backward compatibility. """ if not app_model.workflow_id: return None - # fetch published workflow by workflow_id - workflow = ( - db.session.query(Workflow) + bind = session if session is not None else db.session + workflow = bind.scalar( + select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == app_model.workflow_id, ) - .first() + .limit(1) ) return workflow diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index f8c7964805..00e08ddfe3 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -12,6 +12,7 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -27,8 +28,7 @@ from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from dify_graph.enums import WorkflowExecutionStatus -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from models.enums import ( AppTriggerType, CreatorUserRole, @@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.trigger_provider_service import TriggerProviderService from services.trigger.trigger_request_service import TriggerHttpRequestCachingService @@ -258,59 +259,58 @@ def dispatch_triggered_workflow( tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id) ) trigger_entity: TriggerProviderEntity = provider_controller.entity + + # Ensure expire_on_commit is set to False to remain workflows available with session_factory.create_session() as session: workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers) - end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch( - type=InvokeFrom.TRIGGER, - tenant_id=subscription.tenant_id, - app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers], - user_id=user_id, - ) - for plugin_trigger in subscribers: - # Get workflow from mapping - workflow: Workflow | None = workflows.get(plugin_trigger.app_id) - if not workflow: - logger.error( - "Workflow not found for app %s", - plugin_trigger.app_id, - ) - continue + end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch( + type=InvokeFrom.TRIGGER, + tenant_id=subscription.tenant_id, + app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers], + user_id=user_id, + ) - # Find the trigger node in the workflow - event_node = None - for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): - if node_id == plugin_trigger.node_id: - event_node = node_config - break - - if not event_node: - logger.error("Trigger event node not found for app %s", plugin_trigger.app_id) - continue - - # invoke trigger - trigger_metadata = PluginTriggerMetadata( - plugin_unique_identifier=provider_controller.plugin_unique_identifier or "", - endpoint_id=subscription.endpoint_id, - provider_id=subscription.provider_id, - event_name=event_name, - icon_filename=trigger_entity.identity.icon or "", - icon_dark_filename=trigger_entity.identity.icon_dark or "", + for plugin_trigger in subscribers: + workflow: Workflow | None = workflows.get(plugin_trigger.app_id) + if not workflow: + logger.error( + "Workflow not found for app %s", + plugin_trigger.app_id, ) + continue - # consume quota before invoking trigger - quota_charge = unlimited() - try: - quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) - logger.info( - "Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id - ) - return 0 + event_node = None + for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): + if node_id == plugin_trigger.node_id: + event_node = node_config + break - node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node) - invoke_response: TriggerInvokeEventResponse | None = None + if not event_node: + logger.error("Trigger event node not found for app %s", plugin_trigger.app_id) + continue + + trigger_metadata = PluginTriggerMetadata( + plugin_unique_identifier=provider_controller.plugin_unique_identifier or "", + endpoint_id=subscription.endpoint_id, + provider_id=subscription.provider_id, + event_name=event_name, + icon_filename=trigger_entity.identity.icon or "", + icon_dark_filename=trigger_entity.identity.icon_dark or "", + ) + + quota_charge = unlimited() + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) + logger.info("Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id) + return dispatched_count + + node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node) + invoke_response: TriggerInvokeEventResponse | None = None + + with session_factory.create_session() as session: try: invoke_response = TriggerManager.invoke_trigger_event( tenant_id=subscription.tenant_id, @@ -387,6 +387,7 @@ def dispatch_triggered_workflow( raise ValueError(f"End user not found for app {plugin_trigger.app_id}") AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data) + quota_charge.commit() dispatched_count += 1 logger.info( "Triggered workflow for app %s with trigger event %s", @@ -401,7 +402,7 @@ def dispatch_triggered_workflow( plugin_trigger.app_id, ) - return dispatched_count + return dispatched_count def dispatch_triggered_workflows( diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index 8c64d3ab27..7638652000 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py @@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import ( ScheduleNotFoundError, TenantOwnerNotFoundError, ) -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from models.trigger import WorkflowSchedulePlan from services.async_workflow_service import AsyncWorkflowService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.schedule_service import ScheduleService from services.workflow.entities import ScheduleTriggerData @@ -32,6 +33,7 @@ def run_schedule_trigger(schedule_id: str) -> None: TenantOwnerNotFoundError: If no owner/admin for tenant ScheduleExecutionError: If workflow trigger fails """ + # Ensure expire_on_commit is set to False to remain schedule/tenant_owner available with session_factory.create_session() as session: schedule = session.get(WorkflowSchedulePlan, schedule_id) if not schedule: @@ -41,16 +43,16 @@ def run_schedule_trigger(schedule_id: str) -> None: if not tenant_owner: raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}") - quota_charge = unlimited() - try: - quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) - logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) - return + quota_charge = unlimited() + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) + logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) + return - try: - # Production dispatch: Trigger the workflow normally + try: + with session_factory.create_session() as session: response = AsyncWorkflowService.trigger_workflow_async( session=session, user=tenant_owner, @@ -61,9 +63,10 @@ def run_schedule_trigger(schedule_id: str) -> None: tenant_id=schedule.tenant_id, ), ) - logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id) - except Exception as e: - quota_charge.refund() - raise ScheduleExecutionError( - f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}" - ) from e + quota_charge.commit() + logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id) + except Exception as e: + quota_charge.refund() + raise ScheduleExecutionError( + f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}" + ) from e diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 5b1a4790f5..3229693fd4 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -36,12 +36,19 @@ class TestAppGenerateService: ) as mock_message_based_generator, patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config, + patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config, patch("configs.dify_config", autospec=True) as mock_global_dify_config, ): # Setup default mock returns for billing service - mock_billing_service.update_tenant_feature_plan_usage.return_value = { - "result": "success", - "history_id": "test_history_id", + mock_billing_service.quota_reserve.return_value = { + "reservation_id": "test-reservation-id", + "available": 100, + "reserved": 1, + } + mock_billing_service.quota_commit.return_value = { + "available": 99, + "reserved": 0, + "refunded": 0, } # Setup default mock returns for workflow service @@ -101,6 +108,8 @@ class TestAppGenerateService: mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100 mock_dify_config.APP_DAILY_RATE_LIMIT = 1000 + mock_quota_dify_config.BILLING_ENABLED = False + mock_global_dify_config.BILLING_ENABLED = False mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100 mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000 @@ -118,6 +127,7 @@ class TestAppGenerateService: "message_based_generator": mock_message_based_generator, "account_feature_service": mock_account_feature_service, "dify_config": mock_dify_config, + "quota_dify_config": mock_quota_dify_config, "global_dify_config": mock_global_dify_config, } @@ -465,6 +475,7 @@ class TestAppGenerateService: # Set BILLING_ENABLED to True for this test mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True + mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True # Setup test arguments @@ -478,8 +489,10 @@ class TestAppGenerateService: # Verify the result assert result == ["test_response"] - # Verify billing service was called to consume quota - mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once() + # Verify billing two-phase quota (reserve + commit) + billing = mock_external_service_dependencies["billing_service"] + billing.quota_reserve.assert_called_once() + billing.quota_commit.assert_called_once() def test_generate_with_invalid_app_mode( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py new file mode 100644 index 0000000000..85ce3a6ba6 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py @@ -0,0 +1,517 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from enums.quota_type import QuotaType +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import AppTriggerStatus, AppTriggerType +from models.model import App +from models.trigger import AppTrigger, WorkflowWebhookTrigger +from models.workflow import Workflow +from services.errors.app import QuotaExceededError +from services.trigger.webhook_service import WebhookService + + +class WebhookServiceRelationshipFactory: + @staticmethod + def create_account_and_tenant(db_session_with_containers: Session) -> tuple[Account, Tenant]: + account = Account( + name=f"Account {uuid4()}", + email=f"webhook-{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status="normal") + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_app(db_session_with_containers: Session, tenant: Tenant, account: Account) -> App: + app = App( + tenant_id=tenant.id, + name=f"Webhook App {uuid4()}", + description="", + mode="workflow", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + return app + + @staticmethod + def create_workflow( + db_session_with_containers: Session, + *, + app: App, + account: Account, + node_ids: list[str], + version: str, + ) -> Workflow: + graph = { + "nodes": [ + { + "id": node_id, + "data": { + "type": TRIGGER_WEBHOOK_NODE_TYPE, + "title": f"Webhook {node_id}", + "method": "post", + "content_type": "application/json", + "headers": [], + "params": [], + "body": [], + "status_code": 200, + "response_body": '{"status": "ok"}', + "timeout": 30, + }, + } + for node_id in node_ids + ], + "edges": [], + } + + workflow = Workflow( + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + version=version, + ) + db_session_with_containers.add(workflow) + db_session_with_containers.commit() + return workflow + + @staticmethod + def create_webhook_trigger( + db_session_with_containers: Session, + *, + app: App, + account: Account, + node_id: str, + webhook_id: str | None = None, + ) -> WorkflowWebhookTrigger: + webhook_trigger = WorkflowWebhookTrigger( + app_id=app.id, + node_id=node_id, + tenant_id=app.tenant_id, + webhook_id=webhook_id or uuid4().hex[:24], + created_by=account.id, + ) + db_session_with_containers.add(webhook_trigger) + db_session_with_containers.commit() + return webhook_trigger + + @staticmethod + def create_app_trigger( + db_session_with_containers: Session, + *, + app: App, + node_id: str, + status: AppTriggerStatus, + ) -> AppTrigger: + app_trigger = AppTrigger( + tenant_id=app.tenant_id, + app_id=app.id, + node_id=node_id, + trigger_type=AppTriggerType.TRIGGER_WEBHOOK, + provider_name="webhook", + title=f"Webhook {node_id}", + status=status, + ) + db_session_with_containers.add(app_trigger) + db_session_with_containers.commit() + return app_trigger + + +class TestWebhookServiceLookupWithContainers: + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with pytest.raises(ValueError, match="App trigger not found"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.RATE_LIMITED + ) + + with pytest.raises(ValueError, match="rate limited"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.DISABLED + ) + + with pytest.raises(ValueError, match="disabled"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.ENABLED + ) + + with pytest.raises(ValueError, match="Workflow not found"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["published-node"], + version="2026-04-14.001", + ) + draft_workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["debug-node"], + version=Workflow.VERSION_DRAFT, + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="debug-node" + ) + + got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow( + webhook_trigger.webhook_id, + is_debug=True, + ) + + assert got_trigger.id == webhook_trigger.id + assert got_workflow.id == draft_workflow.id + assert got_node_config["id"] == "debug-node" + + +class TestWebhookServiceTriggerExecutionWithContainers: + def test_trigger_workflow_execution_triggers_async_workflow_successfully( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + end_user = SimpleNamespace(id=str(uuid4())) + webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"} + + quota_charge = MagicMock() + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + return_value=end_user, + ), + patch( + "services.trigger.webhook_service.QuotaService.reserve", + return_value=quota_charge, + ) as mock_reserve, + patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger, + ): + WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow) + + mock_reserve.assert_called_once() + reserve_args = mock_reserve.call_args.args + assert reserve_args[0] == QuotaType.TRIGGER + assert reserve_args[1] == webhook_trigger.tenant_id + quota_charge.commit.assert_called_once() + mock_trigger.assert_called_once() + trigger_args = mock_trigger.call_args.args + assert trigger_args[1] is end_user + assert trigger_args[2].workflow_id == workflow.id + assert trigger_args[2].root_node_id == webhook_trigger.node_id + + def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + return_value=SimpleNamespace(id=str(uuid4())), + ), + patch( + "services.trigger.webhook_service.QuotaService.reserve", + side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1), + ), + patch( + "services.trigger.webhook_service.AppTriggerService.mark_tenant_triggers_rate_limited" + ) as mock_mark_rate_limited, + ): + with pytest.raises(QuotaExceededError): + WebhookService.trigger_workflow_execution( + webhook_trigger, + {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}, + workflow, + ) + + mock_mark_rate_limited.assert_called_once_with(tenant.id) + + def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + side_effect=RuntimeError("boom"), + ), + patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception, + ): + with pytest.raises(RuntimeError, match="boom"): + WebhookService.trigger_workflow_execution( + webhook_trigger, + {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}, + workflow, + ) + + mock_logger_exception.assert_called_once() + + +class TestWebhookServiceRelationshipSyncWithContainers: + def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + node_ids = [f"node-{index}" for index in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)] + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=node_ids, version=Workflow.VERSION_DRAFT + ) + + with pytest.raises(ValueError, match="maximum webhook node limit"): + WebhookService.sync_webhook_relationships(app, workflow) + + def test_sync_webhook_relationships_raises_when_lock_not_acquired( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version=Workflow.VERSION_DRAFT + ) + lock = MagicMock() + lock.acquire.return_value = False + + with patch("services.trigger.webhook_service.redis_client.lock", return_value=lock): + with pytest.raises(RuntimeError, match="Failed to acquire lock"): + WebhookService.sync_webhook_relationships(app, workflow) + + def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + stale_trigger = factory.create_webhook_trigger( + db_session_with_containers, + app=app, + account=account, + node_id="node-stale", + webhook_id="stale-webhook-id-000001", + ) + stale_trigger_id = stale_trigger.id + workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["node-new"], + version=Workflow.VERSION_DRAFT, + ) + + with patch( + "services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="new-webhook-id-000001" + ): + WebhookService.sync_webhook_relationships(app, workflow) + + db_session_with_containers.expire_all() + records = db_session_with_containers.scalars( + select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id) + ).all() + + assert [record.node_id for record in records] == ["node-new"] + assert records[0].webhook_id == "new-webhook-id-000001" + assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None + + def test_sync_webhook_relationships_sets_redis_cache_for_new_record( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["node-cache"], + version=Workflow.VERSION_DRAFT, + ) + cache_key = f"{WebhookService.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:node-cache" + + with patch( + "services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="cache-webhook-id-00001" + ): + WebhookService.sync_webhook_relationships(app, workflow) + + cached_payload = WebhookServiceRelationshipFactory._read_cache(cache_key) + assert cached_payload is not None + assert cached_payload["node_id"] == "node-cache" + assert cached_payload["webhook_id"] == "cache-webhook-id-00001" + + def test_sync_webhook_relationships_logs_when_lock_release_fails( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=[], version=Workflow.VERSION_DRAFT + ) + lock = MagicMock() + lock.acquire.return_value = True + lock.release.side_effect = RuntimeError("release failed") + + with ( + patch("services.trigger.webhook_service.redis_client.lock", return_value=lock), + patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception, + ): + WebhookService.sync_webhook_relationships(app, workflow) + + mock_logger_exception.assert_called_once() + + +def _read_cache(cache_key: str) -> dict[str, str] | None: + from extensions.ext_redis import redis_client + + cached = redis_client.get(cache_key) + if not cached: + return None + if isinstance(cached, bytes): + cached = cached.decode("utf-8") + return json.loads(cached) + + +WebhookServiceRelationshipFactory._read_cache = staticmethod(_read_cache) diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 4ea8d8c1c7..3b05207dcd 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -602,9 +602,9 @@ def test_schedule_trigger_creates_trigger_log( ) # Mock quota to avoid rate limiting - from enums import quota_type + from services import quota_service - monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited()) + monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited()) # Execute schedule trigger workflow_schedule_tasks.run_schedule_trigger(plan.id) diff --git a/api/tests/unit_tests/enums/__init__.py b/api/tests/unit_tests/enums/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/enums/test_quota_type.py b/api/tests/unit_tests/enums/test_quota_type.py new file mode 100644 index 0000000000..f256ff3b4e --- /dev/null +++ b/api/tests/unit_tests/enums/test_quota_type.py @@ -0,0 +1,349 @@ +"""Unit tests for QuotaType, QuotaService, and QuotaCharge.""" + +from unittest.mock import patch + +import pytest + +from enums.quota_type import QuotaType +from services.quota_service import QuotaCharge, QuotaService, unlimited + + +class TestQuotaType: + def test_billing_key_trigger(self): + assert QuotaType.TRIGGER.billing_key == "trigger_event" + + def test_billing_key_workflow(self): + assert QuotaType.WORKFLOW.billing_key == "api_rate_limit" + + def test_billing_key_unlimited_raises(self): + with pytest.raises(ValueError, match="Invalid quota type"): + _ = QuotaType.UNLIMITED.billing_key + + +class TestQuotaService: + def test_reserve_billing_disabled(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService"), + ): + mock_cfg.BILLING_ENABLED = False + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1") + assert charge.success is True + assert charge.charge_id is None + + def test_reserve_zero_amount_raises(self): + with patch("services.quota_service.dify_config") as mock_cfg: + mock_cfg.BILLING_ENABLED = True + with pytest.raises(ValueError, match="greater than 0"): + QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0) + + def test_reserve_success(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99} + + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1) + + assert charge.success is True + assert charge.charge_id == "rid-1" + assert charge._tenant_id == "t1" + assert charge._feature_key == "trigger_event" + assert charge._amount == 1 + mock_bs.quota_reserve.assert_called_once() + + def test_reserve_no_reservation_id_raises(self): + from services.errors.app import QuotaExceededError + + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.return_value = {} + + with pytest.raises(QuotaExceededError): + QuotaService.reserve(QuotaType.TRIGGER, "t1") + + def test_reserve_quota_exceeded_propagates(self): + from services.errors.app import QuotaExceededError + + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1) + + with pytest.raises(QuotaExceededError): + QuotaService.reserve(QuotaType.TRIGGER, "t1") + + def test_reserve_api_exception_returns_unlimited(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.side_effect = RuntimeError("network") + + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1") + assert charge.success is True + assert charge.charge_id is None + + def test_consume_calls_reserve_and_commit(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"} + mock_bs.quota_commit.return_value = {} + + charge = QuotaService.consume(QuotaType.TRIGGER, "t1") + assert charge.success is True + mock_bs.quota_commit.assert_called_once() + + def test_check_billing_disabled(self): + with patch("services.quota_service.dify_config") as mock_cfg: + mock_cfg.BILLING_ENABLED = False + assert QuotaService.check(QuotaType.TRIGGER, "t1") is True + + def test_check_zero_amount_raises(self): + with patch("services.quota_service.dify_config") as mock_cfg: + mock_cfg.BILLING_ENABLED = True + with pytest.raises(ValueError, match="greater than 0"): + QuotaService.check(QuotaType.TRIGGER, "t1", amount=0) + + def test_check_sufficient_quota(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=100), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True + + def test_check_insufficient_quota(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=5), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False + + def test_check_unlimited_quota(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=-1), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True + + def test_check_exception_returns_true(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", side_effect=RuntimeError), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1") is True + + def test_release_billing_disabled(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = False + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + mock_bs.quota_release.assert_not_called() + + def test_release_empty_reservation(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event") + mock_bs.quota_release.assert_not_called() + + def test_release_success(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_release.return_value = {} + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + mock_bs.quota_release.assert_called_once_with( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1" + ) + + def test_release_exception_swallowed(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_release.side_effect = RuntimeError("fail") + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + + def test_get_remaining_normal(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70 + + def test_get_remaining_unlimited(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1 + + def test_get_remaining_over_limit_returns_zero(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + def test_get_remaining_exception_returns_neg1(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.side_effect = RuntimeError + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1 + + def test_get_remaining_empty_response(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + def test_get_remaining_non_dict_response(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = "invalid" + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + def test_get_remaining_feature_not_in_response(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}} + remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1") + assert remaining == 0 + + def test_get_remaining_non_dict_feature_info(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + +class TestQuotaCharge: + def test_commit_success(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.return_value = {} + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=1, + ) + charge.commit() + mock_bs.quota_commit.assert_called_once_with( + tenant_id="t1", + feature_key="trigger_event", + reservation_id="rid-1", + actual_amount=1, + ) + assert charge._committed is True + + def test_commit_with_actual_amount(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.return_value = {} + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=10, + ) + charge.commit(actual_amount=5) + call_kwargs = mock_bs.quota_commit.call_args[1] + assert call_kwargs["actual_amount"] == 5 + + def test_commit_idempotent(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.return_value = {} + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=1, + ) + charge.commit() + charge.commit() + assert mock_bs.quota_commit.call_count == 1 + + def test_commit_no_charge_id_noop(self): + with patch("services.billing_service.BillingService") as mock_bs: + charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER) + charge.commit() + mock_bs.quota_commit.assert_not_called() + + def test_commit_no_tenant_id_noop(self): + with patch("services.billing_service.BillingService") as mock_bs: + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id=None, + _feature_key="trigger_event", + ) + charge.commit() + mock_bs.quota_commit.assert_not_called() + + def test_commit_exception_swallowed(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.side_effect = RuntimeError("fail") + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=1, + ) + charge.commit() + + def test_refund_success(self): + with patch.object(QuotaService, "release") as mock_rel: + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + ) + charge.refund() + mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + + def test_refund_no_charge_id_noop(self): + with patch.object(QuotaService, "release") as mock_rel: + charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER) + charge.refund() + mock_rel.assert_not_called() + + def test_refund_no_tenant_id_noop(self): + with patch.object(QuotaService, "release") as mock_rel: + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id=None, + ) + charge.refund() + mock_rel.assert_not_called() + + +class TestUnlimited: + def test_unlimited_returns_success_with_no_charge_id(self): + charge = unlimited() + assert charge.success is True + assert charge.charge_id is None + assert charge._quota_type == QuotaType.UNLIMITED diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py index c2b430c551..c88daf6b1e 100644 --- a/api/tests/unit_tests/services/test_app_generate_service.py +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -23,6 +23,7 @@ import pytest import services.app_generate_service as ags_module from core.app.entities.app_invoke_entities import InvokeFrom +from enums.quota_type import QuotaType from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError @@ -447,8 +448,8 @@ class TestGenerateBilling: def test_billing_enabled_consumes_quota(self, mocker, monkeypatch): monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) quota_charge = MagicMock() - consume_mock = mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + reserve_mock = mocker.patch( + "services.app_generate_service.QuotaService.reserve", return_value=quota_charge, ) mocker.patch( @@ -467,7 +468,8 @@ class TestGenerateBilling: invoke_from=InvokeFrom.SERVICE_API, streaming=False, ) - consume_mock.assert_called_once_with("tenant-id") + reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id") + quota_charge.commit.assert_called_once() def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch): from services.errors.app import QuotaExceededError @@ -475,7 +477,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + "services.app_generate_service.QuotaService.reserve", side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1), ) @@ -492,7 +494,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) quota_charge = MagicMock() mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + "services.app_generate_service.QuotaService.reserve", return_value=quota_charge, ) mocker.patch( diff --git a/api/tests/unit_tests/services/test_async_workflow_service.py b/api/tests/unit_tests/services/test_async_workflow_service.py index 639e091041..226234290e 100644 --- a/api/tests/unit_tests/services/test_async_workflow_service.py +++ b/api/tests/unit_tests/services/test_async_workflow_service.py @@ -57,7 +57,7 @@ class TestAsyncWorkflowService: - repo: SQLAlchemyWorkflowTriggerLogRepository - dispatcher_manager_class: QueueDispatcherManager class - dispatcher: dispatcher instance - - quota_workflow: QuotaType.WORKFLOW + - quota_service: QuotaService mock - get_workflow: AsyncWorkflowService._get_workflow method - professional_task: execute_workflow_professional - team_task: execute_workflow_team @@ -72,12 +72,7 @@ class TestAsyncWorkflowService: mock_repo.create.side_effect = _create_side_effect mock_dispatcher = MagicMock() - quota_workflow = MagicMock() - mock_get_workflow = MagicMock() - - mock_professional_task = MagicMock() - mock_team_task = MagicMock() - mock_sandbox_task = MagicMock() + mock_quota_service = MagicMock() with ( patch.object( @@ -93,8 +88,8 @@ class TestAsyncWorkflowService: ) as mock_get_workflow, patch.object( async_workflow_service_module, - "QuotaType", - new=SimpleNamespace(WORKFLOW=quota_workflow), + "QuotaService", + new=mock_quota_service, ), patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task, patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task, @@ -107,7 +102,7 @@ class TestAsyncWorkflowService: "repo": mock_repo, "dispatcher_manager_class": mock_dispatcher_manager_class, "dispatcher": mock_dispatcher, - "quota_workflow": quota_workflow, + "quota_service": mock_quota_service, "get_workflow": mock_get_workflow, "professional_task": mock_professional_task, "team_task": mock_team_task, @@ -146,6 +141,9 @@ class TestAsyncWorkflowService: mocks["team_task"].delay.return_value = task_result mocks["sandbox_task"].delay.return_value = task_result + quota_charge_mock = MagicMock() + mocks["quota_service"].reserve.return_value = quota_charge_mock + class DummyAccount: def __init__(self, user_id: str): self.id = user_id @@ -163,8 +161,9 @@ class TestAsyncWorkflowService: assert result.status == "queued" assert result.queue == queue_name - mocks["quota_workflow"].consume.assert_called_once_with("tenant-123") - assert session.commit.call_count == 2 + mocks["quota_service"].reserve.assert_called_once() + quota_charge_mock.commit.assert_called_once() + assert session.commit.call_count == 3 created_log = mocks["repo"].create.call_args[0][0] assert created_log.status == WorkflowTriggerStatus.QUEUED @@ -250,7 +249,7 @@ class TestAsyncWorkflowService: mocks = async_workflow_trigger_mocks mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM mocks["get_workflow"].return_value = workflow - mocks["quota_workflow"].consume.side_effect = QuotaExceededError( + mocks["quota_service"].reserve.side_effect = QuotaExceededError( feature="workflow", tenant_id="tenant-123", required=1, @@ -267,7 +266,7 @@ class TestAsyncWorkflowService: trigger_data=trigger_data, ) - assert session.commit.call_count == 2 + assert session.commit.call_count == 3 updated_log = mocks["repo"].update.call_args[0][0] assert updated_log.status == WorkflowTriggerStatus.RATE_LIMITED assert "Quota limit reached" in updated_log.error @@ -463,7 +462,7 @@ class TestAsyncWorkflowServiceGetWorkflow: # Assert assert result == workflow - workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123") + workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123", session=None) workflow_service.get_published_workflow.assert_not_called() def test_should_raise_when_specific_workflow_id_not_found(self): @@ -491,7 +490,7 @@ class TestAsyncWorkflowServiceGetWorkflow: # Assert assert result == workflow - workflow_service.get_published_workflow.assert_called_once_with(app_model) + workflow_service.get_published_workflow.assert_called_once_with(app_model, session=None) workflow_service.get_published_workflow_by_id.assert_not_called() def test_should_raise_when_default_published_workflow_not_found(self): diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 7d7ec81de4..d616b2e879 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -425,7 +425,7 @@ class TestBillingServiceUsageCalculation: yield mock def test_get_tenant_feature_plan_usage_info(self, mock_send_request): - """Test retrieval of tenant feature plan usage information.""" + """Test retrieval of tenant feature plan usage information (legacy endpoint).""" # Arrange tenant_id = "tenant-123" expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}} @@ -438,6 +438,20 @@ class TestBillingServiceUsageCalculation: assert result == expected_response mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id}) + def test_get_quota_info(self, mock_send_request): + """Test retrieval of quota info from new endpoint.""" + # Arrange + tenant_id = "tenant-123" + expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_quota_info(tenant_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id}) + def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request): """Test updating tenant feature usage with positive delta (adding credits).""" # Arrange @@ -515,6 +529,150 @@ class TestBillingServiceUsageCalculation: ) +class TestBillingServiceQuotaOperations: + """Unit tests for quota reserve/commit/release operations.""" + + @pytest.fixture + def mock_send_request(self): + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_quota_reserve_success(self, mock_send_request): + expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1} + mock_send_request.return_value = expected + + result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1) + + assert result == expected + mock_send_request.assert_called_once_with( + "POST", + "/quota/reserve", + json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1}, + ) + + def test_quota_reserve_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"} + + result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1) + + assert result["available"] == 99 + assert isinstance(result["available"], int) + assert result["reserved"] == 1 + assert isinstance(result["reserved"], int) + + def test_quota_reserve_with_meta(self, mock_send_request): + mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1} + meta = {"source": "webhook"} + + BillingService.quota_reserve( + tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta + ) + + call_json = mock_send_request.call_args[1]["json"] + assert call_json["meta"] == {"source": "webhook"} + + def test_quota_commit_success(self, mock_send_request): + expected = {"available": 98, "reserved": 0, "refunded": 0} + mock_send_request.return_value = expected + + result = BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1 + ) + + assert result == expected + mock_send_request.assert_called_once_with( + "POST", + "/quota/commit", + json={ + "tenant_id": "t1", + "feature_key": "trigger_event", + "reservation_id": "rid-1", + "actual_amount": 1, + }, + ) + + def test_quota_commit_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"} + + result = BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1 + ) + + assert result["available"] == 97 + assert isinstance(result["available"], int) + assert result["refunded"] == 1 + assert isinstance(result["refunded"], int) + + def test_quota_commit_with_meta(self, mock_send_request): + mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0} + meta = {"reason": "partial"} + + BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta + ) + + call_json = mock_send_request.call_args[1]["json"] + assert call_json["meta"] == {"reason": "partial"} + + def test_quota_release_success(self, mock_send_request): + expected = {"available": 100, "reserved": 0, "released": 1} + mock_send_request.return_value = expected + + result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1") + + assert result == expected + mock_send_request.assert_called_once_with( + "POST", + "/quota/release", + json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"}, + ) + + def test_quota_release_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"} + + result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s") + + assert result["available"] == 100 + assert isinstance(result["available"], int) + assert result["released"] == 1 + assert isinstance(result["released"], int) + + def test_get_quota_info_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int for get_quota_info.""" + mock_send_request.return_value = { + "trigger_event": {"usage": "42", "limit": "3000", "reset_date": "1700000000"}, + "api_rate_limit": {"usage": "10", "limit": "-1", "reset_date": "-1"}, + } + + result = BillingService.get_quota_info("t1") + + assert result["trigger_event"]["usage"] == 42 + assert isinstance(result["trigger_event"]["usage"], int) + assert result["trigger_event"]["limit"] == 3000 + assert isinstance(result["trigger_event"]["limit"], int) + assert result["trigger_event"]["reset_date"] == 1700000000 + assert isinstance(result["trigger_event"]["reset_date"], int) + assert result["api_rate_limit"]["limit"] == -1 + assert isinstance(result["api_rate_limit"]["limit"], int) + + def test_get_quota_info_accepts_int_values(self, mock_send_request): + """Test that get_quota_info works with native int values.""" + expected = { + "trigger_event": {"usage": 42, "limit": 3000, "reset_date": 1700000000}, + "api_rate_limit": {"usage": 0, "limit": -1}, + } + mock_send_request.return_value = expected + + result = BillingService.get_quota_info("t1") + + assert result["trigger_event"]["usage"] == 42 + assert result["trigger_event"]["limit"] == 3000 + assert result["api_rate_limit"]["limit"] == -1 + + class TestBillingServiceRateLimitEnforcement: """Unit tests for rate limit enforcement mechanisms. diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index d26c2f674f..2419367b6e 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -337,10 +337,7 @@ class TestWorkflowService: app = TestWorkflowAssociatedDataFactory.create_app_mock() mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock() - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + mock_db_session.session.scalar.return_value = mock_workflow result = workflow_service.get_draft_workflow(app) @@ -350,10 +347,7 @@ class TestWorkflowService: """Test get_draft_workflow returns None when no draft exists.""" app = TestWorkflowAssociatedDataFactory.create_app_mock() - # Mock database query to return None - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = None + mock_db_session.session.scalar.return_value = None result = workflow_service.get_draft_workflow(app) @@ -365,10 +359,7 @@ class TestWorkflowService: workflow_id = "workflow-123" mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1") - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + mock_db_session.session.scalar.return_value = mock_workflow result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id) @@ -383,10 +374,7 @@ class TestWorkflowService: workflow_id = "workflow-123" mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1") - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + mock_db_session.session.scalar.return_value = mock_workflow result = workflow_service.get_published_workflow_by_id(app, workflow_id) @@ -405,10 +393,7 @@ class TestWorkflowService: workflow_id=workflow_id, version=Workflow.VERSION_DRAFT ) - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + mock_db_session.session.scalar.return_value = mock_workflow with pytest.raises(IsDraftWorkflowError): workflow_service.get_published_workflow_by_id(app, workflow_id) @@ -418,10 +403,7 @@ class TestWorkflowService: app = TestWorkflowAssociatedDataFactory.create_app_mock() workflow_id = "nonexistent-workflow" - # Mock database query to return None - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = None + mock_db_session.session.scalar.return_value = None result = workflow_service.get_published_workflow_by_id(app, workflow_id) @@ -433,10 +415,7 @@ class TestWorkflowService: app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=workflow_id) mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1") - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + mock_db_session.session.scalar.return_value = mock_workflow result = workflow_service.get_published_workflow(app) @@ -465,11 +444,7 @@ class TestWorkflowService: graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph() features = {"file_upload": {"enabled": False}} - # Mock get_draft_workflow to return None (no existing draft) - # This simulates the first time a workflow is created for an app - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = None + mock_db_session.session.scalar.return_value = None with ( patch.object(workflow_service, "validate_features_structure"), @@ -506,9 +481,7 @@ class TestWorkflowService: # Mock existing draft workflow mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash=unique_hash) - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + mock_db_session.session.scalar.return_value = mock_workflow with ( patch.object(workflow_service, "validate_features_structure"), @@ -547,9 +520,7 @@ class TestWorkflowService: # Mock existing draft workflow with different hash mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash="old-hash") - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + mock_db_session.session.scalar.return_value = mock_workflow with pytest.raises(WorkflowHashNotEqualError): workflow_service.sync_draft_workflow( diff --git a/api/tests/unit_tests/tasks/test_trigger_processing_tasks.py b/api/tests/unit_tests/tasks/test_trigger_processing_tasks.py new file mode 100644 index 0000000000..59da5cc7a2 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_trigger_processing_tasks.py @@ -0,0 +1,204 @@ +from unittest.mock import MagicMock, patch + +import pytest + +import tasks.trigger_processing_tasks as trigger_processing_tasks_module +from services.errors.app import QuotaExceededError +from tasks.trigger_processing_tasks import dispatch_triggered_workflow + + +class TestDispatchTriggeredWorkflow: + """Unit tests covering branch behaviours of ``dispatch_triggered_workflow``. + + The covered branches are: + - workflow missing for ``plugin_trigger.app_id`` → log + ``continue`` + - ``QuotaService.reserve`` raising ``QuotaExceededError`` → + ``mark_tenant_triggers_rate_limited`` + early ``return`` + - ``trigger_workflow_async`` succeeds → + ``quota_charge.commit()`` + ``dispatched_count`` increments + """ + + @pytest.fixture + def subscription(self): + sub = MagicMock() + sub.id = "subscription-123" + sub.tenant_id = "tenant-123" + sub.provider_id = "langgenius/test_plugin/test_plugin" + sub.endpoint_id = "endpoint-123" + sub.credentials = {} + sub.credential_type = "api_key" + return sub + + @pytest.fixture + def plugin_trigger(self): + trigger = MagicMock() + trigger.id = "plugin-trigger-123" + trigger.app_id = "app-123" + trigger.node_id = "node-123" + return trigger + + @pytest.fixture + def provider_controller(self): + controller = MagicMock() + controller.plugin_unique_identifier = "langgenius/test_plugin:0.0.1" + controller.entity.identity.name = "Test Plugin" + controller.entity.identity.icon = "icon.svg" + controller.entity.identity.icon_dark = "icon_dark.svg" + return controller + + @pytest.fixture + def dispatch_mocks(self, subscription, plugin_trigger, provider_controller): + """Patch all external dependencies reached by ``dispatch_triggered_workflow``. + + Defaults are configured so the code flow can reach the final async + trigger block (line ~385); each test overrides specific handles + (``get_workflows``, ``reserve``, ``create_end_user_batch``, ...) to + drive the path it targets. + """ + session_cm = MagicMock() + session_cm.__enter__.return_value = MagicMock() + session_cm.__exit__.return_value = False + + invoke_response = MagicMock() + invoke_response.cancelled = False + invoke_response.variables = {} + + quota_charge = MagicMock() + + with ( + patch.object( + trigger_processing_tasks_module.TriggerHttpRequestCachingService, + "get_request", + return_value=MagicMock(), + ), + patch.object( + trigger_processing_tasks_module.TriggerHttpRequestCachingService, + "get_payload", + return_value=MagicMock(), + ), + patch.object( + trigger_processing_tasks_module.TriggerSubscriptionOperatorService, + "get_subscriber_triggers", + return_value=[plugin_trigger], + ), + patch.object( + trigger_processing_tasks_module.TriggerManager, + "get_trigger_provider", + return_value=provider_controller, + ), + patch.object( + trigger_processing_tasks_module.TriggerManager, + "invoke_trigger_event", + return_value=invoke_response, + ) as invoke_trigger_event, + patch.object( + trigger_processing_tasks_module.TriggerEventNodeData, + "model_validate", + return_value=MagicMock(), + ), + patch.object( + trigger_processing_tasks_module, + "_get_latest_workflows_by_app_ids", + ) as get_workflows, + patch.object( + trigger_processing_tasks_module.EndUserService, + "create_end_user_batch", + return_value={}, + ) as create_end_user_batch, + patch.object( + trigger_processing_tasks_module.session_factory, + "create_session", + return_value=session_cm, + ), + patch.object( + trigger_processing_tasks_module.QuotaService, + "reserve", + return_value=quota_charge, + ) as reserve, + patch.object( + trigger_processing_tasks_module.AppTriggerService, + "mark_tenant_triggers_rate_limited", + ) as mark_rate_limited, + patch.object( + trigger_processing_tasks_module.AsyncWorkflowService, + "trigger_workflow_async", + ) as trigger_workflow_async, + ): + yield { + "get_workflows": get_workflows, + "reserve": reserve, + "quota_charge": quota_charge, + "mark_rate_limited": mark_rate_limited, + "invoke_trigger_event": invoke_trigger_event, + "invoke_response": invoke_response, + "create_end_user_batch": create_end_user_batch, + "trigger_workflow_async": trigger_workflow_async, + } + + def test_dispatch_skips_when_workflow_missing(self, subscription, dispatch_mocks): + """Covers missing workflow → log + ``continue``.""" + dispatch_mocks["get_workflows"].return_value = {} + + dispatched = dispatch_triggered_workflow( + user_id="user-123", + subscription=subscription, + event_name="test_event", + request_id="request-123", + ) + + assert dispatched == 0 + dispatch_mocks["reserve"].assert_not_called() + dispatch_mocks["invoke_trigger_event"].assert_not_called() + dispatch_mocks["mark_rate_limited"].assert_not_called() + + def test_dispatch_marks_rate_limited_when_quota_exceeded(self, subscription, plugin_trigger, dispatch_mocks): + """Covers QuotaExceededError → mark rate-limited + early return.""" + workflow_mock = MagicMock() + workflow_mock.walk_nodes.return_value = iter( + [(plugin_trigger.node_id, {"type": trigger_processing_tasks_module.TRIGGER_PLUGIN_NODE_TYPE})] + ) + dispatch_mocks["get_workflows"].return_value = {plugin_trigger.app_id: workflow_mock} + dispatch_mocks["reserve"].side_effect = QuotaExceededError( + feature="trigger", tenant_id=subscription.tenant_id, required=1 + ) + + dispatched = dispatch_triggered_workflow( + user_id="user-123", + subscription=subscription, + event_name="test_event", + request_id="request-123", + ) + + assert dispatched == 0 + dispatch_mocks["reserve"].assert_called_once() + dispatch_mocks["mark_rate_limited"].assert_called_once_with(subscription.tenant_id) + dispatch_mocks["invoke_trigger_event"].assert_not_called() + + def test_dispatch_commits_quota_and_counts_when_workflow_triggered( + self, subscription, plugin_trigger, dispatch_mocks + ): + """Happy path: end user exists and async trigger succeeds.""" + workflow_mock = MagicMock() + workflow_mock.id = "workflow-123" + workflow_mock.walk_nodes.return_value = iter( + [(plugin_trigger.node_id, {"type": trigger_processing_tasks_module.TRIGGER_PLUGIN_NODE_TYPE})] + ) + dispatch_mocks["get_workflows"].return_value = {plugin_trigger.app_id: workflow_mock} + + end_user_mock = MagicMock() + dispatch_mocks["create_end_user_batch"].return_value = {plugin_trigger.app_id: end_user_mock} + + dispatched = dispatch_triggered_workflow( + user_id="user-123", + subscription=subscription, + event_name="test_event", + request_id="request-123", + ) + + assert dispatched == 1 + dispatch_mocks["trigger_workflow_async"].assert_called_once() + _, kwargs = dispatch_mocks["trigger_workflow_async"].call_args + assert kwargs["user"] is end_user_mock + dispatch_mocks["quota_charge"].commit.assert_called_once() + dispatch_mocks["quota_charge"].refund.assert_not_called() + dispatch_mocks["mark_rate_limited"].assert_not_called()