Merge remote-tracking branch 'origin/deploy/dev' into feat/evaluation

# Conflicts:
#	api/tests/unit_tests/services/test_async_workflow_service.py
#	api/tests/unit_tests/services/test_webhook_service.py

Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2026-04-14 07:28:30 +00:00 committed by GitHub
commit e2e5ad0c33
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 2392 additions and 317 deletions

View File

@ -1,4 +1,6 @@
import base64
import json
from datetime import UTC, datetime, timedelta
from typing import Literal
from flask import request
@ -10,6 +12,7 @@ from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@ -77,3 +80,39 @@ class PartnerTenants(Resource):
raise BadRequest("Invalid partner information")
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
_DEBUG_KEY = "billing:debug"
_DEBUG_TTL = timedelta(days=7)
class DebugDataPayload(BaseModel):
type: str = Field(..., min_length=1, description="Data type key")
data: str = Field(..., min_length=1, description="Data value to append")
@console_ns.route("/billing/debug/data")
class DebugData(Resource):
def post(self):
body = DebugDataPayload.model_validate(request.get_json(force=True))
item = json.dumps({
"type": body.type,
"data": body.data,
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
})
redis_client.lpush(_DEBUG_KEY, item)
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
return {"result": "ok"}, 201
def get(self):
recent = request.args.get("recent", 10, type=int)
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
return {
"data": [
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
]
}
def delete(self):
redis_client.delete(_DEBUG_KEY)
return {"result": "ok"}

View File

@ -1,56 +1,17 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum):
"""
Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
"""
# Trigger execution quota
TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto()
UNLIMITED = auto()
@property
def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self:
case QuotaType.TRIGGER:
return "trigger_event"
@ -58,152 +19,3 @@ class QuotaType(StrEnum):
return "api_rate_limit"
case _:
raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Consume quota for the feature.
Args:
tenant_id: The tenant identifier
amount: Amount to consume (default: 1)
Returns:
QuotaCharge with success status and charge_id for refund
Raises:
QuotaExceededError: When quota is insufficient
"""
from configs import dify_config
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to consume must be greater than 0")
try:
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
if response.get("result") != "success":
logger.warning(
"Failed to consume quota for %s, feature %s details: %s",
tenant_id,
self.value,
response.get("detail"),
)
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
charge_id = response.get("history_id")
logger.debug(
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
amount,
self.value,
tenant_id,
charge_id,
)
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
except QuotaExceededError:
raise
except Exception:
# fail-safe: allow request on billing errors
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
"""
Check if tenant has sufficient quota without consuming.
Args:
tenant_id: The tenant identifier
amount: Amount to check (default: 1)
Returns:
True if quota is sufficient, False otherwise
"""
from configs import dify_config
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = self.get_remaining(tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
# fail-safe: allow request on billing errors
return True
def refund(self, charge_id: str) -> None:
"""
Refund quota using charge_id from consume().
This method guarantees no exceptions will be raised.
All errors are logged but silently handled.
Args:
charge_id: The UUID returned from consume()
"""
try:
from configs import dify_config
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not charge_id:
logger.warning("Cannot refund: charge_id is empty")
return
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
if response.get("result") == "success":
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
else:
logger.warning("Refund failed for charge_id: %s", charge_id)
except Exception:
# Catch ALL exceptions - refund must never fail
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
# Don't raise - refund is best-effort and must be silent
def get_remaining(self, tenant_id: str) -> int:
"""
Get remaining quota for the tenant.
Args:
tenant_id: The tenant identifier
Returns:
Remaining quota amount
"""
from services.billing_service import BillingService
try:
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
if isinstance(usage_info, dict):
return usage_info.get("remaining", 0)
# If it returns a simple number, treat it as remaining
return int(usage_info) if usage_info else 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
return -1
def unlimited() -> QuotaCharge:
"""
Return a quota charge for unlimited quota.
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
"""
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)

View File

@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
@ -106,7 +107,7 @@ class AppGenerateService:
quota_charge = unlimited()
if dify_config.BILLING_ENABLED:
try:
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
except QuotaExceededError:
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
@ -116,6 +117,7 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
quota_charge.commit()
effective_mode = (
AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
)

View File

@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@ -131,9 +132,10 @@ class AsyncWorkflowService:
trigger_log = trigger_log_repo.create(trigger_log)
session.commit()
# 7. Check and consume quota
# 7. Reserve quota (commit after successful dispatch)
quota_charge = unlimited()
try:
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
except QuotaExceededError as e:
# Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@ -153,13 +155,18 @@ class AsyncWorkflowService:
# 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json")
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
try:
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED

View File

@ -32,6 +32,102 @@ class SubscriptionPlan(TypedDict):
expiration_date: int
class QuotaReserveResult(TypedDict):
reservation_id: str
available: int
reserved: int
class QuotaCommitResult(TypedDict):
available: int
reserved: int
refunded: int
class QuotaReleaseResult(TypedDict):
available: int
reserved: int
released: int
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
class _BillingQuota(TypedDict):
size: int
limit: int
class _VectorSpaceQuota(TypedDict):
size: float
limit: int
class _KnowledgeRateLimit(TypedDict):
# NOTE (hj24):
# 1. Return for sandbox users but is null for other plans, it's defined but never used.
# 2. Keep it for compatibility for now, can be deprecated in future versions.
size: NotRequired[int]
# NOTE END
limit: int
class _BillingSubscription(TypedDict):
plan: str
interval: str
education: bool
class BillingInfo(TypedDict):
"""Response of /subscription/info.
NOTE (hj24):
- Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
- To ensure the precision, billing may convert fields like int as str, be careful when use TypeAdapter:
1. validate_python in non-strict mode will coerce it to the expected type
2. In strict mode, it will raise ValidationError
3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
"""
enabled: bool
subscription: _BillingSubscription
members: _BillingQuota
apps: _BillingQuota
vector_space: _VectorSpaceQuota
knowledge_rate_limit: _KnowledgeRateLimit
documents_upload_quota: _BillingQuota
annotation_quota_limit: _BillingQuota
docs_processing: str
can_replace_logo: bool
model_load_balancing_enabled: bool
knowledge_pipeline_publish_enabled: bool
next_credit_reset_date: NotRequired[int]
_billing_info_adapter = TypeAdapter(BillingInfo)
class _TenantFeatureQuota(TypedDict):
usage: int
limit: int
reset_date: NotRequired[int]
class TenantFeatureQuotaInfo(TypedDict):
"""Response of /quota/info.
NOTE (hj24):
- Same convention as BillingInfo: billing may return int fields as str,
always keep non-strict mode to auto-coerce.
"""
trigger_event: _TenantFeatureQuota
api_rate_limit: _TenantFeatureQuota
_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)
class _BillingQuota(TypedDict):
size: int
limit: int
@ -149,11 +245,63 @@ class BillingService:
@classmethod
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
"""Deprecated: Use get_quota_info instead."""
params = {"tenant_id": tenant_id}
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
return usage_info
@classmethod
def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
params = {"tenant_id": tenant_id}
return _tenant_feature_quota_info_adapter.validate_python(
cls._send_request("GET", "/quota/info", params=params)
)
@classmethod
def quota_reserve(
cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
) -> QuotaReserveResult:
"""Reserve quota before task execution."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"request_id": request_id,
"amount": amount,
}
if meta:
payload["meta"] = meta
return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))
@classmethod
def quota_commit(
cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
) -> QuotaCommitResult:
"""Commit a reservation with actual consumption."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
"actual_amount": actual_amount,
}
if meta:
payload["meta"] = meta
return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))
@classmethod
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
"""Release a reservation (cancel, return frozen quota)."""
return _quota_release_adapter.validate_python(
cls._send_request(
"POST",
"/quota/release",
json={
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
},
)
)
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
params = {"tenant_id": tenant_id}

View File

@ -281,7 +281,7 @@ class FeatureService:
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
features_usage_info = BillingService.get_quota_info(tenant_id)
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]

View File

@ -0,0 +1,233 @@
from __future__ import annotations
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from enums.quota_type import QuotaType
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota reservation (Reserve phase).
Lifecycle:
charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
try:
do_work()
charge.commit() # Confirm consumption
except:
charge.refund() # Release frozen quota
If neither commit() nor refund() is called, the billing system's
cleanup CronJob will auto-release the reservation within ~75 seconds.
"""
success: bool
charge_id: str | None # reservation_id
_quota_type: QuotaType
_tenant_id: str | None = None
_feature_key: str | None = None
_amount: int = 0
_committed: bool = field(default=False, repr=False)
def commit(self, actual_amount: int | None = None) -> None:
"""
Confirm the consumption with actual amount.
Args:
actual_amount: Actual amount consumed. Defaults to the reserved amount.
If less than reserved, the difference is refunded automatically.
"""
if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
return
try:
from services.billing_service import BillingService
amount = actual_amount if actual_amount is not None else self._amount
BillingService.quota_commit(
tenant_id=self._tenant_id,
feature_key=self._feature_key,
reservation_id=self.charge_id,
actual_amount=amount,
)
self._committed = True
logger.debug(
"Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
self._quota_type,
self._tenant_id,
self.charge_id,
amount,
)
except Exception:
logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)
def refund(self) -> None:
"""
Release the reserved quota (cancel the charge).
Safe to call even if:
- charge failed or was disabled (charge_id is None)
- already committed (Release after Commit is a no-op)
- already refunded (idempotent)
This method guarantees no exceptions will be raised.
"""
if not self.charge_id or not self._tenant_id or not self._feature_key:
return
QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
def unlimited() -> QuotaCharge:
from enums.quota_type import QuotaType
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
class QuotaService:
"""Orchestrates quota reserve / commit / release lifecycle via BillingService."""
@staticmethod
def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Reserve + immediate Commit (one-shot mode).
The returned QuotaCharge supports .refund() which calls Release.
For two-phase usage (e.g. streaming), use reserve() directly.
"""
charge = QuotaService.reserve(quota_type, tenant_id, amount)
if charge.success and charge.charge_id:
charge.commit()
return charge
@staticmethod
def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Reserve quota before task execution (Reserve phase only).
The caller MUST call charge.commit() after the task succeeds,
or charge.refund() if the task fails.
Raises:
QuotaExceededError: When quota is insufficient
"""
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)
logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to reserve must be greater than 0")
request_id = str(uuid.uuid4())
feature_key = quota_type.billing_key
try:
reserve_resp = BillingService.quota_reserve(
tenant_id=tenant_id,
feature_key=feature_key,
request_id=request_id,
amount=amount,
)
reservation_id = reserve_resp.get("reservation_id")
if not reservation_id:
logger.warning(
"Reserve returned no reservation_id for %s, feature %s, response: %s",
tenant_id,
quota_type.value,
reserve_resp,
)
raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
logger.debug(
"Reserved %d %s quota for tenant %s, reservation_id: %s",
amount,
quota_type.value,
tenant_id,
reservation_id,
)
return QuotaCharge(
success=True,
charge_id=reservation_id,
_quota_type=quota_type,
_tenant_id=tenant_id,
_feature_key=feature_key,
_amount=amount,
)
except QuotaExceededError:
raise
except ValueError:
raise
except Exception:
logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
return unlimited()
@staticmethod
def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = QuotaService.get_remaining(quota_type, tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
return True
@staticmethod
def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
"""Release a reservation. Guarantees no exceptions."""
try:
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not reservation_id:
return
logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
BillingService.quota_release(
tenant_id=tenant_id,
feature_key=feature_key,
reservation_id=reservation_id,
)
except Exception:
logger.exception("Failed to release quota, reservation_id: %s", reservation_id)
@staticmethod
def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
from services.billing_service import BillingService
try:
usage_info = BillingService.get_quota_info(tenant_id)
if isinstance(usage_info, dict):
feature_info = usage_info.get(quota_type.billing_key, {})
if isinstance(feature_info, dict):
limit = feature_info.get("limit", 0)
usage = feature_info.get("usage", 0)
if limit == -1:
return -1
return max(0, limit - usage)
return 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
return -1

View File

@ -38,6 +38,7 @@ from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData
@ -819,9 +820,9 @@ class WebhookService:
user_id=None,
)
# consume quota before triggering workflow execution
# reserve quota before triggering workflow execution
try:
QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
@ -832,11 +833,16 @@ class WebhookService:
raise
# Trigger workflow execution asynchronously
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
try:
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
except Exception:
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)

View File

@ -28,7 +28,7 @@ from core.trigger.entities.entities import TriggerProviderEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from models.enums import (
AppTriggerType,
CreatorUserRole,
@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom,
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
@ -298,10 +299,10 @@ def dispatch_triggered_workflow(
icon_dark_filename=trigger_entity.identity.icon_dark or "",
)
# consume quota before invoking trigger
# reserve quota before invoking trigger
quota_charge = unlimited()
try:
quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info(
@ -387,6 +388,7 @@ def dispatch_triggered_workflow(
raise ValueError(f"End user not found for app {plugin_trigger.app_id}")
AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
quota_charge.commit()
dispatched_count += 1
logger.info(
"Triggered workflow for app %s with trigger event %s",

View File

@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import (
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.schedule_service import ScheduleService
from services.workflow.entities import ScheduleTriggerData
@ -43,7 +44,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
quota_charge = unlimited()
try:
quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
@ -61,6 +62,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
tenant_id=schedule.tenant_id,
),
)
quota_charge.commit()
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()

View File

@ -36,12 +36,19 @@ class TestAppGenerateService:
) as mock_message_based_generator,
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
):
# Setup default mock returns for billing service
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
"result": "success",
"history_id": "test_history_id",
mock_billing_service.quota_reserve.return_value = {
"reservation_id": "test-reservation-id",
"available": 100,
"reserved": 1,
}
mock_billing_service.quota_commit.return_value = {
"available": 99,
"reserved": 0,
"refunded": 0,
}
# Setup default mock returns for workflow service
@ -101,6 +108,8 @@ class TestAppGenerateService:
mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
mock_quota_dify_config.BILLING_ENABLED = False
mock_global_dify_config.BILLING_ENABLED = False
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
@ -118,6 +127,7 @@ class TestAppGenerateService:
"message_based_generator": mock_message_based_generator,
"account_feature_service": mock_account_feature_service,
"dify_config": mock_dify_config,
"quota_dify_config": mock_quota_dify_config,
"global_dify_config": mock_global_dify_config,
}
@ -465,6 +475,7 @@ class TestAppGenerateService:
# Set BILLING_ENABLED to True for this test
mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True
# Setup test arguments
@ -478,8 +489,10 @@ class TestAppGenerateService:
# Verify the result
assert result == ["test_response"]
# Verify billing service was called to consume quota
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
# Verify billing two-phase quota (reserve + commit)
billing = mock_external_service_dependencies["billing_service"]
billing.quota_reserve.assert_called_once()
billing.quota_commit.assert_called_once()
def test_generate_with_invalid_app_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies

View File

@ -605,9 +605,9 @@ def test_schedule_trigger_creates_trigger_log(
)
# Mock quota to avoid rate limiting
from enums import quota_type
from services import quota_service
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())
# Execute schedule trigger
workflow_schedule_tasks.run_schedule_trigger(plan.id)

View File

View File

@ -0,0 +1,349 @@
"""Unit tests for QuotaType, QuotaService, and QuotaCharge."""
from unittest.mock import patch
import pytest
from enums.quota_type import QuotaType
from services.quota_service import QuotaCharge, QuotaService, unlimited
class TestQuotaType:
def test_billing_key_trigger(self):
assert QuotaType.TRIGGER.billing_key == "trigger_event"
def test_billing_key_workflow(self):
assert QuotaType.WORKFLOW.billing_key == "api_rate_limit"
def test_billing_key_unlimited_raises(self):
with pytest.raises(ValueError, match="Invalid quota type"):
_ = QuotaType.UNLIMITED.billing_key
class TestQuotaService:
def test_reserve_billing_disabled(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService"),
):
mock_cfg.BILLING_ENABLED = False
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
assert charge.success is True
assert charge.charge_id is None
def test_reserve_zero_amount_raises(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = True
with pytest.raises(ValueError, match="greater than 0"):
QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0)
def test_reserve_success(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99}
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1)
assert charge.success is True
assert charge.charge_id == "rid-1"
assert charge._tenant_id == "t1"
assert charge._feature_key == "trigger_event"
assert charge._amount == 1
mock_bs.quota_reserve.assert_called_once()
def test_reserve_no_reservation_id_raises(self):
from services.errors.app import QuotaExceededError
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {}
with pytest.raises(QuotaExceededError):
QuotaService.reserve(QuotaType.TRIGGER, "t1")
def test_reserve_quota_exceeded_propagates(self):
from services.errors.app import QuotaExceededError
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1)
with pytest.raises(QuotaExceededError):
QuotaService.reserve(QuotaType.TRIGGER, "t1")
def test_reserve_api_exception_returns_unlimited(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.side_effect = RuntimeError("network")
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
assert charge.success is True
assert charge.charge_id is None
def test_consume_calls_reserve_and_commit(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"}
mock_bs.quota_commit.return_value = {}
charge = QuotaService.consume(QuotaType.TRIGGER, "t1")
assert charge.success is True
mock_bs.quota_commit.assert_called_once()
def test_check_billing_disabled(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = False
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
def test_check_zero_amount_raises(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = True
with pytest.raises(ValueError, match="greater than 0"):
QuotaService.check(QuotaType.TRIGGER, "t1", amount=0)
def test_check_sufficient_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=100),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True
def test_check_insufficient_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=5),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False
def test_check_unlimited_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=-1),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True
def test_check_exception_returns_true(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", side_effect=RuntimeError),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
def test_release_billing_disabled(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = False
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
mock_bs.quota_release.assert_not_called()
def test_release_empty_reservation(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event")
mock_bs.quota_release.assert_not_called()
def test_release_success(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_release.return_value = {}
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
mock_bs.quota_release.assert_called_once_with(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1"
)
def test_release_exception_swallowed(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_release.side_effect = RuntimeError("fail")
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
def test_get_remaining_normal(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70
def test_get_remaining_unlimited(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
def test_get_remaining_over_limit_returns_zero(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_exception_returns_neg1(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.side_effect = RuntimeError
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
def test_get_remaining_empty_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_non_dict_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = "invalid"
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_feature_not_in_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}}
remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1")
assert remaining == 0
def test_get_remaining_non_dict_feature_info(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
class TestQuotaCharge:
def test_commit_success(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
mock_bs.quota_commit.assert_called_once_with(
tenant_id="t1",
feature_key="trigger_event",
reservation_id="rid-1",
actual_amount=1,
)
assert charge._committed is True
def test_commit_with_actual_amount(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=10,
)
charge.commit(actual_amount=5)
call_kwargs = mock_bs.quota_commit.call_args[1]
assert call_kwargs["actual_amount"] == 5
def test_commit_idempotent(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
charge.commit()
assert mock_bs.quota_commit.call_count == 1
def test_commit_no_charge_id_noop(self):
with patch("services.billing_service.BillingService") as mock_bs:
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
charge.commit()
mock_bs.quota_commit.assert_not_called()
def test_commit_no_tenant_id_noop(self):
with patch("services.billing_service.BillingService") as mock_bs:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id=None,
_feature_key="trigger_event",
)
charge.commit()
mock_bs.quota_commit.assert_not_called()
def test_commit_exception_swallowed(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.side_effect = RuntimeError("fail")
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
def test_refund_success(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
)
charge.refund()
mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
def test_refund_no_charge_id_noop(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
charge.refund()
mock_rel.assert_not_called()
def test_refund_no_tenant_id_noop(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id=None,
)
charge.refund()
mock_rel.assert_not_called()
class TestUnlimited:
def test_unlimited_returns_success_with_no_charge_id(self):
charge = unlimited()
assert charge.success is True
assert charge.charge_id is None
assert charge._quota_type == QuotaType.UNLIMITED

View File

@ -23,6 +23,7 @@ import pytest
import services.app_generate_service as ags_module
from core.app.entities.app_invoke_entities import InvokeFrom
from enums.quota_type import QuotaType
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
@ -447,8 +448,8 @@ class TestGenerateBilling:
def test_billing_enabled_consumes_quota(self, mocker, monkeypatch):
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
quota_charge = MagicMock()
consume_mock = mocker.patch(
"services.app_generate_service.QuotaType.WORKFLOW.consume",
reserve_mock = mocker.patch(
"services.app_generate_service.QuotaService.reserve",
return_value=quota_charge,
)
mocker.patch(
@ -467,7 +468,8 @@ class TestGenerateBilling:
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
consume_mock.assert_called_once_with("tenant-id")
reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id")
quota_charge.commit.assert_called_once()
def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch):
from services.errors.app import QuotaExceededError
@ -475,7 +477,7 @@ class TestGenerateBilling:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
mocker.patch(
"services.app_generate_service.QuotaType.WORKFLOW.consume",
"services.app_generate_service.QuotaService.reserve",
side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1),
)
@ -492,7 +494,7 @@ class TestGenerateBilling:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
quota_charge = MagicMock()
mocker.patch(
"services.app_generate_service.QuotaType.WORKFLOW.consume",
"services.app_generate_service.QuotaService.reserve",
return_value=quota_charge,
)
mocker.patch(

View File

@ -57,7 +57,7 @@ class TestAsyncWorkflowService:
- repo: SQLAlchemyWorkflowTriggerLogRepository
- dispatcher_manager_class: QueueDispatcherManager class
- dispatcher: dispatcher instance
- quota_workflow: QuotaType.WORKFLOW
- quota_service: QuotaService mock
- get_workflow: AsyncWorkflowService._get_workflow method
- professional_task: execute_workflow_professional
- team_task: execute_workflow_team
@ -72,7 +72,12 @@ class TestAsyncWorkflowService:
mock_repo.create.side_effect = _create_side_effect
mock_dispatcher = MagicMock()
quota_workflow = MagicMock()
mock_quota_service = MagicMock()
mock_get_workflow = MagicMock()
mock_professional_task = MagicMock()
mock_team_task = MagicMock()
mock_sandbox_task = MagicMock()
with (
patch.object(
@ -88,8 +93,8 @@ class TestAsyncWorkflowService:
) as mock_get_workflow,
patch.object(
async_workflow_service_module,
"QuotaType",
new=SimpleNamespace(WORKFLOW=quota_workflow),
"QuotaService",
new=mock_quota_service,
),
patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task,
patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task,
@ -102,7 +107,7 @@ class TestAsyncWorkflowService:
"repo": mock_repo,
"dispatcher_manager_class": mock_dispatcher_manager_class,
"dispatcher": mock_dispatcher,
"quota_workflow": quota_workflow,
"quota_service": mock_quota_service,
"get_workflow": mock_get_workflow,
"professional_task": mock_professional_task,
"team_task": mock_team_task,
@ -141,6 +146,9 @@ class TestAsyncWorkflowService:
mocks["team_task"].delay.return_value = task_result
mocks["sandbox_task"].delay.return_value = task_result
quota_charge_mock = MagicMock()
mocks["quota_service"].reserve.return_value = quota_charge_mock
class DummyAccount:
def __init__(self, user_id: str):
self.id = user_id
@ -158,7 +166,8 @@ class TestAsyncWorkflowService:
assert result.status == "queued"
assert result.queue == queue_name
mocks["quota_workflow"].consume.assert_called_once_with("tenant-123")
mocks["quota_service"].reserve.assert_called_once()
quota_charge_mock.commit.assert_called_once()
assert session.commit.call_count == 2
created_log = mocks["repo"].create.call_args[0][0]
@ -245,7 +254,7 @@ class TestAsyncWorkflowService:
mocks = async_workflow_trigger_mocks
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM
mocks["get_workflow"].return_value = workflow
mocks["quota_workflow"].consume.side_effect = QuotaExceededError(
mocks["quota_service"].reserve.side_effect = QuotaExceededError(
feature="workflow",
tenant_id="tenant-123",
required=1,

View File

@ -425,7 +425,7 @@ class TestBillingServiceUsageCalculation:
yield mock
def test_get_tenant_feature_plan_usage_info(self, mock_send_request):
"""Test retrieval of tenant feature plan usage information."""
"""Test retrieval of tenant feature plan usage information (legacy endpoint)."""
# Arrange
tenant_id = "tenant-123"
expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}}
@ -438,6 +438,20 @@ class TestBillingServiceUsageCalculation:
assert result == expected_response
mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id})
def test_get_quota_info(self, mock_send_request):
"""Test retrieval of quota info from new endpoint."""
# Arrange
tenant_id = "tenant-123"
expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}}
mock_send_request.return_value = expected_response
# Act
result = BillingService.get_quota_info(tenant_id)
# Assert
assert result == expected_response
mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id})
def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request):
"""Test updating tenant feature usage with positive delta (adding credits)."""
# Arrange
@ -515,6 +529,150 @@ class TestBillingServiceUsageCalculation:
)
class TestBillingServiceQuotaOperations:
"""Unit tests for quota reserve/commit/release operations."""
@pytest.fixture
def mock_send_request(self):
with patch.object(BillingService, "_send_request") as mock:
yield mock
def test_quota_reserve_success(self, mock_send_request):
expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1}
mock_send_request.return_value = expected
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1)
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/reserve",
json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1},
)
def test_quota_reserve_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"}
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1)
assert result["available"] == 99
assert isinstance(result["available"], int)
assert result["reserved"] == 1
assert isinstance(result["reserved"], int)
def test_quota_reserve_with_meta(self, mock_send_request):
mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1}
meta = {"source": "webhook"}
BillingService.quota_reserve(
tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta
)
call_json = mock_send_request.call_args[1]["json"]
assert call_json["meta"] == {"source": "webhook"}
def test_quota_commit_success(self, mock_send_request):
expected = {"available": 98, "reserved": 0, "refunded": 0}
mock_send_request.return_value = expected
result = BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1
)
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/commit",
json={
"tenant_id": "t1",
"feature_key": "trigger_event",
"reservation_id": "rid-1",
"actual_amount": 1,
},
)
def test_quota_commit_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"}
result = BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1
)
assert result["available"] == 97
assert isinstance(result["available"], int)
assert result["refunded"] == 1
assert isinstance(result["refunded"], int)
def test_quota_commit_with_meta(self, mock_send_request):
mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0}
meta = {"reason": "partial"}
BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta
)
call_json = mock_send_request.call_args[1]["json"]
assert call_json["meta"] == {"reason": "partial"}
def test_quota_release_success(self, mock_send_request):
expected = {"available": 100, "reserved": 0, "released": 1}
mock_send_request.return_value = expected
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1")
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/release",
json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"},
)
def test_quota_release_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"}
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s")
assert result["available"] == 100
assert isinstance(result["available"], int)
assert result["released"] == 1
assert isinstance(result["released"], int)
def test_get_quota_info_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int for get_quota_info."""
mock_send_request.return_value = {
"trigger_event": {"usage": "42", "limit": "3000", "reset_date": "1700000000"},
"api_rate_limit": {"usage": "10", "limit": "-1", "reset_date": "-1"},
}
result = BillingService.get_quota_info("t1")
assert result["trigger_event"]["usage"] == 42
assert isinstance(result["trigger_event"]["usage"], int)
assert result["trigger_event"]["limit"] == 3000
assert isinstance(result["trigger_event"]["limit"], int)
assert result["trigger_event"]["reset_date"] == 1700000000
assert isinstance(result["trigger_event"]["reset_date"], int)
assert result["api_rate_limit"]["limit"] == -1
assert isinstance(result["api_rate_limit"]["limit"], int)
def test_get_quota_info_accepts_int_values(self, mock_send_request):
"""Test that get_quota_info works with native int values."""
expected = {
"trigger_event": {"usage": 42, "limit": 3000, "reset_date": 1700000000},
"api_rate_limit": {"usage": 0, "limit": -1},
}
mock_send_request.return_value = expected
result = BillingService.get_quota_info("t1")
assert result["trigger_event"]["usage"] == 42
assert result["trigger_event"]["limit"] == 3000
assert result["api_rate_limit"]["limit"] == -1
class TestBillingServiceRateLimitEnforcement:
"""Unit tests for rate limit enforcement mechanisms.

View File

@ -559,3 +559,772 @@ class TestWebhookServiceUnit:
result = _prepare_webhook_execution("test_webhook", is_debug=True)
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
# === Merged from test_webhook_service_additional.py ===
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from flask import Flask
from graphon.variables.types import SegmentType
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import RequestEntityTooLarge
from core.workflow.nodes.trigger_webhook.entities import (
ContentType,
WebhookBodyParameter,
WebhookData,
WebhookParameter,
)
from models.enums import AppTriggerStatus
from models.model import App
from models.trigger import WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService
class _FakeQuery:
def __init__(self, result: Any) -> None:
self._result = result
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def first(self) -> Any:
return self._result
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionmakerContext:
def __init__(self, session: Any) -> None:
self._session = session
def begin(self) -> "_SessionmakerContext":
return self
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
@pytest.fixture
def flask_app() -> Flask:
return Flask(__name__)
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
def _workflow(**kwargs: Any) -> Workflow:
return cast(Workflow, SimpleNamespace(**kwargs))
def _app(**kwargs: Any) -> App:
return cast(App, SimpleNamespace(**kwargs))
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
fake_session = MagicMock()
fake_session.scalar.return_value = None
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="Webhook not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, None]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="App trigger not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="rate limited"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="disabled"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="Workflow not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
_patch_session(monkeypatch, fake_session)
# Act
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
# Assert
assert got_trigger is webhook_trigger
assert got_workflow is workflow
assert got_node_config == {"data": {"key": "value"}}
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, workflow]
_patch_session(monkeypatch, fake_session)
# Act
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
"webhook-1", is_debug=True
)
# Assert
assert got_trigger is webhook_trigger
assert got_workflow is workflow
assert got_node_config == {"data": {"mode": "debug"}}
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
warning_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
webhook_trigger = MagicMock()
# Act
with flask_app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/vnd.custom"},
data="plain content",
):
result = WebhookService.extract_webhook_data(webhook_trigger)
# Assert
assert result["body"] == {"raw": "plain content"}
warning_mock.assert_called_once()
def test_extract_webhook_data_should_raise_for_request_too_large(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
# Act / Assert
with flask_app.test_request_context("/webhook", method="POST", data="ab"):
with pytest.raises(RequestEntityTooLarge):
WebhookService.extract_webhook_data(MagicMock())
def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None:
# Arrange
webhook_trigger = MagicMock()
# Act
with flask_app.test_request_context("/webhook", method="POST", data=b""):
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
# Assert
assert body == {"raw": None}
assert files == {}
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = MagicMock()
monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream"))
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
# Act
with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
# Assert
assert body == {"raw": None}
assert files == {}
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
# Act
with flask_app.test_request_context("/webhook", method="POST", data="abc"):
body, files = WebhookService._extract_text_body()
# Assert
assert body == {"raw": ""}
assert files == {}
def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
fake_magic = MagicMock()
fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
monkeypatch.setattr(service_module, "magic", fake_magic)
# Act
result = WebhookService._detect_binary_mimetype(b"binary")
# Assert
assert result == "application/octet-stream"
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
file_obj = MagicMock()
file_obj.to_dict.return_value = {"id": "f-1"}
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
uploaded = MagicMock()
uploaded.filename = "file.unknown"
uploaded.content_type = None
uploaded.read.return_value = b"content"
# Act
result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
# Assert
assert result == {"f": {"id": "f-1"}}
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
manager = MagicMock()
manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
expected_file = MagicMock()
monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
# Act
result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
# Assert
assert result is expected_file
manager.create_file_by_raw.assert_called_once()
@pytest.mark.parametrize(
("raw_value", "param_type", "expected"),
[
("42", SegmentType.NUMBER, 42),
("3.14", SegmentType.NUMBER, 3.14),
("yes", SegmentType.BOOLEAN, True),
("no", SegmentType.BOOLEAN, False),
],
)
def test_convert_form_value_should_convert_supported_types(
raw_value: str,
param_type: str,
expected: Any,
) -> None:
# Arrange
# Act
result = WebhookService._convert_form_value("param", raw_value, param_type)
# Assert
assert result == expected
def test_convert_form_value_should_raise_for_unsupported_type() -> None:
# Arrange
# Act / Assert
with pytest.raises(ValueError, match="Unsupported type"):
WebhookService._convert_form_value("p", "x", SegmentType.FILE)
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
warning_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
# Act
result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
# Assert
assert result == {"x": 1}
warning_mock.assert_called_once()
def test_validate_and_convert_value_should_wrap_conversion_errors() -> None:
# Arrange
# Act / Assert
with pytest.raises(ValueError, match="validation failed"):
WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
def test_process_parameters_should_raise_when_required_parameter_missing() -> None:
# Arrange
raw_params = {"optional": "x"}
config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
# Act / Assert
with pytest.raises(ValueError, match="Required parameter missing"):
WebhookService._process_parameters(raw_params, config, is_form_data=True)
def test_process_parameters_should_include_unconfigured_parameters() -> None:
# Arrange
raw_params = {"known": "1", "unknown": "x"}
config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
# Act
result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
# Assert
assert result == {"known": 1, "unknown": "x"}
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None:
# Arrange
# Act / Assert
with pytest.raises(ValueError, match="Required body content missing"):
WebhookService._process_body_parameters(
raw_body={"raw": ""},
body_configs=[WebhookBodyParameter(name="raw", required=True)],
content_type=ContentType.TEXT,
)
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None:
# Arrange
raw_body = {"message": "hello", "extra": "x"}
body_configs = [
WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
]
# Act
result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
# Assert
assert result == {"message": "hello", "extra": "x"}
def test_validate_required_headers_should_accept_sanitized_header_names() -> None:
# Arrange
headers = {"x_api_key": "123"}
configs = [WebhookParameter(name="x-api-key", required=True)]
# Act
WebhookService._validate_required_headers(headers, configs)
# Assert
assert True
def test_validate_required_headers_should_raise_when_required_header_missing() -> None:
# Arrange
headers = {"x-other": "123"}
configs = [WebhookParameter(name="x-api-key", required=True)]
# Act / Assert
with pytest.raises(ValueError, match="Required header missing"):
WebhookService._validate_required_headers(headers, configs)
def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None:
# Arrange
webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
node_data = WebhookData(method="post", content_type=ContentType.TEXT)
# Act
result = WebhookService._validate_http_metadata(webhook_data, node_data)
# Assert
assert result["valid"] is False
assert "Content-type mismatch" in result["error"]
def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None:
# Arrange
headers = {"content-type": "application/json; charset=utf-8"}
# Act
result = WebhookService._extract_content_type(headers)
# Assert
assert result == "application/json"
def test_build_workflow_inputs_should_include_expected_keys() -> None:
# Arrange
webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}
# Act
result = WebhookService.build_workflow_inputs(webhook_data)
# Assert
assert result["webhook_data"] == webhook_data
assert result["webhook_headers"] == {"h": "v"}
assert result["webhook_query_params"] == {"q": 1}
assert result["webhook_body"] == {"b": 2}
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
webhook_data = {"body": {"x": 1}}
session = MagicMock()
_patch_session(monkeypatch, session)
end_user = SimpleNamespace(id="end-user-1")
monkeypatch.setattr(
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user)
)
quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
monkeypatch.setattr(service_module, "QuotaType", quota_type)
trigger_async_mock = MagicMock()
monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
# Act
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
# Assert
trigger_async_mock.assert_called_once()
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
session = MagicMock()
_patch_session(monkeypatch, session)
monkeypatch.setattr(
service_module.EndUserService,
"get_or_create_end_user_by_type",
MagicMock(return_value=SimpleNamespace(id="end-user-1")),
)
monkeypatch.setattr(
service_module.QuotaService,
"reserve",
MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1)),
)
mark_rate_limited_mock = MagicMock()
monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock)
# Act / Assert
with pytest.raises(QuotaExceededError):
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
mark_rate_limited_mock.assert_called_once_with("tenant-1")
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
session = MagicMock()
_patch_session(monkeypatch, session)
monkeypatch.setattr(
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom"))
)
logger_exception_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
# Act / Assert
with pytest.raises(RuntimeError, match="boom"):
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
logger_exception_mock.assert_called_once()
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(
walk_nodes=lambda _node_type: [
(f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
]
)
# Act / Assert
with pytest.raises(ValueError, match="maximum webhook node limit"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
lock = MagicMock()
lock.acquire.return_value = False
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
# Act / Assert
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
class _WorkflowWebhookTrigger:
app_id = "app_id"
tenant_id = "tenant_id"
webhook_id = "webhook_id"
node_id = "node_id"
def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
self.id = None
self.app_id = app_id
self.tenant_id = tenant_id
self.node_id = node_id
self.webhook_id = webhook_id
self.created_by = created_by
class _Select:
def where(self, *args: Any, **kwargs: Any) -> "_Select":
return self
class _Session:
def __init__(self) -> None:
self.added: list[Any] = []
self.deleted: list[Any] = []
self.commit_count = 0
self.existing_records = [SimpleNamespace(node_id="node-stale")]
def scalars(self, _stmt: Any) -> Any:
return SimpleNamespace(all=lambda: self.existing_records)
def add(self, obj: Any) -> None:
self.added.append(obj)
def flush(self) -> None:
for idx, obj in enumerate(self.added, start=1):
if obj.id is None:
obj.id = f"rec-{idx}"
def commit(self) -> None:
self.commit_count += 1
def delete(self, obj: Any) -> None:
self.deleted.append(obj)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.return_value = None
fake_session = _Session()
monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
redis_set_mock = MagicMock()
redis_delete_mock = MagicMock()
monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
_patch_session(monkeypatch, fake_session)
# Act
WebhookService.sync_webhook_relationships(app, workflow)
# Assert
assert len(fake_session.added) == 1
assert len(fake_session.deleted) == 1
redis_set_mock.assert_called_once()
redis_delete_mock.assert_called_once()
lock.release.assert_called_once()
def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [])
class _Select:
def where(self, *args: Any, **kwargs: Any) -> "_Select":
return self
class _Session:
def scalars(self, _stmt: Any) -> Any:
return SimpleNamespace(all=lambda: [])
def commit(self) -> None:
return None
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = RuntimeError("release failed")
logger_exception_mock = MagicMock()
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
_patch_session(monkeypatch, _Session())
# Act
WebhookService.sync_webhook_relationships(app, workflow)
# Assert
assert logger_exception_mock.call_count == 1
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None:
# Arrange
node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
# Act
body, status = WebhookService.generate_webhook_response(node_config)
# Assert
assert status == 200
assert "message" in body
def test_generate_webhook_id_should_return_24_character_identifier() -> None:
# Arrange
# Act
webhook_id = WebhookService.generate_webhook_id()
# Assert
assert isinstance(webhook_id, str)
assert len(webhook_id) == 24
def test_sanitize_key_should_return_original_value_for_non_string_input() -> None:
# Arrange
# Act
result = WebhookService._sanitize_key(123) # type: ignore[arg-type]
# Assert
assert result == 123

View File

@ -9,6 +9,7 @@ import {
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
} from '@/app/education-apply/constants'
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking'
import { sendGAEvent } from '@/utils/gtag'
import { fetchSetupStatusWithCache } from '@/utils/setup-status'
import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect'
@ -45,6 +46,8 @@ export const AppInitializer = ({
(async () => {
const action = searchParams.get('action')
rememberCreateAppExternalAttribution({ searchParams })
if (oauthNewUser) {
let utmInfo = null
const utmInfoStr = Cookies.get('utm_info')

View File

@ -4,7 +4,6 @@ import { AppModeEnum } from '@/types/app'
import Apps from '../index'
const mockUseExploreAppList = vi.fn()
const mockTrackEvent = vi.fn()
const mockImportDSL = vi.fn()
const mockFetchAppDetail = vi.fn()
const mockHandleCheckPluginDependencies = vi.fn()
@ -12,6 +11,7 @@ const mockGetRedirection = vi.fn()
const mockPush = vi.fn()
const mockToastSuccess = vi.fn()
const mockToastError = vi.fn()
const mockTrackCreateApp = vi.fn()
let latestDebounceFn = () => {}
vi.mock('ahooks', () => ({
@ -92,8 +92,8 @@ vi.mock('@/app/components/base/ui/toast', () => ({
error: (...args: unknown[]) => mockToastError(...args),
},
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: (...args: unknown[]) => mockTrackEvent(...args),
vi.mock('@/utils/create-app-tracking', () => ({
trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
}))
vi.mock('@/service/apps', () => ({
importDSL: (...args: unknown[]) => mockImportDSL(...args),
@ -246,10 +246,9 @@ describe('Apps', () => {
}))
})
expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_template', expect.objectContaining({
template_id: 'Alpha',
template_name: 'Alpha',
}))
expect(mockTrackCreateApp).toHaveBeenCalledWith({
appMode: AppModeEnum.CHAT,
})
expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated')
expect(onSuccess).toHaveBeenCalled()
expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('created-app-id')

View File

@ -8,7 +8,6 @@ import * as React from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import AppTypeSelector from '@/app/components/app/type-selector'
import { trackEvent } from '@/app/components/base/amplitude'
import Divider from '@/app/components/base/divider'
import Input from '@/app/components/base/input'
import Loading from '@/app/components/base/loading'
@ -25,6 +24,7 @@ import { useExploreAppList } from '@/service/use-explore'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
import { cn } from '@/utils/classnames'
import { trackCreateApp } from '@/utils/create-app-tracking'
import AppCard from '../app-card'
import Sidebar, { AppCategories, AppCategoryLabel } from './sidebar'
@ -127,14 +127,7 @@ const Apps = ({
icon_background,
description,
})
// Track app creation from template
trackEvent('create_app_with_template', {
app_mode: mode,
template_id: currApp?.app.id,
template_name: currApp?.app.name,
description,
})
trackCreateApp({ appMode: mode })
setIsShowCreateModal(false)
toast.success(t('newApp.appCreated', { ns: 'app' }))

View File

@ -1,7 +1,6 @@
import type { App } from '@/types/app'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
import { trackEvent } from '@/app/components/base/amplitude'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { useAppContext } from '@/context/app-context'
@ -10,6 +9,7 @@ import { useRouter } from '@/next/navigation'
import { createApp } from '@/service/apps'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
import { trackCreateApp } from '@/utils/create-app-tracking'
import CreateAppModal from '../index'
const ahooksMocks = vi.hoisted(() => ({
@ -31,8 +31,8 @@ vi.mock('ahooks', () => ({
vi.mock('@/next/navigation', () => ({
useRouter: vi.fn(),
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: vi.fn(),
vi.mock('@/utils/create-app-tracking', () => ({
trackCreateApp: vi.fn(),
}))
vi.mock('@/service/apps', () => ({
createApp: vi.fn(),
@ -87,7 +87,7 @@ vi.mock('@/hooks/use-theme', () => ({
const mockUseRouter = vi.mocked(useRouter)
const mockPush = vi.fn()
const mockCreateApp = vi.mocked(createApp)
const mockTrackEvent = vi.mocked(trackEvent)
const mockTrackCreateApp = vi.mocked(trackCreateApp)
const mockGetRedirection = vi.mocked(getRedirection)
const mockUseProviderContext = vi.mocked(useProviderContext)
const mockUseAppContext = vi.mocked(useAppContext)
@ -178,10 +178,7 @@ describe('CreateAppModal', () => {
mode: AppModeEnum.ADVANCED_CHAT,
}))
expect(mockTrackEvent).toHaveBeenCalledWith('create_app', {
app_mode: AppModeEnum.ADVANCED_CHAT,
description: '',
})
expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.ADVANCED_CHAT })
expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated')
expect(onSuccess).toHaveBeenCalled()
expect(onClose).toHaveBeenCalled()

View File

@ -6,7 +6,6 @@ import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon
import { useDebounceFn, useKeyPress } from 'ahooks'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { trackEvent } from '@/app/components/base/amplitude'
import AppIcon from '@/app/components/base/app-icon'
import Button from '@/app/components/base/button'
import Divider from '@/app/components/base/divider'
@ -25,6 +24,7 @@ import { createApp } from '@/service/apps'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
import { cn } from '@/utils/classnames'
import { trackCreateApp } from '@/utils/create-app-tracking'
import { basePath } from '@/utils/var'
import AppIconPicker from '../../base/app-icon-picker'
import ShortcutsName from '../../workflow/shortcuts-name'
@ -80,11 +80,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
mode: appMode,
})
// Track app creation success
trackEvent('create_app', {
app_mode: appMode,
description,
})
trackCreateApp({ appMode: app.mode })
toast.success(t('newApp.appCreated', { ns: 'app' }))
onSuccess()

View File

@ -2,12 +2,13 @@
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { DSLImportMode, DSLImportStatus } from '@/models/app'
import { AppModeEnum } from '@/types/app'
import CreateFromDSLModal, { CreateFromDSLModalTab } from '../index'
const mockPush = vi.fn()
const mockImportDSL = vi.fn()
const mockImportDSLConfirm = vi.fn()
const mockTrackEvent = vi.fn()
const mockTrackCreateApp = vi.fn()
const mockHandleCheckPluginDependencies = vi.fn()
const mockGetRedirection = vi.fn()
const toastMocks = vi.hoisted(() => ({
@ -43,8 +44,8 @@ vi.mock('@/next/navigation', () => ({
}),
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: (...args: unknown[]) => mockTrackEvent(...args),
vi.mock('@/utils/create-app-tracking', () => ({
trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
}))
vi.mock('@/service/apps', () => ({
@ -172,7 +173,7 @@ describe('CreateFromDSLModal', () => {
id: 'import-1',
status: DSLImportStatus.COMPLETED,
app_id: 'app-1',
app_mode: 'chat',
app_mode: AppModeEnum.CHAT,
})
render(
@ -196,10 +197,7 @@ describe('CreateFromDSLModal', () => {
mode: DSLImportMode.YAML_URL,
yaml_url: 'https://example.com/app.yml',
})
expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_dsl', expect.objectContaining({
creation_method: 'dsl_url',
has_warnings: false,
}))
expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.CHAT })
expect(handleSuccess).toHaveBeenCalledTimes(1)
expect(handleClose).toHaveBeenCalledTimes(1)
expect(localStorage.getItem(NEED_REFRESH_APP_LIST_KEY)).toBe('1')
@ -212,7 +210,7 @@ describe('CreateFromDSLModal', () => {
id: 'import-2',
status: DSLImportStatus.COMPLETED_WITH_WARNINGS,
app_id: 'app-2',
app_mode: 'chat',
app_mode: AppModeEnum.CHAT,
})
render(
@ -275,7 +273,7 @@ describe('CreateFromDSLModal', () => {
mockImportDSLConfirm.mockResolvedValue({
status: DSLImportStatus.COMPLETED,
app_id: 'app-3',
app_mode: 'workflow',
app_mode: AppModeEnum.WORKFLOW,
})
render(
@ -305,6 +303,7 @@ describe('CreateFromDSLModal', () => {
expect(mockImportDSLConfirm).toHaveBeenCalledWith({
import_id: 'import-3',
})
expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.WORKFLOW })
})
it('should ignore empty import responses and prevent duplicate submissions while a request is in flight', async () => {
@ -332,7 +331,7 @@ describe('CreateFromDSLModal', () => {
id: 'import-in-flight',
status: DSLImportStatus.COMPLETED,
app_id: 'app-1',
app_mode: 'chat',
app_mode: AppModeEnum.CHAT,
})
})

View File

@ -6,7 +6,6 @@ import { useDebounceFn, useKeyPress } from 'ahooks'
import { noop } from 'es-toolkit/function'
import { useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { trackEvent } from '@/app/components/base/amplitude'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import Modal from '@/app/components/base/modal'
@ -27,6 +26,7 @@ import {
} from '@/service/apps'
import { getRedirection } from '@/utils/app-redirection'
import { cn } from '@/utils/classnames'
import { trackCreateApp } from '@/utils/create-app-tracking'
import ShortcutsName from '../../workflow/shortcuts-name'
import Uploader from './uploader'
@ -112,12 +112,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
return
const { id, status, app_id, app_mode, imported_dsl_version, current_dsl_version } = response
if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) {
// Track app creation from DSL import
trackEvent('create_app_with_dsl', {
app_mode,
creation_method: currentTab === CreateFromDSLModalTab.FROM_FILE ? 'dsl_file' : 'dsl_url',
has_warnings: status === DSLImportStatus.COMPLETED_WITH_WARNINGS,
})
trackCreateApp({ appMode: app_mode })
if (onSuccess)
onSuccess()
@ -179,6 +174,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
const { status, app_id, app_mode } = response
if (status === DSLImportStatus.COMPLETED) {
trackCreateApp({ appMode: app_mode })
if (onSuccess)
onSuccess()
if (onClose)
@ -228,7 +224,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
isShow={show}
onClose={noop}
>
<div className="flex items-center justify-between pb-3 pl-6 pr-5 pt-6 text-text-primary title-2xl-semi-bold">
<div className="flex items-center justify-between pt-6 pr-5 pb-3 pl-6 title-2xl-semi-bold text-text-primary">
{t('importFromDSL', { ns: 'app' })}
<div
className="flex h-8 w-8 cursor-pointer items-center"
@ -237,7 +233,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
<RiCloseLine className="h-5 w-5 text-text-tertiary" />
</div>
</div>
<div className="flex h-9 items-center space-x-6 border-b border-divider-subtle px-6 text-text-tertiary system-md-semibold">
<div className="flex h-9 items-center space-x-6 border-b border-divider-subtle px-6 system-md-semibold text-text-tertiary">
{
tabs.map(tab => (
<div
@ -271,7 +267,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
{
currentTab === CreateFromDSLModalTab.FROM_URL && (
<div>
<div className="mb-1 text-text-secondary system-md-semibold">DSL URL</div>
<div className="mb-1 system-md-semibold text-text-secondary">DSL URL</div>
<Input
placeholder={t('importFromDSLUrlPlaceholder', { ns: 'app' }) || ''}
value={dslUrlValue}
@ -305,8 +301,8 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
className="w-[480px]"
>
<div className="flex flex-col items-start gap-2 self-stretch pb-4">
<div className="text-text-primary title-2xl-semi-bold">{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}</div>
<div className="flex grow flex-col text-text-secondary system-md-regular">
<div className="title-2xl-semi-bold text-text-primary">{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}</div>
<div className="flex grow flex-col system-md-regular text-text-secondary">
<div>{t('newApp.appCreateDSLErrorPart1', { ns: 'app' })}</div>
<div>{t('newApp.appCreateDSLErrorPart2', { ns: 'app' })}</div>
<br />

View File

@ -1,12 +1,48 @@
import type { ReactNode } from 'react'
import type { App } from '@/models/explore'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { render, screen } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { useContextSelector } from 'use-context-selector'
import AppListContext from '@/context/app-list-context'
import { fetchAppDetail } from '@/service/explore'
import { AppModeEnum } from '@/types/app'
import Apps from '../index'
let documentTitleCalls: string[] = []
let educationInitCalls: number = 0
const mockHandleImportDSL = vi.fn()
const mockHandleImportDSLConfirm = vi.fn()
const mockTrackCreateApp = vi.fn()
const mockFetchAppDetail = vi.mocked(fetchAppDetail)
const mockTemplateApp: App = {
app_id: 'template-1',
category: 'Assistant',
app: {
id: 'template-1',
mode: AppModeEnum.CHAT,
icon_type: 'emoji',
icon: '🤖',
icon_background: '#fff',
icon_url: '',
name: 'Sample App',
description: 'Sample App',
use_icon_as_answer_icon: false,
},
description: 'Sample App',
can_trial: true,
copyright: '',
privacy_policy: null,
custom_disclaimer: null,
position: 1,
is_listed: true,
install_count: 0,
installed: false,
editable: false,
is_agent: false,
}
vi.mock('@/hooks/use-document-title', () => ({
default: (title: string) => {
@ -22,17 +58,80 @@ vi.mock('@/app/education-apply/hooks', () => ({
vi.mock('@/hooks/use-import-dsl', () => ({
useImportDSL: () => ({
handleImportDSL: vi.fn(),
handleImportDSLConfirm: vi.fn(),
handleImportDSL: mockHandleImportDSL,
handleImportDSLConfirm: mockHandleImportDSLConfirm,
versions: [],
isFetching: false,
}),
}))
vi.mock('../list', () => ({
default: () => {
return React.createElement('div', { 'data-testid': 'apps-list' }, 'Apps List')
},
vi.mock('../list', () => {
const MockList = () => {
const setShowTryAppPanel = useContextSelector(AppListContext, ctx => ctx.setShowTryAppPanel)
return React.createElement(
'div',
{ 'data-testid': 'apps-list' },
React.createElement('span', null, 'Apps List'),
React.createElement(
'button',
{
'data-testid': 'open-preview',
'onClick': () => setShowTryAppPanel(true, {
appId: mockTemplateApp.app_id,
app: mockTemplateApp,
}),
},
'Open Preview',
),
)
}
return { default: MockList }
})
vi.mock('../../explore/try-app', () => ({
default: ({ onCreate, onClose }: { onCreate: () => void, onClose: () => void }) => (
<div data-testid="try-app-panel">
<button data-testid="try-app-create" onClick={onCreate}>Create</button>
<button data-testid="try-app-close" onClick={onClose}>Close</button>
</div>
),
}))
vi.mock('../../explore/create-app-modal', () => ({
default: ({ show, onConfirm, onHide }: { show: boolean, onConfirm: (payload: Record<string, string>) => Promise<void>, onHide: () => void }) => show
? (
<div data-testid="create-app-modal">
<button
data-testid="confirm-create"
onClick={() => onConfirm({
name: 'Created App',
icon_type: 'emoji',
icon: '🤖',
icon_background: '#fff',
description: 'created from preview',
})}
>
Confirm
</button>
<button data-testid="hide-create" onClick={onHide}>Hide</button>
</div>
)
: null,
}))
vi.mock('../../app/create-from-dsl-modal/dsl-confirm-modal', () => ({
default: ({ onConfirm }: { onConfirm: () => void }) => (
<button data-testid="confirm-dsl" onClick={onConfirm}>Confirm DSL</button>
),
}))
vi.mock('@/service/explore', () => ({
fetchAppDetail: vi.fn(),
}))
vi.mock('@/utils/create-app-tracking', () => ({
trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
}))
describe('Apps', () => {
@ -59,6 +158,14 @@ describe('Apps', () => {
vi.clearAllMocks()
documentTitleCalls = []
educationInitCalls = 0
mockFetchAppDetail.mockResolvedValue({
id: 'template-1',
name: 'Sample App',
icon: '🤖',
icon_background: '#fff',
mode: AppModeEnum.CHAT,
export_data: 'yaml-content',
})
})
describe('Rendering', () => {
@ -116,6 +223,25 @@ describe('Apps', () => {
)
expect(screen.getByTestId('apps-list')).toBeInTheDocument()
})
it('should track template preview creation after a successful import', async () => {
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
options.onSuccess?.()
})
renderWithClient(<Apps />)
fireEvent.click(screen.getByTestId('open-preview'))
fireEvent.click(await screen.findByTestId('try-app-create'))
fireEvent.click(await screen.findByTestId('confirm-create'))
await waitFor(() => {
expect(mockFetchAppDetail).toHaveBeenCalledWith('template-1')
expect(mockTrackCreateApp).toHaveBeenCalledWith({
appMode: AppModeEnum.CHAT,
})
})
})
})
describe('Styling', () => {

View File

@ -1,7 +1,7 @@
'use client'
import type { CreateAppModalProps } from '../explore/create-app-modal'
import type { TryAppSelection } from '@/types/try-app'
import { useCallback, useState } from 'react'
import { useCallback, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useEducationInit } from '@/app/education-apply/hooks'
import AppListContext from '@/context/app-list-context'
@ -10,6 +10,7 @@ import { useImportDSL } from '@/hooks/use-import-dsl'
import { DSLImportMode } from '@/models/app'
import dynamic from '@/next/dynamic'
import { fetchAppDetail } from '@/service/explore'
import { trackCreateApp } from '@/utils/create-app-tracking'
import List from './list'
const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false })
@ -23,6 +24,7 @@ const Apps = () => {
useEducationInit()
const [currentTryAppParams, setCurrentTryAppParams] = useState<TryAppSelection | undefined>(undefined)
const currentCreateAppModeRef = useRef<TryAppSelection['app']['app']['mode'] | null>(null)
const currApp = currentTryAppParams?.app
const [isShowTryAppPanel, setIsShowTryAppPanel] = useState(false)
const hideTryAppPanel = useCallback(() => {
@ -40,6 +42,12 @@ const Apps = () => {
const handleShowFromTryApp = useCallback(() => {
setIsShowCreateModal(true)
}, [])
const trackCurrentCreateApp = useCallback(() => {
if (!currentCreateAppModeRef.current)
return
trackCreateApp({ appMode: currentCreateAppModeRef.current })
}, [])
const [controlRefreshList, setControlRefreshList] = useState(0)
const [controlHideCreateFromTemplatePanel, setControlHideCreateFromTemplatePanel] = useState(0)
@ -59,11 +67,14 @@ const Apps = () => {
const onConfirmDSL = useCallback(async () => {
await handleImportDSLConfirm({
onSuccess,
onSuccess: () => {
trackCurrentCreateApp()
onSuccess()
},
})
}, [handleImportDSLConfirm, onSuccess])
}, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp])
const onCreate: CreateAppModalProps['onConfirm'] = async ({
const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({
name,
icon_type,
icon,
@ -72,9 +83,10 @@ const Apps = () => {
}) => {
hideTryAppPanel()
const { export_data } = await fetchAppDetail(
const { export_data, mode } = await fetchAppDetail(
currApp?.app.id as string,
)
currentCreateAppModeRef.current = mode
const payload = {
mode: DSLImportMode.YAML_CONTENT,
yaml_content: export_data,
@ -86,13 +98,14 @@ const Apps = () => {
}
await handleImportDSL(payload, {
onSuccess: () => {
trackCurrentCreateApp()
setIsShowCreateModal(false)
},
onPending: () => {
setShowDSLConfirmModal(true)
},
})
}
}, [currApp?.app.id, handleImportDSL, hideTryAppPanel, trackCurrentCreateApp])
return (
<AppListContext.Provider value={{

View File

@ -5,7 +5,7 @@ import * as amplitude from '@amplitude/analytics-browser'
import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser'
import * as React from 'react'
import { useEffect } from 'react'
import { AMPLITUDE_API_KEY, isAmplitudeEnabled } from '@/config'
import { AMPLITUDE_API_KEY } from '@/config'
export type IAmplitudeProps = {
sessionReplaySampleRate?: number
@ -54,8 +54,8 @@ const AmplitudeProvider: FC<IAmplitudeProps> = ({
}) => {
useEffect(() => {
// Only enable in Saas edition with valid API key
if (!isAmplitudeEnabled)
return
// if (!isAmplitudeEnabled)
// return
// Initialize Amplitude
amplitude.init(AMPLITUDE_API_KEY, {

View File

@ -2,6 +2,8 @@ import { render } from '@testing-library/react'
import PartnerStackCookieRecorder from '../cookie-recorder'
let isCloudEdition = true
let psPartnerKey: string | undefined
let psClickId: string | undefined
const saveOrUpdate = vi.fn()
@ -13,6 +15,8 @@ vi.mock('@/config', () => ({
vi.mock('../use-ps-info', () => ({
default: () => ({
psPartnerKey,
psClickId,
saveOrUpdate,
}),
}))
@ -21,6 +25,8 @@ describe('PartnerStackCookieRecorder', () => {
beforeEach(() => {
vi.clearAllMocks()
isCloudEdition = true
psPartnerKey = undefined
psClickId = undefined
})
it('should call saveOrUpdate once on mount when running in cloud edition', () => {
@ -42,4 +48,16 @@ describe('PartnerStackCookieRecorder', () => {
expect(container.innerHTML).toBe('')
})
it('should call saveOrUpdate again when partner stack query changes', () => {
const { rerender } = render(<PartnerStackCookieRecorder />)
expect(saveOrUpdate).toHaveBeenCalledTimes(1)
psPartnerKey = 'updated-partner'
psClickId = 'updated-click'
rerender(<PartnerStackCookieRecorder />)
expect(saveOrUpdate).toHaveBeenCalledTimes(2)
})
})

View File

@ -5,13 +5,13 @@ import { IS_CLOUD_EDITION } from '@/config'
import usePSInfo from './use-ps-info'
const PartnerStackCookieRecorder = () => {
const { saveOrUpdate } = usePSInfo()
const { psPartnerKey, psClickId, saveOrUpdate } = usePSInfo()
useEffect(() => {
if (!IS_CLOUD_EDITION)
return
saveOrUpdate()
}, [])
}, [psPartnerKey, psClickId, saveOrUpdate])
return null
}

View File

@ -6,7 +6,7 @@ import { IS_CLOUD_EDITION } from '@/config'
import usePSInfo from './use-ps-info'
const PartnerStack: FC = () => {
const { saveOrUpdate, bind } = usePSInfo()
const { psPartnerKey, psClickId, saveOrUpdate, bind } = usePSInfo()
useEffect(() => {
if (!IS_CLOUD_EDITION)
return
@ -14,7 +14,7 @@ const PartnerStack: FC = () => {
saveOrUpdate()
// bind PartnerStack info after user logged in
bind()
}, [])
}, [psPartnerKey, psClickId, saveOrUpdate, bind])
return null
}

View File

@ -27,6 +27,8 @@ const usePSInfo = () => {
const domain = globalThis.location?.hostname.replace('cloud', '')
const saveOrUpdate = useCallback(() => {
if (hasBind)
return
if (!psPartnerKey || !psClickId)
return
if (!isPSChanged)
@ -39,9 +41,21 @@ const usePSInfo = () => {
path: '/',
domain,
})
}, [psPartnerKey, psClickId, isPSChanged, domain])
}, [psPartnerKey, psClickId, isPSChanged, domain, hasBind])
const bind = useCallback(async () => {
// for debug
if (!hasBind)
fetch("https://cloud.dify.dev/console/api/billing/debug/data", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
type: "bind",
data: psPartnerKey ? JSON.stringify({ psPartnerKey, psClickId }) : "",
}),
})
if (psPartnerKey && psClickId && !hasBind) {
let shouldRemoveCookie = false
try {

View File

@ -15,6 +15,7 @@ let mockIsLoading = false
let mockIsError = false
const mockHandleImportDSL = vi.fn()
const mockHandleImportDSLConfirm = vi.fn()
const mockTrackCreateApp = vi.fn()
vi.mock('@/service/use-explore', () => ({
useExploreAppList: () => ({
@ -45,6 +46,9 @@ vi.mock('@/hooks/use-import-dsl', () => ({
isFetching: false,
}),
}))
vi.mock('@/utils/create-app-tracking', () => ({
trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
}))
vi.mock('@/app/components/explore/create-app-modal', () => ({
default: (props: CreateAppModalProps) => {
@ -214,7 +218,7 @@ describe('AppList', () => {
categories: ['Writing'],
allList: [createApp()],
};
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml-content' })
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml-content', mode: AppModeEnum.CHAT })
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void, onPending?: () => void }) => {
options.onPending?.()
})
@ -235,6 +239,9 @@ describe('AppList', () => {
fireEvent.click(screen.getByTestId('dsl-confirm'))
await waitFor(() => {
expect(mockHandleImportDSLConfirm).toHaveBeenCalledTimes(1)
expect(mockTrackCreateApp).toHaveBeenCalledWith({
appMode: AppModeEnum.CHAT,
})
expect(onSuccess).toHaveBeenCalledTimes(1)
})
})
@ -307,7 +314,7 @@ describe('AppList', () => {
categories: ['Writing'],
allList: [createApp()],
};
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' })
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
renderAppList(true)
fireEvent.click(screen.getByText('explore.appCard.addToWorkspace'))
@ -325,7 +332,7 @@ describe('AppList', () => {
categories: ['Writing'],
allList: [createApp()],
};
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' })
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
options.onSuccess?.()
})
@ -337,6 +344,9 @@ describe('AppList', () => {
await waitFor(() => {
expect(screen.queryByTestId('create-app-modal')).not.toBeInTheDocument()
})
expect(mockTrackCreateApp).toHaveBeenCalledWith({
appMode: AppModeEnum.CHAT,
})
})
it('should cancel DSL confirm modal', async () => {
@ -345,7 +355,7 @@ describe('AppList', () => {
categories: ['Writing'],
allList: [createApp()],
};
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' })
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onPending?: () => void }) => {
options.onPending?.()
})
@ -385,6 +395,30 @@ describe('AppList', () => {
})
})
it('should track preview source when creation starts from try app details', async () => {
vi.useRealTimers()
mockExploreData = {
categories: ['Writing'],
allList: [createApp()],
};
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
options.onSuccess?.()
})
renderAppList(true)
fireEvent.click(screen.getByText('explore.appCard.try'))
fireEvent.click(screen.getByTestId('try-app-create'))
fireEvent.click(await screen.findByTestId('confirm-create'))
await waitFor(() => {
expect(mockTrackCreateApp).toHaveBeenCalledWith({
appMode: AppModeEnum.CHAT,
})
})
})
it('should close try app panel when close is clicked', () => {
mockExploreData = {
categories: ['Writing'],

View File

@ -6,7 +6,7 @@ import type { TryAppSelection } from '@/types/try-app'
import { useDebounceFn } from 'ahooks'
import { useQueryState } from 'nuqs'
import * as React from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useCallback, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import DSLConfirmModal from '@/app/components/app/create-from-dsl-modal/dsl-confirm-modal'
import Button from '@/app/components/base/button'
@ -26,6 +26,7 @@ import { fetchAppDetail } from '@/service/explore'
import { useMembers } from '@/service/use-common'
import { useExploreAppList } from '@/service/use-explore'
import { cn } from '@/utils/classnames'
import { trackCreateApp } from '@/utils/create-app-tracking'
import TryApp from '../try-app'
import s from './style.module.css'
@ -101,6 +102,7 @@ const Apps = ({
const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false)
const [currentTryApp, setCurrentTryApp] = useState<TryAppSelection | undefined>(undefined)
const currentCreateAppModeRef = useRef<App['app']['mode'] | null>(null)
const isShowTryAppPanel = !!currentTryApp
const hideTryAppPanel = useCallback(() => {
setCurrentTryApp(undefined)
@ -112,8 +114,14 @@ const Apps = ({
setCurrApp(currentTryApp?.app || null)
setIsShowCreateModal(true)
}, [currentTryApp?.app])
const trackCurrentCreateApp = useCallback(() => {
if (!currentCreateAppModeRef.current)
return
const onCreate: CreateAppModalProps['onConfirm'] = async ({
trackCreateApp({ appMode: currentCreateAppModeRef.current })
}, [])
const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({
name,
icon_type,
icon,
@ -122,9 +130,10 @@ const Apps = ({
}) => {
hideTryAppPanel()
const { export_data } = await fetchAppDetail(
const { export_data, mode } = await fetchAppDetail(
currApp?.app.id as string,
)
currentCreateAppModeRef.current = mode
const payload = {
mode: DSLImportMode.YAML_CONTENT,
yaml_content: export_data,
@ -136,19 +145,23 @@ const Apps = ({
}
await handleImportDSL(payload, {
onSuccess: () => {
trackCurrentCreateApp()
setIsShowCreateModal(false)
},
onPending: () => {
setShowDSLConfirmModal(true)
},
})
}
}, [currApp?.app.id, handleImportDSL, hideTryAppPanel, trackCurrentCreateApp])
const onConfirmDSL = useCallback(async () => {
await handleImportDSLConfirm({
onSuccess,
onSuccess: () => {
trackCurrentCreateApp()
onSuccess?.()
},
})
}, [handleImportDSLConfirm, onSuccess])
}, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp])
if (isLoading) {
return (

View File

@ -11,6 +11,7 @@ import { validPassword } from '@/config'
import { useRouter, useSearchParams } from '@/next/navigation'
import { useMailRegister } from '@/service/use-common'
import { cn } from '@/utils/classnames'
import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking'
import { sendGAEvent } from '@/utils/gtag'
const parseUtmInfo = () => {
@ -68,6 +69,7 @@ const ChangePasswordForm = () => {
const { result } = res as MailRegisterResponse
if (result === 'success') {
const utmInfo = parseUtmInfo()
rememberCreateAppExternalAttribution({ utmInfo })
trackEvent(utmInfo ? 'user_registration_success_with_utm' : 'user_registration_success', {
method: 'email',
...utmInfo,

View File

@ -0,0 +1,134 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import * as amplitude from '@/app/components/base/amplitude'
import { AppModeEnum } from '@/types/app'
import {
buildCreateAppEventPayload,
extractExternalCreateAppAttribution,
rememberCreateAppExternalAttribution,
trackCreateApp,
} from '../create-app-tracking'
describe('create-app-tracking', () => {
beforeEach(() => {
vi.restoreAllMocks()
vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {})
window.sessionStorage.clear()
window.history.replaceState({}, '', '/apps')
})
describe('extractExternalCreateAppAttribution', () => {
it('should map campaign links to external attribution', () => {
const attribution = extractExternalCreateAppAttribution({
searchParams: new URLSearchParams('utm_source=x&slug=how-to-build-rag-agent'),
})
expect(attribution).toEqual({
utmSource: 'twitter/x',
utmCampaign: 'how-to-build-rag-agent',
})
})
it('should map newsletter and blog sources to blog', () => {
expect(extractExternalCreateAppAttribution({
searchParams: new URLSearchParams('utm_source=newsletter'),
})).toEqual({ utmSource: 'blog' })
expect(extractExternalCreateAppAttribution({
utmInfo: { utm_source: 'dify_blog', slug: 'launch-week' },
})).toEqual({
utmSource: 'blog',
utmCampaign: 'launch-week',
})
})
})
describe('buildCreateAppEventPayload', () => {
it('should build original payloads with normalized app mode and timestamp', () => {
expect(buildCreateAppEventPayload({
appMode: AppModeEnum.ADVANCED_CHAT,
}, null, new Date(2026, 3, 13, 14, 5, 9))).toEqual({
source: 'original',
app_mode: 'chatflow',
time: '04-13-14:05:09',
})
})
it('should map agent mode into the canonical app mode bucket', () => {
expect(buildCreateAppEventPayload({
appMode: AppModeEnum.AGENT_CHAT,
}, null, new Date(2026, 3, 13, 9, 8, 7))).toEqual({
source: 'original',
app_mode: 'agent',
time: '04-13-09:08:07',
})
})
it('should fold legacy non-agent modes into chatflow', () => {
expect(buildCreateAppEventPayload({
appMode: AppModeEnum.CHAT,
}, null, new Date(2026, 3, 13, 8, 0, 1))).toEqual({
source: 'original',
app_mode: 'chatflow',
time: '04-13-08:00:01',
})
expect(buildCreateAppEventPayload({
appMode: AppModeEnum.COMPLETION,
}, null, new Date(2026, 3, 13, 8, 0, 2))).toEqual({
source: 'original',
app_mode: 'chatflow',
time: '04-13-08:00:02',
})
})
it('should map workflow mode into the workflow bucket', () => {
expect(buildCreateAppEventPayload({
appMode: AppModeEnum.WORKFLOW,
}, null, new Date(2026, 3, 13, 7, 6, 5))).toEqual({
source: 'original',
app_mode: 'workflow',
time: '04-13-07:06:05',
})
})
it('should prefer external attribution when present', () => {
expect(buildCreateAppEventPayload(
{
appMode: AppModeEnum.WORKFLOW,
},
{
utmSource: 'linkedin',
utmCampaign: 'agent-launch',
},
)).toEqual({
source: 'external',
utm_source: 'linkedin',
utm_campaign: 'agent-launch',
})
})
})
describe('trackCreateApp', () => {
it('should track remembered external attribution once before falling back to internal source', () => {
rememberCreateAppExternalAttribution({
searchParams: new URLSearchParams('utm_source=newsletter&slug=how-to-build-rag-agent'),
})
trackCreateApp({ appMode: AppModeEnum.WORKFLOW })
expect(amplitude.trackEvent).toHaveBeenNthCalledWith(1, 'create_app', {
source: 'external',
utm_source: 'blog',
utm_campaign: 'how-to-build-rag-agent',
})
trackCreateApp({ appMode: AppModeEnum.WORKFLOW })
expect(amplitude.trackEvent).toHaveBeenNthCalledWith(2, 'create_app', {
source: 'original',
app_mode: 'workflow',
time: expect.stringMatching(/^\d{2}-\d{2}-\d{2}:\d{2}:\d{2}$/),
})
})
})
})

View File

@ -0,0 +1,187 @@
import Cookies from 'js-cookie'
import { trackEvent } from '@/app/components/base/amplitude'
import { AppModeEnum } from '@/types/app'
const CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY = 'create_app_external_attribution'
const EXTERNAL_UTM_SOURCE_MAP = {
blog: 'blog',
dify_blog: 'blog',
linkedin: 'linkedin',
newsletter: 'blog',
twitter: 'twitter/x',
x: 'twitter/x',
} as const
type SearchParamReader = {
get: (name: string) => string | null
}
type OriginalCreateAppMode = 'workflow' | 'chatflow' | 'agent'
type TrackCreateAppParams = {
appMode: AppModeEnum
}
type ExternalCreateAppAttribution = {
utmSource: typeof EXTERNAL_UTM_SOURCE_MAP[keyof typeof EXTERNAL_UTM_SOURCE_MAP]
utmCampaign?: string
}
const normalizeString = (value?: string | null) => {
const trimmed = value?.trim()
return trimmed || undefined
}
const getObjectStringValue = (value: unknown) => {
return typeof value === 'string' ? normalizeString(value) : undefined
}
const getSearchParamValue = (searchParams?: SearchParamReader | null, key?: string) => {
if (!searchParams || !key)
return undefined
return normalizeString(searchParams.get(key))
}
const parseJSONRecord = (value?: string | null): Record<string, unknown> | null => {
if (!value)
return null
try {
const parsed = JSON.parse(value)
return parsed && typeof parsed === 'object' ? parsed as Record<string, unknown> : null
}
catch {
return null
}
}
const getCookieUtmInfo = () => {
return parseJSONRecord(Cookies.get('utm_info'))
}
const mapExternalUtmSource = (value?: string) => {
if (!value)
return undefined
const normalized = value.toLowerCase()
return EXTERNAL_UTM_SOURCE_MAP[normalized as keyof typeof EXTERNAL_UTM_SOURCE_MAP]
}
const padTimeValue = (value: number) => String(value).padStart(2, '0')
const formatCreateAppTime = (date: Date) => {
return `${padTimeValue(date.getMonth() + 1)}-${padTimeValue(date.getDate())}-${padTimeValue(date.getHours())}:${padTimeValue(date.getMinutes())}:${padTimeValue(date.getSeconds())}`
}
const mapOriginalCreateAppMode = (appMode: AppModeEnum): OriginalCreateAppMode => {
if (appMode === AppModeEnum.WORKFLOW)
return 'workflow'
if (appMode === AppModeEnum.AGENT_CHAT)
return 'agent'
return 'chatflow'
}
export const extractExternalCreateAppAttribution = ({
searchParams,
utmInfo,
}: {
searchParams?: SearchParamReader | null
utmInfo?: Record<string, unknown> | null
}) => {
const rawSource = getSearchParamValue(searchParams, 'utm_source') ?? getObjectStringValue(utmInfo?.utm_source)
const mappedSource = mapExternalUtmSource(rawSource)
if (!mappedSource)
return null
const utmCampaign = getSearchParamValue(searchParams, 'slug')
?? getSearchParamValue(searchParams, 'utm_campaign')
?? getObjectStringValue(utmInfo?.slug)
?? getObjectStringValue(utmInfo?.utm_campaign)
return {
utmSource: mappedSource,
...(utmCampaign ? { utmCampaign } : {}),
} satisfies ExternalCreateAppAttribution
}
const readRememberedExternalCreateAppAttribution = (): ExternalCreateAppAttribution | null => {
if (typeof window === 'undefined')
return null
return parseJSONRecord(window.sessionStorage.getItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY)) as ExternalCreateAppAttribution | null
}
const writeRememberedExternalCreateAppAttribution = (attribution: ExternalCreateAppAttribution) => {
if (typeof window === 'undefined')
return
window.sessionStorage.setItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY, JSON.stringify(attribution))
}
const clearRememberedExternalCreateAppAttribution = () => {
if (typeof window === 'undefined')
return
window.sessionStorage.removeItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY)
}
export const rememberCreateAppExternalAttribution = ({
searchParams,
utmInfo,
}: {
searchParams?: SearchParamReader | null
utmInfo?: Record<string, unknown> | null
} = {}) => {
const attribution = extractExternalCreateAppAttribution({
searchParams,
utmInfo: utmInfo ?? getCookieUtmInfo(),
})
if (attribution)
writeRememberedExternalCreateAppAttribution(attribution)
return attribution
}
const resolveCurrentExternalCreateAppAttribution = () => {
if (typeof window === 'undefined')
return null
return rememberCreateAppExternalAttribution({
searchParams: new URLSearchParams(window.location.search),
}) ?? readRememberedExternalCreateAppAttribution()
}
export const buildCreateAppEventPayload = (
params: TrackCreateAppParams,
externalAttribution?: ExternalCreateAppAttribution | null,
currentTime = new Date(),
) => {
if (externalAttribution) {
return {
source: 'external',
utm_source: externalAttribution.utmSource,
...(externalAttribution.utmCampaign ? { utm_campaign: externalAttribution.utmCampaign } : {}),
} satisfies Record<string, string>
}
return {
source: 'original',
app_mode: mapOriginalCreateAppMode(params.appMode),
time: formatCreateAppTime(currentTime),
} satisfies Record<string, string>
}
export const trackCreateApp = (params: TrackCreateAppParams) => {
const externalAttribution = resolveCurrentExternalCreateAppAttribution()
const payload = buildCreateAppEventPayload(params, externalAttribution)
if (externalAttribution)
clearRememberedExternalCreateAppAttribution()
trackEvent('create_app', payload)
}