refactor: quota v3 integration (#35436)

Co-authored-by: Yansong Zhang <916125788@qq.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
hj24 2026-04-27 09:49:40 +08:00 committed by hj24
parent ef50e117da
commit a03d5b8ed3
20 changed files with 1961 additions and 529 deletions

View File

@ -1,56 +1,17 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum):
"""
Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
"""
# Trigger execution quota
TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto()
UNLIMITED = auto()
@property
def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self:
case QuotaType.TRIGGER:
return "trigger_event"
@ -58,152 +19,3 @@ class QuotaType(StrEnum):
return "api_rate_limit"
case _:
raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
    """
    Consume quota for the feature.

    Args:
        tenant_id: The tenant identifier
        amount: Amount to consume (default: 1)

    Returns:
        QuotaCharge with success status and charge_id for refund

    Raises:
        QuotaExceededError: When quota is insufficient
        ValueError: When amount is not positive (only checked when billing is enabled)
    """
    # Deferred imports: keeps this enum module free of service-layer import cycles.
    from configs import dify_config
    from services.billing_service import BillingService
    from services.errors.app import QuotaExceededError

    if not dify_config.BILLING_ENABLED:
        # Billing off: return a charge with no charge_id so refund() is a no-op.
        logger.debug("Billing disabled, allowing request for %s", tenant_id)
        return QuotaCharge(success=True, charge_id=None, _quota_type=self)

    logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
    if amount <= 0:
        raise ValueError("Amount to consume must be greater than 0")

    try:
        response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
        if response.get("result") != "success":
            # Billing explicitly declined the charge -> treat as quota exhausted.
            logger.warning(
                "Failed to consume quota for %s, feature %s details: %s",
                tenant_id,
                self.value,
                response.get("detail"),
            )
            raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)

        # The billing "history_id" is the handle a later refund() needs.
        charge_id = response.get("history_id")
        logger.debug(
            "Successfully consumed %d %s quota for tenant %s, charge_id: %s",
            amount,
            self.value,
            tenant_id,
            charge_id,
        )
        return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
    except QuotaExceededError:
        raise
    except Exception:
        # fail-safe: allow request on billing errors
        logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
        return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
    """
    Check if tenant has sufficient quota without consuming.

    Args:
        tenant_id: The tenant identifier
        amount: Amount to check (default: 1)

    Returns:
        True if quota is sufficient, False otherwise
    """
    from configs import dify_config

    if not dify_config.BILLING_ENABLED:
        return True
    if amount <= 0:
        raise ValueError("Amount to check must be greater than 0")

    try:
        available = self.get_remaining(tenant_id)
        # -1 means "unknown/unlimited" -> allow the request.
        if available == -1:
            return True
        return available >= amount
    except Exception:
        logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
        # fail-safe: allow request on billing errors
        return True
def refund(self, charge_id: str) -> None:
    """
    Refund quota using charge_id from consume().

    Never raises: every error is logged and silently swallowed, since a
    refund is best-effort.

    Args:
        charge_id: The UUID returned from consume()
    """
    try:
        from configs import dify_config
        from services.billing_service import BillingService

        if not dify_config.BILLING_ENABLED:
            return
        if not charge_id:
            logger.warning("Cannot refund: charge_id is empty")
            return

        logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
        response = BillingService.refund_tenant_feature_plan_usage(charge_id)
        refunded = response.get("result") == "success"
        if refunded:
            logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
        else:
            logger.warning("Refund failed for charge_id: %s", charge_id)
    except Exception:
        # Catch everything -- a refund must never propagate an exception.
        logger.exception("Failed to refund quota for charge_id: %s", charge_id)
def get_remaining(self, tenant_id: str) -> int:
    """
    Get remaining quota for the tenant.

    Args:
        tenant_id: The tenant identifier

    Returns:
        Remaining quota amount; -1 signals "unknown" (billing lookup failed),
        which callers such as check() treat as unlimited.
    """
    from services.billing_service import BillingService

    try:
        usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
        # Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
        # -- TODO confirm against the billing service response schema.
        if isinstance(usage_info, dict):
            return usage_info.get("remaining", 0)
        # If it returns a simple number, treat it as remaining
        return int(usage_info) if usage_info else 0
    except Exception:
        logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
        return -1
def unlimited() -> QuotaCharge:
    """
    Return a quota charge for unlimited quota.

    Useful for features that are not subject to quota limits, such as the
    UNLIMITED quota type; the resulting charge carries no charge_id, so a
    refund() on it is a no-op.
    """
    return QuotaCharge(
        success=True,
        charge_id=None,
        _quota_type=QuotaType.UNLIMITED,
    )

View File

@ -4,7 +4,7 @@ import logging
import threading
import uuid
from collections.abc import Callable, Generator, Mapping
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
@ -88,7 +89,7 @@ class AppGenerateService:
def generate(
cls,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
@ -106,7 +107,7 @@ class AppGenerateService:
quota_charge = unlimited()
if dify_config.BILLING_ENABLED:
try:
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
except QuotaExceededError:
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
@ -116,139 +117,150 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
if app_model.mode == AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(
CompletionAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id=request_id,
)
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
return rate_limit.generate(
AgentChatAppGenerator.convert_to_event_stream(
AgentChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id,
)
elif app_model.mode == AppMode.CHAT:
return rate_limit.generate(
ChatAppGenerator.convert_to_event_stream(
ChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id=request_id,
)
elif app_model.mode == AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
# Streaming mode: subscribe to SSE and enqueue the execution on first subscriber
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
)
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
generator = AdvancedChatAppGenerator()
quota_charge.commit()
effective_mode = (
AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
)
match effective_mode:
case AppMode.COMPLETION:
return rate_limit.generate(
generator.convert_to_event_stream(
generator.retrieve_events(
AppMode.ADVANCED_CHAT,
payload.workflow_run_id,
on_subscribe=on_subscribe,
CompletionAppGenerator.convert_to_event_stream(
CompletionAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id=request_id,
)
else:
# Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch.
advanced_generator = AdvancedChatAppGenerator()
case AppMode.AGENT_CHAT:
return rate_limit.generate(
advanced_generator.convert_to_event_stream(
advanced_generator.generate(
AgentChatAppGenerator.convert_to_event_stream(
AgentChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id,
)
case AppMode.CHAT:
return rate_limit.generate(
ChatAppGenerator.convert_to_event_stream(
ChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id=request_id,
)
case AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
# Streaming mode: subscribe to SSE and enqueue the execution on first subscriber
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
workflow_run_id=str(uuid.uuid4()),
streaming=False,
)
),
request_id=request_id,
)
elif app_model.mode == AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
root_node_id=root_node_id,
workflow_run_id=str(uuid.uuid4()),
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
generator = AdvancedChatAppGenerator()
return rate_limit.generate(
generator.convert_to_event_stream(
generator.retrieve_events(
AppMode.ADVANCED_CHAT,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
request_id=request_id,
)
payload_json = payload.model_dump_json()
else:
# Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch.
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
advanced_generator = AdvancedChatAppGenerator()
return rate_limit.generate(
advanced_generator.convert_to_event_stream(
advanced_generator.generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
workflow_run_id=str(uuid.uuid4()),
streaming=False,
pause_state_config=pause_config,
)
),
request_id=request_id,
)
case AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
root_node_id=root_node_id,
workflow_run_id=str(uuid.uuid4()),
)
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
MessageBasedAppGenerator.retrieve_events(
AppMode.WORKFLOW,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
request_id,
)
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
MessageBasedAppGenerator.retrieve_events(
AppMode.WORKFLOW,
payload.workflow_run_id,
on_subscribe=on_subscribe,
WorkflowAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=False,
root_node_id=root_node_id,
call_depth=0,
pause_state_config=pause_config,
),
),
request_id,
)
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=False,
root_node_id=root_node_id,
call_depth=0,
pause_state_config=pause_config,
),
),
request_id,
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
case _:
raise ValueError(f"Invalid app mode {app_model.mode}")
except Exception:
quota_charge.refund()
rate_limit.exit(request_id)
@ -280,53 +292,83 @@ class AppGenerateService:
@classmethod
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
match app_model.mode:
case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
raise ValueError(f"Invalid app mode {app_model.mode}")
case AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
)
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
case AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
case AppMode.CHANNEL | AppMode.RAG_PIPELINE:
raise ValueError(f"Invalid app mode {app_model.mode}")
case _:
raise ValueError(f"Invalid app mode {app_model.mode}")
@classmethod
def generate_single_loop(
cls, app_model: App, user: Account, node_id: str, args: LoopNodeRunPayload, streaming: bool = True
):
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
match app_model.mode:
case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
raise ValueError(f"Invalid app mode {app_model.mode}")
case AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_loop_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
)
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
case AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
case AppMode.CHANNEL | AppMode.RAG_PIPELINE:
raise ValueError(f"Invalid app mode {app_model.mode}")
case _:
raise ValueError(f"Invalid app mode {app_model.mode}")
@classmethod
def generate_more_like_this(
cls,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
message_id: str,
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping, Generator]:
) -> Mapping | Generator:
"""
Generate more like this
:param app_model: app model

View File

@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@ -88,7 +89,10 @@ class AsyncWorkflowService:
raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")
# 2. Get workflow
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session)
# commit read-only session before starting the billing rpc call
session.commit()
# 3. Get dispatcher based on tenant subscription
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
@ -131,9 +135,10 @@ class AsyncWorkflowService:
trigger_log = trigger_log_repo.create(trigger_log)
session.commit()
# 7. Check and consume quota
# 7. Reserve quota (commit after successful dispatch)
quota_charge = unlimited()
try:
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
except QuotaExceededError as e:
# Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@ -153,13 +158,18 @@ class AsyncWorkflowService:
# 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json")
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
try:
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED
@ -295,13 +305,21 @@ class AsyncWorkflowService:
return [log.to_dict() for log in logs]
@staticmethod
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow:
def _get_workflow(
workflow_service: WorkflowService,
app_model: App,
workflow_id: str | None = None,
session: Session | None = None,
) -> Workflow:
"""
Get workflow for the app
Args:
app_model: App model instance
workflow_id: Optional specific workflow ID
session: Reuse this SQLAlchemy session for the lookup when provided,
so the caller's explicit session bears the connection cost
instead of Flask's request-scoped ``db.session``.
Returns:
Workflow instance
@ -311,12 +329,12 @@ class AsyncWorkflowService:
"""
if workflow_id:
# Get specific published workflow
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session)
if not workflow:
raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
else:
# Get default published workflow
workflow = workflow_service.get_published_workflow(app_model)
workflow = workflow_service.get_published_workflow(app_model, session=session)
if not workflow:
raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")

View File

@ -25,6 +25,50 @@ class SubscriptionPlan(TypedDict):
expiration_date: int
class QuotaReserveResult(TypedDict):
    """Response of POST /quota/reserve."""

    # Handle later passed to quota_commit / quota_release.
    reservation_id: str
    available: int
    reserved: int
class QuotaCommitResult(TypedDict):
    """Response of POST /quota/commit."""

    available: int
    reserved: int
    refunded: int
class QuotaReleaseResult(TypedDict):
    """Response of POST /quota/release."""

    available: int
    reserved: int
    released: int
# Pydantic TypeAdapters that validate/coerce the raw quota endpoint responses
# into the TypedDicts declared above.
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
class _TenantFeatureQuota(TypedDict):
    """Per-feature usage/limit entry inside TenantFeatureQuotaInfo."""

    usage: int
    limit: int
    # Optional reset timestamp; presumably epoch seconds and absent for
    # non-resetting quotas -- TODO confirm with the billing service.
    reset_date: NotRequired[int]
class TenantFeatureQuotaInfo(TypedDict):
    """Response of /quota/info.

    NOTE (hj24):
    - Same convention as BillingInfo: billing may return int fields as str,
      always keep non-strict mode to auto-coerce.
    """

    trigger_event: _TenantFeatureQuota
    api_rate_limit: _TenantFeatureQuota


# Non-strict adapter: auto-coerces str-encoded ints per the NOTE above.
_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)
class _BillingQuota(TypedDict):
size: int
limit: int
@ -142,13 +186,65 @@ class BillingService:
@classmethod
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
"""Deprecated: Use get_quota_info instead."""
params = {"tenant_id": tenant_id}
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
return usage_info
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
    """Fetch per-feature quota usage for a tenant from GET /quota/info."""
    params = {"tenant_id": tenant_id}
    return _tenant_feature_quota_info_adapter.validate_python(
        cls._send_request("GET", "/quota/info", params=params)
    )
@classmethod
def quota_reserve(
    cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
) -> QuotaReserveResult:
    """Reserve quota before task execution.

    Args:
        tenant_id: Tenant being charged.
        feature_key: Billing feature key (e.g. "trigger_event").
        request_id: Caller-generated id; presumably used by billing for
            idempotent retries -- TODO confirm.
        amount: Units to freeze (default 1).
        meta: Optional extra context forwarded to billing.
    """
    payload: dict = {
        "tenant_id": tenant_id,
        "feature_key": feature_key,
        "request_id": request_id,
        "amount": amount,
    }
    # Only attach meta when provided, keeping the request body minimal.
    if meta:
        payload["meta"] = meta
    return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))
@classmethod
def quota_commit(
    cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
) -> QuotaCommitResult:
    """Commit a reservation with actual consumption.

    Args:
        tenant_id: Tenant being charged.
        feature_key: Billing feature key.
        reservation_id: Id returned by quota_reserve.
        actual_amount: Units actually consumed; presumably an amount below
            the reserved one refunds the difference -- TODO confirm server-side.
        meta: Optional extra context forwarded to billing.
    """
    payload: dict = {
        "tenant_id": tenant_id,
        "feature_key": feature_key,
        "reservation_id": reservation_id,
        "actual_amount": actual_amount,
    }
    # Only attach meta when provided, keeping the request body minimal.
    if meta:
        payload["meta"] = meta
    return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))
@classmethod
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
    """Release a reservation (cancel, return frozen quota)."""
    payload = {
        "tenant_id": tenant_id,
        "feature_key": feature_key,
        "reservation_id": reservation_id,
    }
    response = cls._send_request("POST", "/quota/release", json=payload)
    return _quota_release_adapter.validate_python(response)
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
params = {"tenant_id": tenant_id}
knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)

View File

@ -281,7 +281,7 @@ class FeatureService:
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
features_usage_info = BillingService.get_quota_info(tenant_id)
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]

View File

@ -0,0 +1,233 @@
from __future__ import annotations
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from enums.quota_type import QuotaType
logger = logging.getLogger(__name__)


@dataclass
class QuotaCharge:
    """
    Result of a quota reservation (Reserve phase).

    Lifecycle:
        charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
        try:
            do_work()
            charge.commit()  # Confirm consumption
        except:
            charge.refund()  # Release frozen quota

    If neither commit() nor refund() is called, the billing system's
    cleanup CronJob will auto-release the reservation within ~75 seconds.
    """

    # Whether the reservation succeeded (or billing was disabled).
    success: bool
    charge_id: str | None  # reservation_id
    _quota_type: QuotaType
    # The three fields below are only populated for a real reservation; when
    # any is missing, commit()/refund() are no-ops.
    _tenant_id: str | None = None
    _feature_key: str | None = None
    _amount: int = 0
    _committed: bool = field(default=False, repr=False)

    def commit(self, actual_amount: int | None = None) -> None:
        """
        Confirm the consumption with actual amount.

        Args:
            actual_amount: Actual amount consumed. Defaults to the reserved amount.
                If less than reserved, the difference is refunded automatically.
        """
        # No-op when already committed or when there is no real reservation
        # (billing disabled / reserve failed): nothing to confirm.
        if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
            return
        try:
            # Deferred import avoids a module-level import cycle with services.
            from services.billing_service import BillingService

            amount = actual_amount if actual_amount is not None else self._amount
            BillingService.quota_commit(
                tenant_id=self._tenant_id,
                feature_key=self._feature_key,
                reservation_id=self.charge_id,
                actual_amount=amount,
            )
            self._committed = True
            logger.debug(
                "Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
                self._quota_type,
                self._tenant_id,
                self.charge_id,
                amount,
            )
        except Exception:
            # Best-effort: a failed commit is logged, never raised; the billing
            # cleanup job eventually reconciles the dangling reservation.
            logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)

    def refund(self) -> None:
        """
        Release the reserved quota (cancel the charge).

        Safe to call even if:
        - charge failed or was disabled (charge_id is None)
        - already committed (Release after Commit is a no-op)
        - already refunded (idempotent)

        This method guarantees no exceptions will be raised.
        """
        if not self.charge_id or not self._tenant_id or not self._feature_key:
            return
        # QuotaService.release swallows all errors, so this never raises.
        QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
def unlimited() -> QuotaCharge:
    """Return a no-op QuotaCharge for features exempt from quota limits."""
    from enums.quota_type import QuotaType

    return QuotaCharge(
        success=True,
        charge_id=None,
        _quota_type=QuotaType.UNLIMITED,
    )
class QuotaService:
    """Orchestrates quota reserve / commit / release lifecycle via BillingService."""

    @staticmethod
    def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
        """
        Reserve + immediate Commit (one-shot mode).

        The returned QuotaCharge supports .refund() which calls Release.
        For two-phase usage (e.g. streaming), use reserve() directly.
        """
        charge = QuotaService.reserve(quota_type, tenant_id, amount)
        # Only a real reservation needs a commit; the disabled/fail-open charges
        # carry charge_id=None and are skipped.
        if charge.success and charge.charge_id:
            charge.commit()
        return charge

    @staticmethod
    def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
        """
        Reserve quota before task execution (Reserve phase only).

        The caller MUST call charge.commit() after the task succeeds,
        or charge.refund() if the task fails.

        Args:
            quota_type: Feature quota to reserve.
            tenant_id: The tenant identifier.
            amount: Units to reserve (must be > 0; default 1).

        Raises:
            QuotaExceededError: When quota is insufficient
            ValueError: When amount is not positive (only checked when billing is enabled)
        """
        # Deferred imports keep this module importable without the service layer.
        from services.billing_service import BillingService
        from services.errors.app import QuotaExceededError

        if not dify_config.BILLING_ENABLED:
            logger.debug("Billing disabled, allowing request for %s", tenant_id)
            return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)

        logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)
        if amount <= 0:
            raise ValueError("Amount to reserve must be greater than 0")

        # request_id presumably lets billing deduplicate retried reserves -- TODO confirm.
        request_id = str(uuid.uuid4())
        feature_key = quota_type.billing_key
        try:
            reserve_resp = BillingService.quota_reserve(
                tenant_id=tenant_id,
                feature_key=feature_key,
                request_id=request_id,
                amount=amount,
            )
            reservation_id = reserve_resp.get("reservation_id")
            if not reservation_id:
                # A reserve response without a reservation_id is treated as a denial.
                logger.warning(
                    "Reserve returned no reservation_id for %s, feature %s, response: %s",
                    tenant_id,
                    quota_type.value,
                    reserve_resp,
                )
                raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
            logger.debug(
                "Reserved %d %s quota for tenant %s, reservation_id: %s",
                amount,
                quota_type.value,
                tenant_id,
                reservation_id,
            )
            return QuotaCharge(
                success=True,
                charge_id=reservation_id,
                _quota_type=quota_type,
                _tenant_id=tenant_id,
                _feature_key=feature_key,
                _amount=amount,
            )
        except QuotaExceededError:
            raise
        except ValueError:
            raise
        except Exception:
            # fail-open: on billing errors allow the request with an UNLIMITED
            # charge, whose commit()/refund() are no-ops.
            logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
            return unlimited()

    @staticmethod
    def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
        """Check whether the tenant still has at least `amount` quota, without consuming."""
        if not dify_config.BILLING_ENABLED:
            return True
        if amount <= 0:
            raise ValueError("Amount to check must be greater than 0")
        try:
            remaining = QuotaService.get_remaining(quota_type, tenant_id)
            # -1 means "unknown/unlimited" -> allow the request.
            return remaining >= amount if remaining != -1 else True
        except Exception:
            # fail-open on billing errors
            logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
            return True

    @staticmethod
    def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
        """Release a reservation. Guarantees no exceptions."""
        try:
            from services.billing_service import BillingService

            if not dify_config.BILLING_ENABLED:
                return
            if not reservation_id:
                return
            logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
            BillingService.quota_release(
                tenant_id=tenant_id,
                feature_key=feature_key,
                reservation_id=reservation_id,
            )
        except Exception:
            # Best-effort: a failed release is logged, never raised.
            logger.exception("Failed to release quota, reservation_id: %s", reservation_id)

    @staticmethod
    def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
        """Return remaining quota units for the feature; -1 means unknown or unlimited."""
        from services.billing_service import BillingService

        try:
            usage_info = BillingService.get_quota_info(tenant_id)
            if isinstance(usage_info, dict):
                feature_info = usage_info.get(quota_type.billing_key, {})
                if isinstance(feature_info, dict):
                    limit = feature_info.get("limit", 0)
                    usage = feature_info.get("usage", 0)
                    # limit == -1 is the billing convention for "unlimited".
                    if limit == -1:
                        return -1
                    return max(0, limit - usage)
            return 0
        except Exception:
            logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
            return -1

View File

@ -37,6 +37,7 @@ from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData
@ -758,45 +759,47 @@ class WebhookService:
Exception: If workflow execution fails
"""
try:
with Session(db.engine) as session:
# Prepare inputs for the webhook node
# The webhook node expects webhook_data in the inputs
workflow_inputs = cls.build_workflow_inputs(webhook_data)
workflow_inputs = cls.build_workflow_inputs(webhook_data)
# Create trigger data
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id, # Start from the webhook node
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id,
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
)
end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
)
raise
end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)
# consume quota before triggering workflow execution
try:
QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
try:
# NOTE: do not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()`
# trigger_workflow_async needs to handle multiple session commits internally
with Session(db.engine, expire_on_commit=False) as session:
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
raise
# Trigger workflow execution asynchronously
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
except Exception:
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)

View File

@ -132,31 +132,38 @@ class WorkflowService:
if workflow_id:
return self.get_published_workflow_by_id(app_model, workflow_id)
# fetch draft workflow by app_model
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
.first()
.limit(1)
)
# return draft workflow
return workflow
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
def get_published_workflow_by_id(
self, app_model: App, workflow_id: str, session: Session | None = None
) -> Workflow | None:
"""
fetch published workflow by workflow_id
When ``session`` is provided, reuse it so callers that already hold a
Session avoid checking out an extra request-scoped ``db.session``
connection. Falls back to ``db.session`` for backward compatibility.
"""
workflow = (
db.session.query(Workflow)
bind = session if session is not None else db.session
workflow = bind.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id,
)
.first()
.limit(1)
)
if not workflow:
return None
@ -167,23 +174,27 @@ class WorkflowService:
)
return workflow
def get_published_workflow(self, app_model: App) -> Workflow | None:
def get_published_workflow(self, app_model: App, session: Session | None = None) -> Workflow | None:
"""
Get published workflow
When ``session`` is provided, reuse it so callers that already hold a
Session avoid checking out an extra request-scoped ``db.session``
connection. Falls back to ``db.session`` for backward compatibility.
"""
if not app_model.workflow_id:
return None
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
bind = session if session is not None else db.session
workflow = bind.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == app_model.workflow_id,
)
.first()
.limit(1)
)
return workflow

View File

@ -12,6 +12,7 @@ from datetime import UTC, datetime
from typing import Any
from celery import shared_task
from graphon.enums import WorkflowExecutionStatus
from sqlalchemy import func, select
from sqlalchemy.orm import Session
@ -27,8 +28,7 @@ from core.trigger.entities.entities import TriggerProviderEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from dify_graph.enums import WorkflowExecutionStatus
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from models.enums import (
AppTriggerType,
CreatorUserRole,
@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom,
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
@ -258,59 +259,58 @@ def dispatch_triggered_workflow(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
trigger_entity: TriggerProviderEntity = provider_controller.entity
# Ensure expire_on_commit is set to False so the loaded workflows remain available
with session_factory.create_session() as session:
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
type=InvokeFrom.TRIGGER,
tenant_id=subscription.tenant_id,
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
user_id=user_id,
)
for plugin_trigger in subscribers:
# Get workflow from mapping
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
if not workflow:
logger.error(
"Workflow not found for app %s",
plugin_trigger.app_id,
)
continue
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
type=InvokeFrom.TRIGGER,
tenant_id=subscription.tenant_id,
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
user_id=user_id,
)
# Find the trigger node in the workflow
event_node = None
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
if node_id == plugin_trigger.node_id:
event_node = node_config
break
if not event_node:
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
continue
# invoke trigger
trigger_metadata = PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
for plugin_trigger in subscribers:
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
if not workflow:
logger.error(
"Workflow not found for app %s",
plugin_trigger.app_id,
)
continue
# consume quota before invoking trigger
quota_charge = unlimited()
try:
quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info(
"Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id
)
return 0
event_node = None
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
if node_id == plugin_trigger.node_id:
event_node = node_config
break
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse | None = None
if not event_node:
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
continue
trigger_metadata = PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
)
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info("Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id)
return dispatched_count
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse | None = None
with session_factory.create_session() as session:
try:
invoke_response = TriggerManager.invoke_trigger_event(
tenant_id=subscription.tenant_id,
@ -387,6 +387,7 @@ def dispatch_triggered_workflow(
raise ValueError(f"End user not found for app {plugin_trigger.app_id}")
AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
quota_charge.commit()
dispatched_count += 1
logger.info(
"Triggered workflow for app %s with trigger event %s",
@ -401,7 +402,7 @@ def dispatch_triggered_workflow(
plugin_trigger.app_id,
)
return dispatched_count
return dispatched_count
def dispatch_triggered_workflows(

View File

@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import (
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.schedule_service import ScheduleService
from services.workflow.entities import ScheduleTriggerData
@ -32,6 +33,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
TenantOwnerNotFoundError: If no owner/admin for tenant
ScheduleExecutionError: If workflow trigger fails
"""
# Ensure expire_on_commit is set to False so schedule/tenant_owner remain available
with session_factory.create_session() as session:
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
@ -41,16 +43,16 @@ def run_schedule_trigger(schedule_id: str) -> None:
if not tenant_owner:
raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}")
quota_charge = unlimited()
try:
quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
return
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
return
try:
# Production dispatch: Trigger the workflow normally
try:
with session_factory.create_session() as session:
response = AsyncWorkflowService.trigger_workflow_async(
session=session,
user=tenant_owner,
@ -61,9 +63,10 @@ def run_schedule_trigger(schedule_id: str) -> None:
tenant_id=schedule.tenant_id,
),
)
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()
raise ScheduleExecutionError(
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
) from e
quota_charge.commit()
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()
raise ScheduleExecutionError(
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
) from e

View File

@ -36,12 +36,19 @@ class TestAppGenerateService:
) as mock_message_based_generator,
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
):
# Setup default mock returns for billing service
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
"result": "success",
"history_id": "test_history_id",
mock_billing_service.quota_reserve.return_value = {
"reservation_id": "test-reservation-id",
"available": 100,
"reserved": 1,
}
mock_billing_service.quota_commit.return_value = {
"available": 99,
"reserved": 0,
"refunded": 0,
}
# Setup default mock returns for workflow service
@ -101,6 +108,8 @@ class TestAppGenerateService:
mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
mock_quota_dify_config.BILLING_ENABLED = False
mock_global_dify_config.BILLING_ENABLED = False
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
@ -118,6 +127,7 @@ class TestAppGenerateService:
"message_based_generator": mock_message_based_generator,
"account_feature_service": mock_account_feature_service,
"dify_config": mock_dify_config,
"quota_dify_config": mock_quota_dify_config,
"global_dify_config": mock_global_dify_config,
}
@ -465,6 +475,7 @@ class TestAppGenerateService:
# Set BILLING_ENABLED to True for this test
mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True
# Setup test arguments
@ -478,8 +489,10 @@ class TestAppGenerateService:
# Verify the result
assert result == ["test_response"]
# Verify billing service was called to consume quota
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
# Verify billing two-phase quota (reserve + commit)
billing = mock_external_service_dependencies["billing_service"]
billing.quota_reserve.assert_called_once()
billing.quota_commit.assert_called_once()
def test_generate_with_invalid_app_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies

View File

@ -0,0 +1,517 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from enums.quota_type import QuotaType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import AppTriggerStatus, AppTriggerType
from models.model import App
from models.trigger import AppTrigger, WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger.webhook_service import WebhookService
class WebhookServiceRelationshipFactory:
    """Builders for the relational fixtures webhook-service tests need.

    Each builder persists one entity (account/tenant, app, workflow,
    webhook trigger, app trigger) through the container-backed session and
    commits immediately, so later builders can reference generated ids.
    """

    @staticmethod
    def create_account_and_tenant(db_session_with_containers: Session) -> tuple[Account, Tenant]:
        """Create an account, its tenant and the OWNER join row.

        The tenant is also set as the account's current tenant so service
        code that reads ``account.current_tenant`` works.
        """
        account = Account(
            name=f"Account {uuid4()}",
            email=f"webhook-{uuid4()}@example.com",
            password="hashed-password",
            password_salt="salt",
            interface_language="en-US",
            timezone="UTC",
        )
        db_session_with_containers.add(account)
        db_session_with_containers.commit()
        tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status="normal")
        db_session_with_containers.add(tenant)
        db_session_with_containers.commit()
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db_session_with_containers.add(join)
        db_session_with_containers.commit()
        account.current_tenant = tenant
        return account, tenant

    @staticmethod
    def create_app(db_session_with_containers: Session, tenant: Tenant, account: Account) -> App:
        """Persist a minimal workflow-mode App owned by ``account`` in ``tenant``."""
        app = App(
            tenant_id=tenant.id,
            name=f"Webhook App {uuid4()}",
            description="",
            mode="workflow",
            icon_type="emoji",
            icon="bot",
            icon_background="#FFFFFF",
            enable_site=False,
            enable_api=True,
            api_rpm=100,
            api_rph=100,
            is_demo=False,
            is_public=False,
            is_universal=False,
            created_by=account.id,
            updated_by=account.id,
        )
        db_session_with_containers.add(app)
        db_session_with_containers.commit()
        return app

    @staticmethod
    def create_workflow(
        db_session_with_containers: Session,
        *,
        app: App,
        account: Account,
        node_ids: list[str],
        version: str,
    ) -> Workflow:
        """Persist a workflow whose graph has one webhook trigger node per id in ``node_ids``.

        ``version`` selects draft vs. published semantics
        (``Workflow.VERSION_DRAFT`` vs. a dated version string).
        """
        graph = {
            "nodes": [
                {
                    "id": node_id,
                    "data": {
                        "type": TRIGGER_WEBHOOK_NODE_TYPE,
                        "title": f"Webhook {node_id}",
                        "method": "post",
                        "content_type": "application/json",
                        "headers": [],
                        "params": [],
                        "body": [],
                        "status_code": 200,
                        "response_body": '{"status": "ok"}',
                        "timeout": 30,
                    },
                }
                for node_id in node_ids
            ],
            "edges": [],
        }
        workflow = Workflow(
            tenant_id=app.tenant_id,
            app_id=app.id,
            type="workflow",
            graph=json.dumps(graph),
            features=json.dumps({}),
            created_by=account.id,
            updated_by=account.id,
            environment_variables=[],
            conversation_variables=[],
            version=version,
        )
        db_session_with_containers.add(workflow)
        db_session_with_containers.commit()
        return workflow

    @staticmethod
    def create_webhook_trigger(
        db_session_with_containers: Session,
        *,
        app: App,
        account: Account,
        node_id: str,
        webhook_id: str | None = None,
    ) -> WorkflowWebhookTrigger:
        """Persist a webhook trigger row; generates a 24-char webhook id unless one is given."""
        webhook_trigger = WorkflowWebhookTrigger(
            app_id=app.id,
            node_id=node_id,
            tenant_id=app.tenant_id,
            webhook_id=webhook_id or uuid4().hex[:24],
            created_by=account.id,
        )
        db_session_with_containers.add(webhook_trigger)
        db_session_with_containers.commit()
        return webhook_trigger

    @staticmethod
    def create_app_trigger(
        db_session_with_containers: Session,
        *,
        app: App,
        node_id: str,
        status: AppTriggerStatus,
    ) -> AppTrigger:
        """Persist an AppTrigger of type TRIGGER_WEBHOOK with the given ``status``."""
        app_trigger = AppTrigger(
            tenant_id=app.tenant_id,
            app_id=app.id,
            node_id=node_id,
            trigger_type=AppTriggerType.TRIGGER_WEBHOOK,
            provider_name="webhook",
            title=f"Webhook {node_id}",
            status=status,
        )
        db_session_with_containers.add(app_trigger)
        db_session_with_containers.commit()
        return app_trigger
class TestWebhookServiceLookupWithContainers:
    """Container-backed tests for WebhookService.get_webhook_trigger_and_workflow lookups."""

    def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """Lookup fails when no AppTrigger row exists for the webhook node."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
        )
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="node-1"
        )
        with pytest.raises(ValueError, match="App trigger not found"):
            WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)

    def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """A RATE_LIMITED app trigger rejects the webhook lookup."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
        )
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="node-1"
        )
        factory.create_app_trigger(
            db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.RATE_LIMITED
        )
        with pytest.raises(ValueError, match="rate limited"):
            WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)

    def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """A DISABLED app trigger rejects the webhook lookup."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
        )
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="node-1"
        )
        factory.create_app_trigger(
            db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.DISABLED
        )
        with pytest.raises(ValueError, match="disabled"):
            WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)

    def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """Lookup fails when the trigger's app has no workflow persisted."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="node-1"
        )
        factory.create_app_trigger(
            db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.ENABLED
        )
        with pytest.raises(ValueError, match="Workflow not found"):
            WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)

    def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """In debug mode the draft workflow (and its node) is resolved instead of the published one."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        # Published workflow contains a different node than the draft; only
        # the draft's node should be found when is_debug=True.
        factory.create_workflow(
            db_session_with_containers,
            app=app,
            account=account,
            node_ids=["published-node"],
            version="2026-04-14.001",
        )
        draft_workflow = factory.create_workflow(
            db_session_with_containers,
            app=app,
            account=account,
            node_ids=["debug-node"],
            version=Workflow.VERSION_DRAFT,
        )
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="debug-node"
        )
        got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
            webhook_trigger.webhook_id,
            is_debug=True,
        )
        assert got_trigger.id == webhook_trigger.id
        assert got_workflow.id == draft_workflow.id
        assert got_node_config["id"] == "debug-node"
class TestWebhookServiceTriggerExecutionWithContainers:
    """Container-backed tests for WebhookService.trigger_workflow_execution (quota + async dispatch)."""

    def test_trigger_workflow_execution_triggers_async_workflow_successfully(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """Happy path: reserves TRIGGER quota, dispatches the async workflow, commits the charge."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        workflow = factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
        )
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="node-1"
        )
        end_user = SimpleNamespace(id=str(uuid4()))
        webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}
        quota_charge = MagicMock()
        with (
            patch(
                "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
                return_value=end_user,
            ),
            patch(
                "services.trigger.webhook_service.QuotaService.reserve",
                return_value=quota_charge,
            ) as mock_reserve,
            patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger,
        ):
            WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
            # Quota must be reserved for the trigger's tenant and then committed.
            mock_reserve.assert_called_once()
            reserve_args = mock_reserve.call_args.args
            assert reserve_args[0] == QuotaType.TRIGGER
            assert reserve_args[1] == webhook_trigger.tenant_id
            quota_charge.commit.assert_called_once()
            # The async dispatch must receive the end user and trigger data
            # pointing at the webhook node of the given workflow.
            mock_trigger.assert_called_once()
            trigger_args = mock_trigger.call_args.args
            assert trigger_args[1] is end_user
            assert trigger_args[2].workflow_id == workflow.id
            assert trigger_args[2].root_node_id == webhook_trigger.node_id

    def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """Quota exhaustion marks the tenant rate-limited and re-raises QuotaExceededError."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        workflow = factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
        )
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="node-1"
        )
        with (
            patch(
                "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
                return_value=SimpleNamespace(id=str(uuid4())),
            ),
            patch(
                "services.trigger.webhook_service.QuotaService.reserve",
                side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1),
            ),
            patch(
                "services.trigger.webhook_service.AppTriggerService.mark_tenant_triggers_rate_limited"
            ) as mock_mark_rate_limited,
        ):
            with pytest.raises(QuotaExceededError):
                WebhookService.trigger_workflow_execution(
                    webhook_trigger,
                    {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"},
                    workflow,
                )
            mock_mark_rate_limited.assert_called_once_with(tenant.id)

    def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """Unexpected failures are logged via logger.exception and then re-raised."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        workflow = factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
        )
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="node-1"
        )
        with (
            patch(
                "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
                side_effect=RuntimeError("boom"),
            ),
            patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
        ):
            with pytest.raises(RuntimeError, match="boom"):
                WebhookService.trigger_workflow_execution(
                    webhook_trigger,
                    {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"},
                    workflow,
                )
            mock_logger_exception.assert_called_once()
class TestWebhookServiceRelationshipSyncWithContainers:
def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
node_ids = [f"node-{index}" for index in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)]
workflow = factory.create_workflow(
db_session_with_containers, app=app, account=account, node_ids=node_ids, version=Workflow.VERSION_DRAFT
)
with pytest.raises(ValueError, match="maximum webhook node limit"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_raises_when_lock_not_acquired(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
workflow = factory.create_workflow(
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version=Workflow.VERSION_DRAFT
)
lock = MagicMock()
lock.acquire.return_value = False
with patch("services.trigger.webhook_service.redis_client.lock", return_value=lock):
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
stale_trigger = factory.create_webhook_trigger(
db_session_with_containers,
app=app,
account=account,
node_id="node-stale",
webhook_id="stale-webhook-id-000001",
)
stale_trigger_id = stale_trigger.id
workflow = factory.create_workflow(
db_session_with_containers,
app=app,
account=account,
node_ids=["node-new"],
version=Workflow.VERSION_DRAFT,
)
with patch(
"services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="new-webhook-id-000001"
):
WebhookService.sync_webhook_relationships(app, workflow)
db_session_with_containers.expire_all()
records = db_session_with_containers.scalars(
select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id)
).all()
assert [record.node_id for record in records] == ["node-new"]
assert records[0].webhook_id == "new-webhook-id-000001"
assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None
def test_sync_webhook_relationships_sets_redis_cache_for_new_record(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
workflow = factory.create_workflow(
db_session_with_containers,
app=app,
account=account,
node_ids=["node-cache"],
version=Workflow.VERSION_DRAFT,
)
cache_key = f"{WebhookService.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:node-cache"
with patch(
"services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="cache-webhook-id-00001"
):
WebhookService.sync_webhook_relationships(app, workflow)
cached_payload = WebhookServiceRelationshipFactory._read_cache(cache_key)
assert cached_payload is not None
assert cached_payload["node_id"] == "node-cache"
assert cached_payload["webhook_id"] == "cache-webhook-id-00001"
    def test_sync_webhook_relationships_logs_when_lock_release_fails(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """A failure while releasing the sync lock is logged, not propagated."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        workflow = factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=[], version=Workflow.VERSION_DRAFT
        )
        # Lock acquires normally but blows up on release.
        lock = MagicMock()
        lock.acquire.return_value = True
        lock.release.side_effect = RuntimeError("release failed")
        with (
            patch("services.trigger.webhook_service.redis_client.lock", return_value=lock),
            patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
        ):
            # Must complete without raising despite the failing release ...
            WebhookService.sync_webhook_relationships(app, workflow)
        # ... and the failure must be recorded via logger.exception exactly once.
        mock_logger_exception.assert_called_once()
def _read_cache(cache_key: str) -> dict[str, str] | None:
    """Fetch and JSON-decode a cached webhook payload, or None when absent."""
    from extensions.ext_redis import redis_client

    raw = redis_client.get(cache_key)
    if not raw:
        return None
    # Redis clients may hand back bytes; normalize to text before parsing.
    text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
    return json.loads(text)


# Expose the helper on the factory so tests call it as a static method.
WebhookServiceRelationshipFactory._read_cache = staticmethod(_read_cache)

View File

@ -602,9 +602,9 @@ def test_schedule_trigger_creates_trigger_log(
)
# Mock quota to avoid rate limiting
from enums import quota_type
from services import quota_service
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())
# Execute schedule trigger
workflow_schedule_tasks.run_schedule_trigger(plan.id)

View File

View File

@ -0,0 +1,349 @@
"""Unit tests for QuotaType, QuotaService, and QuotaCharge."""
from unittest.mock import patch
import pytest
from enums.quota_type import QuotaType
from services.quota_service import QuotaCharge, QuotaService, unlimited
class TestQuotaType:
    """Billing-key resolution for each QuotaType member."""

    def test_billing_key_trigger(self):
        # TRIGGER maps to the trigger_event billing feature.
        key = QuotaType.TRIGGER.billing_key
        assert key == "trigger_event"

    def test_billing_key_workflow(self):
        # WORKFLOW maps to the api_rate_limit billing feature.
        key = QuotaType.WORKFLOW.billing_key
        assert key == "api_rate_limit"

    def test_billing_key_unlimited_raises(self):
        # UNLIMITED has no billing counterpart and must fail loudly.
        with pytest.raises(ValueError, match="Invalid quota type"):
            _ = QuotaType.UNLIMITED.billing_key
class TestQuotaService:
    """Unit tests for QuotaService reserve/consume/check/release/get_remaining."""

    def test_reserve_billing_disabled(self):
        # With billing off, reserve is a no-op charge that always succeeds.
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService"),
        ):
            mock_cfg.BILLING_ENABLED = False
            charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
            assert charge.success is True
            assert charge.charge_id is None

    def test_reserve_zero_amount_raises(self):
        # Amounts must be strictly positive.
        with patch("services.quota_service.dify_config") as mock_cfg:
            mock_cfg.BILLING_ENABLED = True
            with pytest.raises(ValueError, match="greater than 0"):
                QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0)

    def test_reserve_success(self):
        # A successful reservation exposes the billing reservation_id and
        # carries the private context needed for a later commit/refund.
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99}
            charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1)
            assert charge.success is True
            assert charge.charge_id == "rid-1"
            assert charge._tenant_id == "t1"
            assert charge._feature_key == "trigger_event"
            assert charge._amount == 1
            mock_bs.quota_reserve.assert_called_once()

    def test_reserve_no_reservation_id_raises(self):
        # A billing response without reservation_id is treated as exhaustion.
        from services.errors.app import QuotaExceededError

        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_reserve.return_value = {}
            with pytest.raises(QuotaExceededError):
                QuotaService.reserve(QuotaType.TRIGGER, "t1")

    def test_reserve_quota_exceeded_propagates(self):
        # QuotaExceededError from the billing client must bubble up unchanged.
        from services.errors.app import QuotaExceededError

        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1)
            with pytest.raises(QuotaExceededError):
                QuotaService.reserve(QuotaType.TRIGGER, "t1")

    def test_reserve_api_exception_returns_unlimited(self):
        # Any other billing failure fails open: an unlimited (no-op) charge.
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_reserve.side_effect = RuntimeError("network")
            charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
            assert charge.success is True
            assert charge.charge_id is None

    def test_consume_calls_reserve_and_commit(self):
        # consume is reserve immediately followed by commit.
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"}
            mock_bs.quota_commit.return_value = {}
            charge = QuotaService.consume(QuotaType.TRIGGER, "t1")
            assert charge.success is True
            mock_bs.quota_commit.assert_called_once()

    def test_check_billing_disabled(self):
        # With billing off, check always passes.
        with patch("services.quota_service.dify_config") as mock_cfg:
            mock_cfg.BILLING_ENABLED = False
            assert QuotaService.check(QuotaType.TRIGGER, "t1") is True

    def test_check_zero_amount_raises(self):
        with patch("services.quota_service.dify_config") as mock_cfg:
            mock_cfg.BILLING_ENABLED = True
            with pytest.raises(ValueError, match="greater than 0"):
                QuotaService.check(QuotaType.TRIGGER, "t1", amount=0)

    def test_check_sufficient_quota(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch.object(QuotaService, "get_remaining", return_value=100),
        ):
            mock_cfg.BILLING_ENABLED = True
            assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True

    def test_check_insufficient_quota(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch.object(QuotaService, "get_remaining", return_value=5),
        ):
            mock_cfg.BILLING_ENABLED = True
            assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False

    def test_check_unlimited_quota(self):
        # get_remaining == -1 is the unlimited sentinel, so any amount passes.
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch.object(QuotaService, "get_remaining", return_value=-1),
        ):
            mock_cfg.BILLING_ENABLED = True
            assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True

    def test_check_exception_returns_true(self):
        # check fails open when remaining quota cannot be determined.
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch.object(QuotaService, "get_remaining", side_effect=RuntimeError),
        ):
            mock_cfg.BILLING_ENABLED = True
            assert QuotaService.check(QuotaType.TRIGGER, "t1") is True

    def test_release_billing_disabled(self):
        # With billing off, release never contacts the billing API.
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = False
            QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
            mock_bs.quota_release.assert_not_called()

    def test_release_empty_reservation(self):
        # An empty reservation id is a no-op; nothing is sent to billing.
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event")
            mock_bs.quota_release.assert_not_called()

    def test_release_success(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_release.return_value = {}
            QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
            mock_bs.quota_release.assert_called_once_with(
                tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1"
            )

    def test_release_exception_swallowed(self):
        # release must never raise, even when the billing call fails.
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_release.side_effect = RuntimeError("fail")
            QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")

    def test_get_remaining_normal(self):
        # remaining = limit - usage.
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}}
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70

    def test_get_remaining_unlimited(self):
        # limit == -1 is passed through as the unlimited sentinel.
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}}
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1

    def test_get_remaining_over_limit_returns_zero(self):
        # Usage above the limit clamps to zero rather than going negative.
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}}
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0

    def test_get_remaining_exception_returns_neg1(self):
        # Billing failures fail open by reporting unlimited quota.
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.side_effect = RuntimeError
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1

    def test_get_remaining_empty_response(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = {}
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0

    def test_get_remaining_non_dict_response(self):
        # A malformed (non-dict) response is treated as no quota available.
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = "invalid"
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0

    def test_get_remaining_feature_not_in_response(self):
        # A response missing the requested feature yields zero remaining.
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}}
            remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1")
            assert remaining == 0

    def test_get_remaining_non_dict_feature_info(self):
        # A malformed per-feature entry also yields zero remaining.
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"}
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
class TestQuotaCharge:
    """Unit tests for QuotaCharge commit/refund semantics."""

    def test_commit_success(self):
        # Without an explicit actual_amount, the reserved amount is committed
        # and the charge is marked committed.
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.quota_commit.return_value = {}
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id="t1",
                _feature_key="trigger_event",
                _amount=1,
            )
            charge.commit()
            mock_bs.quota_commit.assert_called_once_with(
                tenant_id="t1",
                feature_key="trigger_event",
                reservation_id="rid-1",
                actual_amount=1,
            )
            assert charge._committed is True

    def test_commit_with_actual_amount(self):
        # Committing fewer units than reserved forwards the actual amount.
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.quota_commit.return_value = {}
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id="t1",
                _feature_key="trigger_event",
                _amount=10,
            )
            charge.commit(actual_amount=5)
            call_kwargs = mock_bs.quota_commit.call_args[1]
            assert call_kwargs["actual_amount"] == 5

    def test_commit_idempotent(self):
        # A second commit on the same charge must be a no-op.
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.quota_commit.return_value = {}
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id="t1",
                _feature_key="trigger_event",
                _amount=1,
            )
            charge.commit()
            charge.commit()
            assert mock_bs.quota_commit.call_count == 1

    def test_commit_no_charge_id_noop(self):
        # Charges without a reservation id (billing disabled / fail-open)
        # commit nothing.
        with patch("services.billing_service.BillingService") as mock_bs:
            charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
            charge.commit()
            mock_bs.quota_commit.assert_not_called()

    def test_commit_no_tenant_id_noop(self):
        # Without a tenant id there is no billing context; commit is skipped.
        with patch("services.billing_service.BillingService") as mock_bs:
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id=None,
                _feature_key="trigger_event",
            )
            charge.commit()
            mock_bs.quota_commit.assert_not_called()

    def test_commit_exception_swallowed(self):
        # commit must never raise, even when the billing call fails.
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.quota_commit.side_effect = RuntimeError("fail")
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id="t1",
                _feature_key="trigger_event",
                _amount=1,
            )
            charge.commit()

    def test_refund_success(self):
        # refund delegates to QuotaService.release with the charge's context.
        with patch.object(QuotaService, "release") as mock_rel:
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id="t1",
                _feature_key="trigger_event",
            )
            charge.refund()
            mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")

    def test_refund_no_charge_id_noop(self):
        with patch.object(QuotaService, "release") as mock_rel:
            charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
            charge.refund()
            mock_rel.assert_not_called()

    def test_refund_no_tenant_id_noop(self):
        with patch.object(QuotaService, "release") as mock_rel:
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id=None,
            )
            charge.refund()
            mock_rel.assert_not_called()
class TestUnlimited:
    """Behaviour of the ``unlimited`` charge factory."""

    def test_unlimited_returns_success_with_no_charge_id(self):
        # The sentinel charge succeeds without reserving anything.
        result = unlimited()
        assert result.success is True
        assert result.charge_id is None
        assert result._quota_type == QuotaType.UNLIMITED

View File

@ -23,6 +23,7 @@ import pytest
import services.app_generate_service as ags_module
from core.app.entities.app_invoke_entities import InvokeFrom
from enums.quota_type import QuotaType
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
@ -447,8 +448,8 @@ class TestGenerateBilling:
def test_billing_enabled_consumes_quota(self, mocker, monkeypatch):
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
quota_charge = MagicMock()
consume_mock = mocker.patch(
"services.app_generate_service.QuotaType.WORKFLOW.consume",
reserve_mock = mocker.patch(
"services.app_generate_service.QuotaService.reserve",
return_value=quota_charge,
)
mocker.patch(
@ -467,7 +468,8 @@ class TestGenerateBilling:
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
consume_mock.assert_called_once_with("tenant-id")
reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id")
quota_charge.commit.assert_called_once()
def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch):
from services.errors.app import QuotaExceededError
@ -475,7 +477,7 @@ class TestGenerateBilling:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
mocker.patch(
"services.app_generate_service.QuotaType.WORKFLOW.consume",
"services.app_generate_service.QuotaService.reserve",
side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1),
)
@ -492,7 +494,7 @@ class TestGenerateBilling:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
quota_charge = MagicMock()
mocker.patch(
"services.app_generate_service.QuotaType.WORKFLOW.consume",
"services.app_generate_service.QuotaService.reserve",
return_value=quota_charge,
)
mocker.patch(

View File

@ -57,7 +57,7 @@ class TestAsyncWorkflowService:
- repo: SQLAlchemyWorkflowTriggerLogRepository
- dispatcher_manager_class: QueueDispatcherManager class
- dispatcher: dispatcher instance
- quota_workflow: QuotaType.WORKFLOW
- quota_service: QuotaService mock
- get_workflow: AsyncWorkflowService._get_workflow method
- professional_task: execute_workflow_professional
- team_task: execute_workflow_team
@ -72,12 +72,7 @@ class TestAsyncWorkflowService:
mock_repo.create.side_effect = _create_side_effect
mock_dispatcher = MagicMock()
quota_workflow = MagicMock()
mock_get_workflow = MagicMock()
mock_professional_task = MagicMock()
mock_team_task = MagicMock()
mock_sandbox_task = MagicMock()
mock_quota_service = MagicMock()
with (
patch.object(
@ -93,8 +88,8 @@ class TestAsyncWorkflowService:
) as mock_get_workflow,
patch.object(
async_workflow_service_module,
"QuotaType",
new=SimpleNamespace(WORKFLOW=quota_workflow),
"QuotaService",
new=mock_quota_service,
),
patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task,
patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task,
@ -107,7 +102,7 @@ class TestAsyncWorkflowService:
"repo": mock_repo,
"dispatcher_manager_class": mock_dispatcher_manager_class,
"dispatcher": mock_dispatcher,
"quota_workflow": quota_workflow,
"quota_service": mock_quota_service,
"get_workflow": mock_get_workflow,
"professional_task": mock_professional_task,
"team_task": mock_team_task,
@ -146,6 +141,9 @@ class TestAsyncWorkflowService:
mocks["team_task"].delay.return_value = task_result
mocks["sandbox_task"].delay.return_value = task_result
quota_charge_mock = MagicMock()
mocks["quota_service"].reserve.return_value = quota_charge_mock
class DummyAccount:
def __init__(self, user_id: str):
self.id = user_id
@ -163,8 +161,9 @@ class TestAsyncWorkflowService:
assert result.status == "queued"
assert result.queue == queue_name
mocks["quota_workflow"].consume.assert_called_once_with("tenant-123")
assert session.commit.call_count == 2
mocks["quota_service"].reserve.assert_called_once()
quota_charge_mock.commit.assert_called_once()
assert session.commit.call_count == 3
created_log = mocks["repo"].create.call_args[0][0]
assert created_log.status == WorkflowTriggerStatus.QUEUED
@ -250,7 +249,7 @@ class TestAsyncWorkflowService:
mocks = async_workflow_trigger_mocks
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM
mocks["get_workflow"].return_value = workflow
mocks["quota_workflow"].consume.side_effect = QuotaExceededError(
mocks["quota_service"].reserve.side_effect = QuotaExceededError(
feature="workflow",
tenant_id="tenant-123",
required=1,
@ -267,7 +266,7 @@ class TestAsyncWorkflowService:
trigger_data=trigger_data,
)
assert session.commit.call_count == 2
assert session.commit.call_count == 3
updated_log = mocks["repo"].update.call_args[0][0]
assert updated_log.status == WorkflowTriggerStatus.RATE_LIMITED
assert "Quota limit reached" in updated_log.error
@ -463,7 +462,7 @@ class TestAsyncWorkflowServiceGetWorkflow:
# Assert
assert result == workflow
workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123")
workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123", session=None)
workflow_service.get_published_workflow.assert_not_called()
def test_should_raise_when_specific_workflow_id_not_found(self):
@ -491,7 +490,7 @@ class TestAsyncWorkflowServiceGetWorkflow:
# Assert
assert result == workflow
workflow_service.get_published_workflow.assert_called_once_with(app_model)
workflow_service.get_published_workflow.assert_called_once_with(app_model, session=None)
workflow_service.get_published_workflow_by_id.assert_not_called()
def test_should_raise_when_default_published_workflow_not_found(self):

View File

@ -425,7 +425,7 @@ class TestBillingServiceUsageCalculation:
yield mock
def test_get_tenant_feature_plan_usage_info(self, mock_send_request):
"""Test retrieval of tenant feature plan usage information."""
"""Test retrieval of tenant feature plan usage information (legacy endpoint)."""
# Arrange
tenant_id = "tenant-123"
expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}}
@ -438,6 +438,20 @@ class TestBillingServiceUsageCalculation:
assert result == expected_response
mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id})
def test_get_quota_info(self, mock_send_request):
"""Test retrieval of quota info from new endpoint."""
# Arrange
tenant_id = "tenant-123"
expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}}
mock_send_request.return_value = expected_response
# Act
result = BillingService.get_quota_info(tenant_id)
# Assert
assert result == expected_response
mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id})
def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request):
"""Test updating tenant feature usage with positive delta (adding credits)."""
# Arrange
@ -515,6 +529,150 @@ class TestBillingServiceUsageCalculation:
)
class TestBillingServiceQuotaOperations:
    """Unit tests for quota reserve/commit/release operations."""

    @pytest.fixture
    def mock_send_request(self):
        # Intercept the HTTP layer so no real billing API is contacted.
        with patch.object(BillingService, "_send_request") as mock:
            yield mock

    def test_quota_reserve_success(self, mock_send_request):
        # The reserve response is returned as-is and the request body carries
        # tenant, feature, request id, and amount.
        expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1}
        mock_send_request.return_value = expected
        result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1)
        assert result == expected
        mock_send_request.assert_called_once_with(
            "POST",
            "/quota/reserve",
            json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1},
        )

    def test_quota_reserve_coerces_string_to_int(self, mock_send_request):
        """Test that TypeAdapter coerces string values to int."""
        mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"}
        result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1)
        assert result["available"] == 99
        assert isinstance(result["available"], int)
        assert result["reserved"] == 1
        assert isinstance(result["reserved"], int)

    def test_quota_reserve_with_meta(self, mock_send_request):
        # Optional metadata is forwarded verbatim in the request body.
        mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1}
        meta = {"source": "webhook"}
        BillingService.quota_reserve(
            tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta
        )
        call_json = mock_send_request.call_args[1]["json"]
        assert call_json["meta"] == {"source": "webhook"}

    def test_quota_commit_success(self, mock_send_request):
        expected = {"available": 98, "reserved": 0, "refunded": 0}
        mock_send_request.return_value = expected
        result = BillingService.quota_commit(
            tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1
        )
        assert result == expected
        mock_send_request.assert_called_once_with(
            "POST",
            "/quota/commit",
            json={
                "tenant_id": "t1",
                "feature_key": "trigger_event",
                "reservation_id": "rid-1",
                "actual_amount": 1,
            },
        )

    def test_quota_commit_coerces_string_to_int(self, mock_send_request):
        """Test that TypeAdapter coerces string values to int."""
        mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"}
        result = BillingService.quota_commit(
            tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1
        )
        assert result["available"] == 97
        assert isinstance(result["available"], int)
        assert result["refunded"] == 1
        assert isinstance(result["refunded"], int)

    def test_quota_commit_with_meta(self, mock_send_request):
        # Optional metadata is forwarded verbatim in the commit body.
        mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0}
        meta = {"reason": "partial"}
        BillingService.quota_commit(
            tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta
        )
        call_json = mock_send_request.call_args[1]["json"]
        assert call_json["meta"] == {"reason": "partial"}

    def test_quota_release_success(self, mock_send_request):
        expected = {"available": 100, "reserved": 0, "released": 1}
        mock_send_request.return_value = expected
        result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1")
        assert result == expected
        mock_send_request.assert_called_once_with(
            "POST",
            "/quota/release",
            json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"},
        )

    def test_quota_release_coerces_string_to_int(self, mock_send_request):
        """Test that TypeAdapter coerces string values to int."""
        mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"}
        result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s")
        assert result["available"] == 100
        assert isinstance(result["available"], int)
        assert result["released"] == 1
        assert isinstance(result["released"], int)

    def test_get_quota_info_coerces_string_to_int(self, mock_send_request):
        """Test that TypeAdapter coerces string values to int for get_quota_info."""
        mock_send_request.return_value = {
            "trigger_event": {"usage": "42", "limit": "3000", "reset_date": "1700000000"},
            "api_rate_limit": {"usage": "10", "limit": "-1", "reset_date": "-1"},
        }
        result = BillingService.get_quota_info("t1")
        assert result["trigger_event"]["usage"] == 42
        assert isinstance(result["trigger_event"]["usage"], int)
        assert result["trigger_event"]["limit"] == 3000
        assert isinstance(result["trigger_event"]["limit"], int)
        assert result["trigger_event"]["reset_date"] == 1700000000
        assert isinstance(result["trigger_event"]["reset_date"], int)
        assert result["api_rate_limit"]["limit"] == -1
        assert isinstance(result["api_rate_limit"]["limit"], int)

    def test_get_quota_info_accepts_int_values(self, mock_send_request):
        """Test that get_quota_info works with native int values."""
        expected = {
            "trigger_event": {"usage": 42, "limit": 3000, "reset_date": 1700000000},
            "api_rate_limit": {"usage": 0, "limit": -1},
        }
        mock_send_request.return_value = expected
        result = BillingService.get_quota_info("t1")
        assert result["trigger_event"]["usage"] == 42
        assert result["trigger_event"]["limit"] == 3000
        assert result["api_rate_limit"]["limit"] == -1
class TestBillingServiceRateLimitEnforcement:
"""Unit tests for rate limit enforcement mechanisms.

View File

@ -337,10 +337,7 @@ class TestWorkflowService:
app = TestWorkflowAssociatedDataFactory.create_app_mock()
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock()
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
mock_db_session.session.scalar.return_value = mock_workflow
result = workflow_service.get_draft_workflow(app)
@ -350,10 +347,7 @@ class TestWorkflowService:
"""Test get_draft_workflow returns None when no draft exists."""
app = TestWorkflowAssociatedDataFactory.create_app_mock()
# Mock database query to return None
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = None
mock_db_session.session.scalar.return_value = None
result = workflow_service.get_draft_workflow(app)
@ -365,10 +359,7 @@ class TestWorkflowService:
workflow_id = "workflow-123"
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1")
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
mock_db_session.session.scalar.return_value = mock_workflow
result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id)
@ -383,10 +374,7 @@ class TestWorkflowService:
workflow_id = "workflow-123"
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
mock_db_session.session.scalar.return_value = mock_workflow
result = workflow_service.get_published_workflow_by_id(app, workflow_id)
@ -405,10 +393,7 @@ class TestWorkflowService:
workflow_id=workflow_id, version=Workflow.VERSION_DRAFT
)
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
mock_db_session.session.scalar.return_value = mock_workflow
with pytest.raises(IsDraftWorkflowError):
workflow_service.get_published_workflow_by_id(app, workflow_id)
@ -418,10 +403,7 @@ class TestWorkflowService:
app = TestWorkflowAssociatedDataFactory.create_app_mock()
workflow_id = "nonexistent-workflow"
# Mock database query to return None
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = None
mock_db_session.session.scalar.return_value = None
result = workflow_service.get_published_workflow_by_id(app, workflow_id)
@ -433,10 +415,7 @@ class TestWorkflowService:
app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=workflow_id)
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
mock_db_session.session.scalar.return_value = mock_workflow
result = workflow_service.get_published_workflow(app)
@ -465,11 +444,7 @@ class TestWorkflowService:
graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
features = {"file_upload": {"enabled": False}}
# Mock get_draft_workflow to return None (no existing draft)
# This simulates the first time a workflow is created for an app
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = None
mock_db_session.session.scalar.return_value = None
with (
patch.object(workflow_service, "validate_features_structure"),
@ -506,9 +481,7 @@ class TestWorkflowService:
# Mock existing draft workflow
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash=unique_hash)
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
mock_db_session.session.scalar.return_value = mock_workflow
with (
patch.object(workflow_service, "validate_features_structure"),
@ -547,9 +520,7 @@ class TestWorkflowService:
# Mock existing draft workflow with different hash
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash="old-hash")
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
mock_db_session.session.scalar.return_value = mock_workflow
with pytest.raises(WorkflowHashNotEqualError):
workflow_service.sync_draft_workflow(

View File

@ -0,0 +1,204 @@
from unittest.mock import MagicMock, patch
import pytest
import tasks.trigger_processing_tasks as trigger_processing_tasks_module
from services.errors.app import QuotaExceededError
from tasks.trigger_processing_tasks import dispatch_triggered_workflow
class TestDispatchTriggeredWorkflow:
    """Unit tests covering branch behaviours of ``dispatch_triggered_workflow``.

    The covered branches are:
    - workflow missing for ``plugin_trigger.app_id`` -> log + ``continue``
    - ``QuotaService.reserve`` raising ``QuotaExceededError``
      -> ``mark_tenant_triggers_rate_limited`` + early ``return``
    - ``trigger_workflow_async`` succeeds
      -> ``quota_charge.commit()`` + ``dispatched_count`` increments
    """

    @pytest.fixture
    def subscription(self):
        """A minimal trigger-subscription stand-in.

        Only the attributes that ``dispatch_triggered_workflow`` reads are
        configured; everything else is an auto-created ``MagicMock``.
        """
        sub = MagicMock()
        sub.id = "subscription-123"
        sub.tenant_id = "tenant-123"
        sub.provider_id = "langgenius/test_plugin/test_plugin"
        sub.endpoint_id = "endpoint-123"
        sub.credentials = {}
        sub.credential_type = "api_key"
        return sub

    @pytest.fixture
    def plugin_trigger(self):
        """A plugin trigger pointing at ``app-123`` / ``node-123``."""
        trigger = MagicMock()
        trigger.id = "plugin-trigger-123"
        trigger.app_id = "app-123"
        trigger.node_id = "node-123"
        return trigger

    @pytest.fixture
    def provider_controller(self):
        """A trigger-provider controller mock exposing plugin identity fields."""
        controller = MagicMock()
        controller.plugin_unique_identifier = "langgenius/test_plugin:0.0.1"
        controller.entity.identity.name = "Test Plugin"
        controller.entity.identity.icon = "icon.svg"
        controller.entity.identity.icon_dark = "icon_dark.svg"
        return controller

    @pytest.fixture
    def dispatch_mocks(self, subscription, plugin_trigger, provider_controller):
        """Patch all external dependencies reached by ``dispatch_triggered_workflow``.

        Defaults are configured so the code flow can reach the final async
        trigger block (line ~385); each test overrides specific handles
        (``get_workflows``, ``reserve``, ``create_end_user_batch``, ...) to
        drive the path it targets.
        """
        # Context-manager mock returned by session_factory.create_session().
        session_cm = MagicMock()
        session_cm.__enter__.return_value = MagicMock()
        # __exit__ returns False so exceptions raised inside the session
        # propagate instead of being swallowed.
        session_cm.__exit__.return_value = False

        # A successful (non-cancelled) trigger invocation with no variables.
        invoke_response = MagicMock()
        invoke_response.cancelled = False
        invoke_response.variables = {}

        # Stand-in for the object returned by QuotaService.reserve; tests
        # assert on its commit()/refund() calls.
        quota_charge = MagicMock()

        with (
            patch.object(
                trigger_processing_tasks_module.TriggerHttpRequestCachingService,
                "get_request",
                return_value=MagicMock(),
            ),
            patch.object(
                trigger_processing_tasks_module.TriggerHttpRequestCachingService,
                "get_payload",
                return_value=MagicMock(),
            ),
            # One subscriber trigger so the dispatch loop runs exactly once.
            patch.object(
                trigger_processing_tasks_module.TriggerSubscriptionOperatorService,
                "get_subscriber_triggers",
                return_value=[plugin_trigger],
            ),
            patch.object(
                trigger_processing_tasks_module.TriggerManager,
                "get_trigger_provider",
                return_value=provider_controller,
            ),
            patch.object(
                trigger_processing_tasks_module.TriggerManager,
                "invoke_trigger_event",
                return_value=invoke_response,
            ) as invoke_trigger_event,
            patch.object(
                trigger_processing_tasks_module.TriggerEventNodeData,
                "model_validate",
                return_value=MagicMock(),
            ),
            # No default return value: each test sets get_workflows itself.
            patch.object(
                trigger_processing_tasks_module,
                "_get_latest_workflows_by_app_ids",
            ) as get_workflows,
            patch.object(
                trigger_processing_tasks_module.EndUserService,
                "create_end_user_batch",
                return_value={},
            ) as create_end_user_batch,
            patch.object(
                trigger_processing_tasks_module.session_factory,
                "create_session",
                return_value=session_cm,
            ),
            patch.object(
                trigger_processing_tasks_module.QuotaService,
                "reserve",
                return_value=quota_charge,
            ) as reserve,
            patch.object(
                trigger_processing_tasks_module.AppTriggerService,
                "mark_tenant_triggers_rate_limited",
            ) as mark_rate_limited,
            patch.object(
                trigger_processing_tasks_module.AsyncWorkflowService,
                "trigger_workflow_async",
            ) as trigger_workflow_async,
        ):
            # Hand the patch handles to the test so it can override defaults
            # and assert on calls.
            yield {
                "get_workflows": get_workflows,
                "reserve": reserve,
                "quota_charge": quota_charge,
                "mark_rate_limited": mark_rate_limited,
                "invoke_trigger_event": invoke_trigger_event,
                "invoke_response": invoke_response,
                "create_end_user_batch": create_end_user_batch,
                "trigger_workflow_async": trigger_workflow_async,
            }

    def test_dispatch_skips_when_workflow_missing(self, subscription, dispatch_mocks):
        """Covers missing workflow -> log + ``continue``."""
        # No workflow known for the trigger's app: the trigger is skipped
        # before any quota reservation or event invocation happens.
        dispatch_mocks["get_workflows"].return_value = {}
        dispatched = dispatch_triggered_workflow(
            user_id="user-123",
            subscription=subscription,
            event_name="test_event",
            request_id="request-123",
        )
        assert dispatched == 0
        dispatch_mocks["reserve"].assert_not_called()
        dispatch_mocks["invoke_trigger_event"].assert_not_called()
        dispatch_mocks["mark_rate_limited"].assert_not_called()

    def test_dispatch_marks_rate_limited_when_quota_exceeded(self, subscription, plugin_trigger, dispatch_mocks):
        """Covers QuotaExceededError -> mark rate-limited + early return."""
        # The workflow contains the trigger node, so dispatch proceeds far
        # enough to attempt a quota reservation.
        workflow_mock = MagicMock()
        workflow_mock.walk_nodes.return_value = iter(
            [(plugin_trigger.node_id, {"type": trigger_processing_tasks_module.TRIGGER_PLUGIN_NODE_TYPE})]
        )
        dispatch_mocks["get_workflows"].return_value = {plugin_trigger.app_id: workflow_mock}
        # Reservation fails with the quota error the task is expected to catch.
        dispatch_mocks["reserve"].side_effect = QuotaExceededError(
            feature="trigger", tenant_id=subscription.tenant_id, required=1
        )
        dispatched = dispatch_triggered_workflow(
            user_id="user-123",
            subscription=subscription,
            event_name="test_event",
            request_id="request-123",
        )
        assert dispatched == 0
        dispatch_mocks["reserve"].assert_called_once()
        # The tenant must be flagged rate-limited, and the event never invoked.
        dispatch_mocks["mark_rate_limited"].assert_called_once_with(subscription.tenant_id)
        dispatch_mocks["invoke_trigger_event"].assert_not_called()

    def test_dispatch_commits_quota_and_counts_when_workflow_triggered(
        self, subscription, plugin_trigger, dispatch_mocks
    ):
        """Happy path: end user exists and async trigger succeeds."""
        workflow_mock = MagicMock()
        workflow_mock.id = "workflow-123"
        workflow_mock.walk_nodes.return_value = iter(
            [(plugin_trigger.node_id, {"type": trigger_processing_tasks_module.TRIGGER_PLUGIN_NODE_TYPE})]
        )
        dispatch_mocks["get_workflows"].return_value = {plugin_trigger.app_id: workflow_mock}
        # An end user is resolved for the app, enabling the async trigger call.
        end_user_mock = MagicMock()
        dispatch_mocks["create_end_user_batch"].return_value = {plugin_trigger.app_id: end_user_mock}
        dispatched = dispatch_triggered_workflow(
            user_id="user-123",
            subscription=subscription,
            event_name="test_event",
            request_id="request-123",
        )
        assert dispatched == 1
        dispatch_mocks["trigger_workflow_async"].assert_called_once()
        # The resolved end user must be passed through to the async trigger.
        _, kwargs = dispatch_mocks["trigger_workflow_async"].call_args
        assert kwargs["user"] is end_user_mock
        # On success the quota charge is committed exactly once, never refunded.
        dispatch_mocks["quota_charge"].commit.assert_called_once()
        dispatch_mocks["quota_charge"].refund.assert_not_called()
        dispatch_mocks["mark_rate_limited"].assert_not_called()