From 2431ddfde69330ca84eaeab351b46b160af6e857 Mon Sep 17 00:00:00 2001 From: hj24 Date: Thu, 20 Nov 2025 15:58:05 +0800 Subject: [PATCH] Feat integrate partner stack (#28353) Co-authored-by: Joel Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- api/controllers/console/billing/billing.py | 41 ++- api/services/billing_service.py | 15 +- .../console/billing/test_billing.py | 253 ++++++++++++++++++ .../services/test_billing_service.py | 236 ++++++++++++++++ web/app/(commonLayout)/layout.tsx | 2 + .../billing/partner-stack/index.tsx | 20 ++ .../billing/partner-stack/use-ps-info.ts | 70 +++++ web/app/signin/page.tsx | 7 + web/config/index.ts | 5 + web/service/use-billing.ts | 19 ++ 10 files changed, 665 insertions(+), 3 deletions(-) create mode 100644 api/tests/unit_tests/controllers/console/billing/test_billing.py create mode 100644 api/tests/unit_tests/services/test_billing_service.py create mode 100644 web/app/components/billing/partner-stack/index.tsx create mode 100644 web/app/components/billing/partner-stack/use-ps-info.ts create mode 100644 web/service/use-billing.ts diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 436d29df83..6efb4564ca 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,6 +1,9 @@ -from flask_restx import Resource, reqparse +import base64 -from controllers.console import console_ns +from flask_restx import Resource, fields, reqparse +from werkzeug.exceptions import BadRequest + +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from enums.cloud_plan import CloudPlan from libs.login import current_account_with_tenant, login_required @@ -41,3 +44,37 @@ class Invoices(Resource): current_user, current_tenant_id = current_account_with_tenant() BillingService.is_tenant_owner_or_admin(current_user) return BillingService.get_invoices(current_user.email, current_tenant_id) + + +@console_ns.route("/billing/partners//tenants") +class PartnerTenants(Resource): + @api.doc("sync_partner_tenants_bindings") + @api.doc(description="Sync partner tenants bindings") + @api.doc(params={"partner_key": "Partner key"}) + @api.expect( + api.model( + "SyncPartnerTenantsBindingsRequest", + {"click_id": fields.String(required=True, description="Click Id from partner referral link")}, + ) + ) + @api.response(200, "Tenants synced to partner successfully") + @api.response(400, "Invalid partner information") + @setup_required + @login_required + @account_initialization_required + @only_edition_cloud + def put(self, partner_key: str): + current_user, _ = current_account_with_tenant() + parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json") + args = parser.parse_args() + + try: + click_id = args["click_id"] + decoded_partner_key = base64.b64decode(partner_key).decode("utf-8") + except Exception: + raise BadRequest("Invalid partner_key") + + if not click_id or not decoded_partner_key or not current_user.id: + raise BadRequest("Invalid partner information") + + return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id) diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 1f1aea709f..54e1c9d285 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -3,6 +3,7 @@ from typing import Literal import httpx from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed +from werkzeug.exceptions import InternalServerError from enums.cloud_plan import CloudPlan from extensions.ext_database import db @@ -107,13 +108,20 @@ class BillingService: retry=retry_if_exception_type(httpx.RequestError), reraise=True, ) - def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None): + def _send_request(cls, method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, json=None, params=None): headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = httpx.request(method, url, json=json, params=params, headers=headers) if method == "GET" and response.status_code != httpx.codes.OK: raise ValueError("Unable to retrieve billing information. Please try again later or contact support.") + if method == "PUT": + if response.status_code == httpx.codes.INTERNAL_SERVER_ERROR: + raise InternalServerError( + "Unable to process billing request. Please try again later or contact support." + ) + if response.status_code != httpx.codes.OK: + raise ValueError("Invalid arguments.") if method == "POST" and response.status_code != httpx.codes.OK: raise ValueError(f"Unable to send request to {url}. Please try again later or contact support.") return response.json() @@ -226,3 +234,8 @@ class BillingService: @classmethod def clean_billing_info_cache(cls, tenant_id: str): redis_client.delete(f"tenant:{tenant_id}:billing_info") + + @classmethod + def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str): + payload = {"account_id": account_id, "click_id": click_id} + return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload) diff --git a/api/tests/unit_tests/controllers/console/billing/test_billing.py b/api/tests/unit_tests/controllers/console/billing/test_billing.py new file mode 100644 index 0000000000..eaa489d56b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/billing/test_billing.py @@ -0,0 +1,253 @@ +import base64 +import json +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import BadRequest + +from controllers.console.billing.billing import PartnerTenants +from models.account import Account + + +class TestPartnerTenants: + """Unit tests for PartnerTenants controller.""" + + @pytest.fixture + def app(self): + """Create Flask app for testing.""" + app = Flask(__name__) + app.config["TESTING"] = True + app.config["SECRET_KEY"] = "test-secret-key" + return app + + @pytest.fixture + def mock_account(self): + """Create a mock account.""" + account = MagicMock(spec=Account) + account.id = "account-123" + account.email = "test@example.com" + account.current_tenant_id = "tenant-456" + account.is_authenticated = True + return account + + @pytest.fixture + def mock_billing_service(self): + """Mock BillingService.""" + with patch("controllers.console.billing.billing.BillingService") as mock_service: + yield mock_service + + @pytest.fixture + def mock_decorators(self): + """Mock decorators to avoid database access.""" + with ( + patch("controllers.console.wraps.db") as mock_db, + patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), + patch("libs.login.dify_config.LOGIN_DISABLED", False), + patch("libs.login.check_csrf_token") as mock_csrf, + ): + mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists + mock_csrf.return_value = None + yield {"db": mock_db, "csrf": mock_csrf} + + def test_put_success(self, app, mock_account, mock_billing_service, mock_decorators): + """Test successful partner tenants bindings sync.""" + # Arrange + partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") + click_id = "click-id-789" + expected_response = {"result": "success", "data": {"synced": True}} + + mock_billing_service.sync_partner_tenants_bindings.return_value = expected_response + + with app.test_request_context( + method="PUT", + json={"click_id": click_id}, + path=f"/billing/partners/{partner_key_encoded}/tenants", + ): + with ( + patch( + "controllers.console.billing.billing.current_account_with_tenant", + return_value=(mock_account, "tenant-456"), + ), + patch("libs.login._get_user", return_value=mock_account), + ): + resource = PartnerTenants() + result = resource.put(partner_key_encoded) + + # Assert + assert result == expected_response + mock_billing_service.sync_partner_tenants_bindings.assert_called_once_with( + mock_account.id, "partner-key-123", click_id + ) + + def test_put_invalid_partner_key_base64(self, app, mock_account, mock_billing_service, mock_decorators): + """Test that invalid base64 partner_key raises BadRequest.""" + # Arrange + invalid_partner_key = "invalid-base64-!@#$" + click_id = "click-id-789" + + with app.test_request_context( + method="PUT", + json={"click_id": click_id}, + path=f"/billing/partners/{invalid_partner_key}/tenants", + ): + with ( + patch( + "controllers.console.billing.billing.current_account_with_tenant", + return_value=(mock_account, "tenant-456"), + ), + patch("libs.login._get_user", return_value=mock_account), + ): + resource = PartnerTenants() + + # Act & Assert + with pytest.raises(BadRequest) as exc_info: + resource.put(invalid_partner_key) + assert "Invalid partner_key" in str(exc_info.value) + + def test_put_missing_click_id(self, app, mock_account, mock_billing_service, mock_decorators): + """Test that missing click_id raises BadRequest.""" + # Arrange + partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") + + with app.test_request_context( + method="PUT", + json={}, + path=f"/billing/partners/{partner_key_encoded}/tenants", + ): + with ( + patch( + "controllers.console.billing.billing.current_account_with_tenant", + return_value=(mock_account, "tenant-456"), + ), + patch("libs.login._get_user", return_value=mock_account), + ): + resource = PartnerTenants() + + # Act & Assert + # reqparse will raise BadRequest for missing required field + with pytest.raises(BadRequest): + resource.put(partner_key_encoded) + + def test_put_billing_service_json_decode_error(self, app, mock_account, mock_billing_service, mock_decorators): + """Test handling of billing service JSON decode error. + + When billing service returns non-200 status code with invalid JSON response, + response.json() raises JSONDecodeError. This exception propagates to the controller + and should be handled by the global error handler (handle_general_exception), + which returns a 500 status code with error details. + + Note: In unit tests, when directly calling resource.put(), the exception is raised + directly. In actual Flask application, the error handler would catch it and return + a 500 response with JSON: {"code": "unknown", "message": "...", "status": 500} + """ + # Arrange + partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") + click_id = "click-id-789" + + # Simulate JSON decode error when billing service returns invalid JSON + # This happens when billing service returns non-200 with empty/invalid response body + json_decode_error = json.JSONDecodeError("Expecting value", "", 0) + mock_billing_service.sync_partner_tenants_bindings.side_effect = json_decode_error + + with app.test_request_context( + method="PUT", + json={"click_id": click_id}, + path=f"/billing/partners/{partner_key_encoded}/tenants", + ): + with ( + patch( + "controllers.console.billing.billing.current_account_with_tenant", + return_value=(mock_account, "tenant-456"), + ), + patch("libs.login._get_user", return_value=mock_account), + ): + resource = PartnerTenants() + + # Act & Assert + # JSONDecodeError will be raised from the controller + # In actual Flask app, this would be caught by handle_general_exception + # which returns: {"code": "unknown", "message": str(e), "status": 500} + with pytest.raises(json.JSONDecodeError) as exc_info: + resource.put(partner_key_encoded) + + # Verify the exception is JSONDecodeError + assert isinstance(exc_info.value, json.JSONDecodeError) + assert "Expecting value" in str(exc_info.value) + + def test_put_empty_click_id(self, app, mock_account, mock_billing_service, mock_decorators): + """Test that empty click_id raises BadRequest.""" + # Arrange + partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") + click_id = "" + + with app.test_request_context( + method="PUT", + json={"click_id": click_id}, + path=f"/billing/partners/{partner_key_encoded}/tenants", + ): + with ( + patch( + "controllers.console.billing.billing.current_account_with_tenant", + return_value=(mock_account, "tenant-456"), + ), + patch("libs.login._get_user", return_value=mock_account), + ): + resource = PartnerTenants() + + # Act & Assert + with pytest.raises(BadRequest) as exc_info: + resource.put(partner_key_encoded) + assert "Invalid partner information" in str(exc_info.value) + + def test_put_empty_partner_key_after_decode(self, app, mock_account, mock_billing_service, mock_decorators): + """Test that empty partner_key after decode raises BadRequest.""" + # Arrange + # Base64 encode an empty string + empty_partner_key_encoded = base64.b64encode(b"").decode("utf-8") + click_id = "click-id-789" + + with app.test_request_context( + method="PUT", + json={"click_id": click_id}, + path=f"/billing/partners/{empty_partner_key_encoded}/tenants", + ): + with ( + patch( + "controllers.console.billing.billing.current_account_with_tenant", + return_value=(mock_account, "tenant-456"), + ), + patch("libs.login._get_user", return_value=mock_account), + ): + resource = PartnerTenants() + + # Act & Assert + with pytest.raises(BadRequest) as exc_info: + resource.put(empty_partner_key_encoded) + assert "Invalid partner information" in str(exc_info.value) + + def test_put_empty_user_id(self, app, mock_account, mock_billing_service, mock_decorators): + """Test that empty user id raises BadRequest.""" + # Arrange + partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") + click_id = "click-id-789" + mock_account.id = None # Empty user id + + with app.test_request_context( + method="PUT", + json={"click_id": click_id}, + path=f"/billing/partners/{partner_key_encoded}/tenants", + ): + with ( + patch( + "controllers.console.billing.billing.current_account_with_tenant", + return_value=(mock_account, "tenant-456"), + ), + patch("libs.login._get_user", return_value=mock_account), + ): + resource = PartnerTenants() + + # Act & Assert + with pytest.raises(BadRequest) as exc_info: + resource.put(partner_key_encoded) + assert "Invalid partner information" in str(exc_info.value) diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py new file mode 100644 index 0000000000..dc13143417 --- /dev/null +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -0,0 +1,236 @@ +import json +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from werkzeug.exceptions import InternalServerError + +from services.billing_service import BillingService + + +class TestBillingServiceSendRequest: + """Unit tests for BillingService._send_request method.""" + + @pytest.fixture + def mock_httpx_request(self): + """Mock httpx.request for testing.""" + with patch("services.billing_service.httpx.request") as mock_request: + yield mock_request + + @pytest.fixture + def mock_billing_config(self): + """Mock BillingService configuration.""" + with ( + patch.object(BillingService, "base_url", "https://billing-api.example.com"), + patch.object(BillingService, "secret_key", "test-secret-key"), + ): + yield + + def test_get_request_success(self, mock_httpx_request, mock_billing_config): + """Test successful GET request.""" + # Arrange + expected_response = {"result": "success", "data": {"info": "test"}} + mock_response = MagicMock() + mock_response.status_code = httpx.codes.OK + mock_response.json.return_value = expected_response + mock_httpx_request.return_value = mock_response + + # Act + result = BillingService._send_request("GET", "/test", params={"key": "value"}) + + # Assert + assert result == expected_response + mock_httpx_request.assert_called_once() + call_args = mock_httpx_request.call_args + assert call_args[0][0] == "GET" + assert call_args[0][1] == "https://billing-api.example.com/test" + assert call_args[1]["params"] == {"key": "value"} + assert call_args[1]["headers"]["Billing-Api-Secret-Key"] == "test-secret-key" + assert call_args[1]["headers"]["Content-Type"] == "application/json" + + @pytest.mark.parametrize( + "status_code", [httpx.codes.NOT_FOUND, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.BAD_REQUEST] + ) + def test_get_request_non_200_status_code(self, mock_httpx_request, mock_billing_config, status_code): + """Test GET request with non-200 status code raises ValueError.""" + # Arrange + mock_response = MagicMock() + mock_response.status_code = status_code + mock_httpx_request.return_value = mock_response + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + BillingService._send_request("GET", "/test") + assert "Unable to retrieve billing information" in str(exc_info.value) + + def test_put_request_success(self, mock_httpx_request, mock_billing_config): + """Test successful PUT request.""" + # Arrange + expected_response = {"result": "success"} + mock_response = MagicMock() + mock_response.status_code = httpx.codes.OK + mock_response.json.return_value = expected_response + mock_httpx_request.return_value = mock_response + + # Act + result = BillingService._send_request("PUT", "/test", json={"key": "value"}) + + # Assert + assert result == expected_response + call_args = mock_httpx_request.call_args + assert call_args[0][0] == "PUT" + + def test_put_request_internal_server_error(self, mock_httpx_request, mock_billing_config): + """Test PUT request with INTERNAL_SERVER_ERROR raises InternalServerError.""" + # Arrange + mock_response = MagicMock() + mock_response.status_code = httpx.codes.INTERNAL_SERVER_ERROR + mock_httpx_request.return_value = mock_response + + # Act & Assert + with pytest.raises(InternalServerError) as exc_info: + BillingService._send_request("PUT", "/test", json={"key": "value"}) + assert exc_info.value.code == 500 + assert "Unable to process billing request" in str(exc_info.value.description) + + @pytest.mark.parametrize( + "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.NOT_FOUND, httpx.codes.UNAUTHORIZED, httpx.codes.FORBIDDEN] + ) + def test_put_request_non_200_non_500(self, mock_httpx_request, mock_billing_config, status_code): + """Test PUT request with non-200 and non-500 status code raises ValueError.""" + # Arrange + mock_response = MagicMock() + mock_response.status_code = status_code + mock_httpx_request.return_value = mock_response + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + BillingService._send_request("PUT", "/test", json={"key": "value"}) + assert "Invalid arguments." in str(exc_info.value) + + @pytest.mark.parametrize("method", ["POST", "DELETE"]) + def test_non_get_non_put_request_success(self, mock_httpx_request, mock_billing_config, method): + """Test successful POST/DELETE request.""" + # Arrange + expected_response = {"result": "success"} + mock_response = MagicMock() + mock_response.status_code = httpx.codes.OK + mock_response.json.return_value = expected_response + mock_httpx_request.return_value = mock_response + + # Act + result = BillingService._send_request(method, "/test", json={"key": "value"}) + + # Assert + assert result == expected_response + call_args = mock_httpx_request.call_args + assert call_args[0][0] == method + + @pytest.mark.parametrize( + "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] + ) + def test_post_request_non_200_with_valid_json(self, mock_httpx_request, mock_billing_config, status_code): + """Test POST request with non-200 status code raises ValueError.""" + # Arrange + error_response = {"detail": "Error message"} + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.json.return_value = error_response + mock_httpx_request.return_value = mock_response + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + BillingService._send_request("POST", "/test", json={"key": "value"}) + assert "Unable to send request to" in str(exc_info.value) + + @pytest.mark.parametrize( + "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] + ) + def test_delete_request_non_200_with_valid_json(self, mock_httpx_request, mock_billing_config, status_code): + """Test DELETE request with non-200 status code but valid JSON response. + + DELETE doesn't check status code, so it returns the error JSON. + """ + # Arrange + error_response = {"detail": "Error message"} + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.json.return_value = error_response + mock_httpx_request.return_value = mock_response + + # Act + result = BillingService._send_request("DELETE", "/test", json={"key": "value"}) + + # Assert + assert result == error_response + + @pytest.mark.parametrize( + "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] + ) + def test_post_request_non_200_with_invalid_json(self, mock_httpx_request, mock_billing_config, status_code): + """Test POST request with non-200 status code raises ValueError before JSON parsing.""" + # Arrange + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.text = "" + mock_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + mock_httpx_request.return_value = mock_response + + # Act & Assert + # POST checks status code before calling response.json(), so ValueError is raised + with pytest.raises(ValueError) as exc_info: + BillingService._send_request("POST", "/test", json={"key": "value"}) + assert "Unable to send request to" in str(exc_info.value) + + @pytest.mark.parametrize( + "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] + ) + def test_delete_request_non_200_with_invalid_json(self, mock_httpx_request, mock_billing_config, status_code): + """Test DELETE request with non-200 status code and invalid JSON response raises exception. + + DELETE doesn't check status code, so it calls response.json() which raises JSONDecodeError + when the response cannot be parsed as JSON (e.g., empty response). + """ + # Arrange + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.text = "" + mock_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + mock_httpx_request.return_value = mock_response + + # Act & Assert + with pytest.raises(json.JSONDecodeError): + BillingService._send_request("DELETE", "/test", json={"key": "value"}) + + def test_retry_on_request_error(self, mock_httpx_request, mock_billing_config): + """Test that _send_request retries on httpx.RequestError.""" + # Arrange + expected_response = {"result": "success"} + mock_response = MagicMock() + mock_response.status_code = httpx.codes.OK + mock_response.json.return_value = expected_response + + # First call raises RequestError, second succeeds + mock_httpx_request.side_effect = [ + httpx.RequestError("Network error"), + mock_response, + ] + + # Act + result = BillingService._send_request("GET", "/test") + + # Assert + assert result == expected_response + assert mock_httpx_request.call_count == 2 + + def test_retry_exhausted_raises_exception(self, mock_httpx_request, mock_billing_config): + """Test that _send_request raises exception after retries are exhausted.""" + # Arrange + mock_httpx_request.side_effect = httpx.RequestError("Network error") + + # Act & Assert + with pytest.raises(httpx.RequestError): + BillingService._send_request("GET", "/test") + + # Should retry multiple times (wait=2, stop_before_delay=10 means ~5 attempts) + assert mock_httpx_request.call_count > 1 diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index 7f6bbb1f52..6014f7edc7 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -10,6 +10,7 @@ import { ProviderContextProvider } from '@/context/provider-context' import { ModalContextProvider } from '@/context/modal-context' import GotoAnything from '@/app/components/goto-anything' import Zendesk from '@/app/components/base/zendesk' +import PartnerStack from '../components/billing/partner-stack' import ReadmePanel from '@/app/components/plugins/readme-panel' import Splash from '../components/splash' @@ -26,6 +27,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
{children} + diff --git a/web/app/components/billing/partner-stack/index.tsx b/web/app/components/billing/partner-stack/index.tsx new file mode 100644 index 0000000000..84a09e260d --- /dev/null +++ b/web/app/components/billing/partner-stack/index.tsx @@ -0,0 +1,20 @@ +'use client' +import { IS_CLOUD_EDITION } from '@/config' +import type { FC } from 'react' +import React, { useEffect } from 'react' +import usePSInfo from './use-ps-info' + +const PartnerStack: FC = () => { + const { saveOrUpdate, bind } = usePSInfo() + useEffect(() => { + if (!IS_CLOUD_EDITION) + return + // Save PartnerStack info in cookie first. Because if user hasn't logged in, redirecting to login page would cause lose the partnerStack info in URL. + saveOrUpdate() + // bind PartnerStack info after user logged in + bind() + }, []) + + return null +} +export default React.memo(PartnerStack) diff --git a/web/app/components/billing/partner-stack/use-ps-info.ts b/web/app/components/billing/partner-stack/use-ps-info.ts new file mode 100644 index 0000000000..a308f7446e --- /dev/null +++ b/web/app/components/billing/partner-stack/use-ps-info.ts @@ -0,0 +1,70 @@ +import { PARTNER_STACK_CONFIG } from '@/config' +import { useBindPartnerStackInfo } from '@/service/use-billing' +import { useBoolean } from 'ahooks' +import Cookies from 'js-cookie' +import { useSearchParams } from 'next/navigation' +import { useCallback } from 'react' + +const usePSInfo = () => { + const searchParams = useSearchParams() + const psInfoInCookie = (() => { + try { + return JSON.parse(Cookies.get(PARTNER_STACK_CONFIG.cookieName) || '{}') + } + catch (e) { + console.error('Failed to parse partner stack info from cookie:', e) + return {} + } + })() + const psPartnerKey = searchParams.get('ps_partner_key') || psInfoInCookie?.partnerKey + const psClickId = searchParams.get('ps_xid') || psInfoInCookie?.clickId + const isPSChanged = psInfoInCookie?.partnerKey !== psPartnerKey || psInfoInCookie?.clickId !== psClickId + const [hasBind, { + setTrue: setBind, + }] = useBoolean(false) + const { mutateAsync } = useBindPartnerStackInfo() + // Save to top domain. cloud.dify.ai => .dify.ai + const domain = globalThis.location.hostname.replace('cloud', '') + + const saveOrUpdate = useCallback(() => { + if(!psPartnerKey || !psClickId) + return + if(!isPSChanged) + return + Cookies.set(PARTNER_STACK_CONFIG.cookieName, JSON.stringify({ + partnerKey: psPartnerKey, + clickId: psClickId, + }), { + expires: PARTNER_STACK_CONFIG.saveCookieDays, + path: '/', + domain, + }) + }, [psPartnerKey, psClickId, isPSChanged]) + + const bind = useCallback(async () => { + if (psPartnerKey && psClickId && !hasBind) { + let shouldRemoveCookie = false + try { + await mutateAsync({ + partnerKey: psPartnerKey, + clickId: psClickId, + }) + shouldRemoveCookie = true + } + catch (error: unknown) { + if((error as { status: number })?.status === 400) + shouldRemoveCookie = true + } + if (shouldRemoveCookie) + Cookies.remove(PARTNER_STACK_CONFIG.cookieName, { path: '/', domain }) + setBind() + } + }, [psPartnerKey, psClickId, mutateAsync, hasBind, setBind]) + return { + psPartnerKey, + psClickId, + saveOrUpdate, + bind, + } +} +export default usePSInfo diff --git a/web/app/signin/page.tsx b/web/app/signin/page.tsx index 60fee366df..01c790c760 100644 --- a/web/app/signin/page.tsx +++ b/web/app/signin/page.tsx @@ -2,10 +2,17 @@ import { useSearchParams } from 'next/navigation' import OneMoreStep from './one-more-step' import NormalForm from './normal-form' +import { useEffect } from 'react' +import usePSInfo from '../components/billing/partner-stack/use-ps-info' const SignIn = () => { const searchParams = useSearchParams() const step = searchParams.get('step') + const { saveOrUpdate } = usePSInfo() + + useEffect(() => { + saveOrUpdate() + }, []) if (step === 'next') return diff --git a/web/config/index.ts b/web/config/index.ts index 2f75206cc0..2555a9767e 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -449,3 +449,8 @@ export const STOP_PARAMETER_RULE: ModelParameterRule = { zh_Hans: '输入序列并按 Tab 键', }, } + +export const PARTNER_STACK_CONFIG = { + cookieName: 'partner_stack_info', + saveCookieDays: 90, +} diff --git a/web/service/use-billing.ts b/web/service/use-billing.ts new file mode 100644 index 0000000000..b48a75eab0 --- /dev/null +++ b/web/service/use-billing.ts @@ -0,0 +1,19 @@ +import { useMutation } from '@tanstack/react-query' +import { put } from './base' + +const NAME_SPACE = 'billing' + +export const useBindPartnerStackInfo = () => { + return useMutation({ + mutationKey: [NAME_SPACE, 'bind-partner-stack'], + mutationFn: (data: { partnerKey: string; clickId: string }) => { + return put(`/billing/partners/${data.partnerKey}/tenants`, { + body: { + click_id: data.clickId, + }, + }, { + silent: true, + }) + }, + }) +}