mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
Merge remote-tracking branch 'origin/deploy/dev' into feat/evaluation
# Conflicts: # api/tests/unit_tests/services/test_async_workflow_service.py # api/tests/unit_tests/services/test_webhook_service.py Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
This commit is contained in:
commit
e2e5ad0c33
@ -1,4 +1,6 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
@ -10,6 +12,7 @@ from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
@ -77,3 +80,39 @@ class PartnerTenants(Resource):
|
||||
raise BadRequest("Invalid partner information")
|
||||
|
||||
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
|
||||
|
||||
|
||||
_DEBUG_KEY = "billing:debug"
|
||||
_DEBUG_TTL = timedelta(days=7)
|
||||
|
||||
|
||||
class DebugDataPayload(BaseModel):
|
||||
type: str = Field(..., min_length=1, description="Data type key")
|
||||
data: str = Field(..., min_length=1, description="Data value to append")
|
||||
|
||||
|
||||
@console_ns.route("/billing/debug/data")
|
||||
class DebugData(Resource):
|
||||
def post(self):
|
||||
body = DebugDataPayload.model_validate(request.get_json(force=True))
|
||||
item = json.dumps({
|
||||
"type": body.type,
|
||||
"data": body.data,
|
||||
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||
})
|
||||
redis_client.lpush(_DEBUG_KEY, item)
|
||||
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
|
||||
return {"result": "ok"}, 201
|
||||
|
||||
def get(self):
|
||||
recent = request.args.get("recent", 10, type=int)
|
||||
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
|
||||
return {
|
||||
"data": [
|
||||
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
|
||||
]
|
||||
}
|
||||
|
||||
def delete(self):
|
||||
redis_client.delete(_DEBUG_KEY)
|
||||
return {"result": "ok"}
|
||||
|
||||
@ -1,56 +1,17 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum, auto
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCharge:
|
||||
"""
|
||||
Result of a quota consumption operation.
|
||||
|
||||
Attributes:
|
||||
success: Whether the quota charge succeeded
|
||||
charge_id: UUID for refund, or None if failed/disabled
|
||||
"""
|
||||
|
||||
success: bool
|
||||
charge_id: str | None
|
||||
_quota_type: "QuotaType"
|
||||
|
||||
def refund(self) -> None:
|
||||
"""
|
||||
Refund this quota charge.
|
||||
|
||||
Safe to call even if charge failed or was disabled.
|
||||
This method guarantees no exceptions will be raised.
|
||||
"""
|
||||
if self.charge_id:
|
||||
self._quota_type.refund(self.charge_id)
|
||||
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
|
||||
|
||||
|
||||
class QuotaType(StrEnum):
|
||||
"""
|
||||
Supported quota types for tenant feature usage.
|
||||
|
||||
Add additional types here whenever new billable features become available.
|
||||
"""
|
||||
|
||||
# Trigger execution quota
|
||||
TRIGGER = auto()
|
||||
|
||||
# Workflow execution quota
|
||||
WORKFLOW = auto()
|
||||
|
||||
UNLIMITED = auto()
|
||||
|
||||
@property
|
||||
def billing_key(self) -> str:
|
||||
"""
|
||||
Get the billing key for the feature.
|
||||
"""
|
||||
match self:
|
||||
case QuotaType.TRIGGER:
|
||||
return "trigger_event"
|
||||
@ -58,152 +19,3 @@ class QuotaType(StrEnum):
|
||||
return "api_rate_limit"
|
||||
case _:
|
||||
raise ValueError(f"Invalid quota type: {self}")
|
||||
|
||||
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
|
||||
"""
|
||||
Consume quota for the feature.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
amount: Amount to consume (default: 1)
|
||||
|
||||
Returns:
|
||||
QuotaCharge with success status and charge_id for refund
|
||||
|
||||
Raises:
|
||||
QuotaExceededError: When quota is insufficient
|
||||
"""
|
||||
from configs import dify_config
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
logger.debug("Billing disabled, allowing request for %s", tenant_id)
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
|
||||
|
||||
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to consume must be greater than 0")
|
||||
|
||||
try:
|
||||
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
|
||||
|
||||
if response.get("result") != "success":
|
||||
logger.warning(
|
||||
"Failed to consume quota for %s, feature %s details: %s",
|
||||
tenant_id,
|
||||
self.value,
|
||||
response.get("detail"),
|
||||
)
|
||||
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
|
||||
|
||||
charge_id = response.get("history_id")
|
||||
logger.debug(
|
||||
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
|
||||
amount,
|
||||
self.value,
|
||||
tenant_id,
|
||||
charge_id,
|
||||
)
|
||||
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
|
||||
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception:
|
||||
# fail-safe: allow request on billing errors
|
||||
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
|
||||
return unlimited()
|
||||
|
||||
def check(self, tenant_id: str, amount: int = 1) -> bool:
|
||||
"""
|
||||
Check if tenant has sufficient quota without consuming.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
amount: Amount to check (default: 1)
|
||||
|
||||
Returns:
|
||||
True if quota is sufficient, False otherwise
|
||||
"""
|
||||
from configs import dify_config
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return True
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to check must be greater than 0")
|
||||
|
||||
try:
|
||||
remaining = self.get_remaining(tenant_id)
|
||||
return remaining >= amount if remaining != -1 else True
|
||||
except Exception:
|
||||
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
|
||||
# fail-safe: allow request on billing errors
|
||||
return True
|
||||
|
||||
def refund(self, charge_id: str) -> None:
|
||||
"""
|
||||
Refund quota using charge_id from consume().
|
||||
|
||||
This method guarantees no exceptions will be raised.
|
||||
All errors are logged but silently handled.
|
||||
|
||||
Args:
|
||||
charge_id: The UUID returned from consume()
|
||||
"""
|
||||
try:
|
||||
from configs import dify_config
|
||||
from services.billing_service import BillingService
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return
|
||||
|
||||
if not charge_id:
|
||||
logger.warning("Cannot refund: charge_id is empty")
|
||||
return
|
||||
|
||||
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
|
||||
|
||||
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
|
||||
if response.get("result") == "success":
|
||||
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
|
||||
else:
|
||||
logger.warning("Refund failed for charge_id: %s", charge_id)
|
||||
|
||||
except Exception:
|
||||
# Catch ALL exceptions - refund must never fail
|
||||
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
|
||||
# Don't raise - refund is best-effort and must be silent
|
||||
|
||||
def get_remaining(self, tenant_id: str) -> int:
|
||||
"""
|
||||
Get remaining quota for the tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
|
||||
Returns:
|
||||
Remaining quota amount
|
||||
"""
|
||||
from services.billing_service import BillingService
|
||||
|
||||
try:
|
||||
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
|
||||
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
|
||||
if isinstance(usage_info, dict):
|
||||
return usage_info.get("remaining", 0)
|
||||
# If it returns a simple number, treat it as remaining
|
||||
return int(usage_info) if usage_info else 0
|
||||
except Exception:
|
||||
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
|
||||
return -1
|
||||
|
||||
|
||||
def unlimited() -> QuotaCharge:
|
||||
"""
|
||||
Return a quota charge for unlimited quota.
|
||||
|
||||
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
|
||||
"""
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
|
||||
|
||||
@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit
|
||||
from core.app.features.rate_limiting.rate_limit import rate_limit_context
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
||||
from core.db import session_factory
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from enums.quota_type import QuotaType
|
||||
from extensions.otel import AppGenerateHandler, trace_span
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.quota_service import QuotaService, unlimited
|
||||
from services.workflow_service import WorkflowService
|
||||
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
|
||||
|
||||
@ -106,7 +107,7 @@ class AppGenerateService:
|
||||
quota_charge = unlimited()
|
||||
if dify_config.BILLING_ENABLED:
|
||||
try:
|
||||
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
|
||||
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
|
||||
except QuotaExceededError:
|
||||
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
|
||||
|
||||
@ -116,6 +117,7 @@ class AppGenerateService:
|
||||
request_id = RateLimit.gen_request_key()
|
||||
try:
|
||||
request_id = rate_limit.enter(request_id)
|
||||
quota_charge.commit()
|
||||
effective_mode = (
|
||||
AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
|
||||
)
|
||||
|
||||
@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
|
||||
from models.workflow import Workflow
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
|
||||
from services.quota_service import QuotaService, unlimited
|
||||
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
|
||||
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -131,9 +132,10 @@ class AsyncWorkflowService:
|
||||
trigger_log = trigger_log_repo.create(trigger_log)
|
||||
session.commit()
|
||||
|
||||
# 7. Check and consume quota
|
||||
# 7. Reserve quota (commit after successful dispatch)
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
|
||||
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
|
||||
except QuotaExceededError as e:
|
||||
# Update trigger log status
|
||||
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
|
||||
@ -153,13 +155,18 @@ class AsyncWorkflowService:
|
||||
# 9. Dispatch to appropriate queue
|
||||
task_data_dict = task_data.model_dump(mode="json")
|
||||
|
||||
task: AsyncResult[Any] | None = None
|
||||
if queue_name == QueuePriority.PROFESSIONAL:
|
||||
task = execute_workflow_professional.delay(task_data_dict)
|
||||
elif queue_name == QueuePriority.TEAM:
|
||||
task = execute_workflow_team.delay(task_data_dict)
|
||||
else: # SANDBOX
|
||||
task = execute_workflow_sandbox.delay(task_data_dict)
|
||||
try:
|
||||
task: AsyncResult[Any] | None = None
|
||||
if queue_name == QueuePriority.PROFESSIONAL:
|
||||
task = execute_workflow_professional.delay(task_data_dict)
|
||||
elif queue_name == QueuePriority.TEAM:
|
||||
task = execute_workflow_team.delay(task_data_dict)
|
||||
else: # SANDBOX
|
||||
task = execute_workflow_sandbox.delay(task_data_dict)
|
||||
quota_charge.commit()
|
||||
except Exception:
|
||||
quota_charge.refund()
|
||||
raise
|
||||
|
||||
# 10. Update trigger log with task info
|
||||
trigger_log.status = WorkflowTriggerStatus.QUEUED
|
||||
|
||||
@ -32,6 +32,102 @@ class SubscriptionPlan(TypedDict):
|
||||
expiration_date: int
|
||||
|
||||
|
||||
class QuotaReserveResult(TypedDict):
|
||||
reservation_id: str
|
||||
available: int
|
||||
reserved: int
|
||||
|
||||
|
||||
class QuotaCommitResult(TypedDict):
|
||||
available: int
|
||||
reserved: int
|
||||
refunded: int
|
||||
|
||||
|
||||
class QuotaReleaseResult(TypedDict):
|
||||
available: int
|
||||
reserved: int
|
||||
released: int
|
||||
|
||||
|
||||
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
|
||||
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
|
||||
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
|
||||
class _BillingQuota(TypedDict):
|
||||
size: int
|
||||
limit: int
|
||||
|
||||
|
||||
class _VectorSpaceQuota(TypedDict):
|
||||
size: float
|
||||
limit: int
|
||||
|
||||
|
||||
class _KnowledgeRateLimit(TypedDict):
|
||||
# NOTE (hj24):
|
||||
# 1. Return for sandbox users but is null for other plans, it's defined but never used.
|
||||
# 2. Keep it for compatibility for now, can be deprecated in future versions.
|
||||
size: NotRequired[int]
|
||||
# NOTE END
|
||||
limit: int
|
||||
|
||||
|
||||
class _BillingSubscription(TypedDict):
|
||||
plan: str
|
||||
interval: str
|
||||
education: bool
|
||||
|
||||
|
||||
class BillingInfo(TypedDict):
|
||||
"""Response of /subscription/info.
|
||||
|
||||
NOTE (hj24):
|
||||
- Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
|
||||
- To ensure the precision, billing may convert fields like int as str, be careful when use TypeAdapter:
|
||||
1. validate_python in non-strict mode will coerce it to the expected type
|
||||
2. In strict mode, it will raise ValidationError
|
||||
3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
subscription: _BillingSubscription
|
||||
members: _BillingQuota
|
||||
apps: _BillingQuota
|
||||
vector_space: _VectorSpaceQuota
|
||||
knowledge_rate_limit: _KnowledgeRateLimit
|
||||
documents_upload_quota: _BillingQuota
|
||||
annotation_quota_limit: _BillingQuota
|
||||
docs_processing: str
|
||||
can_replace_logo: bool
|
||||
model_load_balancing_enabled: bool
|
||||
knowledge_pipeline_publish_enabled: bool
|
||||
next_credit_reset_date: NotRequired[int]
|
||||
|
||||
|
||||
_billing_info_adapter = TypeAdapter(BillingInfo)
|
||||
|
||||
|
||||
class _TenantFeatureQuota(TypedDict):
|
||||
usage: int
|
||||
limit: int
|
||||
reset_date: NotRequired[int]
|
||||
|
||||
|
||||
class TenantFeatureQuotaInfo(TypedDict):
|
||||
"""Response of /quota/info.
|
||||
|
||||
NOTE (hj24):
|
||||
- Same convention as BillingInfo: billing may return int fields as str,
|
||||
always keep non-strict mode to auto-coerce.
|
||||
"""
|
||||
|
||||
trigger_event: _TenantFeatureQuota
|
||||
api_rate_limit: _TenantFeatureQuota
|
||||
|
||||
|
||||
_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)
|
||||
|
||||
|
||||
class _BillingQuota(TypedDict):
|
||||
size: int
|
||||
limit: int
|
||||
@ -149,11 +245,63 @@ class BillingService:
|
||||
|
||||
@classmethod
|
||||
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
|
||||
"""Deprecated: Use get_quota_info instead."""
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
|
||||
return usage_info
|
||||
|
||||
@classmethod
|
||||
def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
|
||||
params = {"tenant_id": tenant_id}
|
||||
return _tenant_feature_quota_info_adapter.validate_python(
|
||||
cls._send_request("GET", "/quota/info", params=params)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def quota_reserve(
|
||||
cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
|
||||
) -> QuotaReserveResult:
|
||||
"""Reserve quota before task execution."""
|
||||
payload: dict = {
|
||||
"tenant_id": tenant_id,
|
||||
"feature_key": feature_key,
|
||||
"request_id": request_id,
|
||||
"amount": amount,
|
||||
}
|
||||
if meta:
|
||||
payload["meta"] = meta
|
||||
return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))
|
||||
|
||||
@classmethod
|
||||
def quota_commit(
|
||||
cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
|
||||
) -> QuotaCommitResult:
|
||||
"""Commit a reservation with actual consumption."""
|
||||
payload: dict = {
|
||||
"tenant_id": tenant_id,
|
||||
"feature_key": feature_key,
|
||||
"reservation_id": reservation_id,
|
||||
"actual_amount": actual_amount,
|
||||
}
|
||||
if meta:
|
||||
payload["meta"] = meta
|
||||
return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))
|
||||
|
||||
@classmethod
|
||||
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
|
||||
"""Release a reservation (cancel, return frozen quota)."""
|
||||
return _quota_release_adapter.validate_python(
|
||||
cls._send_request(
|
||||
"POST",
|
||||
"/quota/release",
|
||||
json={
|
||||
"tenant_id": tenant_id,
|
||||
"feature_key": feature_key,
|
||||
"reservation_id": reservation_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
@ -281,7 +281,7 @@ class FeatureService:
|
||||
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
|
||||
billing_info = BillingService.get_info(tenant_id)
|
||||
|
||||
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
|
||||
features_usage_info = BillingService.get_quota_info(tenant_id)
|
||||
|
||||
features.billing.enabled = billing_info["enabled"]
|
||||
features.billing.subscription.plan = billing_info["subscription"]["plan"]
|
||||
|
||||
233
api/services/quota_service.py
Normal file
233
api/services/quota_service.py
Normal file
@ -0,0 +1,233 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enums.quota_type import QuotaType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCharge:
|
||||
"""
|
||||
Result of a quota reservation (Reserve phase).
|
||||
|
||||
Lifecycle:
|
||||
charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
|
||||
try:
|
||||
do_work()
|
||||
charge.commit() # Confirm consumption
|
||||
except:
|
||||
charge.refund() # Release frozen quota
|
||||
|
||||
If neither commit() nor refund() is called, the billing system's
|
||||
cleanup CronJob will auto-release the reservation within ~75 seconds.
|
||||
"""
|
||||
|
||||
success: bool
|
||||
charge_id: str | None # reservation_id
|
||||
_quota_type: QuotaType
|
||||
_tenant_id: str | None = None
|
||||
_feature_key: str | None = None
|
||||
_amount: int = 0
|
||||
_committed: bool = field(default=False, repr=False)
|
||||
|
||||
def commit(self, actual_amount: int | None = None) -> None:
|
||||
"""
|
||||
Confirm the consumption with actual amount.
|
||||
|
||||
Args:
|
||||
actual_amount: Actual amount consumed. Defaults to the reserved amount.
|
||||
If less than reserved, the difference is refunded automatically.
|
||||
"""
|
||||
if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
|
||||
return
|
||||
|
||||
try:
|
||||
from services.billing_service import BillingService
|
||||
|
||||
amount = actual_amount if actual_amount is not None else self._amount
|
||||
BillingService.quota_commit(
|
||||
tenant_id=self._tenant_id,
|
||||
feature_key=self._feature_key,
|
||||
reservation_id=self.charge_id,
|
||||
actual_amount=amount,
|
||||
)
|
||||
self._committed = True
|
||||
logger.debug(
|
||||
"Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
|
||||
self._quota_type,
|
||||
self._tenant_id,
|
||||
self.charge_id,
|
||||
amount,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)
|
||||
|
||||
def refund(self) -> None:
|
||||
"""
|
||||
Release the reserved quota (cancel the charge).
|
||||
|
||||
Safe to call even if:
|
||||
- charge failed or was disabled (charge_id is None)
|
||||
- already committed (Release after Commit is a no-op)
|
||||
- already refunded (idempotent)
|
||||
|
||||
This method guarantees no exceptions will be raised.
|
||||
"""
|
||||
if not self.charge_id or not self._tenant_id or not self._feature_key:
|
||||
return
|
||||
|
||||
QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
|
||||
|
||||
|
||||
def unlimited() -> QuotaCharge:
|
||||
from enums.quota_type import QuotaType
|
||||
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
|
||||
|
||||
|
||||
class QuotaService:
|
||||
"""Orchestrates quota reserve / commit / release lifecycle via BillingService."""
|
||||
|
||||
@staticmethod
|
||||
def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
|
||||
"""
|
||||
Reserve + immediate Commit (one-shot mode).
|
||||
|
||||
The returned QuotaCharge supports .refund() which calls Release.
|
||||
For two-phase usage (e.g. streaming), use reserve() directly.
|
||||
"""
|
||||
charge = QuotaService.reserve(quota_type, tenant_id, amount)
|
||||
if charge.success and charge.charge_id:
|
||||
charge.commit()
|
||||
return charge
|
||||
|
||||
@staticmethod
|
||||
def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
|
||||
"""
|
||||
Reserve quota before task execution (Reserve phase only).
|
||||
|
||||
The caller MUST call charge.commit() after the task succeeds,
|
||||
or charge.refund() if the task fails.
|
||||
|
||||
Raises:
|
||||
QuotaExceededError: When quota is insufficient
|
||||
"""
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
logger.debug("Billing disabled, allowing request for %s", tenant_id)
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)
|
||||
|
||||
logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to reserve must be greater than 0")
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
feature_key = quota_type.billing_key
|
||||
|
||||
try:
|
||||
reserve_resp = BillingService.quota_reserve(
|
||||
tenant_id=tenant_id,
|
||||
feature_key=feature_key,
|
||||
request_id=request_id,
|
||||
amount=amount,
|
||||
)
|
||||
|
||||
reservation_id = reserve_resp.get("reservation_id")
|
||||
if not reservation_id:
|
||||
logger.warning(
|
||||
"Reserve returned no reservation_id for %s, feature %s, response: %s",
|
||||
tenant_id,
|
||||
quota_type.value,
|
||||
reserve_resp,
|
||||
)
|
||||
raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
|
||||
|
||||
logger.debug(
|
||||
"Reserved %d %s quota for tenant %s, reservation_id: %s",
|
||||
amount,
|
||||
quota_type.value,
|
||||
tenant_id,
|
||||
reservation_id,
|
||||
)
|
||||
return QuotaCharge(
|
||||
success=True,
|
||||
charge_id=reservation_id,
|
||||
_quota_type=quota_type,
|
||||
_tenant_id=tenant_id,
|
||||
_feature_key=feature_key,
|
||||
_amount=amount,
|
||||
)
|
||||
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
|
||||
return unlimited()
|
||||
|
||||
@staticmethod
|
||||
def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return True
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to check must be greater than 0")
|
||||
|
||||
try:
|
||||
remaining = QuotaService.get_remaining(quota_type, tenant_id)
|
||||
return remaining >= amount if remaining != -1 else True
|
||||
except Exception:
|
||||
logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
|
||||
"""Release a reservation. Guarantees no exceptions."""
|
||||
try:
|
||||
from services.billing_service import BillingService
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return
|
||||
|
||||
if not reservation_id:
|
||||
return
|
||||
|
||||
logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
|
||||
BillingService.quota_release(
|
||||
tenant_id=tenant_id,
|
||||
feature_key=feature_key,
|
||||
reservation_id=reservation_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to release quota, reservation_id: %s", reservation_id)
|
||||
|
||||
@staticmethod
|
||||
def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
|
||||
from services.billing_service import BillingService
|
||||
|
||||
try:
|
||||
usage_info = BillingService.get_quota_info(tenant_id)
|
||||
if isinstance(usage_info, dict):
|
||||
feature_info = usage_info.get(quota_type.billing_key, {})
|
||||
if isinstance(feature_info, dict):
|
||||
limit = feature_info.get("limit", 0)
|
||||
usage = feature_info.get("usage", 0)
|
||||
if limit == -1:
|
||||
return -1
|
||||
return max(0, limit - usage)
|
||||
return 0
|
||||
except Exception:
|
||||
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
|
||||
return -1
|
||||
@ -38,6 +38,7 @@ from models.workflow import Workflow
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.quota_service import QuotaService
|
||||
from services.trigger.app_trigger_service import AppTriggerService
|
||||
from services.workflow.entities import WebhookTriggerData
|
||||
|
||||
@ -819,9 +820,9 @@ class WebhookService:
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# consume quota before triggering workflow execution
|
||||
# reserve quota before triggering workflow execution
|
||||
try:
|
||||
QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
|
||||
logger.info(
|
||||
@ -832,11 +833,16 @@ class WebhookService:
|
||||
raise
|
||||
|
||||
# Trigger workflow execution asynchronously
|
||||
AsyncWorkflowService.trigger_workflow_async(
|
||||
session,
|
||||
end_user,
|
||||
trigger_data,
|
||||
)
|
||||
try:
|
||||
AsyncWorkflowService.trigger_workflow_async(
|
||||
session,
|
||||
end_user,
|
||||
trigger_data,
|
||||
)
|
||||
quota_charge.commit()
|
||||
except Exception:
|
||||
quota_charge.refund()
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)
|
||||
|
||||
@ -28,7 +28,7 @@ from core.trigger.entities.entities import TriggerProviderEntity
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from enums.quota_type import QuotaType
|
||||
from models.enums import (
|
||||
AppTriggerType,
|
||||
CreatorUserRole,
|
||||
@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom,
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.quota_service import QuotaService, unlimited
|
||||
from services.trigger.app_trigger_service import AppTriggerService
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
|
||||
@ -298,10 +299,10 @@ def dispatch_triggered_workflow(
|
||||
icon_dark_filename=trigger_entity.identity.icon_dark or "",
|
||||
)
|
||||
|
||||
# consume quota before invoking trigger
|
||||
# reserve quota before invoking trigger
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
|
||||
logger.info(
|
||||
@ -387,6 +388,7 @@ def dispatch_triggered_workflow(
|
||||
raise ValueError(f"End user not found for app {plugin_trigger.app_id}")
|
||||
|
||||
AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
|
||||
quota_charge.commit()
|
||||
dispatched_count += 1
|
||||
logger.info(
|
||||
"Triggered workflow for app %s with trigger event %s",
|
||||
|
||||
@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import (
|
||||
ScheduleNotFoundError,
|
||||
TenantOwnerNotFoundError,
|
||||
)
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from enums.quota_type import QuotaType
|
||||
from models.trigger import WorkflowSchedulePlan
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.quota_service import QuotaService, unlimited
|
||||
from services.trigger.app_trigger_service import AppTriggerService
|
||||
from services.trigger.schedule_service import ScheduleService
|
||||
from services.workflow.entities import ScheduleTriggerData
|
||||
@ -43,7 +44,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
|
||||
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
|
||||
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
|
||||
@ -61,6 +62,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
|
||||
tenant_id=schedule.tenant_id,
|
||||
),
|
||||
)
|
||||
quota_charge.commit()
|
||||
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
|
||||
except Exception as e:
|
||||
quota_charge.refund()
|
||||
|
||||
@ -36,12 +36,19 @@ class TestAppGenerateService:
|
||||
) as mock_message_based_generator,
|
||||
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
|
||||
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
|
||||
patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
|
||||
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
|
||||
):
|
||||
# Setup default mock returns for billing service
|
||||
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
|
||||
"result": "success",
|
||||
"history_id": "test_history_id",
|
||||
mock_billing_service.quota_reserve.return_value = {
|
||||
"reservation_id": "test-reservation-id",
|
||||
"available": 100,
|
||||
"reserved": 1,
|
||||
}
|
||||
mock_billing_service.quota_commit.return_value = {
|
||||
"available": 99,
|
||||
"reserved": 0,
|
||||
"refunded": 0,
|
||||
}
|
||||
|
||||
# Setup default mock returns for workflow service
|
||||
@ -101,6 +108,8 @@ class TestAppGenerateService:
|
||||
mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
|
||||
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
|
||||
|
||||
mock_quota_dify_config.BILLING_ENABLED = False
|
||||
|
||||
mock_global_dify_config.BILLING_ENABLED = False
|
||||
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
|
||||
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
|
||||
@ -118,6 +127,7 @@ class TestAppGenerateService:
|
||||
"message_based_generator": mock_message_based_generator,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"dify_config": mock_dify_config,
|
||||
"quota_dify_config": mock_quota_dify_config,
|
||||
"global_dify_config": mock_global_dify_config,
|
||||
}
|
||||
|
||||
@ -465,6 +475,7 @@ class TestAppGenerateService:
|
||||
|
||||
# Set BILLING_ENABLED to True for this test
|
||||
mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
|
||||
mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True
|
||||
mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True
|
||||
|
||||
# Setup test arguments
|
||||
@ -478,8 +489,10 @@ class TestAppGenerateService:
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify billing service was called to consume quota
|
||||
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
|
||||
# Verify billing two-phase quota (reserve + commit)
|
||||
billing = mock_external_service_dependencies["billing_service"]
|
||||
billing.quota_reserve.assert_called_once()
|
||||
billing.quota_commit.assert_called_once()
|
||||
|
||||
def test_generate_with_invalid_app_mode(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
|
||||
@ -605,9 +605,9 @@ def test_schedule_trigger_creates_trigger_log(
|
||||
)
|
||||
|
||||
# Mock quota to avoid rate limiting
|
||||
from enums import quota_type
|
||||
from services import quota_service
|
||||
|
||||
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
|
||||
monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())
|
||||
|
||||
# Execute schedule trigger
|
||||
workflow_schedule_tasks.run_schedule_trigger(plan.id)
|
||||
|
||||
0
api/tests/unit_tests/enums/__init__.py
Normal file
0
api/tests/unit_tests/enums/__init__.py
Normal file
349
api/tests/unit_tests/enums/test_quota_type.py
Normal file
349
api/tests/unit_tests/enums/test_quota_type.py
Normal file
@ -0,0 +1,349 @@
|
||||
"""Unit tests for QuotaType, QuotaService, and QuotaCharge."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from enums.quota_type import QuotaType
|
||||
from services.quota_service import QuotaCharge, QuotaService, unlimited
|
||||
|
||||
|
||||
class TestQuotaType:
|
||||
def test_billing_key_trigger(self):
|
||||
assert QuotaType.TRIGGER.billing_key == "trigger_event"
|
||||
|
||||
def test_billing_key_workflow(self):
|
||||
assert QuotaType.WORKFLOW.billing_key == "api_rate_limit"
|
||||
|
||||
def test_billing_key_unlimited_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid quota type"):
|
||||
_ = QuotaType.UNLIMITED.billing_key
|
||||
|
||||
|
||||
class TestQuotaService:
|
||||
def test_reserve_billing_disabled(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService"),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = False
|
||||
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
|
||||
assert charge.success is True
|
||||
assert charge.charge_id is None
|
||||
|
||||
def test_reserve_zero_amount_raises(self):
|
||||
with patch("services.quota_service.dify_config") as mock_cfg:
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
with pytest.raises(ValueError, match="greater than 0"):
|
||||
QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0)
|
||||
|
||||
def test_reserve_success(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99}
|
||||
|
||||
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1)
|
||||
|
||||
assert charge.success is True
|
||||
assert charge.charge_id == "rid-1"
|
||||
assert charge._tenant_id == "t1"
|
||||
assert charge._feature_key == "trigger_event"
|
||||
assert charge._amount == 1
|
||||
mock_bs.quota_reserve.assert_called_once()
|
||||
|
||||
def test_reserve_no_reservation_id_raises(self):
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.return_value = {}
|
||||
|
||||
with pytest.raises(QuotaExceededError):
|
||||
QuotaService.reserve(QuotaType.TRIGGER, "t1")
|
||||
|
||||
def test_reserve_quota_exceeded_propagates(self):
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1)
|
||||
|
||||
with pytest.raises(QuotaExceededError):
|
||||
QuotaService.reserve(QuotaType.TRIGGER, "t1")
|
||||
|
||||
def test_reserve_api_exception_returns_unlimited(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.side_effect = RuntimeError("network")
|
||||
|
||||
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
|
||||
assert charge.success is True
|
||||
assert charge.charge_id is None
|
||||
|
||||
def test_consume_calls_reserve_and_commit(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"}
|
||||
mock_bs.quota_commit.return_value = {}
|
||||
|
||||
charge = QuotaService.consume(QuotaType.TRIGGER, "t1")
|
||||
assert charge.success is True
|
||||
mock_bs.quota_commit.assert_called_once()
|
||||
|
||||
def test_check_billing_disabled(self):
|
||||
with patch("services.quota_service.dify_config") as mock_cfg:
|
||||
mock_cfg.BILLING_ENABLED = False
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
|
||||
|
||||
def test_check_zero_amount_raises(self):
|
||||
with patch("services.quota_service.dify_config") as mock_cfg:
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
with pytest.raises(ValueError, match="greater than 0"):
|
||||
QuotaService.check(QuotaType.TRIGGER, "t1", amount=0)
|
||||
|
||||
def test_check_sufficient_quota(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch.object(QuotaService, "get_remaining", return_value=100),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True
|
||||
|
||||
def test_check_insufficient_quota(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch.object(QuotaService, "get_remaining", return_value=5),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False
|
||||
|
||||
def test_check_unlimited_quota(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch.object(QuotaService, "get_remaining", return_value=-1),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True
|
||||
|
||||
def test_check_exception_returns_true(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch.object(QuotaService, "get_remaining", side_effect=RuntimeError),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
|
||||
|
||||
def test_release_billing_disabled(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = False
|
||||
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
|
||||
mock_bs.quota_release.assert_not_called()
|
||||
|
||||
def test_release_empty_reservation(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event")
|
||||
mock_bs.quota_release.assert_not_called()
|
||||
|
||||
def test_release_success(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_release.return_value = {}
|
||||
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
|
||||
mock_bs.quota_release.assert_called_once_with(
|
||||
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1"
|
||||
)
|
||||
|
||||
def test_release_exception_swallowed(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_release.side_effect = RuntimeError("fail")
|
||||
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
|
||||
|
||||
def test_get_remaining_normal(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70
|
||||
|
||||
def test_get_remaining_unlimited(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
|
||||
|
||||
def test_get_remaining_over_limit_returns_zero(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
|
||||
|
||||
def test_get_remaining_exception_returns_neg1(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.side_effect = RuntimeError
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
|
||||
|
||||
def test_get_remaining_empty_response(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
|
||||
|
||||
def test_get_remaining_non_dict_response(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = "invalid"
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
|
||||
|
||||
def test_get_remaining_feature_not_in_response(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}}
|
||||
remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1")
|
||||
assert remaining == 0
|
||||
|
||||
def test_get_remaining_non_dict_feature_info(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
|
||||
|
||||
|
||||
class TestQuotaCharge:
|
||||
def test_commit_success(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.quota_commit.return_value = {}
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
_amount=1,
|
||||
)
|
||||
charge.commit()
|
||||
mock_bs.quota_commit.assert_called_once_with(
|
||||
tenant_id="t1",
|
||||
feature_key="trigger_event",
|
||||
reservation_id="rid-1",
|
||||
actual_amount=1,
|
||||
)
|
||||
assert charge._committed is True
|
||||
|
||||
def test_commit_with_actual_amount(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.quota_commit.return_value = {}
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
_amount=10,
|
||||
)
|
||||
charge.commit(actual_amount=5)
|
||||
call_kwargs = mock_bs.quota_commit.call_args[1]
|
||||
assert call_kwargs["actual_amount"] == 5
|
||||
|
||||
def test_commit_idempotent(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.quota_commit.return_value = {}
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
_amount=1,
|
||||
)
|
||||
charge.commit()
|
||||
charge.commit()
|
||||
assert mock_bs.quota_commit.call_count == 1
|
||||
|
||||
def test_commit_no_charge_id_noop(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
|
||||
charge.commit()
|
||||
mock_bs.quota_commit.assert_not_called()
|
||||
|
||||
def test_commit_no_tenant_id_noop(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id=None,
|
||||
_feature_key="trigger_event",
|
||||
)
|
||||
charge.commit()
|
||||
mock_bs.quota_commit.assert_not_called()
|
||||
|
||||
def test_commit_exception_swallowed(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.quota_commit.side_effect = RuntimeError("fail")
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
_amount=1,
|
||||
)
|
||||
charge.commit()
|
||||
|
||||
def test_refund_success(self):
|
||||
with patch.object(QuotaService, "release") as mock_rel:
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
)
|
||||
charge.refund()
|
||||
mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
|
||||
|
||||
def test_refund_no_charge_id_noop(self):
|
||||
with patch.object(QuotaService, "release") as mock_rel:
|
||||
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
|
||||
charge.refund()
|
||||
mock_rel.assert_not_called()
|
||||
|
||||
def test_refund_no_tenant_id_noop(self):
|
||||
with patch.object(QuotaService, "release") as mock_rel:
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id=None,
|
||||
)
|
||||
charge.refund()
|
||||
mock_rel.assert_not_called()
|
||||
|
||||
|
||||
class TestUnlimited:
|
||||
def test_unlimited_returns_success_with_no_charge_id(self):
|
||||
charge = unlimited()
|
||||
assert charge.success is True
|
||||
assert charge.charge_id is None
|
||||
assert charge._quota_type == QuotaType.UNLIMITED
|
||||
@ -23,6 +23,7 @@ import pytest
|
||||
|
||||
import services.app_generate_service as ags_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from enums.quota_type import QuotaType
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
|
||||
@ -447,8 +448,8 @@ class TestGenerateBilling:
|
||||
def test_billing_enabled_consumes_quota(self, mocker, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
|
||||
quota_charge = MagicMock()
|
||||
consume_mock = mocker.patch(
|
||||
"services.app_generate_service.QuotaType.WORKFLOW.consume",
|
||||
reserve_mock = mocker.patch(
|
||||
"services.app_generate_service.QuotaService.reserve",
|
||||
return_value=quota_charge,
|
||||
)
|
||||
mocker.patch(
|
||||
@ -467,7 +468,8 @@ class TestGenerateBilling:
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
consume_mock.assert_called_once_with("tenant-id")
|
||||
reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id")
|
||||
quota_charge.commit.assert_called_once()
|
||||
|
||||
def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch):
|
||||
from services.errors.app import QuotaExceededError
|
||||
@ -475,7 +477,7 @@ class TestGenerateBilling:
|
||||
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.QuotaType.WORKFLOW.consume",
|
||||
"services.app_generate_service.QuotaService.reserve",
|
||||
side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1),
|
||||
)
|
||||
|
||||
@ -492,7 +494,7 @@ class TestGenerateBilling:
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
|
||||
quota_charge = MagicMock()
|
||||
mocker.patch(
|
||||
"services.app_generate_service.QuotaType.WORKFLOW.consume",
|
||||
"services.app_generate_service.QuotaService.reserve",
|
||||
return_value=quota_charge,
|
||||
)
|
||||
mocker.patch(
|
||||
|
||||
@ -57,7 +57,7 @@ class TestAsyncWorkflowService:
|
||||
- repo: SQLAlchemyWorkflowTriggerLogRepository
|
||||
- dispatcher_manager_class: QueueDispatcherManager class
|
||||
- dispatcher: dispatcher instance
|
||||
- quota_workflow: QuotaType.WORKFLOW
|
||||
- quota_service: QuotaService mock
|
||||
- get_workflow: AsyncWorkflowService._get_workflow method
|
||||
- professional_task: execute_workflow_professional
|
||||
- team_task: execute_workflow_team
|
||||
@ -72,7 +72,12 @@ class TestAsyncWorkflowService:
|
||||
mock_repo.create.side_effect = _create_side_effect
|
||||
|
||||
mock_dispatcher = MagicMock()
|
||||
quota_workflow = MagicMock()
|
||||
mock_quota_service = MagicMock()
|
||||
mock_get_workflow = MagicMock()
|
||||
|
||||
mock_professional_task = MagicMock()
|
||||
mock_team_task = MagicMock()
|
||||
mock_sandbox_task = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
@ -88,8 +93,8 @@ class TestAsyncWorkflowService:
|
||||
) as mock_get_workflow,
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"QuotaType",
|
||||
new=SimpleNamespace(WORKFLOW=quota_workflow),
|
||||
"QuotaService",
|
||||
new=mock_quota_service,
|
||||
),
|
||||
patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task,
|
||||
patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task,
|
||||
@ -102,7 +107,7 @@ class TestAsyncWorkflowService:
|
||||
"repo": mock_repo,
|
||||
"dispatcher_manager_class": mock_dispatcher_manager_class,
|
||||
"dispatcher": mock_dispatcher,
|
||||
"quota_workflow": quota_workflow,
|
||||
"quota_service": mock_quota_service,
|
||||
"get_workflow": mock_get_workflow,
|
||||
"professional_task": mock_professional_task,
|
||||
"team_task": mock_team_task,
|
||||
@ -141,6 +146,9 @@ class TestAsyncWorkflowService:
|
||||
mocks["team_task"].delay.return_value = task_result
|
||||
mocks["sandbox_task"].delay.return_value = task_result
|
||||
|
||||
quota_charge_mock = MagicMock()
|
||||
mocks["quota_service"].reserve.return_value = quota_charge_mock
|
||||
|
||||
class DummyAccount:
|
||||
def __init__(self, user_id: str):
|
||||
self.id = user_id
|
||||
@ -158,7 +166,8 @@ class TestAsyncWorkflowService:
|
||||
assert result.status == "queued"
|
||||
assert result.queue == queue_name
|
||||
|
||||
mocks["quota_workflow"].consume.assert_called_once_with("tenant-123")
|
||||
mocks["quota_service"].reserve.assert_called_once()
|
||||
quota_charge_mock.commit.assert_called_once()
|
||||
assert session.commit.call_count == 2
|
||||
|
||||
created_log = mocks["repo"].create.call_args[0][0]
|
||||
@ -245,7 +254,7 @@ class TestAsyncWorkflowService:
|
||||
mocks = async_workflow_trigger_mocks
|
||||
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM
|
||||
mocks["get_workflow"].return_value = workflow
|
||||
mocks["quota_workflow"].consume.side_effect = QuotaExceededError(
|
||||
mocks["quota_service"].reserve.side_effect = QuotaExceededError(
|
||||
feature="workflow",
|
||||
tenant_id="tenant-123",
|
||||
required=1,
|
||||
|
||||
@ -425,7 +425,7 @@ class TestBillingServiceUsageCalculation:
|
||||
yield mock
|
||||
|
||||
def test_get_tenant_feature_plan_usage_info(self, mock_send_request):
|
||||
"""Test retrieval of tenant feature plan usage information."""
|
||||
"""Test retrieval of tenant feature plan usage information (legacy endpoint)."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}}
|
||||
@ -438,6 +438,20 @@ class TestBillingServiceUsageCalculation:
|
||||
assert result == expected_response
|
||||
mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id})
|
||||
|
||||
def test_get_quota_info(self, mock_send_request):
|
||||
"""Test retrieval of quota info from new endpoint."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}}
|
||||
mock_send_request.return_value = expected_response
|
||||
|
||||
# Act
|
||||
result = BillingService.get_quota_info(tenant_id)
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id})
|
||||
|
||||
def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request):
|
||||
"""Test updating tenant feature usage with positive delta (adding credits)."""
|
||||
# Arrange
|
||||
@ -515,6 +529,150 @@ class TestBillingServiceUsageCalculation:
|
||||
)
|
||||
|
||||
|
||||
class TestBillingServiceQuotaOperations:
|
||||
"""Unit tests for quota reserve/commit/release operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_send_request(self):
|
||||
with patch.object(BillingService, "_send_request") as mock:
|
||||
yield mock
|
||||
|
||||
def test_quota_reserve_success(self, mock_send_request):
|
||||
expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1}
|
||||
mock_send_request.return_value = expected
|
||||
|
||||
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1)
|
||||
|
||||
assert result == expected
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/quota/reserve",
|
||||
json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1},
|
||||
)
|
||||
|
||||
def test_quota_reserve_coerces_string_to_int(self, mock_send_request):
|
||||
"""Test that TypeAdapter coerces string values to int."""
|
||||
mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"}
|
||||
|
||||
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1)
|
||||
|
||||
assert result["available"] == 99
|
||||
assert isinstance(result["available"], int)
|
||||
assert result["reserved"] == 1
|
||||
assert isinstance(result["reserved"], int)
|
||||
|
||||
def test_quota_reserve_with_meta(self, mock_send_request):
|
||||
mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1}
|
||||
meta = {"source": "webhook"}
|
||||
|
||||
BillingService.quota_reserve(
|
||||
tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta
|
||||
)
|
||||
|
||||
call_json = mock_send_request.call_args[1]["json"]
|
||||
assert call_json["meta"] == {"source": "webhook"}
|
||||
|
||||
def test_quota_commit_success(self, mock_send_request):
|
||||
expected = {"available": 98, "reserved": 0, "refunded": 0}
|
||||
mock_send_request.return_value = expected
|
||||
|
||||
result = BillingService.quota_commit(
|
||||
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/quota/commit",
|
||||
json={
|
||||
"tenant_id": "t1",
|
||||
"feature_key": "trigger_event",
|
||||
"reservation_id": "rid-1",
|
||||
"actual_amount": 1,
|
||||
},
|
||||
)
|
||||
|
||||
def test_quota_commit_coerces_string_to_int(self, mock_send_request):
|
||||
"""Test that TypeAdapter coerces string values to int."""
|
||||
mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"}
|
||||
|
||||
result = BillingService.quota_commit(
|
||||
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1
|
||||
)
|
||||
|
||||
assert result["available"] == 97
|
||||
assert isinstance(result["available"], int)
|
||||
assert result["refunded"] == 1
|
||||
assert isinstance(result["refunded"], int)
|
||||
|
||||
def test_quota_commit_with_meta(self, mock_send_request):
|
||||
mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0}
|
||||
meta = {"reason": "partial"}
|
||||
|
||||
BillingService.quota_commit(
|
||||
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta
|
||||
)
|
||||
|
||||
call_json = mock_send_request.call_args[1]["json"]
|
||||
assert call_json["meta"] == {"reason": "partial"}
|
||||
|
||||
def test_quota_release_success(self, mock_send_request):
|
||||
expected = {"available": 100, "reserved": 0, "released": 1}
|
||||
mock_send_request.return_value = expected
|
||||
|
||||
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1")
|
||||
|
||||
assert result == expected
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/quota/release",
|
||||
json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"},
|
||||
)
|
||||
|
||||
def test_quota_release_coerces_string_to_int(self, mock_send_request):
|
||||
"""Test that TypeAdapter coerces string values to int."""
|
||||
mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"}
|
||||
|
||||
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s")
|
||||
|
||||
assert result["available"] == 100
|
||||
assert isinstance(result["available"], int)
|
||||
assert result["released"] == 1
|
||||
assert isinstance(result["released"], int)
|
||||
|
||||
def test_get_quota_info_coerces_string_to_int(self, mock_send_request):
|
||||
"""Test that TypeAdapter coerces string values to int for get_quota_info."""
|
||||
mock_send_request.return_value = {
|
||||
"trigger_event": {"usage": "42", "limit": "3000", "reset_date": "1700000000"},
|
||||
"api_rate_limit": {"usage": "10", "limit": "-1", "reset_date": "-1"},
|
||||
}
|
||||
|
||||
result = BillingService.get_quota_info("t1")
|
||||
|
||||
assert result["trigger_event"]["usage"] == 42
|
||||
assert isinstance(result["trigger_event"]["usage"], int)
|
||||
assert result["trigger_event"]["limit"] == 3000
|
||||
assert isinstance(result["trigger_event"]["limit"], int)
|
||||
assert result["trigger_event"]["reset_date"] == 1700000000
|
||||
assert isinstance(result["trigger_event"]["reset_date"], int)
|
||||
assert result["api_rate_limit"]["limit"] == -1
|
||||
assert isinstance(result["api_rate_limit"]["limit"], int)
|
||||
|
||||
def test_get_quota_info_accepts_int_values(self, mock_send_request):
|
||||
"""Test that get_quota_info works with native int values."""
|
||||
expected = {
|
||||
"trigger_event": {"usage": 42, "limit": 3000, "reset_date": 1700000000},
|
||||
"api_rate_limit": {"usage": 0, "limit": -1},
|
||||
}
|
||||
mock_send_request.return_value = expected
|
||||
|
||||
result = BillingService.get_quota_info("t1")
|
||||
|
||||
assert result["trigger_event"]["usage"] == 42
|
||||
assert result["trigger_event"]["limit"] == 3000
|
||||
assert result["api_rate_limit"]["limit"] == -1
|
||||
|
||||
|
||||
class TestBillingServiceRateLimitEnforcement:
|
||||
"""Unit tests for rate limit enforcement mechanisms.
|
||||
|
||||
|
||||
@ -559,3 +559,772 @@ class TestWebhookServiceUnit:
|
||||
|
||||
result = _prepare_webhook_execution("test_webhook", is_debug=True)
|
||||
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
|
||||
|
||||
|
||||
|
||||
# === Merged from test_webhook_service_additional.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from graphon.variables.types import SegmentType
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import RequestEntityTooLarge
|
||||
|
||||
from core.workflow.nodes.trigger_webhook.entities import (
|
||||
ContentType,
|
||||
WebhookBodyParameter,
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import App
|
||||
from models.trigger import WorkflowWebhookTrigger
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.trigger import webhook_service as service_module
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class _FakeQuery:
|
||||
def __init__(self, result: Any) -> None:
|
||||
self._result = result
|
||||
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
return self._result
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _SessionmakerContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def begin(self) -> "_SessionmakerContext":
|
||||
return self
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return Flask(__name__)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
|
||||
|
||||
|
||||
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
|
||||
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _workflow(**kwargs: Any) -> Workflow:
|
||||
return cast(Workflow, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _app(**kwargs: Any) -> App:
|
||||
return cast(App, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.return_value = None
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Webhook not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="App trigger not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="rate limited"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="disabled"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
# Assert
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"key": "value"}}
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
|
||||
"webhook-1", is_debug=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"mode": "debug"}}
|
||||
|
||||
|
||||
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
warning_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
|
||||
webhook_trigger = MagicMock()
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/vnd.custom"},
|
||||
data="plain content",
|
||||
):
|
||||
result = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result["body"] == {"raw": "plain content"}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_extract_webhook_data_should_raise_for_request_too_large(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
|
||||
|
||||
# Act / Assert
|
||||
with flask_app.test_request_context("/webhook", method="POST", data="ab"):
|
||||
with pytest.raises(RequestEntityTooLarge):
|
||||
WebhookService.extract_webhook_data(MagicMock())
|
||||
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = MagicMock()
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data=b""):
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": None}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = MagicMock()
|
||||
monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream"))
|
||||
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": None}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data="abc"):
|
||||
body, files = WebhookService._extract_text_body()
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": ""}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
fake_magic = MagicMock()
|
||||
fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
|
||||
monkeypatch.setattr(service_module, "magic", fake_magic)
|
||||
|
||||
# Act
|
||||
result = WebhookService._detect_binary_mimetype(b"binary")
|
||||
|
||||
# Assert
|
||||
assert result == "application/octet-stream"
|
||||
|
||||
|
||||
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
|
||||
file_obj = MagicMock()
|
||||
file_obj.to_dict.return_value = {"id": "f-1"}
|
||||
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
|
||||
monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
|
||||
|
||||
uploaded = MagicMock()
|
||||
uploaded.filename = "file.unknown"
|
||||
uploaded.content_type = None
|
||||
uploaded.read.return_value = b"content"
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result == {"f": {"id": "f-1"}}
|
||||
|
||||
|
||||
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
|
||||
manager = MagicMock()
|
||||
manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
|
||||
monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
|
||||
expected_file = MagicMock()
|
||||
monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
|
||||
|
||||
# Act
|
||||
result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result is expected_file
|
||||
manager.create_file_by_raw.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw_value", "param_type", "expected"),
|
||||
[
|
||||
("42", SegmentType.NUMBER, 42),
|
||||
("3.14", SegmentType.NUMBER, 3.14),
|
||||
("yes", SegmentType.BOOLEAN, True),
|
||||
("no", SegmentType.BOOLEAN, False),
|
||||
],
|
||||
)
|
||||
def test_convert_form_value_should_convert_supported_types(
|
||||
raw_value: str,
|
||||
param_type: str,
|
||||
expected: Any,
|
||||
) -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = WebhookService._convert_form_value("param", raw_value, param_type)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_convert_form_value_should_raise_for_unsupported_type() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Unsupported type"):
|
||||
WebhookService._convert_form_value("p", "x", SegmentType.FILE)
|
||||
|
||||
|
||||
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
warning_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
|
||||
|
||||
# Act
|
||||
result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
|
||||
|
||||
# Assert
|
||||
assert result == {"x": 1}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_and_convert_value_should_wrap_conversion_errors() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="validation failed"):
|
||||
WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
|
||||
|
||||
|
||||
def test_process_parameters_should_raise_when_required_parameter_missing() -> None:
|
||||
# Arrange
|
||||
raw_params = {"optional": "x"}
|
||||
config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required parameter missing"):
|
||||
WebhookService._process_parameters(raw_params, config, is_form_data=True)
|
||||
|
||||
|
||||
def test_process_parameters_should_include_unconfigured_parameters() -> None:
|
||||
# Arrange
|
||||
raw_params = {"known": "1", "unknown": "x"}
|
||||
config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
|
||||
|
||||
# Assert
|
||||
assert result == {"known": 1, "unknown": "x"}
|
||||
|
||||
|
||||
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required body content missing"):
|
||||
WebhookService._process_body_parameters(
|
||||
raw_body={"raw": ""},
|
||||
body_configs=[WebhookBodyParameter(name="raw", required=True)],
|
||||
content_type=ContentType.TEXT,
|
||||
)
|
||||
|
||||
|
||||
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None:
|
||||
# Arrange
|
||||
raw_body = {"message": "hello", "extra": "x"}
|
||||
body_configs = [
|
||||
WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
|
||||
WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
|
||||
]
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
|
||||
|
||||
# Assert
|
||||
assert result == {"message": "hello", "extra": "x"}
|
||||
|
||||
|
||||
def test_validate_required_headers_should_accept_sanitized_header_names() -> None:
|
||||
# Arrange
|
||||
headers = {"x_api_key": "123"}
|
||||
configs = [WebhookParameter(name="x-api-key", required=True)]
|
||||
|
||||
# Act
|
||||
WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
# Assert
|
||||
assert True
|
||||
|
||||
|
||||
def test_validate_required_headers_should_raise_when_required_header_missing() -> None:
|
||||
# Arrange
|
||||
headers = {"x-other": "123"}
|
||||
configs = [WebhookParameter(name="x-api-key", required=True)]
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required header missing"):
|
||||
WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
|
||||
def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None:
|
||||
# Arrange
|
||||
webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
|
||||
node_data = WebhookData(method="post", content_type=ContentType.TEXT)
|
||||
|
||||
# Act
|
||||
result = WebhookService._validate_http_metadata(webhook_data, node_data)
|
||||
|
||||
# Assert
|
||||
assert result["valid"] is False
|
||||
assert "Content-type mismatch" in result["error"]
|
||||
|
||||
|
||||
def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None:
|
||||
# Arrange
|
||||
headers = {"content-type": "application/json; charset=utf-8"}
|
||||
|
||||
# Act
|
||||
result = WebhookService._extract_content_type(headers)
|
||||
|
||||
# Assert
|
||||
assert result == "application/json"
|
||||
|
||||
|
||||
def test_build_workflow_inputs_should_include_expected_keys() -> None:
|
||||
# Arrange
|
||||
webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}
|
||||
|
||||
# Act
|
||||
result = WebhookService.build_workflow_inputs(webhook_data)
|
||||
|
||||
# Assert
|
||||
assert result["webhook_data"] == webhook_data
|
||||
assert result["webhook_headers"] == {"h": "v"}
|
||||
assert result["webhook_query_params"] == {"q": 1}
|
||||
assert result["webhook_body"] == {"b": 2}
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
webhook_data = {"body": {"x": 1}}
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user)
|
||||
)
|
||||
quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "QuotaType", quota_type)
|
||||
trigger_async_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
|
||||
|
||||
# Act
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
# Assert
|
||||
trigger_async_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService,
|
||||
"get_or_create_end_user_by_type",
|
||||
MagicMock(return_value=SimpleNamespace(id="end-user-1")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
service_module.QuotaService,
|
||||
"reserve",
|
||||
MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1)),
|
||||
)
|
||||
mark_rate_limited_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(QuotaExceededError):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
|
||||
mark_rate_limited_mock.assert_called_once_with("tenant-1")
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom"))
|
||||
)
|
||||
logger_exception_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
|
||||
logger_exception_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(
|
||||
walk_nodes=lambda _node_type: [
|
||||
(f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
|
||||
]
|
||||
)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="maximum webhook node limit"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
|
||||
|
||||
class _WorkflowWebhookTrigger:
|
||||
app_id = "app_id"
|
||||
tenant_id = "tenant_id"
|
||||
webhook_id = "webhook_id"
|
||||
node_id = "node_id"
|
||||
|
||||
def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
|
||||
self.id = None
|
||||
self.app_id = app_id
|
||||
self.tenant_id = tenant_id
|
||||
self.node_id = node_id
|
||||
self.webhook_id = webhook_id
|
||||
self.created_by = created_by
|
||||
|
||||
class _Select:
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_Select":
|
||||
return self
|
||||
|
||||
class _Session:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[Any] = []
|
||||
self.deleted: list[Any] = []
|
||||
self.commit_count = 0
|
||||
self.existing_records = [SimpleNamespace(node_id="node-stale")]
|
||||
|
||||
def scalars(self, _stmt: Any) -> Any:
|
||||
return SimpleNamespace(all=lambda: self.existing_records)
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
def flush(self) -> None:
|
||||
for idx, obj in enumerate(self.added, start=1):
|
||||
if obj.id is None:
|
||||
obj.id = f"rec-{idx}"
|
||||
|
||||
def commit(self) -> None:
|
||||
self.commit_count += 1
|
||||
|
||||
def delete(self, obj: Any) -> None:
|
||||
self.deleted.append(obj)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.return_value = None
|
||||
|
||||
fake_session = _Session()
|
||||
|
||||
monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
|
||||
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
redis_set_mock = MagicMock()
|
||||
redis_delete_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
|
||||
monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
|
||||
monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
# Assert
|
||||
assert len(fake_session.added) == 1
|
||||
assert len(fake_session.deleted) == 1
|
||||
redis_set_mock.assert_called_once()
|
||||
redis_delete_mock.assert_called_once()
|
||||
lock.release.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [])
|
||||
|
||||
class _Select:
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_Select":
|
||||
return self
|
||||
|
||||
class _Session:
|
||||
def scalars(self, _stmt: Any) -> Any:
|
||||
return SimpleNamespace(all=lambda: [])
|
||||
|
||||
def commit(self) -> None:
|
||||
return None
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.side_effect = RuntimeError("release failed")
|
||||
|
||||
logger_exception_mock = MagicMock()
|
||||
|
||||
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
|
||||
_patch_session(monkeypatch, _Session())
|
||||
|
||||
# Act
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
# Assert
|
||||
assert logger_exception_mock.call_count == 1
|
||||
|
||||
|
||||
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None:
|
||||
# Arrange
|
||||
node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
|
||||
|
||||
# Act
|
||||
body, status = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
# Assert
|
||||
assert status == 200
|
||||
assert "message" in body
|
||||
|
||||
|
||||
def test_generate_webhook_id_should_return_24_character_identifier() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
webhook_id = WebhookService.generate_webhook_id()
|
||||
|
||||
# Assert
|
||||
assert isinstance(webhook_id, str)
|
||||
assert len(webhook_id) == 24
|
||||
|
||||
|
||||
def test_sanitize_key_should_return_original_value_for_non_string_input() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = WebhookService._sanitize_key(123) # type: ignore[arg-type]
|
||||
|
||||
# Assert
|
||||
assert result == 123
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ import {
|
||||
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
|
||||
} from '@/app/education-apply/constants'
|
||||
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking'
|
||||
import { sendGAEvent } from '@/utils/gtag'
|
||||
import { fetchSetupStatusWithCache } from '@/utils/setup-status'
|
||||
import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect'
|
||||
@ -45,6 +46,8 @@ export const AppInitializer = ({
|
||||
(async () => {
|
||||
const action = searchParams.get('action')
|
||||
|
||||
rememberCreateAppExternalAttribution({ searchParams })
|
||||
|
||||
if (oauthNewUser) {
|
||||
let utmInfo = null
|
||||
const utmInfoStr = Cookies.get('utm_info')
|
||||
|
||||
@ -4,7 +4,6 @@ import { AppModeEnum } from '@/types/app'
|
||||
import Apps from '../index'
|
||||
|
||||
const mockUseExploreAppList = vi.fn()
|
||||
const mockTrackEvent = vi.fn()
|
||||
const mockImportDSL = vi.fn()
|
||||
const mockFetchAppDetail = vi.fn()
|
||||
const mockHandleCheckPluginDependencies = vi.fn()
|
||||
@ -12,6 +11,7 @@ const mockGetRedirection = vi.fn()
|
||||
const mockPush = vi.fn()
|
||||
const mockToastSuccess = vi.fn()
|
||||
const mockToastError = vi.fn()
|
||||
const mockTrackCreateApp = vi.fn()
|
||||
let latestDebounceFn = () => {}
|
||||
|
||||
vi.mock('ahooks', () => ({
|
||||
@ -92,8 +92,8 @@ vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
error: (...args: unknown[]) => mockToastError(...args),
|
||||
},
|
||||
}))
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: (...args: unknown[]) => mockTrackEvent(...args),
|
||||
vi.mock('@/utils/create-app-tracking', () => ({
|
||||
trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
|
||||
}))
|
||||
vi.mock('@/service/apps', () => ({
|
||||
importDSL: (...args: unknown[]) => mockImportDSL(...args),
|
||||
@ -246,10 +246,9 @@ describe('Apps', () => {
|
||||
}))
|
||||
})
|
||||
|
||||
expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_template', expect.objectContaining({
|
||||
template_id: 'Alpha',
|
||||
template_name: 'Alpha',
|
||||
}))
|
||||
expect(mockTrackCreateApp).toHaveBeenCalledWith({
|
||||
appMode: AppModeEnum.CHAT,
|
||||
})
|
||||
expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated')
|
||||
expect(onSuccess).toHaveBeenCalled()
|
||||
expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('created-app-id')
|
||||
|
||||
@ -8,7 +8,6 @@ import * as React from 'react'
|
||||
import { useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import AppTypeSelector from '@/app/components/app/type-selector'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
@ -25,6 +24,7 @@ import { useExploreAppList } from '@/service/use-explore'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { getRedirection } from '@/utils/app-redirection'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { trackCreateApp } from '@/utils/create-app-tracking'
|
||||
import AppCard from '../app-card'
|
||||
import Sidebar, { AppCategories, AppCategoryLabel } from './sidebar'
|
||||
|
||||
@ -127,14 +127,7 @@ const Apps = ({
|
||||
icon_background,
|
||||
description,
|
||||
})
|
||||
|
||||
// Track app creation from template
|
||||
trackEvent('create_app_with_template', {
|
||||
app_mode: mode,
|
||||
template_id: currApp?.app.id,
|
||||
template_name: currApp?.app.name,
|
||||
description,
|
||||
})
|
||||
trackCreateApp({ appMode: mode })
|
||||
|
||||
setIsShowCreateModal(false)
|
||||
toast.success(t('newApp.appCreated', { ns: 'app' }))
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import type { App } from '@/types/app'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
@ -10,6 +9,7 @@ import { useRouter } from '@/next/navigation'
|
||||
import { createApp } from '@/service/apps'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { getRedirection } from '@/utils/app-redirection'
|
||||
import { trackCreateApp } from '@/utils/create-app-tracking'
|
||||
import CreateAppModal from '../index'
|
||||
|
||||
const ahooksMocks = vi.hoisted(() => ({
|
||||
@ -31,8 +31,8 @@ vi.mock('ahooks', () => ({
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: vi.fn(),
|
||||
vi.mock('@/utils/create-app-tracking', () => ({
|
||||
trackCreateApp: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/service/apps', () => ({
|
||||
createApp: vi.fn(),
|
||||
@ -87,7 +87,7 @@ vi.mock('@/hooks/use-theme', () => ({
|
||||
const mockUseRouter = vi.mocked(useRouter)
|
||||
const mockPush = vi.fn()
|
||||
const mockCreateApp = vi.mocked(createApp)
|
||||
const mockTrackEvent = vi.mocked(trackEvent)
|
||||
const mockTrackCreateApp = vi.mocked(trackCreateApp)
|
||||
const mockGetRedirection = vi.mocked(getRedirection)
|
||||
const mockUseProviderContext = vi.mocked(useProviderContext)
|
||||
const mockUseAppContext = vi.mocked(useAppContext)
|
||||
@ -178,10 +178,7 @@ describe('CreateAppModal', () => {
|
||||
mode: AppModeEnum.ADVANCED_CHAT,
|
||||
}))
|
||||
|
||||
expect(mockTrackEvent).toHaveBeenCalledWith('create_app', {
|
||||
app_mode: AppModeEnum.ADVANCED_CHAT,
|
||||
description: '',
|
||||
})
|
||||
expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.ADVANCED_CHAT })
|
||||
expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated')
|
||||
expect(onSuccess).toHaveBeenCalled()
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
|
||||
@ -6,7 +6,6 @@ import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon
|
||||
import { useDebounceFn, useKeyPress } from 'ahooks'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
@ -25,6 +24,7 @@ import { createApp } from '@/service/apps'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { getRedirection } from '@/utils/app-redirection'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { trackCreateApp } from '@/utils/create-app-tracking'
|
||||
import { basePath } from '@/utils/var'
|
||||
import AppIconPicker from '../../base/app-icon-picker'
|
||||
import ShortcutsName from '../../workflow/shortcuts-name'
|
||||
@ -80,11 +80,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
|
||||
mode: appMode,
|
||||
})
|
||||
|
||||
// Track app creation success
|
||||
trackEvent('create_app', {
|
||||
app_mode: appMode,
|
||||
description,
|
||||
})
|
||||
trackCreateApp({ appMode: app.mode })
|
||||
|
||||
toast.success(t('newApp.appCreated', { ns: 'app' }))
|
||||
onSuccess()
|
||||
|
||||
@ -2,12 +2,13 @@
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import { DSLImportMode, DSLImportStatus } from '@/models/app'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import CreateFromDSLModal, { CreateFromDSLModalTab } from '../index'
|
||||
|
||||
const mockPush = vi.fn()
|
||||
const mockImportDSL = vi.fn()
|
||||
const mockImportDSLConfirm = vi.fn()
|
||||
const mockTrackEvent = vi.fn()
|
||||
const mockTrackCreateApp = vi.fn()
|
||||
const mockHandleCheckPluginDependencies = vi.fn()
|
||||
const mockGetRedirection = vi.fn()
|
||||
const toastMocks = vi.hoisted(() => ({
|
||||
@ -43,8 +44,8 @@ vi.mock('@/next/navigation', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: (...args: unknown[]) => mockTrackEvent(...args),
|
||||
vi.mock('@/utils/create-app-tracking', () => ({
|
||||
trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/apps', () => ({
|
||||
@ -172,7 +173,7 @@ describe('CreateFromDSLModal', () => {
|
||||
id: 'import-1',
|
||||
status: DSLImportStatus.COMPLETED,
|
||||
app_id: 'app-1',
|
||||
app_mode: 'chat',
|
||||
app_mode: AppModeEnum.CHAT,
|
||||
})
|
||||
|
||||
render(
|
||||
@ -196,10 +197,7 @@ describe('CreateFromDSLModal', () => {
|
||||
mode: DSLImportMode.YAML_URL,
|
||||
yaml_url: 'https://example.com/app.yml',
|
||||
})
|
||||
expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_dsl', expect.objectContaining({
|
||||
creation_method: 'dsl_url',
|
||||
has_warnings: false,
|
||||
}))
|
||||
expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.CHAT })
|
||||
expect(handleSuccess).toHaveBeenCalledTimes(1)
|
||||
expect(handleClose).toHaveBeenCalledTimes(1)
|
||||
expect(localStorage.getItem(NEED_REFRESH_APP_LIST_KEY)).toBe('1')
|
||||
@ -212,7 +210,7 @@ describe('CreateFromDSLModal', () => {
|
||||
id: 'import-2',
|
||||
status: DSLImportStatus.COMPLETED_WITH_WARNINGS,
|
||||
app_id: 'app-2',
|
||||
app_mode: 'chat',
|
||||
app_mode: AppModeEnum.CHAT,
|
||||
})
|
||||
|
||||
render(
|
||||
@ -275,7 +273,7 @@ describe('CreateFromDSLModal', () => {
|
||||
mockImportDSLConfirm.mockResolvedValue({
|
||||
status: DSLImportStatus.COMPLETED,
|
||||
app_id: 'app-3',
|
||||
app_mode: 'workflow',
|
||||
app_mode: AppModeEnum.WORKFLOW,
|
||||
})
|
||||
|
||||
render(
|
||||
@ -305,6 +303,7 @@ describe('CreateFromDSLModal', () => {
|
||||
expect(mockImportDSLConfirm).toHaveBeenCalledWith({
|
||||
import_id: 'import-3',
|
||||
})
|
||||
expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.WORKFLOW })
|
||||
})
|
||||
|
||||
it('should ignore empty import responses and prevent duplicate submissions while a request is in flight', async () => {
|
||||
@ -332,7 +331,7 @@ describe('CreateFromDSLModal', () => {
|
||||
id: 'import-in-flight',
|
||||
status: DSLImportStatus.COMPLETED,
|
||||
app_id: 'app-1',
|
||||
app_mode: 'chat',
|
||||
app_mode: AppModeEnum.CHAT,
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ import { useDebounceFn, useKeyPress } from 'ahooks'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import { useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Modal from '@/app/components/base/modal'
|
||||
@ -27,6 +26,7 @@ import {
|
||||
} from '@/service/apps'
|
||||
import { getRedirection } from '@/utils/app-redirection'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { trackCreateApp } from '@/utils/create-app-tracking'
|
||||
import ShortcutsName from '../../workflow/shortcuts-name'
|
||||
import Uploader from './uploader'
|
||||
|
||||
@ -112,12 +112,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
|
||||
return
|
||||
const { id, status, app_id, app_mode, imported_dsl_version, current_dsl_version } = response
|
||||
if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) {
|
||||
// Track app creation from DSL import
|
||||
trackEvent('create_app_with_dsl', {
|
||||
app_mode,
|
||||
creation_method: currentTab === CreateFromDSLModalTab.FROM_FILE ? 'dsl_file' : 'dsl_url',
|
||||
has_warnings: status === DSLImportStatus.COMPLETED_WITH_WARNINGS,
|
||||
})
|
||||
trackCreateApp({ appMode: app_mode })
|
||||
|
||||
if (onSuccess)
|
||||
onSuccess()
|
||||
@ -179,6 +174,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
|
||||
const { status, app_id, app_mode } = response
|
||||
|
||||
if (status === DSLImportStatus.COMPLETED) {
|
||||
trackCreateApp({ appMode: app_mode })
|
||||
if (onSuccess)
|
||||
onSuccess()
|
||||
if (onClose)
|
||||
@ -228,7 +224,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
|
||||
isShow={show}
|
||||
onClose={noop}
|
||||
>
|
||||
<div className="flex items-center justify-between pb-3 pl-6 pr-5 pt-6 text-text-primary title-2xl-semi-bold">
|
||||
<div className="flex items-center justify-between pt-6 pr-5 pb-3 pl-6 title-2xl-semi-bold text-text-primary">
|
||||
{t('importFromDSL', { ns: 'app' })}
|
||||
<div
|
||||
className="flex h-8 w-8 cursor-pointer items-center"
|
||||
@ -237,7 +233,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
|
||||
<RiCloseLine className="h-5 w-5 text-text-tertiary" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex h-9 items-center space-x-6 border-b border-divider-subtle px-6 text-text-tertiary system-md-semibold">
|
||||
<div className="flex h-9 items-center space-x-6 border-b border-divider-subtle px-6 system-md-semibold text-text-tertiary">
|
||||
{
|
||||
tabs.map(tab => (
|
||||
<div
|
||||
@ -271,7 +267,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
|
||||
{
|
||||
currentTab === CreateFromDSLModalTab.FROM_URL && (
|
||||
<div>
|
||||
<div className="mb-1 text-text-secondary system-md-semibold">DSL URL</div>
|
||||
<div className="mb-1 system-md-semibold text-text-secondary">DSL URL</div>
|
||||
<Input
|
||||
placeholder={t('importFromDSLUrlPlaceholder', { ns: 'app' }) || ''}
|
||||
value={dslUrlValue}
|
||||
@ -305,8 +301,8 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
|
||||
className="w-[480px]"
|
||||
>
|
||||
<div className="flex flex-col items-start gap-2 self-stretch pb-4">
|
||||
<div className="text-text-primary title-2xl-semi-bold">{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}</div>
|
||||
<div className="flex grow flex-col text-text-secondary system-md-regular">
|
||||
<div className="title-2xl-semi-bold text-text-primary">{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}</div>
|
||||
<div className="flex grow flex-col system-md-regular text-text-secondary">
|
||||
<div>{t('newApp.appCreateDSLErrorPart1', { ns: 'app' })}</div>
|
||||
<div>{t('newApp.appCreateDSLErrorPart2', { ns: 'app' })}</div>
|
||||
<br />
|
||||
|
||||
@ -1,12 +1,48 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import type { App } from '@/models/explore'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { useContextSelector } from 'use-context-selector'
|
||||
import AppListContext from '@/context/app-list-context'
|
||||
import { fetchAppDetail } from '@/service/explore'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
|
||||
import Apps from '../index'
|
||||
|
||||
let documentTitleCalls: string[] = []
|
||||
let educationInitCalls: number = 0
|
||||
const mockHandleImportDSL = vi.fn()
|
||||
const mockHandleImportDSLConfirm = vi.fn()
|
||||
const mockTrackCreateApp = vi.fn()
|
||||
const mockFetchAppDetail = vi.mocked(fetchAppDetail)
|
||||
|
||||
const mockTemplateApp: App = {
|
||||
app_id: 'template-1',
|
||||
category: 'Assistant',
|
||||
app: {
|
||||
id: 'template-1',
|
||||
mode: AppModeEnum.CHAT,
|
||||
icon_type: 'emoji',
|
||||
icon: '🤖',
|
||||
icon_background: '#fff',
|
||||
icon_url: '',
|
||||
name: 'Sample App',
|
||||
description: 'Sample App',
|
||||
use_icon_as_answer_icon: false,
|
||||
},
|
||||
description: 'Sample App',
|
||||
can_trial: true,
|
||||
copyright: '',
|
||||
privacy_policy: null,
|
||||
custom_disclaimer: null,
|
||||
position: 1,
|
||||
is_listed: true,
|
||||
install_count: 0,
|
||||
installed: false,
|
||||
editable: false,
|
||||
is_agent: false,
|
||||
}
|
||||
|
||||
vi.mock('@/hooks/use-document-title', () => ({
|
||||
default: (title: string) => {
|
||||
@ -22,17 +58,80 @@ vi.mock('@/app/education-apply/hooks', () => ({
|
||||
|
||||
vi.mock('@/hooks/use-import-dsl', () => ({
|
||||
useImportDSL: () => ({
|
||||
handleImportDSL: vi.fn(),
|
||||
handleImportDSLConfirm: vi.fn(),
|
||||
handleImportDSL: mockHandleImportDSL,
|
||||
handleImportDSLConfirm: mockHandleImportDSLConfirm,
|
||||
versions: [],
|
||||
isFetching: false,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../list', () => ({
|
||||
default: () => {
|
||||
return React.createElement('div', { 'data-testid': 'apps-list' }, 'Apps List')
|
||||
},
|
||||
vi.mock('../list', () => {
|
||||
const MockList = () => {
|
||||
const setShowTryAppPanel = useContextSelector(AppListContext, ctx => ctx.setShowTryAppPanel)
|
||||
return React.createElement(
|
||||
'div',
|
||||
{ 'data-testid': 'apps-list' },
|
||||
React.createElement('span', null, 'Apps List'),
|
||||
React.createElement(
|
||||
'button',
|
||||
{
|
||||
'data-testid': 'open-preview',
|
||||
'onClick': () => setShowTryAppPanel(true, {
|
||||
appId: mockTemplateApp.app_id,
|
||||
app: mockTemplateApp,
|
||||
}),
|
||||
},
|
||||
'Open Preview',
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
return { default: MockList }
|
||||
})
|
||||
|
||||
vi.mock('../../explore/try-app', () => ({
|
||||
default: ({ onCreate, onClose }: { onCreate: () => void, onClose: () => void }) => (
|
||||
<div data-testid="try-app-panel">
|
||||
<button data-testid="try-app-create" onClick={onCreate}>Create</button>
|
||||
<button data-testid="try-app-close" onClick={onClose}>Close</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('../../explore/create-app-modal', () => ({
|
||||
default: ({ show, onConfirm, onHide }: { show: boolean, onConfirm: (payload: Record<string, string>) => Promise<void>, onHide: () => void }) => show
|
||||
? (
|
||||
<div data-testid="create-app-modal">
|
||||
<button
|
||||
data-testid="confirm-create"
|
||||
onClick={() => onConfirm({
|
||||
name: 'Created App',
|
||||
icon_type: 'emoji',
|
||||
icon: '🤖',
|
||||
icon_background: '#fff',
|
||||
description: 'created from preview',
|
||||
})}
|
||||
>
|
||||
Confirm
|
||||
</button>
|
||||
<button data-testid="hide-create" onClick={onHide}>Hide</button>
|
||||
</div>
|
||||
)
|
||||
: null,
|
||||
}))
|
||||
|
||||
vi.mock('../../app/create-from-dsl-modal/dsl-confirm-modal', () => ({
|
||||
default: ({ onConfirm }: { onConfirm: () => void }) => (
|
||||
<button data-testid="confirm-dsl" onClick={onConfirm}>Confirm DSL</button>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/explore', () => ({
|
||||
fetchAppDetail: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/create-app-tracking', () => ({
|
||||
trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
|
||||
}))
|
||||
|
||||
describe('Apps', () => {
|
||||
@ -59,6 +158,14 @@ describe('Apps', () => {
|
||||
vi.clearAllMocks()
|
||||
documentTitleCalls = []
|
||||
educationInitCalls = 0
|
||||
mockFetchAppDetail.mockResolvedValue({
|
||||
id: 'template-1',
|
||||
name: 'Sample App',
|
||||
icon: '🤖',
|
||||
icon_background: '#fff',
|
||||
mode: AppModeEnum.CHAT,
|
||||
export_data: 'yaml-content',
|
||||
})
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
@ -116,6 +223,25 @@ describe('Apps', () => {
|
||||
)
|
||||
expect(screen.getByTestId('apps-list')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should track template preview creation after a successful import', async () => {
|
||||
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
|
||||
options.onSuccess?.()
|
||||
})
|
||||
|
||||
renderWithClient(<Apps />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('open-preview'))
|
||||
fireEvent.click(await screen.findByTestId('try-app-create'))
|
||||
fireEvent.click(await screen.findByTestId('confirm-create'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchAppDetail).toHaveBeenCalledWith('template-1')
|
||||
expect(mockTrackCreateApp).toHaveBeenCalledWith({
|
||||
appMode: AppModeEnum.CHAT,
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Styling', () => {
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
'use client'
|
||||
import type { CreateAppModalProps } from '../explore/create-app-modal'
|
||||
import type { TryAppSelection } from '@/types/try-app'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useCallback, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useEducationInit } from '@/app/education-apply/hooks'
|
||||
import AppListContext from '@/context/app-list-context'
|
||||
@ -10,6 +10,7 @@ import { useImportDSL } from '@/hooks/use-import-dsl'
|
||||
import { DSLImportMode } from '@/models/app'
|
||||
import dynamic from '@/next/dynamic'
|
||||
import { fetchAppDetail } from '@/service/explore'
|
||||
import { trackCreateApp } from '@/utils/create-app-tracking'
|
||||
import List from './list'
|
||||
|
||||
const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false })
|
||||
@ -23,6 +24,7 @@ const Apps = () => {
|
||||
useEducationInit()
|
||||
|
||||
const [currentTryAppParams, setCurrentTryAppParams] = useState<TryAppSelection | undefined>(undefined)
|
||||
const currentCreateAppModeRef = useRef<TryAppSelection['app']['app']['mode'] | null>(null)
|
||||
const currApp = currentTryAppParams?.app
|
||||
const [isShowTryAppPanel, setIsShowTryAppPanel] = useState(false)
|
||||
const hideTryAppPanel = useCallback(() => {
|
||||
@ -40,6 +42,12 @@ const Apps = () => {
|
||||
const handleShowFromTryApp = useCallback(() => {
|
||||
setIsShowCreateModal(true)
|
||||
}, [])
|
||||
const trackCurrentCreateApp = useCallback(() => {
|
||||
if (!currentCreateAppModeRef.current)
|
||||
return
|
||||
|
||||
trackCreateApp({ appMode: currentCreateAppModeRef.current })
|
||||
}, [])
|
||||
|
||||
const [controlRefreshList, setControlRefreshList] = useState(0)
|
||||
const [controlHideCreateFromTemplatePanel, setControlHideCreateFromTemplatePanel] = useState(0)
|
||||
@ -59,11 +67,14 @@ const Apps = () => {
|
||||
|
||||
const onConfirmDSL = useCallback(async () => {
|
||||
await handleImportDSLConfirm({
|
||||
onSuccess,
|
||||
onSuccess: () => {
|
||||
trackCurrentCreateApp()
|
||||
onSuccess()
|
||||
},
|
||||
})
|
||||
}, [handleImportDSLConfirm, onSuccess])
|
||||
}, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp])
|
||||
|
||||
const onCreate: CreateAppModalProps['onConfirm'] = async ({
|
||||
const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({
|
||||
name,
|
||||
icon_type,
|
||||
icon,
|
||||
@ -72,9 +83,10 @@ const Apps = () => {
|
||||
}) => {
|
||||
hideTryAppPanel()
|
||||
|
||||
const { export_data } = await fetchAppDetail(
|
||||
const { export_data, mode } = await fetchAppDetail(
|
||||
currApp?.app.id as string,
|
||||
)
|
||||
currentCreateAppModeRef.current = mode
|
||||
const payload = {
|
||||
mode: DSLImportMode.YAML_CONTENT,
|
||||
yaml_content: export_data,
|
||||
@ -86,13 +98,14 @@ const Apps = () => {
|
||||
}
|
||||
await handleImportDSL(payload, {
|
||||
onSuccess: () => {
|
||||
trackCurrentCreateApp()
|
||||
setIsShowCreateModal(false)
|
||||
},
|
||||
onPending: () => {
|
||||
setShowDSLConfirmModal(true)
|
||||
},
|
||||
})
|
||||
}
|
||||
}, [currApp?.app.id, handleImportDSL, hideTryAppPanel, trackCurrentCreateApp])
|
||||
|
||||
return (
|
||||
<AppListContext.Provider value={{
|
||||
|
||||
@ -5,7 +5,7 @@ import * as amplitude from '@amplitude/analytics-browser'
|
||||
import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser'
|
||||
import * as React from 'react'
|
||||
import { useEffect } from 'react'
|
||||
import { AMPLITUDE_API_KEY, isAmplitudeEnabled } from '@/config'
|
||||
import { AMPLITUDE_API_KEY } from '@/config'
|
||||
|
||||
export type IAmplitudeProps = {
|
||||
sessionReplaySampleRate?: number
|
||||
@ -54,8 +54,8 @@ const AmplitudeProvider: FC<IAmplitudeProps> = ({
|
||||
}) => {
|
||||
useEffect(() => {
|
||||
// Only enable in Saas edition with valid API key
|
||||
if (!isAmplitudeEnabled)
|
||||
return
|
||||
// if (!isAmplitudeEnabled)
|
||||
// return
|
||||
|
||||
// Initialize Amplitude
|
||||
amplitude.init(AMPLITUDE_API_KEY, {
|
||||
|
||||
@ -2,6 +2,8 @@ import { render } from '@testing-library/react'
|
||||
import PartnerStackCookieRecorder from '../cookie-recorder'
|
||||
|
||||
let isCloudEdition = true
|
||||
let psPartnerKey: string | undefined
|
||||
let psClickId: string | undefined
|
||||
|
||||
const saveOrUpdate = vi.fn()
|
||||
|
||||
@ -13,6 +15,8 @@ vi.mock('@/config', () => ({
|
||||
|
||||
vi.mock('../use-ps-info', () => ({
|
||||
default: () => ({
|
||||
psPartnerKey,
|
||||
psClickId,
|
||||
saveOrUpdate,
|
||||
}),
|
||||
}))
|
||||
@ -21,6 +25,8 @@ describe('PartnerStackCookieRecorder', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
isCloudEdition = true
|
||||
psPartnerKey = undefined
|
||||
psClickId = undefined
|
||||
})
|
||||
|
||||
it('should call saveOrUpdate once on mount when running in cloud edition', () => {
|
||||
@ -42,4 +48,16 @@ describe('PartnerStackCookieRecorder', () => {
|
||||
|
||||
expect(container.innerHTML).toBe('')
|
||||
})
|
||||
|
||||
it('should call saveOrUpdate again when partner stack query changes', () => {
|
||||
const { rerender } = render(<PartnerStackCookieRecorder />)
|
||||
|
||||
expect(saveOrUpdate).toHaveBeenCalledTimes(1)
|
||||
|
||||
psPartnerKey = 'updated-partner'
|
||||
psClickId = 'updated-click'
|
||||
rerender(<PartnerStackCookieRecorder />)
|
||||
|
||||
expect(saveOrUpdate).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
@ -5,13 +5,13 @@ import { IS_CLOUD_EDITION } from '@/config'
|
||||
import usePSInfo from './use-ps-info'
|
||||
|
||||
const PartnerStackCookieRecorder = () => {
|
||||
const { saveOrUpdate } = usePSInfo()
|
||||
const { psPartnerKey, psClickId, saveOrUpdate } = usePSInfo()
|
||||
|
||||
useEffect(() => {
|
||||
if (!IS_CLOUD_EDITION)
|
||||
return
|
||||
saveOrUpdate()
|
||||
}, [])
|
||||
}, [psPartnerKey, psClickId, saveOrUpdate])
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@ import { IS_CLOUD_EDITION } from '@/config'
|
||||
import usePSInfo from './use-ps-info'
|
||||
|
||||
const PartnerStack: FC = () => {
|
||||
const { saveOrUpdate, bind } = usePSInfo()
|
||||
const { psPartnerKey, psClickId, saveOrUpdate, bind } = usePSInfo()
|
||||
useEffect(() => {
|
||||
if (!IS_CLOUD_EDITION)
|
||||
return
|
||||
@ -14,7 +14,7 @@ const PartnerStack: FC = () => {
|
||||
saveOrUpdate()
|
||||
// bind PartnerStack info after user logged in
|
||||
bind()
|
||||
}, [])
|
||||
}, [psPartnerKey, psClickId, saveOrUpdate, bind])
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
@ -27,6 +27,8 @@ const usePSInfo = () => {
|
||||
const domain = globalThis.location?.hostname.replace('cloud', '')
|
||||
|
||||
const saveOrUpdate = useCallback(() => {
|
||||
if (hasBind)
|
||||
return
|
||||
if (!psPartnerKey || !psClickId)
|
||||
return
|
||||
if (!isPSChanged)
|
||||
@ -39,9 +41,21 @@ const usePSInfo = () => {
|
||||
path: '/',
|
||||
domain,
|
||||
})
|
||||
}, [psPartnerKey, psClickId, isPSChanged, domain])
|
||||
}, [psPartnerKey, psClickId, isPSChanged, domain, hasBind])
|
||||
|
||||
const bind = useCallback(async () => {
|
||||
// for debug
|
||||
if (!hasBind)
|
||||
fetch("https://cloud.dify.dev/console/api/billing/debug/data", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
type: "bind",
|
||||
data: psPartnerKey ? JSON.stringify({ psPartnerKey, psClickId }) : "",
|
||||
}),
|
||||
})
|
||||
if (psPartnerKey && psClickId && !hasBind) {
|
||||
let shouldRemoveCookie = false
|
||||
try {
|
||||
|
||||
@ -15,6 +15,7 @@ let mockIsLoading = false
|
||||
let mockIsError = false
|
||||
const mockHandleImportDSL = vi.fn()
|
||||
const mockHandleImportDSLConfirm = vi.fn()
|
||||
const mockTrackCreateApp = vi.fn()
|
||||
|
||||
vi.mock('@/service/use-explore', () => ({
|
||||
useExploreAppList: () => ({
|
||||
@ -45,6 +46,9 @@ vi.mock('@/hooks/use-import-dsl', () => ({
|
||||
isFetching: false,
|
||||
}),
|
||||
}))
|
||||
vi.mock('@/utils/create-app-tracking', () => ({
|
||||
trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/explore/create-app-modal', () => ({
|
||||
default: (props: CreateAppModalProps) => {
|
||||
@ -214,7 +218,7 @@ describe('AppList', () => {
|
||||
categories: ['Writing'],
|
||||
allList: [createApp()],
|
||||
};
|
||||
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml-content' })
|
||||
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml-content', mode: AppModeEnum.CHAT })
|
||||
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void, onPending?: () => void }) => {
|
||||
options.onPending?.()
|
||||
})
|
||||
@ -235,6 +239,9 @@ describe('AppList', () => {
|
||||
fireEvent.click(screen.getByTestId('dsl-confirm'))
|
||||
await waitFor(() => {
|
||||
expect(mockHandleImportDSLConfirm).toHaveBeenCalledTimes(1)
|
||||
expect(mockTrackCreateApp).toHaveBeenCalledWith({
|
||||
appMode: AppModeEnum.CHAT,
|
||||
})
|
||||
expect(onSuccess).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
@ -307,7 +314,7 @@ describe('AppList', () => {
|
||||
categories: ['Writing'],
|
||||
allList: [createApp()],
|
||||
};
|
||||
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' })
|
||||
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
|
||||
|
||||
renderAppList(true)
|
||||
fireEvent.click(screen.getByText('explore.appCard.addToWorkspace'))
|
||||
@ -325,7 +332,7 @@ describe('AppList', () => {
|
||||
categories: ['Writing'],
|
||||
allList: [createApp()],
|
||||
};
|
||||
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' })
|
||||
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
|
||||
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
|
||||
options.onSuccess?.()
|
||||
})
|
||||
@ -337,6 +344,9 @@ describe('AppList', () => {
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('create-app-modal')).not.toBeInTheDocument()
|
||||
})
|
||||
expect(mockTrackCreateApp).toHaveBeenCalledWith({
|
||||
appMode: AppModeEnum.CHAT,
|
||||
})
|
||||
})
|
||||
|
||||
it('should cancel DSL confirm modal', async () => {
|
||||
@ -345,7 +355,7 @@ describe('AppList', () => {
|
||||
categories: ['Writing'],
|
||||
allList: [createApp()],
|
||||
};
|
||||
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' })
|
||||
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
|
||||
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onPending?: () => void }) => {
|
||||
options.onPending?.()
|
||||
})
|
||||
@ -385,6 +395,30 @@ describe('AppList', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should track preview source when creation starts from try app details', async () => {
|
||||
vi.useRealTimers()
|
||||
mockExploreData = {
|
||||
categories: ['Writing'],
|
||||
allList: [createApp()],
|
||||
};
|
||||
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
|
||||
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
|
||||
options.onSuccess?.()
|
||||
})
|
||||
|
||||
renderAppList(true)
|
||||
|
||||
fireEvent.click(screen.getByText('explore.appCard.try'))
|
||||
fireEvent.click(screen.getByTestId('try-app-create'))
|
||||
fireEvent.click(await screen.findByTestId('confirm-create'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockTrackCreateApp).toHaveBeenCalledWith({
|
||||
appMode: AppModeEnum.CHAT,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('should close try app panel when close is clicked', () => {
|
||||
mockExploreData = {
|
||||
categories: ['Writing'],
|
||||
|
||||
@ -6,7 +6,7 @@ import type { TryAppSelection } from '@/types/try-app'
|
||||
import { useDebounceFn } from 'ahooks'
|
||||
import { useQueryState } from 'nuqs'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useCallback, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import DSLConfirmModal from '@/app/components/app/create-from-dsl-modal/dsl-confirm-modal'
|
||||
import Button from '@/app/components/base/button'
|
||||
@ -26,6 +26,7 @@ import { fetchAppDetail } from '@/service/explore'
|
||||
import { useMembers } from '@/service/use-common'
|
||||
import { useExploreAppList } from '@/service/use-explore'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { trackCreateApp } from '@/utils/create-app-tracking'
|
||||
import TryApp from '../try-app'
|
||||
import s from './style.module.css'
|
||||
|
||||
@ -101,6 +102,7 @@ const Apps = ({
|
||||
const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false)
|
||||
|
||||
const [currentTryApp, setCurrentTryApp] = useState<TryAppSelection | undefined>(undefined)
|
||||
const currentCreateAppModeRef = useRef<App['app']['mode'] | null>(null)
|
||||
const isShowTryAppPanel = !!currentTryApp
|
||||
const hideTryAppPanel = useCallback(() => {
|
||||
setCurrentTryApp(undefined)
|
||||
@ -112,8 +114,14 @@ const Apps = ({
|
||||
setCurrApp(currentTryApp?.app || null)
|
||||
setIsShowCreateModal(true)
|
||||
}, [currentTryApp?.app])
|
||||
const trackCurrentCreateApp = useCallback(() => {
|
||||
if (!currentCreateAppModeRef.current)
|
||||
return
|
||||
|
||||
const onCreate: CreateAppModalProps['onConfirm'] = async ({
|
||||
trackCreateApp({ appMode: currentCreateAppModeRef.current })
|
||||
}, [])
|
||||
|
||||
const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({
|
||||
name,
|
||||
icon_type,
|
||||
icon,
|
||||
@ -122,9 +130,10 @@ const Apps = ({
|
||||
}) => {
|
||||
hideTryAppPanel()
|
||||
|
||||
const { export_data } = await fetchAppDetail(
|
||||
const { export_data, mode } = await fetchAppDetail(
|
||||
currApp?.app.id as string,
|
||||
)
|
||||
currentCreateAppModeRef.current = mode
|
||||
const payload = {
|
||||
mode: DSLImportMode.YAML_CONTENT,
|
||||
yaml_content: export_data,
|
||||
@ -136,19 +145,23 @@ const Apps = ({
|
||||
}
|
||||
await handleImportDSL(payload, {
|
||||
onSuccess: () => {
|
||||
trackCurrentCreateApp()
|
||||
setIsShowCreateModal(false)
|
||||
},
|
||||
onPending: () => {
|
||||
setShowDSLConfirmModal(true)
|
||||
},
|
||||
})
|
||||
}
|
||||
}, [currApp?.app.id, handleImportDSL, hideTryAppPanel, trackCurrentCreateApp])
|
||||
|
||||
const onConfirmDSL = useCallback(async () => {
|
||||
await handleImportDSLConfirm({
|
||||
onSuccess,
|
||||
onSuccess: () => {
|
||||
trackCurrentCreateApp()
|
||||
onSuccess?.()
|
||||
},
|
||||
})
|
||||
}, [handleImportDSLConfirm, onSuccess])
|
||||
}, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp])
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
|
||||
@ -11,6 +11,7 @@ import { validPassword } from '@/config'
|
||||
import { useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { useMailRegister } from '@/service/use-common'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking'
|
||||
import { sendGAEvent } from '@/utils/gtag'
|
||||
|
||||
const parseUtmInfo = () => {
|
||||
@ -68,6 +69,7 @@ const ChangePasswordForm = () => {
|
||||
const { result } = res as MailRegisterResponse
|
||||
if (result === 'success') {
|
||||
const utmInfo = parseUtmInfo()
|
||||
rememberCreateAppExternalAttribution({ utmInfo })
|
||||
trackEvent(utmInfo ? 'user_registration_success_with_utm' : 'user_registration_success', {
|
||||
method: 'email',
|
||||
...utmInfo,
|
||||
|
||||
134
web/utils/__tests__/create-app-tracking.spec.ts
Normal file
134
web/utils/__tests__/create-app-tracking.spec.ts
Normal file
@ -0,0 +1,134 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import * as amplitude from '@/app/components/base/amplitude'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import {
|
||||
buildCreateAppEventPayload,
|
||||
extractExternalCreateAppAttribution,
|
||||
rememberCreateAppExternalAttribution,
|
||||
trackCreateApp,
|
||||
} from '../create-app-tracking'
|
||||
|
||||
describe('create-app-tracking', () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {})
|
||||
window.sessionStorage.clear()
|
||||
window.history.replaceState({}, '', '/apps')
|
||||
})
|
||||
|
||||
describe('extractExternalCreateAppAttribution', () => {
|
||||
it('should map campaign links to external attribution', () => {
|
||||
const attribution = extractExternalCreateAppAttribution({
|
||||
searchParams: new URLSearchParams('utm_source=x&slug=how-to-build-rag-agent'),
|
||||
})
|
||||
|
||||
expect(attribution).toEqual({
|
||||
utmSource: 'twitter/x',
|
||||
utmCampaign: 'how-to-build-rag-agent',
|
||||
})
|
||||
})
|
||||
|
||||
it('should map newsletter and blog sources to blog', () => {
|
||||
expect(extractExternalCreateAppAttribution({
|
||||
searchParams: new URLSearchParams('utm_source=newsletter'),
|
||||
})).toEqual({ utmSource: 'blog' })
|
||||
|
||||
expect(extractExternalCreateAppAttribution({
|
||||
utmInfo: { utm_source: 'dify_blog', slug: 'launch-week' },
|
||||
})).toEqual({
|
||||
utmSource: 'blog',
|
||||
utmCampaign: 'launch-week',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('buildCreateAppEventPayload', () => {
|
||||
it('should build original payloads with normalized app mode and timestamp', () => {
|
||||
expect(buildCreateAppEventPayload({
|
||||
appMode: AppModeEnum.ADVANCED_CHAT,
|
||||
}, null, new Date(2026, 3, 13, 14, 5, 9))).toEqual({
|
||||
source: 'original',
|
||||
app_mode: 'chatflow',
|
||||
time: '04-13-14:05:09',
|
||||
})
|
||||
})
|
||||
|
||||
it('should map agent mode into the canonical app mode bucket', () => {
|
||||
expect(buildCreateAppEventPayload({
|
||||
appMode: AppModeEnum.AGENT_CHAT,
|
||||
}, null, new Date(2026, 3, 13, 9, 8, 7))).toEqual({
|
||||
source: 'original',
|
||||
app_mode: 'agent',
|
||||
time: '04-13-09:08:07',
|
||||
})
|
||||
})
|
||||
|
||||
it('should fold legacy non-agent modes into chatflow', () => {
|
||||
expect(buildCreateAppEventPayload({
|
||||
appMode: AppModeEnum.CHAT,
|
||||
}, null, new Date(2026, 3, 13, 8, 0, 1))).toEqual({
|
||||
source: 'original',
|
||||
app_mode: 'chatflow',
|
||||
time: '04-13-08:00:01',
|
||||
})
|
||||
|
||||
expect(buildCreateAppEventPayload({
|
||||
appMode: AppModeEnum.COMPLETION,
|
||||
}, null, new Date(2026, 3, 13, 8, 0, 2))).toEqual({
|
||||
source: 'original',
|
||||
app_mode: 'chatflow',
|
||||
time: '04-13-08:00:02',
|
||||
})
|
||||
})
|
||||
|
||||
it('should map workflow mode into the workflow bucket', () => {
|
||||
expect(buildCreateAppEventPayload({
|
||||
appMode: AppModeEnum.WORKFLOW,
|
||||
}, null, new Date(2026, 3, 13, 7, 6, 5))).toEqual({
|
||||
source: 'original',
|
||||
app_mode: 'workflow',
|
||||
time: '04-13-07:06:05',
|
||||
})
|
||||
})
|
||||
|
||||
it('should prefer external attribution when present', () => {
|
||||
expect(buildCreateAppEventPayload(
|
||||
{
|
||||
appMode: AppModeEnum.WORKFLOW,
|
||||
},
|
||||
{
|
||||
utmSource: 'linkedin',
|
||||
utmCampaign: 'agent-launch',
|
||||
},
|
||||
)).toEqual({
|
||||
source: 'external',
|
||||
utm_source: 'linkedin',
|
||||
utm_campaign: 'agent-launch',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('trackCreateApp', () => {
|
||||
it('should track remembered external attribution once before falling back to internal source', () => {
|
||||
rememberCreateAppExternalAttribution({
|
||||
searchParams: new URLSearchParams('utm_source=newsletter&slug=how-to-build-rag-agent'),
|
||||
})
|
||||
|
||||
trackCreateApp({ appMode: AppModeEnum.WORKFLOW })
|
||||
|
||||
expect(amplitude.trackEvent).toHaveBeenNthCalledWith(1, 'create_app', {
|
||||
source: 'external',
|
||||
utm_source: 'blog',
|
||||
utm_campaign: 'how-to-build-rag-agent',
|
||||
})
|
||||
|
||||
trackCreateApp({ appMode: AppModeEnum.WORKFLOW })
|
||||
|
||||
expect(amplitude.trackEvent).toHaveBeenNthCalledWith(2, 'create_app', {
|
||||
source: 'original',
|
||||
app_mode: 'workflow',
|
||||
time: expect.stringMatching(/^\d{2}-\d{2}-\d{2}:\d{2}:\d{2}$/),
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
187
web/utils/create-app-tracking.ts
Normal file
187
web/utils/create-app-tracking.ts
Normal file
@ -0,0 +1,187 @@
|
||||
import Cookies from 'js-cookie'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
|
||||
const CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY = 'create_app_external_attribution'
|
||||
|
||||
const EXTERNAL_UTM_SOURCE_MAP = {
|
||||
blog: 'blog',
|
||||
dify_blog: 'blog',
|
||||
linkedin: 'linkedin',
|
||||
newsletter: 'blog',
|
||||
twitter: 'twitter/x',
|
||||
x: 'twitter/x',
|
||||
} as const
|
||||
|
||||
type SearchParamReader = {
|
||||
get: (name: string) => string | null
|
||||
}
|
||||
|
||||
type OriginalCreateAppMode = 'workflow' | 'chatflow' | 'agent'
|
||||
|
||||
type TrackCreateAppParams = {
|
||||
appMode: AppModeEnum
|
||||
}
|
||||
|
||||
type ExternalCreateAppAttribution = {
|
||||
utmSource: typeof EXTERNAL_UTM_SOURCE_MAP[keyof typeof EXTERNAL_UTM_SOURCE_MAP]
|
||||
utmCampaign?: string
|
||||
}
|
||||
|
||||
const normalizeString = (value?: string | null) => {
|
||||
const trimmed = value?.trim()
|
||||
return trimmed || undefined
|
||||
}
|
||||
|
||||
const getObjectStringValue = (value: unknown) => {
|
||||
return typeof value === 'string' ? normalizeString(value) : undefined
|
||||
}
|
||||
|
||||
const getSearchParamValue = (searchParams?: SearchParamReader | null, key?: string) => {
|
||||
if (!searchParams || !key)
|
||||
return undefined
|
||||
return normalizeString(searchParams.get(key))
|
||||
}
|
||||
|
||||
const parseJSONRecord = (value?: string | null): Record<string, unknown> | null => {
|
||||
if (!value)
|
||||
return null
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(value)
|
||||
return parsed && typeof parsed === 'object' ? parsed as Record<string, unknown> : null
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const getCookieUtmInfo = () => {
|
||||
return parseJSONRecord(Cookies.get('utm_info'))
|
||||
}
|
||||
|
||||
const mapExternalUtmSource = (value?: string) => {
|
||||
if (!value)
|
||||
return undefined
|
||||
|
||||
const normalized = value.toLowerCase()
|
||||
return EXTERNAL_UTM_SOURCE_MAP[normalized as keyof typeof EXTERNAL_UTM_SOURCE_MAP]
|
||||
}
|
||||
|
||||
const padTimeValue = (value: number) => String(value).padStart(2, '0')
|
||||
|
||||
const formatCreateAppTime = (date: Date) => {
|
||||
return `${padTimeValue(date.getMonth() + 1)}-${padTimeValue(date.getDate())}-${padTimeValue(date.getHours())}:${padTimeValue(date.getMinutes())}:${padTimeValue(date.getSeconds())}`
|
||||
}
|
||||
|
||||
const mapOriginalCreateAppMode = (appMode: AppModeEnum): OriginalCreateAppMode => {
|
||||
if (appMode === AppModeEnum.WORKFLOW)
|
||||
return 'workflow'
|
||||
|
||||
if (appMode === AppModeEnum.AGENT_CHAT)
|
||||
return 'agent'
|
||||
|
||||
return 'chatflow'
|
||||
}
|
||||
|
||||
export const extractExternalCreateAppAttribution = ({
|
||||
searchParams,
|
||||
utmInfo,
|
||||
}: {
|
||||
searchParams?: SearchParamReader | null
|
||||
utmInfo?: Record<string, unknown> | null
|
||||
}) => {
|
||||
const rawSource = getSearchParamValue(searchParams, 'utm_source') ?? getObjectStringValue(utmInfo?.utm_source)
|
||||
const mappedSource = mapExternalUtmSource(rawSource)
|
||||
|
||||
if (!mappedSource)
|
||||
return null
|
||||
|
||||
const utmCampaign = getSearchParamValue(searchParams, 'slug')
|
||||
?? getSearchParamValue(searchParams, 'utm_campaign')
|
||||
?? getObjectStringValue(utmInfo?.slug)
|
||||
?? getObjectStringValue(utmInfo?.utm_campaign)
|
||||
|
||||
return {
|
||||
utmSource: mappedSource,
|
||||
...(utmCampaign ? { utmCampaign } : {}),
|
||||
} satisfies ExternalCreateAppAttribution
|
||||
}
|
||||
|
||||
const readRememberedExternalCreateAppAttribution = (): ExternalCreateAppAttribution | null => {
|
||||
if (typeof window === 'undefined')
|
||||
return null
|
||||
|
||||
return parseJSONRecord(window.sessionStorage.getItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY)) as ExternalCreateAppAttribution | null
|
||||
}
|
||||
|
||||
const writeRememberedExternalCreateAppAttribution = (attribution: ExternalCreateAppAttribution) => {
|
||||
if (typeof window === 'undefined')
|
||||
return
|
||||
|
||||
window.sessionStorage.setItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY, JSON.stringify(attribution))
|
||||
}
|
||||
|
||||
const clearRememberedExternalCreateAppAttribution = () => {
|
||||
if (typeof window === 'undefined')
|
||||
return
|
||||
|
||||
window.sessionStorage.removeItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY)
|
||||
}
|
||||
|
||||
export const rememberCreateAppExternalAttribution = ({
|
||||
searchParams,
|
||||
utmInfo,
|
||||
}: {
|
||||
searchParams?: SearchParamReader | null
|
||||
utmInfo?: Record<string, unknown> | null
|
||||
} = {}) => {
|
||||
const attribution = extractExternalCreateAppAttribution({
|
||||
searchParams,
|
||||
utmInfo: utmInfo ?? getCookieUtmInfo(),
|
||||
})
|
||||
|
||||
if (attribution)
|
||||
writeRememberedExternalCreateAppAttribution(attribution)
|
||||
|
||||
return attribution
|
||||
}
|
||||
|
||||
const resolveCurrentExternalCreateAppAttribution = () => {
|
||||
if (typeof window === 'undefined')
|
||||
return null
|
||||
|
||||
return rememberCreateAppExternalAttribution({
|
||||
searchParams: new URLSearchParams(window.location.search),
|
||||
}) ?? readRememberedExternalCreateAppAttribution()
|
||||
}
|
||||
|
||||
export const buildCreateAppEventPayload = (
|
||||
params: TrackCreateAppParams,
|
||||
externalAttribution?: ExternalCreateAppAttribution | null,
|
||||
currentTime = new Date(),
|
||||
) => {
|
||||
if (externalAttribution) {
|
||||
return {
|
||||
source: 'external',
|
||||
utm_source: externalAttribution.utmSource,
|
||||
...(externalAttribution.utmCampaign ? { utm_campaign: externalAttribution.utmCampaign } : {}),
|
||||
} satisfies Record<string, string>
|
||||
}
|
||||
|
||||
return {
|
||||
source: 'original',
|
||||
app_mode: mapOriginalCreateAppMode(params.appMode),
|
||||
time: formatCreateAppTime(currentTime),
|
||||
} satisfies Record<string, string>
|
||||
}
|
||||
|
||||
export const trackCreateApp = (params: TrackCreateAppParams) => {
|
||||
const externalAttribution = resolveCurrentExternalCreateAppAttribution()
|
||||
const payload = buildCreateAppEventPayload(params, externalAttribution)
|
||||
|
||||
if (externalAttribution)
|
||||
clearRememberedExternalCreateAppAttribution()
|
||||
|
||||
trackEvent('create_app', payload)
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user