import json
import logging
import os
from collections.abc import Sequence
from typing import Literal, NotRequired, TypedDict

import httpx
from pydantic import TypeAdapter
from sqlalchemy import select
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from werkzeug.exceptions import InternalServerError

from core.helper.http_client_pooling import get_pooled_http_client
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import RateLimiter
from models import Account, TenantAccountJoin, TenantAccountRole

logger = logging.getLogger(__name__)

# Pooled HTTP client shared by every billing API call made through BillingService._send_request.
_http_client: httpx.Client = get_pooled_http_client(
    "billing:default",
    lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
)


class SubscriptionPlan(TypedDict):
    """Tenant subscription plan information."""

    plan: str
    expiration_date: int


class _BillingQuota(TypedDict):
    size: int
    limit: int


class _VectorSpaceQuota(TypedDict):
    size: float
    limit: int


class _KnowledgeRateLimit(TypedDict):
    # NOTE (hj24):
    # 1. Returned for sandbox users but null for other plans; it is defined but never used.
    # 2. Kept for compatibility for now; can be deprecated in future versions.
    size: NotRequired[int]
    # NOTE END
    limit: int


class _BillingSubscription(TypedDict):
    plan: str
    interval: str
    education: bool


class BillingInfo(TypedDict):
    """Response of /subscription/info.

    NOTE (hj24):
    - Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
    - To preserve precision, billing may serialize int fields as str, so be careful when using TypeAdapter:
        1. validate_python in non-strict mode will coerce them to the expected type
        2. In strict mode, it will raise ValidationError
        3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
    """

    enabled: bool
    subscription: _BillingSubscription
    members: _BillingQuota
    apps: _BillingQuota
    vector_space: _VectorSpaceQuota
    knowledge_rate_limit: _KnowledgeRateLimit
    documents_upload_quota: _BillingQuota
    annotation_quota_limit: _BillingQuota
    docs_processing: str
    can_replace_logo: bool
    model_load_balancing_enabled: bool
    knowledge_pipeline_publish_enabled: bool
    next_credit_reset_date: NotRequired[int]


# Adapters are built once at import time; constructing a TypeAdapter is not free,
# so they should not be rebuilt on every request.
_billing_info_adapter = TypeAdapter(BillingInfo)
_subscription_plan_adapter = TypeAdapter(SubscriptionPlan)


class KnowledgeRateLimitDict(TypedDict):
    limit: int
    subscription_plan: str


class TenantFeaturePlanUsageDict(TypedDict):
    result: str
    history_id: str


class LangContentDict(TypedDict):
    lang: str
    title: str
    subtitle: str
    body: str
    title_pic_url: str


class NotificationDict(TypedDict):
    notification_id: str
    contents: dict[str, LangContentDict]
    frequency: Literal["once", "every_page_load"]


class AccountNotificationDict(TypedDict, total=False):
    should_show: bool
    notification: NotificationDict
    shouldShow: bool
    notifications: list[dict]


class UpsertNotificationDict(TypedDict):
    notification_id: str


class BatchAddNotificationAccountsDict(TypedDict):
    count: int


class DismissNotificationDict(TypedDict):
    success: bool


class BillingService:
    """Thin client for the external billing API plus a Redis-backed tenant plan cache."""

    base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
    secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")

    # Positional args presumably map to (prefix, max_attempts, time_window), matching the
    # keyword form used by EducationIdentity below — TODO confirm against RateLimiter.
    compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)

    # Redis key prefix for tenant plan cache
    _PLAN_CACHE_KEY_PREFIX = "tenant_plan:"
    # Cache TTL: 10 minutes
    _PLAN_CACHE_TTL = 600

    @classmethod
    def get_info(cls, tenant_id: str) -> BillingInfo:
        """Fetch /subscription/info for a tenant and validate it (non-strict) as BillingInfo."""
        params = {"tenant_id": tenant_id}

        billing_info = cls._send_request("GET", "/subscription/info", params=params)
        return _billing_info_adapter.validate_python(billing_info)

    @classmethod
    def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
        """Return the raw response of GET /tenant-feature-usage/info for a tenant."""
        params = {"tenant_id": tenant_id}

        usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
        return usage_info

    @classmethod
    def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
        """Return the tenant's knowledge rate limit.

        Missing fields fall back to ``limit=10`` and the sandbox plan.
        """
        params = {"tenant_id": tenant_id}

        knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)

        return {
            "limit": knowledge_rate_limit.get("limit", 10),
            "subscription_plan": knowledge_rate_limit.get("subscription_plan", CloudPlan.SANDBOX),
        }

    @classmethod
    def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
        """Request GET /subscription/payment-link for the given plan and billing interval."""
        params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}
        return cls._send_request("GET", "/subscription/payment-link", params=params)

    @classmethod
    def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str):
        """Request GET /model-provider/payment-link for a model provider purchase."""
        params = {
            "provider_name": provider_name,
            "tenant_id": tenant_id,
            "account_id": account_id,
            "prefilled_email": prefilled_email,
        }
        return cls._send_request("GET", "/model-provider/payment-link", params=params)

    @classmethod
    def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""):
        """Request GET /invoices for the given tenant / email."""
        params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id}
        return cls._send_request("GET", "/invoices", params=params)

    @classmethod
    def update_tenant_feature_plan_usage(
        cls, tenant_id: str, feature_key: str, delta: int
    ) -> TenantFeaturePlanUsageDict:
        """
        Update tenant feature plan usage.

        Args:
            tenant_id: Tenant identifier
            feature_key: Feature key (e.g., 'trigger', 'workflow')
            delta: Usage delta (positive to add, negative to consume)

        Returns:
            Response dict with 'result' and 'history_id'
            Example: {"result": "success", "history_id": "uuid"}
        """
        return cls._send_request(
            "POST",
            "/tenant-feature-usage/usage",
            params={"tenant_id": tenant_id, "feature_key": feature_key, "delta": delta},
        )

    @classmethod
    def refund_tenant_feature_plan_usage(cls, history_id: str) -> TenantFeaturePlanUsageDict:
        """
        Refund a previous usage charge.

        Args:
            history_id: The history_id returned from update_tenant_feature_plan_usage

        Returns:
            Response dict with 'result' and 'history_id'
        """
        return cls._send_request("POST", "/tenant-feature-usage/refund", params={"quota_usage_history_id": history_id})

    @classmethod
    def get_tenant_feature_plan_usage(cls, tenant_id: str, feature_key: str):
        """Return the raw usage record for one feature key of a tenant."""
        params = {"tenant_id": tenant_id, "feature_key": feature_key}
        return cls._send_request("GET", "/billing/tenant_feature_plan/usage", params=params)

    @classmethod
    @retry(
        wait=wait_fixed(2),
        stop=stop_before_delay(10),
        retry=retry_if_exception_type(httpx.RequestError),
        reraise=True,
    )
    def _send_request(cls, method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, json=None, params=None):
        """Send one request to the billing API and return the decoded JSON body.

        Only transport-level failures (``httpx.RequestError``) are retried, at 2-second
        intervals for up to about 10 seconds; the last error is then re-raised. HTTP
        error statuses are not retried.

        Raises:
            ValueError: a non-200 response to GET/POST/DELETE, or a non-200/non-500 PUT.
            InternalServerError: a PUT answered with HTTP 500.
        """
        headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}

        url = f"{cls.base_url}{endpoint}"
        response = _http_client.request(method, url, json=json, params=params, headers=headers, follow_redirects=True)
        if method == "GET" and response.status_code != httpx.codes.OK:
            raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
        if method == "PUT":
            if response.status_code == httpx.codes.INTERNAL_SERVER_ERROR:
                raise InternalServerError(
                    "Unable to process billing request. Please try again later or contact support."
                )
            if response.status_code != httpx.codes.OK:
                raise ValueError("Invalid arguments.")
        if method == "POST" and response.status_code != httpx.codes.OK:
            raise ValueError(f"Unable to send request to {url}. Please try again later or contact support.")
        if method == "DELETE" and response.status_code != httpx.codes.OK:
            logger.error("billing_service: DELETE response: %s %s", response.status_code, response.text)
            raise ValueError(f"Unable to process delete request {url}. Please try again later or contact support.")
        return response.json()

    @staticmethod
    def is_tenant_owner_or_admin(current_user: Account):
        """Raise ValueError unless ``current_user`` holds a privileged role in their current tenant."""
        tenant_id = current_user.current_tenant_id

        join: TenantAccountJoin | None = db.session.scalar(
            select(TenantAccountJoin)
            .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
            .limit(1)
        )

        if not join:
            raise ValueError("Tenant account join not found")

        if not TenantAccountRole.is_privileged_role(TenantAccountRole(join.role)):
            raise ValueError("Only team owner or team admin can perform this action")

    @classmethod
    def delete_account(cls, account_id: str):
        """Delete account."""
        params = {"account_id": account_id}
        return cls._send_request("DELETE", "/account", params=params)

    @classmethod
    def is_email_in_freeze(cls, email: str) -> bool:
        """Return whether billing reports ``email`` as frozen.

        The check fails open: any error (network, non-200, unexpected payload) is logged
        and reported as "not frozen".
        """
        params = {"email": email}
        try:
            response = cls._send_request("GET", "/account/in-freeze", params=params)
            return bool(response.get("data", False))
        except Exception:
            # Deliberate best-effort, but leave a trace instead of swallowing silently.
            # The email itself is not logged to keep PII out of the logs.
            logger.exception("billing_service: is_email_in_freeze check failed, treating email as not frozen")
            return False

    @classmethod
    def update_account_deletion_feedback(cls, email: str, feedback: str):
        """Update account deletion feedback."""
        payload = {"email": email, "feedback": feedback}
        return cls._send_request("POST", "/account/delete-feedback", json=payload)

    class EducationIdentity:
        """Education verification endpoints; verify/activate are rate-limited per account email."""

        verification_rate_limit = RateLimiter(prefix="edu_verification_rate_limit", max_attempts=10, time_window=60)
        activation_rate_limit = RateLimiter(prefix="edu_activation_rate_limit", max_attempts=10, time_window=60)

        @classmethod
        def verify(cls, account_id: str, account_email: str):
            """Request GET /education/verify; raises EducationVerifyLimitError when rate-limited."""
            if cls.verification_rate_limit.is_rate_limited(account_email):
                from controllers.console.error import EducationVerifyLimitError

                raise EducationVerifyLimitError()

            # The attempt is counted before the request, so failed calls also consume quota.
            cls.verification_rate_limit.increment_rate_limit(account_email)

            params = {"account_id": account_id}
            return BillingService._send_request("GET", "/education/verify", params=params)

        @classmethod
        def status(cls, account_id: str):
            """Request GET /education/status for an account."""
            params = {"account_id": account_id}
            return BillingService._send_request("GET", "/education/status", params=params)

        @classmethod
        def activate(cls, account: Account, token: str, institution: str, role: str):
            """POST /education/ for the account; raises EducationActivateLimitError when rate-limited."""
            if cls.activation_rate_limit.is_rate_limited(account.email):
                from controllers.console.error import EducationActivateLimitError

                raise EducationActivateLimitError()

            cls.activation_rate_limit.increment_rate_limit(account.email)
            params = {"account_id": account.id, "curr_tenant_id": account.current_tenant_id}
            payload = {
                "institution": institution,
                "token": token,
                "role": role,
            }
            return BillingService._send_request("POST", "/education/", json=payload, params=params)

        @classmethod
        def autocomplete(cls, keywords: str, page: int = 0, limit: int = 20):
            """Request GET /education/autocomplete with paging parameters."""
            params = {"keywords": keywords, "page": page, "limit": limit}
            return BillingService._send_request("GET", "/education/autocomplete", params=params)

    @classmethod
    def get_compliance_download_link(
        cls,
        doc_name: str,
        account_id: str,
        tenant_id: str,
        ip: str,
        device_info: str,
    ):
        """Request a compliance document download link via POST /compliance/download.

        Rate-limited per (account_id, tenant_id). The attempt is only counted after the
        billing request returns successfully. Raises ComplianceRateLimitError when limited.
        """
        limiter_key = f"{account_id}:{tenant_id}"
        if cls.compliance_download_rate_limiter.is_rate_limited(limiter_key):
            from controllers.console.error import ComplianceRateLimitError

            raise ComplianceRateLimitError()

        payload = {
            "doc_name": doc_name,
            "account_id": account_id,
            "tenant_id": tenant_id,
            "ip_address": ip,
            "device_info": device_info,
        }

        res = cls._send_request("POST", "/compliance/download", json=payload)
        cls.compliance_download_rate_limiter.increment_rate_limit(limiter_key)
        return res

    @classmethod
    def clean_billing_info_cache(cls, tenant_id: str):
        """Delete the Redis key ``tenant:{tenant_id}:billing_info``."""
        redis_client.delete(f"tenant:{tenant_id}:billing_info")

    @classmethod
    def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str):
        """PUT the account/click binding to /partners/{partner_key}/tenants."""
        payload = {"account_id": account_id, "click_id": click_id}
        return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload)

    @classmethod
    def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
        """
        Bulk fetch billing subscription plan via billing API.
        Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request)
        Returns:
            Mapping of tenant_id -> {plan: str, expiration_date: int}

        Failures are best-effort: a chunk whose request fails, or a tenant whose plan
        fails validation, is logged and omitted from the result.
        """
        results: dict[str, SubscriptionPlan] = {}

        chunk_size = 200
        for i in range(0, len(tenant_ids), chunk_size):
            chunk = tenant_ids[i : i + chunk_size]
            try:
                resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
                data = resp.get("data", {})

                for tenant_id, plan in data.items():
                    try:
                        results[tenant_id] = _subscription_plan_adapter.validate_python(plan)
                    except Exception:
                        logger.exception(
                            "get_plan_bulk: failed to validate subscription plan for tenant(%s)", tenant_id
                        )
            except Exception:
                logger.exception("get_plan_bulk: failed to fetch billing info batch for tenants: %s", chunk)

        return results

    @classmethod
    def _make_plan_cache_key(cls, tenant_id: str) -> str:
        """Build the Redis key under which a tenant's plan is cached."""
        return f"{cls._PLAN_CACHE_KEY_PREFIX}{tenant_id}"

    @classmethod
    def get_plan_bulk_with_cache(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
        """
        Bulk fetch billing subscription plans with a Redis cache to reduce billing API load in batch jobs.

        NOTE: if you need high data consistency, use get_plan_bulk instead.

        Cached entries live for _PLAN_CACHE_TTL seconds. Any Redis failure falls back to the
        billing API; a failure to write the cache is logged and does not affect the result.

        Returns:
            Mapping of tenant_id -> {plan: str, expiration_date: int}
        """
        tenant_plans: dict[str, SubscriptionPlan] = {}

        if not tenant_ids:
            return tenant_plans

        # Step 1: Batch fetch from Redis cache using mget
        redis_keys = [cls._make_plan_cache_key(tenant_id) for tenant_id in tenant_ids]
        try:
            cached_values = redis_client.mget(redis_keys)

            if len(cached_values) != len(tenant_ids):
                raise Exception(
                    "get_plan_bulk_with_cache: unexpected error: redis mget failed: cached values length mismatch"
                )

            # Map cached values back to tenant_ids
            cache_misses: list[str] = []

            for tenant_id, cached_value in zip(tenant_ids, cached_values):
                if cached_value:
                    try:
                        # Redis returns bytes, decode to string and parse JSON
                        json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
                        plan_dict = json.loads(json_str)

                        # NOTE (hj24): New billing versions may return timestamp as str, and validate_python
                        # in non-strict mode will coerce it to the expected int type.
                        # To preserve compatibility, always keep non-strict mode here and avoid strict mode.
                        subscription_plan = _subscription_plan_adapter.validate_python(plan_dict)
                        # NOTE END

                        tenant_plans[tenant_id] = subscription_plan
                    except Exception:
                        logger.exception(
                            "get_plan_bulk_with_cache: process tenant(%s) failed, add to cache misses", tenant_id
                        )
                        cache_misses.append(tenant_id)
                else:
                    cache_misses.append(tenant_id)

            logger.info(
                "get_plan_bulk_with_cache: cache hits=%s, cache misses=%s",
                len(tenant_plans),
                len(cache_misses),
            )
        except Exception:
            logger.exception("get_plan_bulk_with_cache: redis mget failed, falling back to API")
            # Any partially-filled tenant_plans entries are overwritten by the API results below.
            cache_misses = list(tenant_ids)

        # Step 2: Fetch missing plans from billing API
        if cache_misses:
            bulk_plans = cls.get_plan_bulk(cache_misses)

            if bulk_plans:
                tenant_plans.update(bulk_plans)

                # Step 3: Batch update Redis cache using pipeline
                try:
                    pipe = redis_client.pipeline()
                    for tenant_id, subscription_plan in bulk_plans.items():
                        redis_key = cls._make_plan_cache_key(tenant_id)
                        # Serialize dict to JSON string
                        json_str = json.dumps(subscription_plan)
                        pipe.setex(redis_key, cls._PLAN_CACHE_TTL, json_str)
                    pipe.execute()

                    logger.info(
                        "get_plan_bulk_with_cache: cached %s new tenant plans to Redis",
                        len(bulk_plans),
                    )
                except Exception:
                    logger.exception("get_plan_bulk_with_cache: redis pipeline failed")

        return tenant_plans

    @classmethod
    def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
        """Return the tenant ids listed by GET /subscription/cleanup/whitelist."""
        resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
        data = resp.get("data", [])
        return [item["tenant_id"] for item in data]

    @classmethod
    def get_account_notification(cls, account_id: str) -> AccountNotificationDict:
        """Return the active in-product notification for account_id, if any.

        Calling this endpoint also marks the notification as seen; subsequent
        calls will return should_show=false when frequency='once'.

        Response shape (mirrors GetAccountNotificationReply):
            {
                "should_show": bool,
                "notification": {                     # present only when should_show=true
                    "notification_id": str,
                    "contents": {                     # lang -> LangContent
                        "en": {"lang": "en", "title": ..., "subtitle": ..., "body": ..., "title_pic_url": ...},
                        ...
                    },
                    "frequency": "once" | "every_page_load"
                }
            }
        """
        return cls._send_request("GET", "/notifications/active", params={"account_id": account_id})

    @classmethod
    def upsert_notification(
        cls,
        contents: list[LangContentDict],
        frequency: str = "once",
        status: str = "active",
        notification_id: str | None = None,
        start_time: str | None = None,
        end_time: str | None = None,
    ) -> UpsertNotificationDict:
        """Create or update a notification.

        contents: list of {"lang": str, "title": str, "subtitle": str, "body": str, "title_pic_url": str}
        start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional.
        Optional fields are only sent when truthy.
        Returns {"notification_id": str}.
        """
        payload: dict = {
            "contents": contents,
            "frequency": frequency,
            "status": status,
        }
        if notification_id:
            payload["notification_id"] = notification_id
        if start_time:
            payload["start_time"] = start_time
        if end_time:
            payload["end_time"] = end_time
        return cls._send_request("POST", "/notifications", json=payload)

    @classmethod
    def batch_add_notification_accounts(
        cls, notification_id: str, account_ids: list[str]
    ) -> BatchAddNotificationAccountsDict:
        """Register target account IDs for a notification (max 1000 per call).

        Returns {"count": int}.
        """
        return cls._send_request(
            "POST",
            f"/notifications/{notification_id}/accounts",
            json={"account_ids": account_ids},
        )

    @classmethod
    def dismiss_notification(cls, notification_id: str, account_id: str) -> DismissNotificationDict:
        """Mark a notification as dismissed for an account.

        Returns {"success": bool}.
        """
        return cls._send_request(
            "POST",
            f"/notifications/{notification_id}/dismiss",
            json={"account_id": account_id},
        )