diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 45de338559..ce2870c82e 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,4 +1,6 @@ import base64 +import json +from datetime import UTC, datetime, timedelta from typing import Literal from flask import request @@ -10,6 +12,7 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from enums.cloud_plan import CloudPlan +from extensions.ext_redis import redis_client from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService @@ -77,3 +80,39 @@ class PartnerTenants(Resource): raise BadRequest("Invalid partner information") return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id) + + +_DEBUG_KEY = "billing:debug" +_DEBUG_TTL = timedelta(days=7) + + +class DebugDataPayload(BaseModel): + type: str = Field(..., min_length=1, description="Data type key") + data: str = Field(..., min_length=1, description="Data value to append") + + +@console_ns.route("/billing/debug/data") +class DebugData(Resource): + def post(self): + body = DebugDataPayload.model_validate(request.get_json(force=True)) + item = json.dumps({ + "type": body.type, + "data": body.data, + "createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"), + }) + redis_client.lpush(_DEBUG_KEY, item) + redis_client.expire(_DEBUG_KEY, _DEBUG_TTL) + return {"result": "ok"}, 201 + + def get(self): + recent = request.args.get("recent", 10, type=int) + items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1) + return { + "data": [ + json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items + ] + } + + def delete(self): + redis_client.delete(_DEBUG_KEY) + return {"result": "ok"} diff --git a/api/enums/quota_type.py b/api/enums/quota_type.py index 9f511b88ef..a10ac21f69 100644 --- a/api/enums/quota_type.py +++ b/api/enums/quota_type.py @@ -1,56 +1,17 @@ -import logging -from dataclasses import dataclass from enum import StrEnum, auto -logger = logging.getLogger(__name__) - - -@dataclass -class QuotaCharge: - """ - Result of a quota consumption operation. - - Attributes: - success: Whether the quota charge succeeded - charge_id: UUID for refund, or None if failed/disabled - """ - - success: bool - charge_id: str | None - _quota_type: "QuotaType" - - def refund(self) -> None: - """ - Refund this quota charge. - - Safe to call even if charge failed or was disabled. - This method guarantees no exceptions will be raised. - """ - if self.charge_id: - self._quota_type.refund(self.charge_id) - logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id) - class QuotaType(StrEnum): """ Supported quota types for tenant feature usage. - - Add additional types here whenever new billable features become available. """ - # Trigger execution quota TRIGGER = auto() - - # Workflow execution quota WORKFLOW = auto() - UNLIMITED = auto() @property def billing_key(self) -> str: - """ - Get the billing key for the feature. - """ match self: case QuotaType.TRIGGER: return "trigger_event" @@ -58,152 +19,3 @@ class QuotaType(StrEnum): return "api_rate_limit" case _: raise ValueError(f"Invalid quota type: {self}") - - def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge: - """ - Consume quota for the feature. - - Args: - tenant_id: The tenant identifier - amount: Amount to consume (default: 1) - - Returns: - QuotaCharge with success status and charge_id for refund - - Raises: - QuotaExceededError: When quota is insufficient - """ - from configs import dify_config - from services.billing_service import BillingService - from services.errors.app import QuotaExceededError - - if not dify_config.BILLING_ENABLED: - logger.debug("Billing disabled, allowing request for %s", tenant_id) - return QuotaCharge(success=True, charge_id=None, _quota_type=self) - - logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id) - - if amount <= 0: - raise ValueError("Amount to consume must be greater than 0") - - try: - response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount) - - if response.get("result") != "success": - logger.warning( - "Failed to consume quota for %s, feature %s details: %s", - tenant_id, - self.value, - response.get("detail"), - ) - raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount) - - charge_id = response.get("history_id") - logger.debug( - "Successfully consumed %d %s quota for tenant %s, charge_id: %s", - amount, - self.value, - tenant_id, - charge_id, - ) - return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self) - - except QuotaExceededError: - raise - except Exception: - # fail-safe: allow request on billing errors - logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value) - return unlimited() - - def check(self, tenant_id: str, amount: int = 1) -> bool: - """ - Check if tenant has sufficient quota without consuming. - - Args: - tenant_id: The tenant identifier - amount: Amount to check (default: 1) - - Returns: - True if quota is sufficient, False otherwise - """ - from configs import dify_config - - if not dify_config.BILLING_ENABLED: - return True - - if amount <= 0: - raise ValueError("Amount to check must be greater than 0") - - try: - remaining = self.get_remaining(tenant_id) - return remaining >= amount if remaining != -1 else True - except Exception: - logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value) - # fail-safe: allow request on billing errors - return True - - def refund(self, charge_id: str) -> None: - """ - Refund quota using charge_id from consume(). - - This method guarantees no exceptions will be raised. - All errors are logged but silently handled. - - Args: - charge_id: The UUID returned from consume() - """ - try: - from configs import dify_config - from services.billing_service import BillingService - - if not dify_config.BILLING_ENABLED: - return - - if not charge_id: - logger.warning("Cannot refund: charge_id is empty") - return - - logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id) - - response = BillingService.refund_tenant_feature_plan_usage(charge_id) - if response.get("result") == "success": - logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id) - else: - logger.warning("Refund failed for charge_id: %s", charge_id) - - except Exception: - # Catch ALL exceptions - refund must never fail - logger.exception("Failed to refund quota for charge_id: %s", charge_id) - # Don't raise - refund is best-effort and must be silent - - def get_remaining(self, tenant_id: str) -> int: - """ - Get remaining quota for the tenant. - - Args: - tenant_id: The tenant identifier - - Returns: - Remaining quota amount - """ - from services.billing_service import BillingService - - try: - usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key) - # Assuming the API returns a dict with 'remaining' or 'limit' and 'used' - if isinstance(usage_info, dict): - return usage_info.get("remaining", 0) - # If it returns a simple number, treat it as remaining - return int(usage_info) if usage_info else 0 - except Exception: - logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value) - return -1 - - -def unlimited() -> QuotaCharge: - """ - Return a quota charge for unlimited quota. - - This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type. - """ - return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 5e8c7aa337..2c9d815b64 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit from core.app.features.rate_limiting.rate_limit import rate_limit_context from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.db import session_factory -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow, WorkflowRun from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError +from services.quota_service import QuotaService, unlimited from services.workflow_service import WorkflowService from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task @@ -106,7 +107,7 @@ class AppGenerateService: quota_charge = unlimited() if dify_config.BILLING_ENABLED: try: - quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id) except QuotaExceededError: raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}") @@ -116,6 +117,7 @@ class AppGenerateService: request_id = RateLimit.gen_request_key() try: request_id = rate_limit.enter(request_id) + quota_charge.commit() effective_mode = ( AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode ) diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index a731d5c048..8b39d63385 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict from models.workflow import Workflow from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError +from services.quota_service import QuotaService, unlimited from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority from services.workflow_service import WorkflowService @@ -131,9 +132,10 @@ class AsyncWorkflowService: trigger_log = trigger_log_repo.create(trigger_log) session.commit() - # 7. Check and consume quota + # 7. Reserve quota (commit after successful dispatch) + quota_charge = unlimited() try: - QuotaType.WORKFLOW.consume(trigger_data.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id) except QuotaExceededError as e: # Update trigger log status trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED @@ -153,13 +155,18 @@ class AsyncWorkflowService: # 9. Dispatch to appropriate queue task_data_dict = task_data.model_dump(mode="json") - task: AsyncResult[Any] | None = None - if queue_name == QueuePriority.PROFESSIONAL: - task = execute_workflow_professional.delay(task_data_dict) - elif queue_name == QueuePriority.TEAM: - task = execute_workflow_team.delay(task_data_dict) - else: # SANDBOX - task = execute_workflow_sandbox.delay(task_data_dict) + try: + task: AsyncResult[Any] | None = None + if queue_name == QueuePriority.PROFESSIONAL: + task = execute_workflow_professional.delay(task_data_dict) + elif queue_name == QueuePriority.TEAM: + task = execute_workflow_team.delay(task_data_dict) + else: # SANDBOX + task = execute_workflow_sandbox.delay(task_data_dict) + quota_charge.commit() + except Exception: + quota_charge.refund() + raise # 10. Update trigger log with task info trigger_log.status = WorkflowTriggerStatus.QUEUED diff --git a/api/services/billing_service.py b/api/services/billing_service.py index a1362ccad6..eeaddfee2f 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -32,6 +32,102 @@ class SubscriptionPlan(TypedDict): expiration_date: int +class QuotaReserveResult(TypedDict): + reservation_id: str + available: int + reserved: int + + +class QuotaCommitResult(TypedDict): + available: int + reserved: int + refunded: int + + +class QuotaReleaseResult(TypedDict): + available: int + reserved: int + released: int + + +_quota_reserve_adapter = TypeAdapter(QuotaReserveResult) +_quota_commit_adapter = TypeAdapter(QuotaCommitResult) +_quota_release_adapter = TypeAdapter(QuotaReleaseResult) +class _BillingQuota(TypedDict): + size: int + limit: int + + +class _VectorSpaceQuota(TypedDict): + size: float + limit: int + + +class _KnowledgeRateLimit(TypedDict): + # NOTE (hj24): + # 1. Return for sandbox users but is null for other plans, it's defined but never used. + # 2. Keep it for compatibility for now, can be deprecated in future versions. + size: NotRequired[int] + # NOTE END + limit: int + + +class _BillingSubscription(TypedDict): + plan: str + interval: str + education: bool + + +class BillingInfo(TypedDict): + """Response of /subscription/info. + + NOTE (hj24): + - Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python() + - To ensure the precision, billing may convert fields like int as str, be careful when use TypeAdapter: + 1. validate_python in non-strict mode will coerce it to the expected type + 2. In strict mode, it will raise ValidationError + 3. To preserve compatibility, always keep non-strict mode here and avoid strict mode + """ + + enabled: bool + subscription: _BillingSubscription + members: _BillingQuota + apps: _BillingQuota + vector_space: _VectorSpaceQuota + knowledge_rate_limit: _KnowledgeRateLimit + documents_upload_quota: _BillingQuota + annotation_quota_limit: _BillingQuota + docs_processing: str + can_replace_logo: bool + model_load_balancing_enabled: bool + knowledge_pipeline_publish_enabled: bool + next_credit_reset_date: NotRequired[int] + + +_billing_info_adapter = TypeAdapter(BillingInfo) + + +class _TenantFeatureQuota(TypedDict): + usage: int + limit: int + reset_date: NotRequired[int] + + +class TenantFeatureQuotaInfo(TypedDict): + """Response of /quota/info. + + NOTE (hj24): + - Same convention as BillingInfo: billing may return int fields as str, + always keep non-strict mode to auto-coerce. + """ + + trigger_event: _TenantFeatureQuota + api_rate_limit: _TenantFeatureQuota + + +_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo) + + class _BillingQuota(TypedDict): size: int limit: int @@ -149,11 +245,63 @@ class BillingService: @classmethod def get_tenant_feature_plan_usage_info(cls, tenant_id: str): + """Deprecated: Use get_quota_info instead.""" params = {"tenant_id": tenant_id} - usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params) return usage_info + @classmethod + def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo: + params = {"tenant_id": tenant_id} + return _tenant_feature_quota_info_adapter.validate_python( + cls._send_request("GET", "/quota/info", params=params) + ) + + @classmethod + def quota_reserve( + cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None + ) -> QuotaReserveResult: + """Reserve quota before task execution.""" + payload: dict = { + "tenant_id": tenant_id, + "feature_key": feature_key, + "request_id": request_id, + "amount": amount, + } + if meta: + payload["meta"] = meta + return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload)) + + @classmethod + def quota_commit( + cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None + ) -> QuotaCommitResult: + """Commit a reservation with actual consumption.""" + payload: dict = { + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + "actual_amount": actual_amount, + } + if meta: + payload["meta"] = meta + return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload)) + + @classmethod + def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult: + """Release a reservation (cancel, return frozen quota).""" + return _quota_release_adapter.validate_python( + cls._send_request( + "POST", + "/quota/release", + json={ + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + }, + ) + ) + @classmethod def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict: params = {"tenant_id": tenant_id} diff --git a/api/services/feature_service.py b/api/services/feature_service.py index df653e0ba7..9216a7fb99 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -281,7 +281,7 @@ class FeatureService: def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) - features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id) + features_usage_info = BillingService.get_quota_info(tenant_id) features.billing.enabled = billing_info["enabled"] features.billing.subscription.plan = billing_info["subscription"]["plan"] diff --git a/api/services/quota_service.py b/api/services/quota_service.py new file mode 100644 index 0000000000..4c784315c7 --- /dev/null +++ b/api/services/quota_service.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import logging +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from enums.quota_type import QuotaType + +logger = logging.getLogger(__name__) + + +@dataclass +class QuotaCharge: + """ + Result of a quota reservation (Reserve phase). + + Lifecycle: + charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id) + try: + do_work() + charge.commit() # Confirm consumption + except: + charge.refund() # Release frozen quota + + If neither commit() nor refund() is called, the billing system's + cleanup CronJob will auto-release the reservation within ~75 seconds. + """ + + success: bool + charge_id: str | None # reservation_id + _quota_type: QuotaType + _tenant_id: str | None = None + _feature_key: str | None = None + _amount: int = 0 + _committed: bool = field(default=False, repr=False) + + def commit(self, actual_amount: int | None = None) -> None: + """ + Confirm the consumption with actual amount. + + Args: + actual_amount: Actual amount consumed. Defaults to the reserved amount. + If less than reserved, the difference is refunded automatically. + """ + if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key: + return + + try: + from services.billing_service import BillingService + + amount = actual_amount if actual_amount is not None else self._amount + BillingService.quota_commit( + tenant_id=self._tenant_id, + feature_key=self._feature_key, + reservation_id=self.charge_id, + actual_amount=amount, + ) + self._committed = True + logger.debug( + "Committed %s quota for tenant %s, reservation_id: %s, amount: %d", + self._quota_type, + self._tenant_id, + self.charge_id, + amount, + ) + except Exception: + logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id) + + def refund(self) -> None: + """ + Release the reserved quota (cancel the charge). + + Safe to call even if: + - charge failed or was disabled (charge_id is None) + - already committed (Release after Commit is a no-op) + - already refunded (idempotent) + + This method guarantees no exceptions will be raised. + """ + if not self.charge_id or not self._tenant_id or not self._feature_key: + return + + QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key) + + +def unlimited() -> QuotaCharge: + from enums.quota_type import QuotaType + + return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) + + +class QuotaService: + """Orchestrates quota reserve / commit / release lifecycle via BillingService.""" + + @staticmethod + def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve + immediate Commit (one-shot mode). + + The returned QuotaCharge supports .refund() which calls Release. + For two-phase usage (e.g. streaming), use reserve() directly. + """ + charge = QuotaService.reserve(quota_type, tenant_id, amount) + if charge.success and charge.charge_id: + charge.commit() + return charge + + @staticmethod + def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve quota before task execution (Reserve phase only). + + The caller MUST call charge.commit() after the task succeeds, + or charge.refund() if the task fails. + + Raises: + QuotaExceededError: When quota is insufficient + """ + from services.billing_service import BillingService + from services.errors.app import QuotaExceededError + + if not dify_config.BILLING_ENABLED: + logger.debug("Billing disabled, allowing request for %s", tenant_id) + return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type) + + logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id) + + if amount <= 0: + raise ValueError("Amount to reserve must be greater than 0") + + request_id = str(uuid.uuid4()) + feature_key = quota_type.billing_key + + try: + reserve_resp = BillingService.quota_reserve( + tenant_id=tenant_id, + feature_key=feature_key, + request_id=request_id, + amount=amount, + ) + + reservation_id = reserve_resp.get("reservation_id") + if not reservation_id: + logger.warning( + "Reserve returned no reservation_id for %s, feature %s, response: %s", + tenant_id, + quota_type.value, + reserve_resp, + ) + raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount) + + logger.debug( + "Reserved %d %s quota for tenant %s, reservation_id: %s", + amount, + quota_type.value, + tenant_id, + reservation_id, + ) + return QuotaCharge( + success=True, + charge_id=reservation_id, + _quota_type=quota_type, + _tenant_id=tenant_id, + _feature_key=feature_key, + _amount=amount, + ) + + except QuotaExceededError: + raise + except ValueError: + raise + except Exception: + logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value) + return unlimited() + + @staticmethod + def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool: + if not dify_config.BILLING_ENABLED: + return True + + if amount <= 0: + raise ValueError("Amount to check must be greater than 0") + + try: + remaining = QuotaService.get_remaining(quota_type, tenant_id) + return remaining >= amount if remaining != -1 else True + except Exception: + logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value) + return True + + @staticmethod + def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None: + """Release a reservation. Guarantees no exceptions.""" + try: + from services.billing_service import BillingService + + if not dify_config.BILLING_ENABLED: + return + + if not reservation_id: + return + + logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id) + BillingService.quota_release( + tenant_id=tenant_id, + feature_key=feature_key, + reservation_id=reservation_id, + ) + except Exception: + logger.exception("Failed to release quota, reservation_id: %s", reservation_id) + + @staticmethod + def get_remaining(quota_type: QuotaType, tenant_id: str) -> int: + from services.billing_service import BillingService + + try: + usage_info = BillingService.get_quota_info(tenant_id) + if isinstance(usage_info, dict): + feature_info = usage_info.get(quota_type.billing_key, {}) + if isinstance(feature_info, dict): + limit = feature_info.get("limit", 0) + usage = feature_info.get("usage", 0) + if limit == -1: + return -1 + return max(0, limit - usage) + return 0 + except Exception: + logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value) + return -1 diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index bb767a6759..c782bffad4 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -38,6 +38,7 @@ from models.workflow import Workflow from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService from services.trigger.app_trigger_service import AppTriggerService from services.workflow.entities import WebhookTriggerData @@ -819,9 +820,9 @@ class WebhookService: user_id=None, ) - # consume quota before triggering workflow execution + # reserve quota before triggering workflow execution try: - QuotaType.TRIGGER.consume(webhook_trigger.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id) except QuotaExceededError: AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id) logger.info( @@ -832,11 +833,16 @@ class WebhookService: raise # Trigger workflow execution asynchronously - AsyncWorkflowService.trigger_workflow_async( - session, - end_user, - trigger_data, - ) + try: + AsyncWorkflowService.trigger_workflow_async( + session, + end_user, + trigger_data, + ) + quota_charge.commit() + except Exception: + quota_charge.refund() + raise except Exception: logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id) diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 56626e372e..b9f382eccf 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -28,7 +28,7 @@ from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from models.enums import ( AppTriggerType, CreatorUserRole, @@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.trigger_provider_service import TriggerProviderService from services.trigger.trigger_request_service import TriggerHttpRequestCachingService @@ -298,10 +299,10 @@ def dispatch_triggered_workflow( icon_dark_filename=trigger_entity.identity.icon_dark or "", ) - # consume quota before invoking trigger + # reserve quota before invoking trigger quota_charge = unlimited() try: - quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id) except QuotaExceededError: AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) logger.info( @@ -387,6 +388,7 @@ def dispatch_triggered_workflow( raise ValueError(f"End user not found for app {plugin_trigger.app_id}") AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data) + quota_charge.commit() dispatched_count += 1 logger.info( "Triggered workflow for app %s with trigger event %s", diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index 8c64d3ab27..dfb2fb3391 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py @@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import ( ScheduleNotFoundError, TenantOwnerNotFoundError, ) -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from models.trigger import WorkflowSchedulePlan from services.async_workflow_service import AsyncWorkflowService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.schedule_service import ScheduleService from services.workflow.entities import ScheduleTriggerData @@ -43,7 +44,7 @@ def run_schedule_trigger(schedule_id: str) -> None: quota_charge = unlimited() try: - quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id) except QuotaExceededError: AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) @@ -61,6 +62,7 @@ def run_schedule_trigger(schedule_id: str) -> None: tenant_id=schedule.tenant_id, ), ) + quota_charge.commit() logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id) except Exception as e: quota_charge.refund() diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 5b1a4790f5..3229693fd4 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -36,12 +36,19 @@ class TestAppGenerateService: ) as mock_message_based_generator, patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config, + patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config, patch("configs.dify_config", autospec=True) as mock_global_dify_config, ): # Setup default mock returns for billing service - mock_billing_service.update_tenant_feature_plan_usage.return_value = { - "result": "success", - "history_id": "test_history_id", + mock_billing_service.quota_reserve.return_value = { + "reservation_id": "test-reservation-id", + "available": 100, + "reserved": 1, + } + mock_billing_service.quota_commit.return_value = { + "available": 99, + "reserved": 0, + "refunded": 0, } # Setup default mock returns for workflow service @@ -101,6 +108,8 @@ class TestAppGenerateService: mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100 mock_dify_config.APP_DAILY_RATE_LIMIT = 1000 + mock_quota_dify_config.BILLING_ENABLED = False + mock_global_dify_config.BILLING_ENABLED = False mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100 mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000 @@ -118,6 +127,7 @@ class TestAppGenerateService: "message_based_generator": mock_message_based_generator, "account_feature_service": mock_account_feature_service, "dify_config": mock_dify_config, + "quota_dify_config": mock_quota_dify_config, "global_dify_config": mock_global_dify_config, } @@ -465,6 +475,7 @@ class TestAppGenerateService: # Set BILLING_ENABLED to True for this test mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True + mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True # Setup test arguments @@ -478,8 +489,10 @@ class TestAppGenerateService: # Verify the result assert result == ["test_response"] - # Verify billing service was called to consume quota - mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once() + # Verify billing two-phase quota (reserve + commit) + billing = mock_external_service_dependencies["billing_service"] + billing.quota_reserve.assert_called_once() + billing.quota_commit.assert_called_once() def test_generate_with_invalid_app_mode( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index d725fb990a..7c4553d4a0 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -605,9 +605,9 @@ def test_schedule_trigger_creates_trigger_log( ) # Mock quota to avoid rate limiting - from enums import quota_type + from services import quota_service - monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited()) + monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited()) # Execute schedule trigger workflow_schedule_tasks.run_schedule_trigger(plan.id) diff --git a/api/tests/unit_tests/enums/__init__.py b/api/tests/unit_tests/enums/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/enums/test_quota_type.py b/api/tests/unit_tests/enums/test_quota_type.py new file mode 100644 index 0000000000..f256ff3b4e --- /dev/null +++ b/api/tests/unit_tests/enums/test_quota_type.py @@ -0,0 +1,349 @@ +"""Unit tests for QuotaType, QuotaService, and QuotaCharge.""" + +from unittest.mock import patch + +import pytest + +from enums.quota_type import QuotaType +from services.quota_service import QuotaCharge, QuotaService, unlimited + + +class TestQuotaType: + def test_billing_key_trigger(self): + assert QuotaType.TRIGGER.billing_key == "trigger_event" + + def test_billing_key_workflow(self): + assert QuotaType.WORKFLOW.billing_key == "api_rate_limit" + + def test_billing_key_unlimited_raises(self): + with pytest.raises(ValueError, match="Invalid quota type"): + _ = QuotaType.UNLIMITED.billing_key + + +class TestQuotaService: + def test_reserve_billing_disabled(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService"), + ): + mock_cfg.BILLING_ENABLED = False + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1") + assert charge.success is True + assert charge.charge_id is None + + def test_reserve_zero_amount_raises(self): + with patch("services.quota_service.dify_config") as mock_cfg: + mock_cfg.BILLING_ENABLED = True + with pytest.raises(ValueError, match="greater than 0"): + QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0) + + def test_reserve_success(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99} + + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1) + + assert charge.success is True + assert charge.charge_id == "rid-1" + assert charge._tenant_id == "t1" + assert charge._feature_key == "trigger_event" + assert charge._amount == 1 + mock_bs.quota_reserve.assert_called_once() + + def test_reserve_no_reservation_id_raises(self): + from services.errors.app import QuotaExceededError + + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.return_value = {} + + with pytest.raises(QuotaExceededError): + QuotaService.reserve(QuotaType.TRIGGER, "t1") + + def test_reserve_quota_exceeded_propagates(self): + from services.errors.app import QuotaExceededError + + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1) + + with pytest.raises(QuotaExceededError): + QuotaService.reserve(QuotaType.TRIGGER, "t1") + + def test_reserve_api_exception_returns_unlimited(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.side_effect = RuntimeError("network") + + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1") + assert charge.success is True + assert charge.charge_id is None + + def test_consume_calls_reserve_and_commit(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"} + mock_bs.quota_commit.return_value = {} + + charge = QuotaService.consume(QuotaType.TRIGGER, "t1") + assert charge.success is True + mock_bs.quota_commit.assert_called_once() + + def test_check_billing_disabled(self): + with patch("services.quota_service.dify_config") as mock_cfg: + mock_cfg.BILLING_ENABLED = False + assert QuotaService.check(QuotaType.TRIGGER, "t1") is True + + def test_check_zero_amount_raises(self): + with patch("services.quota_service.dify_config") as mock_cfg: + mock_cfg.BILLING_ENABLED = True + with pytest.raises(ValueError, match="greater than 0"): + QuotaService.check(QuotaType.TRIGGER, "t1", amount=0) + + def test_check_sufficient_quota(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=100), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True + + def test_check_insufficient_quota(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=5), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False + + def test_check_unlimited_quota(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=-1), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True + + def test_check_exception_returns_true(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", side_effect=RuntimeError), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1") is True + + def test_release_billing_disabled(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = False + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + mock_bs.quota_release.assert_not_called() + + def test_release_empty_reservation(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event") + mock_bs.quota_release.assert_not_called() + + def test_release_success(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_release.return_value = {} + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + mock_bs.quota_release.assert_called_once_with( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1" + ) + + def test_release_exception_swallowed(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_release.side_effect = RuntimeError("fail") + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + + def test_get_remaining_normal(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70 + + def test_get_remaining_unlimited(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1 + + def test_get_remaining_over_limit_returns_zero(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + def test_get_remaining_exception_returns_neg1(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.side_effect = RuntimeError + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1 + + def test_get_remaining_empty_response(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + def test_get_remaining_non_dict_response(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = "invalid" + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + def test_get_remaining_feature_not_in_response(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}} + remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1") + assert remaining == 0 + + def test_get_remaining_non_dict_feature_info(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + +class TestQuotaCharge: + def test_commit_success(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.return_value = {} + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=1, + ) + charge.commit() + mock_bs.quota_commit.assert_called_once_with( + tenant_id="t1", + feature_key="trigger_event", + reservation_id="rid-1", + actual_amount=1, + ) + assert charge._committed is True + + def test_commit_with_actual_amount(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.return_value = {} + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=10, + ) + charge.commit(actual_amount=5) + call_kwargs = mock_bs.quota_commit.call_args[1] + assert call_kwargs["actual_amount"] == 5 + + def test_commit_idempotent(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.return_value = {} + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=1, + ) + charge.commit() + charge.commit() + assert mock_bs.quota_commit.call_count == 1 + + def test_commit_no_charge_id_noop(self): + with patch("services.billing_service.BillingService") as mock_bs: + charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER) + charge.commit() + mock_bs.quota_commit.assert_not_called() + + def test_commit_no_tenant_id_noop(self): + with patch("services.billing_service.BillingService") as mock_bs: + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id=None, + _feature_key="trigger_event", + ) + charge.commit() + mock_bs.quota_commit.assert_not_called() + + def test_commit_exception_swallowed(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.side_effect = RuntimeError("fail") + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=1, + ) + charge.commit() + + def test_refund_success(self): + with patch.object(QuotaService, "release") as mock_rel: + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + ) + charge.refund() + mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + + def test_refund_no_charge_id_noop(self): + with patch.object(QuotaService, "release") as mock_rel: + charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER) + charge.refund() + mock_rel.assert_not_called() + + def test_refund_no_tenant_id_noop(self): + with patch.object(QuotaService, "release") as mock_rel: + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id=None, + ) + charge.refund() + mock_rel.assert_not_called() + + +class TestUnlimited: + def test_unlimited_returns_success_with_no_charge_id(self): + charge = unlimited() + assert charge.success is True + assert charge.charge_id is None + assert charge._quota_type == QuotaType.UNLIMITED diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py index c2b430c551..c88daf6b1e 100644 --- a/api/tests/unit_tests/services/test_app_generate_service.py +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -23,6 +23,7 @@ import pytest import services.app_generate_service as ags_module from core.app.entities.app_invoke_entities import InvokeFrom +from enums.quota_type import QuotaType from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError @@ -447,8 +448,8 @@ class TestGenerateBilling: def test_billing_enabled_consumes_quota(self, mocker, monkeypatch): monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) quota_charge = MagicMock() - consume_mock = mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + reserve_mock = mocker.patch( + "services.app_generate_service.QuotaService.reserve", return_value=quota_charge, ) mocker.patch( @@ -467,7 +468,8 @@ class TestGenerateBilling: invoke_from=InvokeFrom.SERVICE_API, streaming=False, ) - consume_mock.assert_called_once_with("tenant-id") + reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id") + quota_charge.commit.assert_called_once() def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch): from services.errors.app import QuotaExceededError @@ -475,7 +477,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + "services.app_generate_service.QuotaService.reserve", side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1), ) @@ -492,7 +494,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) quota_charge = MagicMock() mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + "services.app_generate_service.QuotaService.reserve", return_value=quota_charge, ) mocker.patch( diff --git a/api/tests/unit_tests/services/test_async_workflow_service.py b/api/tests/unit_tests/services/test_async_workflow_service.py index ca6ff9dc63..361e95a557 100644 --- a/api/tests/unit_tests/services/test_async_workflow_service.py +++ b/api/tests/unit_tests/services/test_async_workflow_service.py @@ -57,7 +57,7 @@ class TestAsyncWorkflowService: - repo: SQLAlchemyWorkflowTriggerLogRepository - dispatcher_manager_class: QueueDispatcherManager class - dispatcher: dispatcher instance - - quota_workflow: QuotaType.WORKFLOW + - quota_service: QuotaService mock - get_workflow: AsyncWorkflowService._get_workflow method - professional_task: execute_workflow_professional - team_task: execute_workflow_team @@ -72,7 +72,12 @@ class TestAsyncWorkflowService: mock_repo.create.side_effect = _create_side_effect mock_dispatcher = MagicMock() - quota_workflow = MagicMock() + mock_quota_service = MagicMock() + mock_get_workflow = MagicMock() + + mock_professional_task = MagicMock() + mock_team_task = MagicMock() + mock_sandbox_task = MagicMock() with ( patch.object( @@ -88,8 +93,8 @@ class TestAsyncWorkflowService: ) as mock_get_workflow, patch.object( async_workflow_service_module, - "QuotaType", - new=SimpleNamespace(WORKFLOW=quota_workflow), + "QuotaService", + new=mock_quota_service, ), patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task, patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task, @@ -102,7 +107,7 @@ class TestAsyncWorkflowService: "repo": mock_repo, "dispatcher_manager_class": mock_dispatcher_manager_class, "dispatcher": mock_dispatcher, - "quota_workflow": quota_workflow, + "quota_service": mock_quota_service, "get_workflow": mock_get_workflow, "professional_task": mock_professional_task, "team_task": mock_team_task, @@ -141,6 +146,9 @@ class TestAsyncWorkflowService: mocks["team_task"].delay.return_value = task_result mocks["sandbox_task"].delay.return_value = task_result + quota_charge_mock = MagicMock() + mocks["quota_service"].reserve.return_value = quota_charge_mock + class DummyAccount: def __init__(self, user_id: str): self.id = user_id @@ -158,7 +166,8 @@ class TestAsyncWorkflowService: assert result.status == "queued" assert result.queue == queue_name - mocks["quota_workflow"].consume.assert_called_once_with("tenant-123") + mocks["quota_service"].reserve.assert_called_once() + quota_charge_mock.commit.assert_called_once() assert session.commit.call_count == 2 created_log = mocks["repo"].create.call_args[0][0] @@ -245,7 +254,7 @@ class TestAsyncWorkflowService: mocks = async_workflow_trigger_mocks mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM mocks["get_workflow"].return_value = workflow - mocks["quota_workflow"].consume.side_effect = QuotaExceededError( + mocks["quota_service"].reserve.side_effect = QuotaExceededError( feature="workflow", tenant_id="tenant-123", required=1, diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 9ab0171eac..36592196c6 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -425,7 +425,7 @@ class TestBillingServiceUsageCalculation: yield mock def test_get_tenant_feature_plan_usage_info(self, mock_send_request): - """Test retrieval of tenant feature plan usage information.""" + """Test retrieval of tenant feature plan usage information (legacy endpoint).""" # Arrange tenant_id = "tenant-123" expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}} @@ -438,6 +438,20 @@ class TestBillingServiceUsageCalculation: assert result == expected_response mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id}) + def test_get_quota_info(self, mock_send_request): + """Test retrieval of quota info from new endpoint.""" + # Arrange + tenant_id = "tenant-123" + expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_quota_info(tenant_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id}) + def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request): """Test updating tenant feature usage with positive delta (adding credits).""" # Arrange @@ -515,6 +529,150 @@ class TestBillingServiceUsageCalculation: ) +class TestBillingServiceQuotaOperations: + """Unit tests for quota reserve/commit/release operations.""" + + @pytest.fixture + def mock_send_request(self): + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_quota_reserve_success(self, mock_send_request): + expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1} + mock_send_request.return_value = expected + + result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1) + + assert result == expected + mock_send_request.assert_called_once_with( + "POST", + "/quota/reserve", + json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1}, + ) + + def test_quota_reserve_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"} + + result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1) + + assert result["available"] == 99 + assert isinstance(result["available"], int) + assert result["reserved"] == 1 + assert isinstance(result["reserved"], int) + + def test_quota_reserve_with_meta(self, mock_send_request): + mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1} + meta = {"source": "webhook"} + + BillingService.quota_reserve( + tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta + ) + + call_json = mock_send_request.call_args[1]["json"] + assert call_json["meta"] == {"source": "webhook"} + + def test_quota_commit_success(self, mock_send_request): + expected = {"available": 98, "reserved": 0, "refunded": 0} + mock_send_request.return_value = expected + + result = BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1 + ) + + assert result == expected + mock_send_request.assert_called_once_with( + "POST", + "/quota/commit", + json={ + "tenant_id": "t1", + "feature_key": "trigger_event", + "reservation_id": "rid-1", + "actual_amount": 1, + }, + ) + + def test_quota_commit_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"} + + result = BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1 + ) + + assert result["available"] == 97 + assert isinstance(result["available"], int) + assert result["refunded"] == 1 + assert isinstance(result["refunded"], int) + + def test_quota_commit_with_meta(self, mock_send_request): + mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0} + meta = {"reason": "partial"} + + BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta + ) + + call_json = mock_send_request.call_args[1]["json"] + assert call_json["meta"] == {"reason": "partial"} + + def test_quota_release_success(self, mock_send_request): + expected = {"available": 100, "reserved": 0, "released": 1} + mock_send_request.return_value = expected + + result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1") + + assert result == expected + mock_send_request.assert_called_once_with( + "POST", + "/quota/release", + json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"}, + ) + + def test_quota_release_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"} + + result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s") + + assert result["available"] == 100 + assert isinstance(result["available"], int) + assert result["released"] == 1 + assert isinstance(result["released"], int) + + def test_get_quota_info_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int for get_quota_info.""" + mock_send_request.return_value = { + "trigger_event": {"usage": "42", "limit": "3000", "reset_date": "1700000000"}, + "api_rate_limit": {"usage": "10", "limit": "-1", "reset_date": "-1"}, + } + + result = BillingService.get_quota_info("t1") + + assert result["trigger_event"]["usage"] == 42 + assert isinstance(result["trigger_event"]["usage"], int) + assert result["trigger_event"]["limit"] == 3000 + assert isinstance(result["trigger_event"]["limit"], int) + assert result["trigger_event"]["reset_date"] == 1700000000 + assert isinstance(result["trigger_event"]["reset_date"], int) + assert result["api_rate_limit"]["limit"] == -1 + assert isinstance(result["api_rate_limit"]["limit"], int) + + def test_get_quota_info_accepts_int_values(self, mock_send_request): + """Test that get_quota_info works with native int values.""" + expected = { + "trigger_event": {"usage": 42, "limit": 3000, "reset_date": 1700000000}, + "api_rate_limit": {"usage": 0, "limit": -1}, + } + mock_send_request.return_value = expected + + result = BillingService.get_quota_info("t1") + + assert result["trigger_event"]["usage"] == 42 + assert result["trigger_event"]["limit"] == 3000 + assert result["api_rate_limit"]["limit"] == -1 + + class TestBillingServiceRateLimitEnforcement: """Unit tests for rate limit enforcement mechanisms. diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index ffdcc046f9..02fbe473df 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -559,3 +559,772 @@ class TestWebhookServiceUnit: result = _prepare_webhook_execution("test_webhook", is_debug=True) assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None) + + + +# === Merged from test_webhook_service_additional.py === + + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from flask import Flask +from graphon.variables.types import SegmentType +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import RequestEntityTooLarge + +from core.workflow.nodes.trigger_webhook.entities import ( + ContentType, + WebhookBodyParameter, + WebhookData, + WebhookParameter, +) +from models.enums import AppTriggerStatus +from models.model import App +from models.trigger import WorkflowWebhookTrigger +from models.workflow import Workflow +from services.errors.app import QuotaExceededError +from services.trigger import webhook_service as service_module +from services.trigger.webhook_service import WebhookService + + +class _FakeQuery: + def __init__(self, result: Any) -> None: + self._result = result + + def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery": + return self + + def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery": + return self + + def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery": + return self + + def first(self) -> Any: + return self._result + + +class _SessionContext: + def __init__(self, session: Any) -> None: + self._session = session + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + +class _SessionmakerContext: + def __init__(self, session: Any) -> None: + self._session = session + + def begin(self) -> "_SessionmakerContext": + return self + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + +@pytest.fixture +def flask_app() -> Flask: + return Flask(__name__) + + +def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock())) + monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session)) + monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session)) + + +def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger: + return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs)) + + +def _workflow(**kwargs: Any) -> Workflow: + return cast(Workflow, SimpleNamespace(**kwargs)) + + +def _app(**kwargs: Any) -> App: + return cast(App, SimpleNamespace(**kwargs)) + + +def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + fake_session = MagicMock() + fake_session.scalar.return_value = None + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="Webhook not found"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + fake_session = MagicMock() + fake_session.scalar.side_effect = [webhook_trigger, None] + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="App trigger not found"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED) + fake_session = MagicMock() + fake_session.scalar.side_effect = [webhook_trigger, app_trigger] + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="rate limited"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED) + fake_session = MagicMock() + fake_session.scalar.side_effect = [webhook_trigger, app_trigger] + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="disabled"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED) + fake_session = MagicMock() + fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None] + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="Workflow not found"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED) + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}} + + fake_session = MagicMock() + fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow] + _patch_session(monkeypatch, fake_session) + + # Act + got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + # Assert + assert got_trigger is webhook_trigger + assert got_workflow is workflow + assert got_node_config == {"data": {"key": "value"}} + + +def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}} + + fake_session = MagicMock() + fake_session.scalar.side_effect = [webhook_trigger, workflow] + _patch_session(monkeypatch, fake_session) + + # Act + got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow( + "webhook-1", is_debug=True + ) + + # Assert + assert got_trigger is webhook_trigger + assert got_workflow is workflow + assert got_node_config == {"data": {"mode": "debug"}} + + +def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type( + flask_app: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + warning_mock = MagicMock() + monkeypatch.setattr(service_module.logger, "warning", warning_mock) + webhook_trigger = MagicMock() + + # Act + with flask_app.test_request_context( + "/webhook", + method="POST", + headers={"Content-Type": "application/vnd.custom"}, + data="plain content", + ): + result = WebhookService.extract_webhook_data(webhook_trigger) + + # Assert + assert result["body"] == {"raw": "plain content"} + warning_mock.assert_called_once() + + +def test_extract_webhook_data_should_raise_for_request_too_large( + flask_app: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1) + + # Act / Assert + with flask_app.test_request_context("/webhook", method="POST", data="ab"): + with pytest.raises(RequestEntityTooLarge): + WebhookService.extract_webhook_data(MagicMock()) + + +def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None: + # Arrange + webhook_trigger = MagicMock() + + # Act + with flask_app.test_request_context("/webhook", method="POST", data=b""): + body, files = WebhookService._extract_octet_stream_body(webhook_trigger) + + # Assert + assert body == {"raw": None} + assert files == {} + + +def test_extract_octet_stream_body_should_return_none_when_processing_raises( + flask_app: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = MagicMock() + monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream")) + monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom"))) + + # Act + with flask_app.test_request_context("/webhook", method="POST", data=b"abc"): + body, files = WebhookService._extract_octet_stream_body(webhook_trigger) + + # Assert + assert body == {"raw": None} + assert files == {} + + +def test_extract_text_body_should_return_empty_string_when_request_read_fails( + flask_app: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error"))) + + # Act + with flask_app.test_request_context("/webhook", method="POST", data="abc"): + body, files = WebhookService._extract_text_body() + + # Assert + assert body == {"raw": ""} + assert files == {} + + +def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + fake_magic = MagicMock() + fake_magic.from_buffer.side_effect = RuntimeError("magic failed") + monkeypatch.setattr(service_module, "magic", fake_magic) + + # Act + result = WebhookService._detect_binary_mimetype(b"binary") + + # Assert + assert result == "application/octet-stream" + + +def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1") + file_obj = MagicMock() + file_obj.to_dict.return_value = {"id": "f-1"} + monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj)) + monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None))) + + uploaded = MagicMock() + uploaded.filename = "file.unknown" + uploaded.content_type = None + uploaded.read.return_value = b"content" + + # Act + result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger) + + # Assert + assert result == {"f": {"id": "f-1"}} + + +def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1") + manager = MagicMock() + manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1") + monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager)) + expected_file = MagicMock() + monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file)) + + # Act + result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger) + + # Assert + assert result is expected_file + manager.create_file_by_raw.assert_called_once() + + +@pytest.mark.parametrize( + ("raw_value", "param_type", "expected"), + [ + ("42", SegmentType.NUMBER, 42), + ("3.14", SegmentType.NUMBER, 3.14), + ("yes", SegmentType.BOOLEAN, True), + ("no", SegmentType.BOOLEAN, False), + ], +) +def test_convert_form_value_should_convert_supported_types( + raw_value: str, + param_type: str, + expected: Any, +) -> None: + # Arrange + + # Act + result = WebhookService._convert_form_value("param", raw_value, param_type) + + # Assert + assert result == expected + + +def test_convert_form_value_should_raise_for_unsupported_type() -> None: + # Arrange + + # Act / Assert + with pytest.raises(ValueError, match="Unsupported type"): + WebhookService._convert_form_value("p", "x", SegmentType.FILE) + + +def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + warning_mock = MagicMock() + monkeypatch.setattr(service_module.logger, "warning", warning_mock) + + # Act + result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type") + + # Assert + assert result == {"x": 1} + warning_mock.assert_called_once() + + +def test_validate_and_convert_value_should_wrap_conversion_errors() -> None: + # Arrange + + # Act / Assert + with pytest.raises(ValueError, match="validation failed"): + WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True) + + +def test_process_parameters_should_raise_when_required_parameter_missing() -> None: + # Arrange + raw_params = {"optional": "x"} + config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)] + + # Act / Assert + with pytest.raises(ValueError, match="Required parameter missing"): + WebhookService._process_parameters(raw_params, config, is_form_data=True) + + +def test_process_parameters_should_include_unconfigured_parameters() -> None: + # Arrange + raw_params = {"known": "1", "unknown": "x"} + config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)] + + # Act + result = WebhookService._process_parameters(raw_params, config, is_form_data=True) + + # Assert + assert result == {"known": 1, "unknown": "x"} + + +def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None: + # Arrange + + # Act / Assert + with pytest.raises(ValueError, match="Required body content missing"): + WebhookService._process_body_parameters( + raw_body={"raw": ""}, + body_configs=[WebhookBodyParameter(name="raw", required=True)], + content_type=ContentType.TEXT, + ) + + +def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None: + # Arrange + raw_body = {"message": "hello", "extra": "x"} + body_configs = [ + WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True), + WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True), + ] + + # Act + result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA) + + # Assert + assert result == {"message": "hello", "extra": "x"} + + +def test_validate_required_headers_should_accept_sanitized_header_names() -> None: + # Arrange + headers = {"x_api_key": "123"} + configs = [WebhookParameter(name="x-api-key", required=True)] + + # Act + WebhookService._validate_required_headers(headers, configs) + + # Assert + assert True + + +def test_validate_required_headers_should_raise_when_required_header_missing() -> None: + # Arrange + headers = {"x-other": "123"} + configs = [WebhookParameter(name="x-api-key", required=True)] + + # Act / Assert + with pytest.raises(ValueError, match="Required header missing"): + WebhookService._validate_required_headers(headers, configs) + + +def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None: + # Arrange + webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}} + node_data = WebhookData(method="post", content_type=ContentType.TEXT) + + # Act + result = WebhookService._validate_http_metadata(webhook_data, node_data) + + # Assert + assert result["valid"] is False + assert "Content-type mismatch" in result["error"] + + +def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None: + # Arrange + headers = {"content-type": "application/json; charset=utf-8"} + + # Act + result = WebhookService._extract_content_type(headers) + + # Assert + assert result == "application/json" + + +def test_build_workflow_inputs_should_include_expected_keys() -> None: + # Arrange + webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}} + + # Act + result = WebhookService.build_workflow_inputs(webhook_data) + + # Assert + assert result["webhook_data"] == webhook_data + assert result["webhook_headers"] == {"h": "v"} + assert result["webhook_query_params"] == {"q": 1} + assert result["webhook_body"] == {"b": 2} + + +def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + webhook_trigger = _workflow_trigger( + app_id="app-1", + node_id="node-1", + tenant_id="tenant-1", + webhook_id="webhook-1", + ) + workflow = _workflow(id="wf-1") + webhook_data = {"body": {"x": 1}} + + session = MagicMock() + _patch_session(monkeypatch, session) + + end_user = SimpleNamespace(id="end-user-1") + monkeypatch.setattr( + service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user) + ) + quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock())) + monkeypatch.setattr(service_module, "QuotaType", quota_type) + trigger_async_mock = MagicMock() + monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock) + + # Act + WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow) + + # Assert + trigger_async_mock.assert_called_once() + + +def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = _workflow_trigger( + app_id="app-1", + node_id="node-1", + tenant_id="tenant-1", + webhook_id="webhook-1", + ) + workflow = _workflow(id="wf-1") + + session = MagicMock() + _patch_session(monkeypatch, session) + + monkeypatch.setattr( + service_module.EndUserService, + "get_or_create_end_user_by_type", + MagicMock(return_value=SimpleNamespace(id="end-user-1")), + ) + monkeypatch.setattr( + service_module.QuotaService, + "reserve", + MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1)), + ) + mark_rate_limited_mock = MagicMock() + monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock) + + # Act / Assert + with pytest.raises(QuotaExceededError): + WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow) + mark_rate_limited_mock.assert_called_once_with("tenant-1") + + +def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + webhook_trigger = _workflow_trigger( + app_id="app-1", + node_id="node-1", + tenant_id="tenant-1", + webhook_id="webhook-1", + ) + workflow = _workflow(id="wf-1") + + session = MagicMock() + _patch_session(monkeypatch, session) + + monkeypatch.setattr( + service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom")) + ) + logger_exception_mock = MagicMock() + monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock) + + # Act / Assert + with pytest.raises(RuntimeError, match="boom"): + WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow) + logger_exception_mock.assert_called_once() + + +def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None: + # Arrange + app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1") + workflow = _workflow( + walk_nodes=lambda _node_type: [ + (f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1) + ] + ) + + # Act / Assert + with pytest.raises(ValueError, match="maximum webhook node limit"): + WebhookService.sync_webhook_relationships(app, workflow) + + +def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1") + workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})]) + + lock = MagicMock() + lock.acquire.return_value = False + monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock)) + + # Act / Assert + with pytest.raises(RuntimeError, match="Failed to acquire lock"): + WebhookService.sync_webhook_relationships(app, workflow) + + +def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1") + workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})]) + + class _WorkflowWebhookTrigger: + app_id = "app_id" + tenant_id = "tenant_id" + webhook_id = "webhook_id" + node_id = "node_id" + + def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None: + self.id = None + self.app_id = app_id + self.tenant_id = tenant_id + self.node_id = node_id + self.webhook_id = webhook_id + self.created_by = created_by + + class _Select: + def where(self, *args: Any, **kwargs: Any) -> "_Select": + return self + + class _Session: + def __init__(self) -> None: + self.added: list[Any] = [] + self.deleted: list[Any] = [] + self.commit_count = 0 + self.existing_records = [SimpleNamespace(node_id="node-stale")] + + def scalars(self, _stmt: Any) -> Any: + return SimpleNamespace(all=lambda: self.existing_records) + + def add(self, obj: Any) -> None: + self.added.append(obj) + + def flush(self) -> None: + for idx, obj in enumerate(self.added, start=1): + if obj.id is None: + obj.id = f"rec-{idx}" + + def commit(self) -> None: + self.commit_count += 1 + + def delete(self, obj: Any) -> None: + self.deleted.append(obj) + + lock = MagicMock() + lock.acquire.return_value = True + lock.release.return_value = None + + fake_session = _Session() + + monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger) + monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select())) + monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock)) + redis_set_mock = MagicMock() + redis_delete_mock = MagicMock() + monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock) + monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock) + monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id")) + _patch_session(monkeypatch, fake_session) + + # Act + WebhookService.sync_webhook_relationships(app, workflow) + + # Assert + assert len(fake_session.added) == 1 + assert len(fake_session.deleted) == 1 + redis_set_mock.assert_called_once() + redis_delete_mock.assert_called_once() + lock.release.assert_called_once() + + +def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1") + workflow = _workflow(walk_nodes=lambda _node_type: []) + + class _Select: + def where(self, *args: Any, **kwargs: Any) -> "_Select": + return self + + class _Session: + def scalars(self, _stmt: Any) -> Any: + return SimpleNamespace(all=lambda: []) + + def commit(self) -> None: + return None + + lock = MagicMock() + lock.acquire.return_value = True + lock.release.side_effect = RuntimeError("release failed") + + logger_exception_mock = MagicMock() + + monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select())) + monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock) + _patch_session(monkeypatch, _Session()) + + # Act + WebhookService.sync_webhook_relationships(app, workflow) + + # Assert + assert logger_exception_mock.call_count == 1 + + +def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None: + # Arrange + node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}} + + # Act + body, status = WebhookService.generate_webhook_response(node_config) + + # Assert + assert status == 200 + assert "message" in body + + +def test_generate_webhook_id_should_return_24_character_identifier() -> None: + # Arrange + + # Act + webhook_id = WebhookService.generate_webhook_id() + + # Assert + assert isinstance(webhook_id, str) + assert len(webhook_id) == 24 + + +def test_sanitize_key_should_return_original_value_for_non_string_input() -> None: + # Arrange + + # Act + result = WebhookService._sanitize_key(123) # type: ignore[arg-type] + + # Assert + assert result == 123 + diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index e08ece6666..30d8f3e410 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -9,6 +9,7 @@ import { EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, } from '@/app/education-apply/constants' import { usePathname, useRouter, useSearchParams } from '@/next/navigation' +import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking' import { sendGAEvent } from '@/utils/gtag' import { fetchSetupStatusWithCache } from '@/utils/setup-status' import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect' @@ -45,6 +46,8 @@ export const AppInitializer = ({ (async () => { const action = searchParams.get('action') + rememberCreateAppExternalAttribution({ searchParams }) + if (oauthNewUser) { let utmInfo = null const utmInfoStr = Cookies.get('utm_info') diff --git a/web/app/components/app/create-app-dialog/app-list/__tests__/index.spec.tsx b/web/app/components/app/create-app-dialog/app-list/__tests__/index.spec.tsx index 3ebc5f7157..a319bb58f7 100644 --- a/web/app/components/app/create-app-dialog/app-list/__tests__/index.spec.tsx +++ b/web/app/components/app/create-app-dialog/app-list/__tests__/index.spec.tsx @@ -4,7 +4,6 @@ import { AppModeEnum } from '@/types/app' import Apps from '../index' const mockUseExploreAppList = vi.fn() -const mockTrackEvent = vi.fn() const mockImportDSL = vi.fn() const mockFetchAppDetail = vi.fn() const mockHandleCheckPluginDependencies = vi.fn() @@ -12,6 +11,7 @@ const mockGetRedirection = vi.fn() const mockPush = vi.fn() const mockToastSuccess = vi.fn() const mockToastError = vi.fn() +const mockTrackCreateApp = vi.fn() let latestDebounceFn = () => {} vi.mock('ahooks', () => ({ @@ -92,8 +92,8 @@ vi.mock('@/app/components/base/ui/toast', () => ({ error: (...args: unknown[]) => mockToastError(...args), }, })) -vi.mock('@/app/components/base/amplitude', () => ({ - trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args), })) vi.mock('@/service/apps', () => ({ importDSL: (...args: unknown[]) => mockImportDSL(...args), @@ -246,10 +246,9 @@ describe('Apps', () => { })) }) - expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_template', expect.objectContaining({ - template_id: 'Alpha', - template_name: 'Alpha', - })) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ + appMode: AppModeEnum.CHAT, + }) expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated') expect(onSuccess).toHaveBeenCalled() expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('created-app-id') diff --git a/web/app/components/app/create-app-dialog/app-list/index.tsx b/web/app/components/app/create-app-dialog/app-list/index.tsx index 1aa40d2014..daf49115c8 100644 --- a/web/app/components/app/create-app-dialog/app-list/index.tsx +++ b/web/app/components/app/create-app-dialog/app-list/index.tsx @@ -8,7 +8,6 @@ import * as React from 'react' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import AppTypeSelector from '@/app/components/app/type-selector' -import { trackEvent } from '@/app/components/base/amplitude' import Divider from '@/app/components/base/divider' import Input from '@/app/components/base/input' import Loading from '@/app/components/base/loading' @@ -25,6 +24,7 @@ import { useExploreAppList } from '@/service/use-explore' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { trackCreateApp } from '@/utils/create-app-tracking' import AppCard from '../app-card' import Sidebar, { AppCategories, AppCategoryLabel } from './sidebar' @@ -127,14 +127,7 @@ const Apps = ({ icon_background, description, }) - - // Track app creation from template - trackEvent('create_app_with_template', { - app_mode: mode, - template_id: currApp?.app.id, - template_name: currApp?.app.name, - description, - }) + trackCreateApp({ appMode: mode }) setIsShowCreateModal(false) toast.success(t('newApp.appCreated', { ns: 'app' })) diff --git a/web/app/components/app/create-app-modal/__tests__/index.spec.tsx b/web/app/components/app/create-app-modal/__tests__/index.spec.tsx index ee24ab4006..3e06b89f0e 100644 --- a/web/app/components/app/create-app-modal/__tests__/index.spec.tsx +++ b/web/app/components/app/create-app-modal/__tests__/index.spec.tsx @@ -1,7 +1,6 @@ import type { App } from '@/types/app' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' -import { trackEvent } from '@/app/components/base/amplitude' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' @@ -10,6 +9,7 @@ import { useRouter } from '@/next/navigation' import { createApp } from '@/service/apps' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' +import { trackCreateApp } from '@/utils/create-app-tracking' import CreateAppModal from '../index' const ahooksMocks = vi.hoisted(() => ({ @@ -31,8 +31,8 @@ vi.mock('ahooks', () => ({ vi.mock('@/next/navigation', () => ({ useRouter: vi.fn(), })) -vi.mock('@/app/components/base/amplitude', () => ({ - trackEvent: vi.fn(), +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: vi.fn(), })) vi.mock('@/service/apps', () => ({ createApp: vi.fn(), @@ -87,7 +87,7 @@ vi.mock('@/hooks/use-theme', () => ({ const mockUseRouter = vi.mocked(useRouter) const mockPush = vi.fn() const mockCreateApp = vi.mocked(createApp) -const mockTrackEvent = vi.mocked(trackEvent) +const mockTrackCreateApp = vi.mocked(trackCreateApp) const mockGetRedirection = vi.mocked(getRedirection) const mockUseProviderContext = vi.mocked(useProviderContext) const mockUseAppContext = vi.mocked(useAppContext) @@ -178,10 +178,7 @@ describe('CreateAppModal', () => { mode: AppModeEnum.ADVANCED_CHAT, })) - expect(mockTrackEvent).toHaveBeenCalledWith('create_app', { - app_mode: AppModeEnum.ADVANCED_CHAT, - description: '', - }) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.ADVANCED_CHAT }) expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated') expect(onSuccess).toHaveBeenCalled() expect(onClose).toHaveBeenCalled() diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index f2ced9b6c0..96c3045c59 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -6,7 +6,6 @@ import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon import { useDebounceFn, useKeyPress } from 'ahooks' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { trackEvent } from '@/app/components/base/amplitude' import AppIcon from '@/app/components/base/app-icon' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' @@ -25,6 +24,7 @@ import { createApp } from '@/service/apps' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { trackCreateApp } from '@/utils/create-app-tracking' import { basePath } from '@/utils/var' import AppIconPicker from '../../base/app-icon-picker' import ShortcutsName from '../../workflow/shortcuts-name' @@ -80,11 +80,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: mode: appMode, }) - // Track app creation success - trackEvent('create_app', { - app_mode: appMode, - description, - }) + trackCreateApp({ appMode: app.mode }) toast.success(t('newApp.appCreated', { ns: 'app' })) onSuccess() diff --git a/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx b/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx index c1ffbc22e8..e106cc7eb3 100644 --- a/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx +++ b/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx @@ -2,12 +2,13 @@ import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { DSLImportMode, DSLImportStatus } from '@/models/app' +import { AppModeEnum } from '@/types/app' import CreateFromDSLModal, { CreateFromDSLModalTab } from '../index' const mockPush = vi.fn() const mockImportDSL = vi.fn() const mockImportDSLConfirm = vi.fn() -const mockTrackEvent = vi.fn() +const mockTrackCreateApp = vi.fn() const mockHandleCheckPluginDependencies = vi.fn() const mockGetRedirection = vi.fn() const toastMocks = vi.hoisted(() => ({ @@ -43,8 +44,8 @@ vi.mock('@/next/navigation', () => ({ }), })) -vi.mock('@/app/components/base/amplitude', () => ({ - trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args), })) vi.mock('@/service/apps', () => ({ @@ -172,7 +173,7 @@ describe('CreateFromDSLModal', () => { id: 'import-1', status: DSLImportStatus.COMPLETED, app_id: 'app-1', - app_mode: 'chat', + app_mode: AppModeEnum.CHAT, }) render( @@ -196,10 +197,7 @@ describe('CreateFromDSLModal', () => { mode: DSLImportMode.YAML_URL, yaml_url: 'https://example.com/app.yml', }) - expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_dsl', expect.objectContaining({ - creation_method: 'dsl_url', - has_warnings: false, - })) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.CHAT }) expect(handleSuccess).toHaveBeenCalledTimes(1) expect(handleClose).toHaveBeenCalledTimes(1) expect(localStorage.getItem(NEED_REFRESH_APP_LIST_KEY)).toBe('1') @@ -212,7 +210,7 @@ describe('CreateFromDSLModal', () => { id: 'import-2', status: DSLImportStatus.COMPLETED_WITH_WARNINGS, app_id: 'app-2', - app_mode: 'chat', + app_mode: AppModeEnum.CHAT, }) render( @@ -275,7 +273,7 @@ describe('CreateFromDSLModal', () => { mockImportDSLConfirm.mockResolvedValue({ status: DSLImportStatus.COMPLETED, app_id: 'app-3', - app_mode: 'workflow', + app_mode: AppModeEnum.WORKFLOW, }) render( @@ -305,6 +303,7 @@ describe('CreateFromDSLModal', () => { expect(mockImportDSLConfirm).toHaveBeenCalledWith({ import_id: 'import-3', }) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.WORKFLOW }) }) it('should ignore empty import responses and prevent duplicate submissions while a request is in flight', async () => { @@ -332,7 +331,7 @@ describe('CreateFromDSLModal', () => { id: 'import-in-flight', status: DSLImportStatus.COMPLETED, app_id: 'app-1', - app_mode: 'chat', + app_mode: AppModeEnum.CHAT, }) }) diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index dd17655e3c..77000dbf0a 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -6,7 +6,6 @@ import { useDebounceFn, useKeyPress } from 'ahooks' import { noop } from 'es-toolkit/function' import { useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { trackEvent } from '@/app/components/base/amplitude' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' @@ -27,6 +26,7 @@ import { } from '@/service/apps' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { trackCreateApp } from '@/utils/create-app-tracking' import ShortcutsName from '../../workflow/shortcuts-name' import Uploader from './uploader' @@ -112,12 +112,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS return const { id, status, app_id, app_mode, imported_dsl_version, current_dsl_version } = response if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) { - // Track app creation from DSL import - trackEvent('create_app_with_dsl', { - app_mode, - creation_method: currentTab === CreateFromDSLModalTab.FROM_FILE ? 'dsl_file' : 'dsl_url', - has_warnings: status === DSLImportStatus.COMPLETED_WITH_WARNINGS, - }) + trackCreateApp({ appMode: app_mode }) if (onSuccess) onSuccess() @@ -179,6 +174,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS const { status, app_id, app_mode } = response if (status === DSLImportStatus.COMPLETED) { + trackCreateApp({ appMode: app_mode }) if (onSuccess) onSuccess() if (onClose) @@ -228,7 +224,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS isShow={show} onClose={noop} > -