feat: marketplace and oauth fixes (#35509)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Junyan Chin 2026-04-24 15:53:14 +08:00 committed by fatelei
parent 775f9212f3
commit df28c99817
No known key found for this signature in database
GPG Key ID: 2F91DA05646F4EED
48 changed files with 1604 additions and 739 deletions

View File

@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Creators Platform configuration
CREATORS_PLATFORM_FEATURES_ENABLED=true
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}

View File

@ -11,7 +11,7 @@ from configs import dify_config
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db
from models import Tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
)
class CreatorsPlatformConfig(BaseSettings):
"""
Configuration for Creators Platform integration
"""
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
description="Enable or disable Creators Platform features",
default=True,
)
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
description="Creators Platform API URL",
default=HttpUrl("https://creators.dify.ai"),
)
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
description="OAuth client ID for Creators Platform integration",
default="",
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
@ -1379,6 +1400,7 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
CreatorsPlatformConfig,
TriggerConfig,
AsyncWorkflowConfig,
PluginConfig,

View File

@ -692,6 +692,32 @@ class AppExportApi(Resource):
return payload.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
class AppPublishToCreatorsPlatformApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
"""Publish app to Creators Platform"""
from configs import dify_config
from core.helper.creators import get_redirect_url, upload_dsl
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
return {"error": "Creators Platform features are not enabled"}, 403
current_user, _ = current_account_with_tenant()
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
dsl_bytes = dsl_content.encode("utf-8")
claim_code = upload_dsl(dsl_bytes)
redirect_url = get_redirect_url(str(current_user.id), claim_code)
return {"redirect_url": redirect_url}
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@console_ns.doc("check_app_name")

View File

@ -0,0 +1,41 @@
"""
Helper module for Creators Platform integration.
Provides functionality to upload DSL files to the Creators Platform
and generate redirect URLs with OAuth authorization codes.
"""
import logging
from urllib.parse import urlencode
import httpx
from yarl import URL
from configs import dify_config
logger = logging.getLogger(__name__)
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
response.raise_for_status()
data = response.json()
claim_code = data.get("data", {}).get("claim_code")
if not claim_code:
raise ValueError("Creators Platform did not return a valid claim_code")
return claim_code
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
params: dict[str, str] = {"dsl_claim_code": claim_code}
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
if client_id:
from services.oauth_server import OAuthServerService
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
params["oauth_code"] = oauth_code
return f"{base_url}?{urlencode(params)}"

View File

@ -14,23 +14,23 @@ from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
class EncryptionError(Exception):
"""Encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
class SystemEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
A simple parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
This class provides methods to encrypt and decrypt parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: str | None = None):
"""
Initialize the OAuth encrypter.
Initialize the encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
def encrypt_params(self, params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
Encrypt parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
EncryptionError: If encryption fails
ValueError: If params is invalid
"""
try:
@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
raise EncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
Decrypt parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
EncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
if not isinstance(params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
return params
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
raise EncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
"""
Create an OAuth encrypter instance.
Create an encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
return SystemEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: SystemOAuthEncrypter | None = None
_encrypter: SystemEncrypter | None = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
def get_system_encrypter() -> SystemEncrypter:
"""
Get the global OAuth encrypter instance.
Get the global encrypter instance.
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
global _encrypter
if _encrypter is None:
_encrypter = SystemEncrypter()
return _encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
def encrypt_system_params(params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Encrypt parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
params: Parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
return get_system_encrypter().encrypt_params(params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Decrypt parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
return get_system_encrypter().decrypt_params(encrypted_data)

View File

@ -177,6 +177,7 @@ class SystemFeatureModel(BaseModel):
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
trial_models: list[str] = []
enable_creators_platform: bool = False
enable_trial_app: bool = False
enable_explore_banner: bool = False
@ -241,6 +242,9 @@ class FeatureService:
if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True
if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
system_features.enable_creators_platform = True
return system_features
@classmethod

View File

@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import ToolProviderID
@ -521,7 +521,7 @@ class BuiltinToolManageService:
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from core.trigger.entities.api_entities import (
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
@ -635,7 +635,7 @@ class TriggerProviderService:
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@ -0,0 +1,106 @@
"""Tests for the Creators Platform helper module."""
from unittest.mock import MagicMock, patch
import httpx
import pytest
from yarl import URL
@pytest.fixture(autouse=True)
def _patch_creators_url(monkeypatch):
"""Patch the module-level creators_platform_api_url for all tests."""
monkeypatch.setattr(
"core.helper.creators.creators_platform_api_url",
URL("https://creators.example.com"),
)
class TestUploadDSL:
@patch("core.helper.creators.httpx.post")
def test_returns_claim_code(self, mock_post):
mock_response = MagicMock(spec=httpx.Response)
mock_response.json.return_value = {"data": {"claim_code": "abc123"}}
mock_response.raise_for_status = MagicMock()
mock_post.return_value = mock_response
from core.helper.creators import upload_dsl
result = upload_dsl(b"app: demo", "demo.yaml")
assert result == "abc123"
mock_post.assert_called_once()
call_kwargs = mock_post.call_args
assert "anonymous-upload" in call_kwargs.args[0]
assert call_kwargs.kwargs["timeout"] == 30
@patch("core.helper.creators.httpx.post")
def test_raises_on_missing_claim_code(self, mock_post):
mock_response = MagicMock(spec=httpx.Response)
mock_response.json.return_value = {"data": {}}
mock_response.raise_for_status = MagicMock()
mock_post.return_value = mock_response
from core.helper.creators import upload_dsl
with pytest.raises(ValueError, match="claim_code"):
upload_dsl(b"app: demo")
@patch("core.helper.creators.httpx.post")
def test_raises_on_http_error(self, mock_post):
mock_response = MagicMock(spec=httpx.Response)
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"Server Error",
request=MagicMock(),
response=MagicMock(),
)
mock_post.return_value = mock_response
from core.helper.creators import upload_dsl
with pytest.raises(httpx.HTTPStatusError):
upload_dsl(b"app: demo")
class TestGetRedirectUrl:
@patch("core.helper.creators.dify_config")
def test_without_oauth_client_id(self, mock_config):
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com"
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = ""
from core.helper.creators import get_redirect_url
url = get_redirect_url("user-1", "claim-abc")
assert "dsl_claim_code=claim-abc" in url
assert "oauth_code" not in url
assert url.startswith("https://creators.example.com")
@patch("core.helper.creators.dify_config")
def test_with_oauth_client_id(self, mock_config):
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com"
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "client-xyz"
with patch(
"services.oauth_server.OAuthServerService.sign_oauth_authorization_code",
return_value="oauth-code-123",
) as mock_sign:
from core.helper.creators import get_redirect_url
url = get_redirect_url("user-1", "claim-abc")
mock_sign.assert_called_once_with("client-xyz", "user-1")
assert "dsl_claim_code=claim-abc" in url
assert "oauth_code=oauth-code-123" in url
@patch("core.helper.creators.dify_config")
def test_strips_trailing_slash(self, mock_config):
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com/"
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = ""
from core.helper.creators import get_redirect_url
url = get_redirect_url("user-1", "claim-abc")
assert url.startswith("https://creators.example.com?")
assert "creators.example.com/?" not in url

View File

@ -2,50 +2,50 @@ from __future__ import annotations
import pytest
from core.tools.utils import system_oauth_encryption as oauth_encryption
from core.tools.utils.system_oauth_encryption import OAuthEncryptionError, SystemOAuthEncrypter
from core.tools.utils import system_encryption as encryption
from core.tools.utils.system_encryption import EncryptionError, SystemEncrypter
def test_system_oauth_encrypter_roundtrip():
encrypter = SystemOAuthEncrypter(secret_key="test-secret")
def test_system_encrypter_roundtrip():
encrypter = SystemEncrypter(secret_key="test-secret")
payload = {"client_id": "cid", "client_secret": "csecret", "grant_type": "authorization_code"}
encrypted = encrypter.encrypt_oauth_params(payload)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypted = encrypter.encrypt_params(payload)
decrypted = encrypter.decrypt_params(encrypted)
assert encrypted
assert dict(decrypted) == payload
def test_system_oauth_encrypter_decrypt_validates_input():
encrypter = SystemOAuthEncrypter(secret_key="test-secret")
def test_system_encrypter_decrypt_validates_input():
encrypter = SystemEncrypter(secret_key="test-secret")
with pytest.raises(ValueError, match="must be a string"):
encrypter.decrypt_oauth_params(123) # type: ignore[arg-type]
encrypter.decrypt_params(123) # type: ignore[arg-type]
with pytest.raises(ValueError, match="cannot be empty"):
encrypter.decrypt_oauth_params("")
encrypter.decrypt_params("")
def test_system_oauth_encrypter_raises_oauth_error_for_invalid_ciphertext():
encrypter = SystemOAuthEncrypter(secret_key="test-secret")
def test_system_encrypter_raises_error_for_invalid_ciphertext():
encrypter = SystemEncrypter(secret_key="test-secret")
with pytest.raises(OAuthEncryptionError, match="Decryption failed"):
encrypter.decrypt_oauth_params("not-base64")
with pytest.raises(EncryptionError, match="Decryption failed"):
encrypter.decrypt_params("not-base64")
def test_system_oauth_helpers_use_global_cached_instance(monkeypatch):
monkeypatch.setattr(oauth_encryption, "_oauth_encrypter", None)
monkeypatch.setattr("core.tools.utils.system_oauth_encryption.dify_config.SECRET_KEY", "global-secret")
def test_system_helpers_use_global_cached_instance(monkeypatch):
monkeypatch.setattr(encryption, "_encrypter", None)
monkeypatch.setattr("core.tools.utils.system_encryption.dify_config.SECRET_KEY", "global-secret")
first = oauth_encryption.get_system_oauth_encrypter()
second = oauth_encryption.get_system_oauth_encrypter()
first = encryption.get_system_encrypter()
second = encryption.get_system_encrypter()
assert first is second
encrypted = oauth_encryption.encrypt_system_oauth_params({"k": "v"})
assert oauth_encryption.decrypt_system_oauth_params(encrypted) == {"k": "v"}
encrypted = encryption.encrypt_system_params({"k": "v"})
assert encryption.decrypt_system_params(encrypted) == {"k": "v"}
def test_create_system_oauth_encrypter_factory():
encrypter = oauth_encryption.create_system_oauth_encrypter(secret_key="factory-secret")
assert isinstance(encrypter, SystemOAuthEncrypter)
def test_create_system_encrypter_factory():
encrypter = encryption.create_system_encrypter(secret_key="factory-secret")
assert isinstance(encrypter, SystemEncrypter)

View File

@ -694,7 +694,7 @@ def test_get_oauth_client_should_return_decrypted_system_client_when_verified(
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
mocker.patch(
"services.trigger.trigger_provider_service.decrypt_system_oauth_params",
"services.trigger.trigger_provider_service.decrypt_system_params",
return_value={"client_id": "system"},
)
@ -716,7 +716,7 @@ def test_get_oauth_client_should_raise_error_when_system_decryption_fails(
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
mocker.patch(
"services.trigger.trigger_provider_service.decrypt_system_oauth_params",
"services.trigger.trigger_provider_service.decrypt_system_params",
side_effect=RuntimeError("bad data"),
)

View File

@ -280,7 +280,7 @@ class TestGetOauthClient:
assert result == {"client_id": "id", "client_secret": "secret"}
@patch(f"{MODULE}.decrypt_system_oauth_params", return_value={"sys_key": "sys_val"})
@patch(f"{MODULE}.decrypt_system_params", return_value={"sys_key": "sys_val"})
@patch(f"{MODULE}.PluginService")
@patch(f"{MODULE}.create_provider_encrypter")
@patch(f"{MODULE}.ToolManager")

View File

@ -0,0 +1,619 @@
import base64
import hashlib
from unittest.mock import patch
import pytest
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad
from core.tools.utils.system_encryption import (
EncryptionError,
SystemEncrypter,
create_system_encrypter,
decrypt_system_params,
encrypt_system_params,
get_system_encrypter,
)
class TestSystemEncrypter:
"""Test cases for SystemEncrypter class"""
def test_init_with_secret_key(self):
"""Test initialization with provided secret key"""
secret_key = "test_secret_key"
encrypter = SystemEncrypter(secret_key=secret_key)
expected_key = hashlib.sha256(secret_key.encode()).digest()
assert encrypter.key == expected_key
def test_init_with_none_secret_key(self):
"""Test initialization with None secret key falls back to config"""
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = SystemEncrypter(secret_key=None)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
def test_init_with_empty_secret_key(self):
"""Test initialization with empty secret key"""
encrypter = SystemEncrypter(secret_key="")
expected_key = hashlib.sha256(b"").digest()
assert encrypter.key == expected_key
def test_init_without_secret_key_uses_config(self):
"""Test initialization without secret key uses config"""
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "default_secret"
encrypter = SystemEncrypter()
expected_key = hashlib.sha256(b"default_secret").digest()
assert encrypter.key == expected_key
def test_encrypt_params_basic(self):
"""Test basic parameters encryption"""
encrypter = SystemEncrypter("test_secret")
params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter.encrypt_params(params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
# Should be valid base64
try:
base64.b64decode(encrypted)
except Exception:
pytest.fail("Encrypted result is not valid base64")
def test_encrypt_params_empty_dict(self):
"""Test encryption with empty dictionary"""
encrypter = SystemEncrypter("test_secret")
params = {}
encrypted = encrypter.encrypt_params(params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_params_complex_data(self):
"""Test encryption with complex data structures"""
encrypter = SystemEncrypter("test_secret")
params = {
"client_id": "test_id",
"client_secret": "test_secret",
"scopes": ["read", "write", "admin"],
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
"numeric_value": 42,
"boolean_value": False,
"null_value": None,
}
encrypted = encrypter.encrypt_params(params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_params_unicode_data(self):
"""Test encryption with unicode data"""
encrypter = SystemEncrypter("test_secret")
params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"}
encrypted = encrypter.encrypt_params(params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_params_large_data(self):
"""Test encryption with large data"""
encrypter = SystemEncrypter("test_secret")
params = {
"client_id": "test_id",
"large_data": "x" * 10000, # 10KB of data
}
encrypted = encrypter.encrypt_params(params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_params_invalid_input(self):
"""Test encryption with invalid input types"""
encrypter = SystemEncrypter("test_secret")
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_params(None)
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_params("not_a_dict")
def test_decrypt_params_basic(self):
"""Test basic parameters decryption"""
encrypter = SystemEncrypter("test_secret")
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter.encrypt_params(original_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == original_params
def test_decrypt_params_empty_dict(self):
"""Test decryption of empty dictionary"""
encrypter = SystemEncrypter("test_secret")
original_params = {}
encrypted = encrypter.encrypt_params(original_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == original_params
def test_decrypt_params_complex_data(self):
"""Test decryption with complex data structures"""
encrypter = SystemEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"client_secret": "test_secret",
"scopes": ["read", "write", "admin"],
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
"numeric_value": 42,
"boolean_value": False,
"null_value": None,
}
encrypted = encrypter.encrypt_params(original_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == original_params
def test_decrypt_params_unicode_data(self):
"""Test decryption with unicode data"""
encrypter = SystemEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"client_secret": "test_secret",
"description": "This is a test case 🚀",
}
encrypted = encrypter.encrypt_params(original_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == original_params
def test_decrypt_params_large_data(self):
"""Test decryption with large data"""
encrypter = SystemEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"large_data": "x" * 10000, # 10KB of data
}
encrypted = encrypter.encrypt_params(original_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == original_params
def test_decrypt_params_invalid_base64(self):
"""Test decryption with invalid base64 data"""
encrypter = SystemEncrypter("test_secret")
with pytest.raises(EncryptionError):
encrypter.decrypt_params("invalid_base64!")
def test_decrypt_params_empty_string(self):
"""Test decryption with empty string"""
encrypter = SystemEncrypter("test_secret")
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_params("")
assert "encrypted_data cannot be empty" in str(exc_info.value)
def test_decrypt_params_non_string_input(self):
"""Test decryption with non-string input"""
encrypter = SystemEncrypter("test_secret")
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_params(123)
assert "encrypted_data must be a string" in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_params(None)
assert "encrypted_data must be a string" in str(exc_info.value)
def test_decrypt_params_too_short_data(self):
"""Test decryption with too short encrypted data"""
encrypter = SystemEncrypter("test_secret")
# Create data that's too short (less than 32 bytes)
short_data = base64.b64encode(b"short").decode()
with pytest.raises(EncryptionError) as exc_info:
encrypter.decrypt_params(short_data)
assert "Invalid encrypted data format" in str(exc_info.value)
def test_decrypt_params_corrupted_data(self):
"""Test decryption with corrupted data"""
encrypter = SystemEncrypter("test_secret")
# Create corrupted data (valid base64 but invalid encrypted content)
corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage
with pytest.raises(EncryptionError):
encrypter.decrypt_params(corrupted_data)
def test_decrypt_params_wrong_key(self):
"""Test decryption with wrong key"""
encrypter1 = SystemEncrypter("secret1")
encrypter2 = SystemEncrypter("secret2")
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter1.encrypt_params(original_params)
with pytest.raises(EncryptionError):
encrypter2.decrypt_params(encrypted)
def test_encryption_decryption_consistency(self):
"""Test that encryption and decryption are consistent"""
encrypter = SystemEncrypter("test_secret")
test_cases = [
{},
{"simple": "value"},
{"client_id": "id", "client_secret": "secret"},
{"complex": {"nested": {"deep": "value"}}},
{"unicode": "test 🚀"},
{"numbers": 42, "boolean": True, "null": None},
{"array": [1, 2, 3, "four", {"five": 5}]},
]
for original_params in test_cases:
encrypted = encrypter.encrypt_params(original_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == original_params, f"Failed for case: {original_params}"
def test_encryption_randomness(self):
"""Test that encryption produces different results for same input"""
encrypter = SystemEncrypter("test_secret")
params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted1 = encrypter.encrypt_params(params)
encrypted2 = encrypter.encrypt_params(params)
# Should be different due to random IV
assert encrypted1 != encrypted2
# But should decrypt to same result
decrypted1 = encrypter.decrypt_params(encrypted1)
decrypted2 = encrypter.decrypt_params(encrypted2)
assert decrypted1 == decrypted2 == params
def test_different_secret_keys_produce_different_results(self):
"""Test that different secret keys produce different encrypted results"""
encrypter1 = SystemEncrypter("secret1")
encrypter2 = SystemEncrypter("secret2")
params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted1 = encrypter1.encrypt_params(params)
encrypted2 = encrypter2.encrypt_params(params)
# Should produce different encrypted results
assert encrypted1 != encrypted2
# But each should decrypt correctly with its own key
decrypted1 = encrypter1.decrypt_params(encrypted1)
decrypted2 = encrypter2.decrypt_params(encrypted2)
assert decrypted1 == decrypted2 == params
@patch("core.tools.utils.system_encryption.get_random_bytes")
def test_encrypt_params_crypto_error(self, mock_get_random_bytes):
"""Test encryption when crypto operation fails"""
mock_get_random_bytes.side_effect = Exception("Crypto error")
encrypter = SystemEncrypter("test_secret")
params = {"client_id": "test_id"}
with pytest.raises(EncryptionError) as exc_info:
encrypter.encrypt_params(params)
assert "Encryption failed" in str(exc_info.value)
@patch("core.tools.utils.system_encryption.TypeAdapter")
def test_encrypt_params_serialization_error(self, mock_type_adapter):
"""Test encryption when JSON serialization fails"""
mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
encrypter = SystemEncrypter("test_secret")
params = {"client_id": "test_id"}
with pytest.raises(EncryptionError) as exc_info:
encrypter.encrypt_params(params)
assert "Encryption failed" in str(exc_info.value)
def test_decrypt_params_invalid_json(self):
"""Test decryption with invalid JSON data"""
encrypter = SystemEncrypter("test_secret")
# Create valid encrypted data but with invalid JSON content
iv = get_random_bytes(16)
cipher = AES.new(encrypter.key, AES.MODE_CBC, iv)
invalid_json = b"invalid json content"
padded_data = pad(invalid_json, AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
combined = iv + encrypted_data
encoded = base64.b64encode(combined).decode()
with pytest.raises(EncryptionError):
encrypter.decrypt_params(encoded)
def test_key_derivation_consistency(self):
"""Test that key derivation is consistent"""
secret_key = "test_secret"
encrypter1 = SystemEncrypter(secret_key)
encrypter2 = SystemEncrypter(secret_key)
assert encrypter1.key == encrypter2.key
# Keys should be 32 bytes (256 bits)
assert len(encrypter1.key) == 32
class TestFactoryFunctions:
"""Test cases for factory functions"""
def test_create_system_encrypter_with_secret(self):
"""Test factory function with secret key"""
secret_key = "test_secret"
encrypter = create_system_encrypter(secret_key)
assert isinstance(encrypter, SystemEncrypter)
expected_key = hashlib.sha256(secret_key.encode()).digest()
assert encrypter.key == expected_key
def test_create_system_encrypter_without_secret(self):
"""Test factory function without secret key"""
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = create_system_encrypter()
assert isinstance(encrypter, SystemEncrypter)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
def test_create_system_encrypter_with_none_secret(self):
"""Test factory function with None secret key"""
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = create_system_encrypter(None)
assert isinstance(encrypter, SystemEncrypter)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
class TestGlobalEncrypterInstance:
"""Test cases for global encrypter instance"""
def test_get_system_encrypter_singleton(self):
"""Test that get_system_encrypter returns singleton instance"""
# Clear the global instance first
import core.tools.utils.system_encryption
core.tools.utils.system_encryption._encrypter = None
encrypter1 = get_system_encrypter()
encrypter2 = get_system_encrypter()
assert encrypter1 is encrypter2
assert isinstance(encrypter1, SystemEncrypter)
def test_get_system_encrypter_uses_config(self):
"""Test that global encrypter uses config"""
# Clear the global instance first
import core.tools.utils.system_encryption
core.tools.utils.system_encryption._encrypter = None
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "global_secret"
encrypter = get_system_encrypter()
expected_key = hashlib.sha256(b"global_secret").digest()
assert encrypter.key == expected_key
class TestConvenienceFunctions:
"""Test cases for convenience functions"""
def test_encrypt_system_params(self):
"""Test encrypt_system_params convenience function"""
params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypt_system_params(params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_decrypt_system_params(self):
"""Test decrypt_system_params convenience function"""
params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypt_system_params(params)
decrypted = decrypt_system_params(encrypted)
assert decrypted == params
def test_convenience_functions_consistency(self):
"""Test that convenience functions work consistently"""
test_cases = [
{},
{"simple": "value"},
{"client_id": "id", "client_secret": "secret"},
{"complex": {"nested": {"deep": "value"}}},
{"unicode": "test 🚀"},
{"numbers": 42, "boolean": True, "null": None},
]
for original_params in test_cases:
encrypted = encrypt_system_params(original_params)
decrypted = decrypt_system_params(encrypted)
assert decrypted == original_params, f"Failed for case: {original_params}"
def test_convenience_functions_with_errors(self):
"""Test convenience functions with error conditions"""
# Test encryption with invalid input
with pytest.raises(Exception): # noqa: B017
encrypt_system_params(None)
# Test decryption with invalid input
with pytest.raises(ValueError):
decrypt_system_params("")
with pytest.raises(ValueError):
decrypt_system_params(None)
class TestErrorHandling:
"""Test cases for error handling"""
def test_encryption_error_inheritance(self):
"""Test that EncryptionError is a proper exception"""
error = EncryptionError("Test error")
assert isinstance(error, Exception)
assert str(error) == "Test error"
def test_encryption_error_with_cause(self):
"""Test EncryptionError with cause"""
original_error = ValueError("Original error")
error = EncryptionError("Wrapper error")
error.__cause__ = original_error
assert isinstance(error, Exception)
assert str(error) == "Wrapper error"
assert error.__cause__ is original_error
def test_error_messages_are_informative(self):
"""Test that error messages are informative"""
encrypter = SystemEncrypter("test_secret")
# Test empty string error
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_params("")
assert "encrypted_data cannot be empty" in str(exc_info.value)
# Test non-string error
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_params(123)
assert "encrypted_data must be a string" in str(exc_info.value)
# Test invalid format error
short_data = base64.b64encode(b"short").decode()
with pytest.raises(EncryptionError) as exc_info:
encrypter.decrypt_params(short_data)
assert "Invalid encrypted data format" in str(exc_info.value)
class TestEdgeCases:
"""Test cases for edge cases and boundary conditions"""
def test_very_long_secret_key(self):
"""Test with very long secret key"""
long_secret = "x" * 10000
encrypter = SystemEncrypter(long_secret)
# Key should still be 32 bytes due to SHA-256
assert len(encrypter.key) == 32
# Should still work normally
params = {"client_id": "test_id"}
encrypted = encrypter.encrypt_params(params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == params
def test_special_characters_in_secret_key(self):
"""Test with special characters in secret key"""
special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀"
encrypter = SystemEncrypter(special_secret)
params = {"client_id": "test_id"}
encrypted = encrypter.encrypt_params(params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == params
def test_empty_values_in_params(self):
"""Test with empty values in params"""
params = {
"client_id": "",
"client_secret": "",
"empty_dict": {},
"empty_list": [],
"empty_string": "",
"zero": 0,
"false": False,
"none": None,
}
encrypter = SystemEncrypter("test_secret")
encrypted = encrypter.encrypt_params(params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == params
def test_deeply_nested_params(self):
"""Test with deeply nested params"""
params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}}
encrypter = SystemEncrypter("test_secret")
encrypted = encrypter.encrypt_params(params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == params
def test_params_with_all_json_types(self):
"""Test with all JSON-supported data types"""
params = {
"string": "test_string",
"integer": 42,
"float": 3.14159,
"boolean_true": True,
"boolean_false": False,
"null_value": None,
"empty_string": "",
"array": [1, "two", 3.0, True, False, None],
"object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True},
}
encrypter = SystemEncrypter("test_secret")
encrypted = encrypter.encrypt_params(params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == params
class TestPerformance:
"""Test cases for performance considerations"""
def test_large_params(self):
"""Test with large params"""
large_value = "x" * 100000 # 100KB
params = {"client_id": "test_id", "large_data": large_value}
encrypter = SystemEncrypter("test_secret")
encrypted = encrypter.encrypt_params(params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == params
def test_many_fields_params(self):
"""Test with many fields in params"""
params = {f"field_{i}": f"value_{i}" for i in range(1000)}
encrypter = SystemEncrypter("test_secret")
encrypted = encrypter.encrypt_params(params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == params
def test_repeated_encryption_decryption(self):
"""Test repeated encryption and decryption operations"""
encrypter = SystemEncrypter("test_secret")
params = {"client_id": "test_id", "client_secret": "test_secret"}
# Test multiple rounds of encryption/decryption
for i in range(100):
encrypted = encrypter.encrypt_params(params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == params

View File

@ -1,619 +0,0 @@
import base64
import hashlib
from unittest.mock import patch
import pytest
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad
from core.tools.utils.system_oauth_encryption import (
OAuthEncryptionError,
SystemOAuthEncrypter,
create_system_oauth_encrypter,
decrypt_system_oauth_params,
encrypt_system_oauth_params,
get_system_oauth_encrypter,
)
class TestSystemOAuthEncrypter:
"""Test cases for SystemOAuthEncrypter class"""
def test_init_with_secret_key(self):
"""Test initialization with provided secret key"""
secret_key = "test_secret_key"
encrypter = SystemOAuthEncrypter(secret_key=secret_key)
expected_key = hashlib.sha256(secret_key.encode()).digest()
assert encrypter.key == expected_key
def test_init_with_none_secret_key(self):
"""Test initialization with None secret key falls back to config"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = SystemOAuthEncrypter(secret_key=None)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
def test_init_with_empty_secret_key(self):
"""Test initialization with empty secret key"""
encrypter = SystemOAuthEncrypter(secret_key="")
expected_key = hashlib.sha256(b"").digest()
assert encrypter.key == expected_key
def test_init_without_secret_key_uses_config(self):
"""Test initialization without secret key uses config"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "default_secret"
encrypter = SystemOAuthEncrypter()
expected_key = hashlib.sha256(b"default_secret").digest()
assert encrypter.key == expected_key
def test_encrypt_oauth_params_basic(self):
"""Test basic OAuth parameters encryption"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
# Should be valid base64
try:
base64.b64decode(encrypted)
except Exception:
pytest.fail("Encrypted result is not valid base64")
def test_encrypt_oauth_params_empty_dict(self):
"""Test encryption with empty dictionary"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_complex_data(self):
"""Test encryption with complex data structures"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {
"client_id": "test_id",
"client_secret": "test_secret",
"scopes": ["read", "write", "admin"],
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
"numeric_value": 42,
"boolean_value": False,
"null_value": None,
}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_unicode_data(self):
"""Test encryption with unicode data"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_large_data(self):
"""Test encryption with large data"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {
"client_id": "test_id",
"large_data": "x" * 10000, # 10KB of data
}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_invalid_input(self):
"""Test encryption with invalid input types"""
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_oauth_params(None)
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_oauth_params("not_a_dict")
def test_decrypt_oauth_params_basic(self):
"""Test basic OAuth parameters decryption"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_empty_dict(self):
"""Test decryption of empty dictionary"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_complex_data(self):
"""Test decryption with complex data structures"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"client_secret": "test_secret",
"scopes": ["read", "write", "admin"],
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
"numeric_value": 42,
"boolean_value": False,
"null_value": None,
}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_unicode_data(self):
"""Test decryption with unicode data"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"client_secret": "test_secret",
"description": "This is a test case 🚀",
}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_large_data(self):
"""Test decryption with large data"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"large_data": "x" * 10000, # 10KB of data
}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_invalid_base64(self):
"""Test decryption with invalid base64 data"""
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(OAuthEncryptionError):
encrypter.decrypt_oauth_params("invalid_base64!")
def test_decrypt_oauth_params_empty_string(self):
"""Test decryption with empty string"""
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params("")
assert "encrypted_data cannot be empty" in str(exc_info.value)
def test_decrypt_oauth_params_non_string_input(self):
"""Test decryption with non-string input"""
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(123)
assert "encrypted_data must be a string" in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(None)
assert "encrypted_data must be a string" in str(exc_info.value)
def test_decrypt_oauth_params_too_short_data(self):
"""Test decryption with too short encrypted data"""
encrypter = SystemOAuthEncrypter("test_secret")
# Create data that's too short (less than 32 bytes)
short_data = base64.b64encode(b"short").decode()
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.decrypt_oauth_params(short_data)
assert "Invalid encrypted data format" in str(exc_info.value)
def test_decrypt_oauth_params_corrupted_data(self):
"""Test decryption with corrupted data"""
encrypter = SystemOAuthEncrypter("test_secret")
# Create corrupted data (valid base64 but invalid encrypted content)
corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage
with pytest.raises(OAuthEncryptionError):
encrypter.decrypt_oauth_params(corrupted_data)
def test_decrypt_oauth_params_wrong_key(self):
"""Test decryption with wrong key"""
encrypter1 = SystemOAuthEncrypter("secret1")
encrypter2 = SystemOAuthEncrypter("secret2")
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter1.encrypt_oauth_params(original_params)
with pytest.raises(OAuthEncryptionError):
encrypter2.decrypt_oauth_params(encrypted)
def test_encryption_decryption_consistency(self):
"""Test that encryption and decryption are consistent"""
encrypter = SystemOAuthEncrypter("test_secret")
test_cases = [
{},
{"simple": "value"},
{"client_id": "id", "client_secret": "secret"},
{"complex": {"nested": {"deep": "value"}}},
{"unicode": "test 🚀"},
{"numbers": 42, "boolean": True, "null": None},
{"array": [1, 2, 3, "four", {"five": 5}]},
]
for original_params in test_cases:
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params, f"Failed for case: {original_params}"
def test_encryption_randomness(self):
"""Test that encryption produces different results for same input"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted1 = encrypter.encrypt_oauth_params(oauth_params)
encrypted2 = encrypter.encrypt_oauth_params(oauth_params)
# Should be different due to random IV
assert encrypted1 != encrypted2
# But should decrypt to same result
decrypted1 = encrypter.decrypt_oauth_params(encrypted1)
decrypted2 = encrypter.decrypt_oauth_params(encrypted2)
assert decrypted1 == decrypted2 == oauth_params
def test_different_secret_keys_produce_different_results(self):
"""Test that different secret keys produce different encrypted results"""
encrypter1 = SystemOAuthEncrypter("secret1")
encrypter2 = SystemOAuthEncrypter("secret2")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted1 = encrypter1.encrypt_oauth_params(oauth_params)
encrypted2 = encrypter2.encrypt_oauth_params(oauth_params)
# Should produce different encrypted results
assert encrypted1 != encrypted2
# But each should decrypt correctly with its own key
decrypted1 = encrypter1.decrypt_oauth_params(encrypted1)
decrypted2 = encrypter2.decrypt_oauth_params(encrypted2)
assert decrypted1 == decrypted2 == oauth_params
@patch("core.tools.utils.system_oauth_encryption.get_random_bytes")
def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes):
"""Test encryption when crypto operation fails"""
mock_get_random_bytes.side_effect = Exception("Crypto error")
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id"}
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.encrypt_oauth_params(oauth_params)
assert "Encryption failed" in str(exc_info.value)
@patch("core.tools.utils.system_oauth_encryption.TypeAdapter")
def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter):
"""Test encryption when JSON serialization fails"""
mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id"}
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.encrypt_oauth_params(oauth_params)
assert "Encryption failed" in str(exc_info.value)
def test_decrypt_oauth_params_invalid_json(self):
"""Test decryption with invalid JSON data"""
encrypter = SystemOAuthEncrypter("test_secret")
# Create valid encrypted data but with invalid JSON content
iv = get_random_bytes(16)
cipher = AES.new(encrypter.key, AES.MODE_CBC, iv)
invalid_json = b"invalid json content"
padded_data = pad(invalid_json, AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
combined = iv + encrypted_data
encoded = base64.b64encode(combined).decode()
with pytest.raises(OAuthEncryptionError):
encrypter.decrypt_oauth_params(encoded)
def test_key_derivation_consistency(self):
"""Test that key derivation is consistent"""
secret_key = "test_secret"
encrypter1 = SystemOAuthEncrypter(secret_key)
encrypter2 = SystemOAuthEncrypter(secret_key)
assert encrypter1.key == encrypter2.key
# Keys should be 32 bytes (256 bits)
assert len(encrypter1.key) == 32
class TestFactoryFunctions:
"""Test cases for factory functions"""
def test_create_system_oauth_encrypter_with_secret(self):
"""Test factory function with secret key"""
secret_key = "test_secret"
encrypter = create_system_oauth_encrypter(secret_key)
assert isinstance(encrypter, SystemOAuthEncrypter)
expected_key = hashlib.sha256(secret_key.encode()).digest()
assert encrypter.key == expected_key
def test_create_system_oauth_encrypter_without_secret(self):
"""Test factory function without secret key"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = create_system_oauth_encrypter()
assert isinstance(encrypter, SystemOAuthEncrypter)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
def test_create_system_oauth_encrypter_with_none_secret(self):
"""Test factory function with None secret key"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = create_system_oauth_encrypter(None)
assert isinstance(encrypter, SystemOAuthEncrypter)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
class TestGlobalEncrypterInstance:
"""Test cases for global encrypter instance"""
def test_get_system_oauth_encrypter_singleton(self):
"""Test that get_system_oauth_encrypter returns singleton instance"""
# Clear the global instance first
import core.tools.utils.system_oauth_encryption
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
encrypter1 = get_system_oauth_encrypter()
encrypter2 = get_system_oauth_encrypter()
assert encrypter1 is encrypter2
assert isinstance(encrypter1, SystemOAuthEncrypter)
def test_get_system_oauth_encrypter_uses_config(self):
"""Test that global encrypter uses config"""
# Clear the global instance first
import core.tools.utils.system_oauth_encryption
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "global_secret"
encrypter = get_system_oauth_encrypter()
expected_key = hashlib.sha256(b"global_secret").digest()
assert encrypter.key == expected_key
class TestConvenienceFunctions:
"""Test cases for convenience functions"""
def test_encrypt_system_oauth_params(self):
"""Test encrypt_system_oauth_params convenience function"""
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypt_system_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_decrypt_system_oauth_params(self):
"""Test decrypt_system_oauth_params convenience function"""
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypt_system_oauth_params(oauth_params)
decrypted = decrypt_system_oauth_params(encrypted)
assert decrypted == oauth_params
def test_convenience_functions_consistency(self):
"""Test that convenience functions work consistently"""
test_cases = [
{},
{"simple": "value"},
{"client_id": "id", "client_secret": "secret"},
{"complex": {"nested": {"deep": "value"}}},
{"unicode": "test 🚀"},
{"numbers": 42, "boolean": True, "null": None},
]
for original_params in test_cases:
encrypted = encrypt_system_oauth_params(original_params)
decrypted = decrypt_system_oauth_params(encrypted)
assert decrypted == original_params, f"Failed for case: {original_params}"
def test_convenience_functions_with_errors(self):
"""Test convenience functions with error conditions"""
# Test encryption with invalid input
with pytest.raises(Exception): # noqa: B017
encrypt_system_oauth_params(None)
# Test decryption with invalid input
with pytest.raises(ValueError):
decrypt_system_oauth_params("")
with pytest.raises(ValueError):
decrypt_system_oauth_params(None)
class TestErrorHandling:
"""Test cases for error handling"""
def test_oauth_encryption_error_inheritance(self):
"""Test that OAuthEncryptionError is a proper exception"""
error = OAuthEncryptionError("Test error")
assert isinstance(error, Exception)
assert str(error) == "Test error"
def test_oauth_encryption_error_with_cause(self):
"""Test OAuthEncryptionError with cause"""
original_error = ValueError("Original error")
error = OAuthEncryptionError("Wrapper error")
error.__cause__ = original_error
assert isinstance(error, Exception)
assert str(error) == "Wrapper error"
assert error.__cause__ is original_error
def test_error_messages_are_informative(self):
"""Test that error messages are informative"""
encrypter = SystemOAuthEncrypter("test_secret")
# Test empty string error
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params("")
assert "encrypted_data cannot be empty" in str(exc_info.value)
# Test non-string error
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(123)
assert "encrypted_data must be a string" in str(exc_info.value)
# Test invalid format error
short_data = base64.b64encode(b"short").decode()
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.decrypt_oauth_params(short_data)
assert "Invalid encrypted data format" in str(exc_info.value)
class TestEdgeCases:
"""Test cases for edge cases and boundary conditions"""
def test_very_long_secret_key(self):
"""Test with very long secret key"""
long_secret = "x" * 10000
encrypter = SystemOAuthEncrypter(long_secret)
# Key should still be 32 bytes due to SHA-256
assert len(encrypter.key) == 32
# Should still work normally
oauth_params = {"client_id": "test_id"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_special_characters_in_secret_key(self):
"""Test with special characters in secret key"""
special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀"
encrypter = SystemOAuthEncrypter(special_secret)
oauth_params = {"client_id": "test_id"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_empty_values_in_oauth_params(self):
"""Test with empty values in oauth params"""
oauth_params = {
"client_id": "",
"client_secret": "",
"empty_dict": {},
"empty_list": [],
"empty_string": "",
"zero": 0,
"false": False,
"none": None,
}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_deeply_nested_oauth_params(self):
"""Test with deeply nested oauth params"""
oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_oauth_params_with_all_json_types(self):
"""Test with all JSON-supported data types"""
oauth_params = {
"string": "test_string",
"integer": 42,
"float": 3.14159,
"boolean_true": True,
"boolean_false": False,
"null_value": None,
"empty_string": "",
"array": [1, "two", 3.0, True, False, None],
"object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True},
}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
class TestPerformance:
"""Test cases for performance considerations"""
def test_large_oauth_params(self):
"""Test with large oauth params"""
large_value = "x" * 100000 # 100KB
oauth_params = {"client_id": "test_id", "large_data": large_value}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_many_fields_oauth_params(self):
"""Test with many fields in oauth params"""
oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_repeated_encryption_decryption(self):
"""Test repeated encryption and decryption operations"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
# Test multiple rounds of encryption/decryption
for i in range(100):
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params

View File

@ -1467,6 +1467,11 @@ ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id}
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Creators Platform configuration
CREATORS_PLATFORM_FEATURES_ENABLED=true
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
FORCE_VERIFYING_SIGNATURE=true
ENFORCE_LANGGENIUS_PLUGIN_SIGNATURES=true

View File

@ -629,6 +629,9 @@ x-shared-env: &shared-api-worker-env
ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}}
MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai}
CREATORS_PLATFORM_FEATURES_ENABLED: ${CREATORS_PLATFORM_FEATURES_ENABLED:-true}
CREATORS_PLATFORM_API_URL: ${CREATORS_PLATFORM_API_URL:-https://creators.dify.ai}
CREATORS_PLATFORM_OAUTH_CLIENT_ID: ${CREATORS_PLATFORM_OAUTH_CLIENT_ID:-}
FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true}
ENFORCE_LANGGENIUS_PLUGIN_SIGNATURES: ${ENFORCE_LANGGENIUS_PLUGIN_SIGNATURES:-true}
PLUGIN_STDIO_BUFFER_SIZE: ${PLUGIN_STDIO_BUFFER_SIZE:-1024}

View File

@ -16,9 +16,9 @@ import { useEffect, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import Loading from '@/app/components/base/loading'
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
import { setOAuthPendingRedirect } from '@/app/signin/utils/post-login-redirect'
import { useRouter, useSearchParams } from '@/next/navigation'
import { isLegacyBase401, userProfileQueryOptions } from '@/service/use-common'
import { isLegacyBase401, useLogout, userProfileQueryOptions } from '@/service/use-common'
import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth'
function buildReturnUrl(pathname: string, search: string) {
@ -73,14 +73,17 @@ export default function OAuthAuthorize() {
const userProfile = userProfileResp?.profile
const { data: authAppInfo, isLoading: isOAuthLoading, isError } = useOAuthAppInfo(client_id, redirect_uri)
const { mutateAsync: authorize, isPending: authorizing } = useAuthorizeOAuthApp()
const { mutateAsync: logout } = useLogout()
const hasNotifiedRef = useRef(false)
const isLoading = isOAuthLoading || isProfileLoading
const onLoginSwitchClick = () => {
const onLoginSwitchClick = async () => {
try {
const returnUrl = buildReturnUrl('/account/oauth/authorize', `?client_id=${encodeURIComponent(client_id)}&redirect_uri=${encodeURIComponent(redirect_uri)}`)
setPostLoginRedirect(returnUrl)
router.push('/signin')
const returnUrl = buildReturnUrl('/account/oauth/authorize', `?${searchParams.toString()}`)
setOAuthPendingRedirect(returnUrl)
if (isLoggedIn)
await logout()
router.push(`/signin?redirect_url=${encodeURIComponent(returnUrl)}`)
}
catch {
router.push('/signin')

View File

@ -85,7 +85,7 @@ export const AppInitializer = ({
return
}
const redirectUrl = resolvePostLoginRedirect()
const redirectUrl = resolvePostLoginRedirect(searchParams)
if (redirectUrl) {
location.replace(redirectUrl)
return

View File

@ -80,8 +80,11 @@ vi.mock('@/service/explore', () => ({
fetchInstalledAppList: (...args: unknown[]) => mockFetchInstalledAppList(...args),
}))
const mockPublishToCreatorsPlatform = vi.fn()
vi.mock('@/service/apps', () => ({
fetchAppDetailDirect: (...args: unknown[]) => mockFetchAppDetailDirect(...args),
publishToCreatorsPlatform: (...args: unknown[]) => mockPublishToCreatorsPlatform(...args),
}))
vi.mock('@/service/use-workflow', () => ({
@ -434,6 +437,76 @@ describe('AppPublisher', () => {
})
})
it('should show marketplace button and open redirect URL on success', async () => {
mockPublishToCreatorsPlatform.mockResolvedValue({ redirect_url: 'https://marketplace.example.com/publish?code=abc' })
const windowOpenSpy = vi.spyOn(window, 'open').mockImplementation(() => null)
renderWithSystemFeatures(
<AppPublisher
publishedAt={Date.now()}
onPublish={mockOnPublish}
/>,
{ systemFeatures: { webapp_auth: { enabled: true }, enable_creators_platform: true } },
)
fireEvent.click(screen.getByText('common.publish'))
fireEvent.click(screen.getByText('common.publishToMarketplace'))
await waitFor(() => {
expect(mockPublishToCreatorsPlatform).toHaveBeenCalledWith({ appID: 'app-1' })
expect(windowOpenSpy).toHaveBeenCalledWith('https://marketplace.example.com/publish?code=abc', '_blank')
})
windowOpenSpy.mockRestore()
})
it('should show toast error when publish to marketplace fails', async () => {
mockPublishToCreatorsPlatform.mockRejectedValue(new Error('network error'))
renderWithSystemFeatures(
<AppPublisher
publishedAt={Date.now()}
onPublish={mockOnPublish}
/>,
{ systemFeatures: { webapp_auth: { enabled: true }, enable_creators_platform: true } },
)
fireEvent.click(screen.getByText('common.publish'))
fireEvent.click(screen.getByText('common.publishToMarketplace'))
await waitFor(() => {
expect(mockToastError).toHaveBeenCalledWith('common.publishToMarketplaceFailed')
})
})
it('should disable marketplace button when not yet published', () => {
renderWithSystemFeatures(
<AppPublisher
onPublish={mockOnPublish}
/>,
{ systemFeatures: { webapp_auth: { enabled: true }, enable_creators_platform: true } },
)
fireEvent.click(screen.getByText('common.publish'))
const marketplaceButton = screen.getByText('common.publishToMarketplace').closest('a, button, div[role="button"]') as HTMLElement
expect(marketplaceButton).toBeInTheDocument()
// clicking should not call the API because publishedAt is undefined
fireEvent.click(screen.getByText('common.publishToMarketplace'))
expect(mockPublishToCreatorsPlatform).not.toHaveBeenCalled()
})
it('should hide marketplace button when enable_creators_platform is false', () => {
render(
<AppPublisher
publishedAt={Date.now()}
onPublish={mockOnPublish}
/>,
)
fireEvent.click(screen.getByText('common.publish'))
expect(screen.queryByText('common.publishToMarketplace')).not.toBeInTheDocument()
})
it('should keep access control open when app detail is unavailable during confirmation', async () => {
mockAppDetail = null

View File

@ -5,6 +5,7 @@ import type { PublishWorkflowParams } from '@/types/workflow'
import { Button } from '@langgenius/dify-ui/button'
import { Popover, PopoverContent, PopoverTrigger } from '@langgenius/dify-ui/popover'
import { toast } from '@langgenius/dify-ui/toast'
import { RiStoreLine } from '@remixicon/react'
import { useSuspenseQuery } from '@tanstack/react-query'
import { useKeyPress } from 'ahooks'
import {
@ -26,7 +27,7 @@ import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now'
import { AccessMode } from '@/models/access-control'
import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/access-control'
import { fetchAppDetailDirect } from '@/service/apps'
import { fetchAppDetailDirect, publishToCreatorsPlatform } from '@/service/apps'
import { fetchInstalledAppList } from '@/service/explore'
import { systemFeaturesQueryOptions } from '@/service/system-features'
import { useInvalidateAppWorkflow } from '@/service/use-workflow'
@ -40,6 +41,7 @@ import {
PublisherActionsSection,
PublisherSummarySection,
} from './sections'
import SuggestedAction from './suggested-action'
import {
getDisabledFunctionTooltip,
getPublisherAppUrl,
@ -100,6 +102,7 @@ const AppPublisher = ({
const [showAppAccessControl, setShowAppAccessControl] = useState(false)
const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false)
const [publishingToMarketplace, setPublishingToMarketplace] = useState(false)
const workflowStore = useContext(WorkflowContext)
const appDetail = useAppStore(state => state.appDetail)
@ -219,6 +222,23 @@ const AppPublisher = ({
}
}, [appDetail, setAppDetail])
const handlePublishToMarketplace = useCallback(async () => {
if (!appDetail?.id || publishingToMarketplace)
return
setPublishingToMarketplace(true)
try {
const res = await publishToCreatorsPlatform({ appID: appDetail.id })
if (res.redirect_url)
window.open(res.redirect_url, '_blank')
}
catch {
toast.error(t('common.publishToMarketplaceFailed', { ns: 'workflow' }))
}
finally {
setPublishingToMarketplace(false)
}
}, [appDetail?.id, publishingToMarketplace, t])
useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.shift.p`, (e) => {
e.preventDefault()
if (publishDisabled || published)
@ -336,6 +356,19 @@ const AppPublisher = ({
workflowToolAvailable={workflowToolAvailable}
workflowToolMessage={workflowToolMessage}
/>
{systemFeatures.enable_creators_platform && (
<div className="border-t border-divider-subtle p-4">
<SuggestedAction
icon={<RiStoreLine className="h-4 w-4" />}
disabled={!publishedAt || publishingToMarketplace}
onClick={handlePublishToMarketplace}
>
{publishingToMarketplace
? t('common.publishingToMarketplace', { ns: 'workflow' })
: t('common.publishToMarketplace', { ns: 'workflow' })}
</SuggestedAction>
</div>
)}
</div>
</PopoverContent>
<EmbeddedModal

View File

@ -137,7 +137,7 @@ describe('CreateFromDSLModal', () => {
/>,
)
expect(screen.getByText('importFromDSL'))!.toBeInTheDocument()
expect(screen.getByText('importApp'))!.toBeInTheDocument()
await waitFor(() => {
expect(screen.getByText('demo.yml'))!.toBeInTheDocument()
@ -161,7 +161,7 @@ describe('CreateFromDSLModal', () => {
})
expect(screen.getByPlaceholderText('importFromDSLUrlPlaceholder'))!.toBeInTheDocument()
const closeTrigger = screen.getByText('importFromDSL').parentElement?.querySelector('.cursor-pointer.items-center') as HTMLElement
const closeTrigger = screen.getByText('importApp').parentElement?.querySelector('.cursor-pointer.items-center') as HTMLElement
fireEvent.click(closeTrigger)
expect(handleClose).toHaveBeenCalledTimes(1)
})

View File

@ -225,7 +225,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
onClose={noop}
>
<div className="flex items-center justify-between pt-6 pr-5 pb-3 pl-6 title-2xl-semi-bold text-text-primary">
{t('importFromDSL', { ns: 'app' })}
{t('importApp', { ns: 'app' })}
<div
className="flex h-8 w-8 cursor-pointer items-center"
onClick={() => onClose()}

View File

@ -7,9 +7,21 @@ import { useContextSelector } from 'use-context-selector'
import AppListContext from '@/context/app-list-context'
import { fetchAppDetail } from '@/service/explore'
import { AppModeEnum } from '@/types/app'
import Apps from '../index'
vi.mock('@/next/dynamic', () => ({
default: (loader: () => Promise<{ default: React.ComponentType }>) => {
const LazyComp = React.lazy(loader)
return function DynamicWrapper(props: Record<string, unknown>) {
return React.createElement(
React.Suspense,
{ fallback: null },
React.createElement(LazyComp, props),
)
}
},
}))
let documentTitleCalls: string[] = []
let educationInitCalls: number = 0
const mockHandleImportDSL = vi.fn()
@ -65,6 +77,16 @@ vi.mock('@/hooks/use-import-dsl', () => ({
}),
}))
const mockReplace = vi.fn()
let mockSearchParams = new URLSearchParams()
vi.mock('@/next/navigation', () => ({
useRouter: () => ({
replace: mockReplace,
}),
useSearchParams: () => mockSearchParams,
}))
vi.mock('../list', () => {
const MockList = () => {
const setShowTryAppPanel = useContextSelector(AppListContext, ctx => ctx.setShowTryAppPanel)
@ -129,6 +151,16 @@ vi.mock('../../app/create-from-dsl-modal/dsl-confirm-modal', () => ({
),
}))
vi.mock('../import-from-marketplace-template-modal', () => ({
default: ({ templateId, onClose, onConfirm }: { templateId: string, onClose: () => void, onConfirm: (dsl: string) => void }) => (
<div data-testid="marketplace-template-modal">
<span data-testid="template-id">{templateId}</span>
<button data-testid="close-template" onClick={onClose}>Close Template</button>
<button data-testid="confirm-template" onClick={() => onConfirm('yaml-dsl-content')}>Confirm Template</button>
</div>
),
}))
vi.mock('@/service/explore', () => ({
fetchAppDetail: vi.fn(),
}))
@ -161,6 +193,8 @@ describe('Apps', () => {
vi.clearAllMocks()
documentTitleCalls = []
educationInitCalls = 0
mockSearchParams = new URLSearchParams()
mockReplace.mockClear()
mockFetchAppDetail.mockResolvedValue({
id: 'template-1',
name: 'Sample App',
@ -304,6 +338,66 @@ describe('Apps', () => {
})
})
describe('Marketplace Template', () => {
it('should render the template modal when template-id is in search params', async () => {
mockSearchParams = new URLSearchParams('template-id=tpl-42')
renderWithClient(<Apps />)
expect(await screen.findByTestId('marketplace-template-modal')).toBeInTheDocument()
expect(screen.getByTestId('template-id')).toHaveTextContent('tpl-42')
})
it('should not render the template modal when no template-id is present', () => {
renderWithClient(<Apps />)
expect(screen.queryByTestId('marketplace-template-modal')).not.toBeInTheDocument()
})
it('should close the template modal and remove template-id from URL', async () => {
mockSearchParams = new URLSearchParams('template-id=tpl-42')
renderWithClient(<Apps />)
fireEvent.click(await screen.findByTestId('close-template'))
expect(mockReplace).toHaveBeenCalledTimes(1)
const replaceArg = mockReplace.mock.calls[0]![0] as string
expect(replaceArg).not.toContain('template-id')
})
it('should import DSL from marketplace template on confirm', async () => {
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
options.onSuccess?.()
})
mockSearchParams = new URLSearchParams('template-id=tpl-42')
renderWithClient(<Apps />)
fireEvent.click(await screen.findByTestId('confirm-template'))
await waitFor(() => {
expect(mockHandleImportDSL).toHaveBeenCalledWith(
{ mode: 'yaml-content', yaml_content: 'yaml-dsl-content' },
expect.objectContaining({ onSuccess: expect.any(Function) }),
)
expect(mockReplace).toHaveBeenCalled()
})
})
it('should show DSL confirm modal when marketplace import is pending', async () => {
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onPending?: () => void }) => {
options.onPending?.()
})
mockSearchParams = new URLSearchParams('template-id=tpl-42')
renderWithClient(<Apps />)
fireEvent.click(await screen.findByTestId('confirm-template'))
await waitFor(() => {
expect(screen.getByTestId('dsl-confirm-modal')).toBeInTheDocument()
expect(mockReplace).toHaveBeenCalled()
})
})
})
describe('Styling', () => {
it('should have overflow-y-auto class', () => {
const { container } = renderWithClient(<Apps />)

View File

@ -0,0 +1,182 @@
'use client'
import { Button } from '@langgenius/dify-ui/button'
import { Dialog, DialogContent } from '@langgenius/dify-ui/dialog'
import { toast } from '@langgenius/dify-ui/toast'
import { RiCloseLine } from '@remixicon/react'
import { useCallback, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { MARKETPLACE_API_PREFIX } from '@/config'
import {
fetchMarketplaceTemplateDSL,
useMarketplaceTemplateDetail,
} from '@/service/marketplace-templates'
type ImportFromMarketplaceTemplateModalProps = {
templateId: string
onClose: () => void
onConfirm: (dslContent: string) => void
}
const ImportFromMarketplaceTemplateModal = ({
templateId,
onClose,
onConfirm,
}: ImportFromMarketplaceTemplateModalProps) => {
const { t } = useTranslation()
const { data, isLoading, isError } = useMarketplaceTemplateDetail(templateId)
const template = data?.data
const [importing, setImporting] = useState(false)
const isImportingRef = useRef(false)
const CATEGORY_I18N_MAP: Record<string, string> = useMemo(() => ({
marketing: t('marketplace.template.category.marketing', { ns: 'app' }),
sales: t('marketplace.template.category.sales', { ns: 'app' }),
support: t('marketplace.template.category.support', { ns: 'app' }),
operations: t('marketplace.template.category.operations', { ns: 'app' }),
it: t('marketplace.template.category.it', { ns: 'app' }),
knowledge: t('marketplace.template.category.knowledge', { ns: 'app' }),
design: t('marketplace.template.category.design', { ns: 'app' }),
}), [t])
const translateCategory = useCallback((slug: string) => {
return CATEGORY_I18N_MAP[slug] ?? slug
}, [CATEGORY_I18N_MAP])
const handleConfirm = useCallback(async () => {
if (isImportingRef.current)
return
isImportingRef.current = true
setImporting(true)
try {
const dsl = await fetchMarketplaceTemplateDSL(templateId)
onConfirm(dsl)
}
catch {
toast.error(t('marketplace.template.importFailed', { ns: 'app' }))
}
finally {
setImporting(false)
isImportingRef.current = false
}
}, [templateId, onConfirm, t])
return (
<Dialog
open
onOpenChange={(open) => {
if (!open)
onClose()
}}
>
<DialogContent
className="w-[520px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg p-0 shadow-xl"
>
<div className="flex items-center justify-between pt-6 pr-5 pb-3 pl-6 title-2xl-semi-bold text-text-primary">
{t('marketplace.template.modalTitle', { ns: 'app' })}
<div
className="flex h-8 w-8 cursor-pointer items-center"
onClick={onClose}
>
<RiCloseLine className="h-5 w-5 text-text-tertiary" />
</div>
</div>
<div className="px-6 py-4">
{isLoading && (
<div className="flex items-center justify-center py-8">
<div className="system-md-regular text-text-tertiary">Loading...</div>
</div>
)}
{isError && (
<div className="flex items-center justify-center py-8">
<div className="system-md-regular text-text-destructive">
{t('marketplace.template.fetchFailed', { ns: 'app' })}
</div>
</div>
)}
{template && (
<div className="flex flex-col gap-4">
<div className="flex items-center gap-3">
{template.icon_file_key
? (
<img
src={`${MARKETPLACE_API_PREFIX}/templates/${template.id}/icon`}
alt={template.template_name}
className="h-10 w-10 rounded-lg object-cover"
/>
)
: (
<div
className="flex h-10 w-10 items-center justify-center rounded-lg text-xl"
style={{ background: template.icon_background || '#F3F4F6' }}
>
{template.icon || '📄'}
</div>
)}
<div className="flex flex-col">
<div className="system-md-semibold text-text-primary">{template.template_name}</div>
<div className="flex items-center gap-1 system-xs-regular text-text-tertiary">
<span>
{t('marketplace.template.publishedBy', { ns: 'app' })}
{' '}
{template.publisher_unique_handle}
</span>
<span>·</span>
<span>
{t('marketplace.template.usageCount', { ns: 'app' })}
{' '}
{template.usage_count}
</span>
</div>
</div>
</div>
{template.overview && (
<div className="flex flex-col gap-1">
<div className="system-xs-medium-uppercase text-text-tertiary">
{t('marketplace.template.overview', { ns: 'app' })}
</div>
<div className="system-sm-regular text-text-secondary">
{template.overview}
</div>
</div>
)}
{template.categories.length > 0 && (
<div className="flex flex-wrap items-center gap-2">
{template.categories.map(cat => (
<span
key={cat}
className="inline-flex items-center rounded-full bg-components-label-gray px-2.5 py-1 system-sm-regular text-text-secondary"
>
{translateCategory(cat)}
</span>
))}
</div>
)}
</div>
)}
</div>
<div className="flex justify-end px-6 py-5">
<Button className="mr-2" onClick={onClose}>
{t('newApp.Cancel', { ns: 'app' })}
</Button>
<Button
variant="primary"
disabled={isLoading || isError || importing}
loading={importing}
onClick={handleConfirm}
>
{t('marketplace.template.importConfirm', { ns: 'app' })}
</Button>
</div>
</DialogContent>
</Dialog>
)
}
export default ImportFromMarketplaceTemplateModal

View File

@ -9,6 +9,7 @@ import useDocumentTitle from '@/hooks/use-document-title'
import { useImportDSL } from '@/hooks/use-import-dsl'
import { DSLImportMode } from '@/models/app'
import dynamic from '@/next/dynamic'
import { useRouter, useSearchParams } from '@/next/navigation'
import { fetchAppDetail } from '@/service/explore'
import { trackCreateApp } from '@/utils/create-app-tracking'
import List from './list'
@ -16,9 +17,14 @@ import List from './list'
const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false })
const CreateAppModal = dynamic(() => import('../explore/create-app-modal'), { ssr: false })
const TryApp = dynamic(() => import('../explore/try-app'), { ssr: false })
const ImportFromMarketplaceTemplateModal = dynamic(() => import('./import-from-marketplace-template-modal'), { ssr: false })
const Apps = () => {
const { t } = useTranslation()
const searchParams = useSearchParams()
const { replace } = useRouter()
const templateId = searchParams.get('template-id')
const templateDismissedRef = useRef(false)
useDocumentTitle(t('menus.apps', { ns: 'common' }))
useEducationInit()
@ -58,6 +64,14 @@ const Apps = () => {
const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false)
const handleCloseTemplateModal = useCallback(() => {
templateDismissedRef.current = true
const params = new URLSearchParams(searchParams.toString())
params.delete('template-id')
const query = params.toString()
replace(query ? `?${query}` : window.location.pathname, { scroll: false })
}, [searchParams, replace])
const {
handleImportDSL,
handleImportDSLConfirm,
@ -74,6 +88,22 @@ const Apps = () => {
})
}, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp])
const handleMarketplaceTemplateConfirm = useCallback(async (dslContent: string) => {
await handleImportDSL({
mode: DSLImportMode.YAML_CONTENT,
yaml_content: dslContent,
}, {
onSuccess: () => {
handleCloseTemplateModal()
onSuccess()
},
onPending: () => {
handleCloseTemplateModal()
setShowDSLConfirmModal(true)
},
})
}, [handleImportDSL, handleCloseTemplateModal, onSuccess])
const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({
name,
icon_type,
@ -152,6 +182,14 @@ const Apps = () => {
onHide={() => setIsShowCreateModal(false)}
/>
)}
{templateId && !templateDismissedRef.current && (
<ImportFromMarketplaceTemplateModal
templateId={templateId}
onClose={handleCloseTemplateModal}
onConfirm={handleMarketplaceTemplateConfirm}
/>
)}
</div>
</AppListContext.Provider>
)

View File

@ -156,7 +156,7 @@ describe('PanelContextmenu', () => {
fireEvent.click(screen.getByText('common.run'))
fireEvent.click(screen.getByText('common.pasteHere'))
fireEvent.click(screen.getByText('export'))
fireEvent.click(screen.getByText('common.importDSL'))
fireEvent.click(screen.getByText('importApp'))
clickAwayHandler?.()
expect(mockHandleAddNote).toHaveBeenCalledTimes(1)

View File

@ -137,7 +137,7 @@ const PanelContextmenu = () => {
className="flex h-8 cursor-pointer items-center justify-between rounded-lg px-3 text-sm text-text-secondary hover:bg-state-base-hover"
onClick={() => setShowImportDSLModal(true)}
>
{t('common.importDSL', { ns: 'workflow' })}
{t('importApp', { ns: 'app' })}
</div>
</div>
</div>

View File

@ -205,7 +205,7 @@ const UpdateDSLModal = ({
onClose={onCancel}
>
<div className="mb-3 flex items-center justify-between">
<div className="title-2xl-semi-bold text-text-primary">{t('common.importDSL', { ns: 'workflow' })}</div>
<div className="title-2xl-semi-bold text-text-primary">{t('importApp', { ns: 'app' })}</div>
<div className="flex h-[22px] w-[22px] cursor-pointer items-center justify-center" onClick={onCancel}>
<RiCloseLine className="h-[18px] w-[18px] text-text-tertiary" />
</div>

View File

@ -1,18 +1,23 @@
import Loading from '@/app/components/base/loading'
import Link from '@/next/link'
import { redirect } from '@/next/navigation'
const Home = async () => {
return (
<div className="flex min-h-screen flex-col justify-center py-12 sm:px-6 lg:px-8">
type HomePageProps = {
searchParams: Promise<Record<string, string | string[] | undefined>>
}
<div className="sm:mx-auto sm:w-full sm:max-w-md">
<Loading type="area" />
<div className="mt-10 text-center">
<Link href="/apps">🚀</Link>
</div>
</div>
</div>
)
const Home = async ({ searchParams }: HomePageProps) => {
const resolvedSearchParams = await searchParams
const urlSearchParams = new URLSearchParams()
Object.entries(resolvedSearchParams).forEach(([key, value]) => {
if (value === undefined)
return
if (Array.isArray(value)) {
value.forEach(item => urlSearchParams.append(key, item))
return
}
urlSearchParams.set(key, value)
})
const queryString = urlSearchParams.toString()
redirect(queryString ? `/apps?${queryString}` : '/apps')
}
export default Home

View File

@ -51,7 +51,7 @@ export default function CheckCode() {
router.replace(`/signin/invite-settings?${searchParams.toString()}`)
}
else {
const redirectUrl = resolvePostLoginRedirect()
const redirectUrl = resolvePostLoginRedirect(searchParams)
router.replace(redirectUrl || '/apps')
}
}

View File

@ -75,7 +75,7 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis
router.replace(`/signin/invite-settings?${searchParams.toString()}`)
}
else {
const redirectUrl = resolvePostLoginRedirect()
const redirectUrl = resolvePostLoginRedirect(searchParams)
router.replace(redirectUrl || '/apps')
}
}

View File

@ -65,7 +65,7 @@ export default function InviteSettingsPage() {
if (res.result === 'success') {
// Tokens are now stored in cookies by the backend
await setLocaleOnClient(language!, false)
const redirectUrl = resolvePostLoginRedirect()
const redirectUrl = resolvePostLoginRedirect(searchParams)
router.replace(redirectUrl || '/apps')
}
}

View File

@ -49,7 +49,7 @@ const NormalForm = () => {
try {
if (isLoggedIn) {
setIsRedirecting(true)
const redirectUrl = resolvePostLoginRedirect()
const redirectUrl = resolvePostLoginRedirect(searchParams)
router.replace(redirectUrl || '/apps')
return
}

View File

@ -1,15 +1,63 @@
let postLoginRedirect: string | null = null
import type { ReadonlyURLSearchParams } from '@/next/navigation'
export const setPostLoginRedirect = (value: string | null) => {
postLoginRedirect = value
const OAUTH_AUTHORIZE_PENDING_KEY = 'oauth_authorize_pending_redirect'
const REDIRECT_URL_KEY = 'redirect_url'
type OAuthPendingRedirect = {
value?: string
expiry?: number
}
export const resolvePostLoginRedirect = () => {
if (postLoginRedirect) {
const redirectUrl = postLoginRedirect
postLoginRedirect = null
return redirectUrl
const getCurrentUnixTimestamp = () => Math.floor(Date.now() / 1000)
function removeOAuthPendingRedirect() {
try {
localStorage.removeItem(OAUTH_AUTHORIZE_PENDING_KEY)
}
return null
catch {}
}
function getOAuthPendingRedirect(): string | null {
try {
const raw = localStorage.getItem(OAUTH_AUTHORIZE_PENDING_KEY)
if (!raw)
return null
removeOAuthPendingRedirect()
const item: OAuthPendingRedirect = JSON.parse(raw)
if (!item.value || typeof item.expiry !== 'number')
return null
return getCurrentUnixTimestamp() > item.expiry ? null : item.value
}
catch {
removeOAuthPendingRedirect()
return null
}
}
export function setOAuthPendingRedirect(url: string, ttlSeconds: number = 300) {
try {
const item: OAuthPendingRedirect = {
value: url,
expiry: getCurrentUnixTimestamp() + ttlSeconds,
}
localStorage.setItem(OAUTH_AUTHORIZE_PENDING_KEY, JSON.stringify(item))
}
catch {}
}
export const resolvePostLoginRedirect = (searchParams?: ReadonlyURLSearchParams) => {
if (searchParams) {
const redirectUrl = searchParams.get(REDIRECT_URL_KEY)
if (redirectUrl) {
try {
removeOAuthPendingRedirect()
return decodeURIComponent(redirectUrl)
}
catch {
removeOAuthPendingRedirect()
return redirectUrl
}
}
}
return getOAuthPendingRedirect()
}

View File

@ -1,5 +1,6 @@
import type { CollectionsAndPluginsSearchParams, MarketplaceCollection, PluginsSearchParams } from '@/app/components/plugins/marketplace/types'
import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types'
import type { MarketplaceTemplate } from '@/types/marketplace-template'
import { type } from '@orpc/contract'
import { base } from './base'
@ -54,3 +55,15 @@ export const searchAdvancedContract = base
body: Omit<PluginsSearchParams, 'type'>
}>())
.output(type<{ data: PluginsFromMarketplaceResponse }>())
export const templateDetailContract = base
.route({
path: '/templates/{templateId}',
method: 'GET',
})
.input(type<{
params: {
templateId: string
}
}>())
.output(type<{ data: MarketplaceTemplate }>())

View File

@ -42,12 +42,13 @@ import {
workflowDraftUpdateFeaturesContract,
} from './console/workflow'
import { workflowCommentContracts } from './console/workflow-comment'
import { collectionPluginsContract, collectionsContract, searchAdvancedContract } from './marketplace'
import { collectionPluginsContract, collectionsContract, searchAdvancedContract, templateDetailContract } from './marketplace'
export const marketplaceRouterContract = {
collections: collectionsContract,
collectionPlugins: collectionPluginsContract,
searchAdvanced: searchAdvancedContract,
templateDetail: templateDetailContract,
}
export type MarketPlaceInputs = InferContractRouterInputs<typeof marketplaceRouterContract>

View File

@ -118,12 +118,29 @@
"iconPicker.emoji": "Emoji",
"iconPicker.image": "Image",
"iconPicker.ok": "OK",
"importApp": "Import App",
"importDSL": "Import DSL file",
"importFromDSL": "Import from DSL",
"importFromDSLFile": "From DSL file",
"importFromDSLUrl": "From URL",
"importFromDSLUrlPlaceholder": "Paste DSL link here",
"join": "Join the community",
"marketplace.template.categories": "Categories",
"marketplace.template.category.design": "Design",
"marketplace.template.category.it": "IT",
"marketplace.template.category.knowledge": "Knowledge",
"marketplace.template.category.marketing": "Marketing",
"marketplace.template.category.operations": "Operations",
"marketplace.template.category.sales": "Sales",
"marketplace.template.category.support": "Support",
"marketplace.template.fetchFailed": "Failed to fetch template",
"marketplace.template.importConfirm": "Import",
"marketplace.template.importFailed": "Failed to import template",
"marketplace.template.modalTitle": "Import from Marketplace",
"marketplace.template.overview": "Overview",
"marketplace.template.publishedBy": "By",
"marketplace.template.usageCount": "Usage",
"marketplace.template.viewOnMarketplace": "View on Marketplace",
"maxActiveRequests": "Max concurrent requests",
"maxActiveRequestsPlaceholder": "Enter 0 for unlimited",
"maxActiveRequestsTip": "Maximum number of concurrent active requests per app (0 for unlimited)",

View File

@ -229,9 +229,12 @@
"common.previewPlaceholder": "Enter content in the box below to start debugging the Chatbot",
"common.processData": "Process Data",
"common.publish": "Publish",
"common.publishToMarketplace": "Publish to Marketplace",
"common.publishToMarketplaceFailed": "Failed to publish to Marketplace",
"common.publishUpdate": "Publish Update",
"common.published": "Published",
"common.publishedAt": "Published",
"common.publishingToMarketplace": "Publishing...",
"common.redo": "Redo",
"common.restart": "Restart",
"common.restore": "Restore",

View File

@ -118,12 +118,29 @@
"iconPicker.emoji": "表情符号",
"iconPicker.image": "图片",
"iconPicker.ok": "确认",
"importApp": "导入应用",
"importDSL": "导入 DSL 文件",
"importFromDSL": "导入 DSL",
"importFromDSLFile": "文件",
"importFromDSLUrl": "URL",
"importFromDSLUrlPlaceholder": "输入 DSL 文件的 URL",
"join": "参与社区",
"marketplace.template.categories": "分类",
"marketplace.template.category.design": "设计",
"marketplace.template.category.it": "IT",
"marketplace.template.category.knowledge": "知识",
"marketplace.template.category.marketing": "营销",
"marketplace.template.category.operations": "运营",
"marketplace.template.category.sales": "销售",
"marketplace.template.category.support": "支持",
"marketplace.template.fetchFailed": "获取模板失败",
"marketplace.template.importConfirm": "导入",
"marketplace.template.importFailed": "导入模板失败",
"marketplace.template.modalTitle": "从市场导入",
"marketplace.template.overview": "概述",
"marketplace.template.publishedBy": "来自",
"marketplace.template.usageCount": "使用次数",
"marketplace.template.viewOnMarketplace": "在市场查看",
"maxActiveRequests": "最大活跃请求数",
"maxActiveRequestsPlaceholder": "0 表示不限制",
"maxActiveRequestsTip": "当前应用的最大活跃请求数0 表示不限制)",

View File

@ -229,9 +229,12 @@
"common.previewPlaceholder": "在下面的框中输入内容开始调试聊天机器人",
"common.processData": "数据处理",
"common.publish": "发布",
"common.publishToMarketplace": "发布到市场",
"common.publishToMarketplaceFailed": "发布到市场失败",
"common.publishUpdate": "发布更新",
"common.published": "已发布",
"common.publishedAt": "发布于",
"common.publishingToMarketplace": "发布中...",
"common.redo": "重做",
"common.restart": "重新开始",
"common.restore": "恢复",

View File

@ -1,4 +1,5 @@
export {
redirect,
useParams,
usePathname,
useRouter,
@ -6,3 +7,4 @@ export {
useSelectedLayoutSegment,
useSelectedLayoutSegments,
} from 'next/navigation'
export type { ReadonlyURLSearchParams } from 'next/navigation'

View File

@ -0,0 +1,68 @@
import { buildSigninUrlWithRedirect } from '../base'
vi.mock('@/utils/var', () => ({
basePath: '/app',
API_PREFIX: '/console/api',
PUBLIC_API_PREFIX: '/api',
IS_CE_EDITION: false,
}))
describe('buildSigninUrlWithRedirect', () => {
const originalLocation = globalThis.location
beforeEach(() => {
Object.defineProperty(globalThis, 'location', {
value: {
origin: 'https://example.com',
pathname: '/apps',
href: 'https://example.com/apps',
},
writable: true,
configurable: true,
})
})
afterEach(() => {
Object.defineProperty(globalThis, 'location', {
value: originalLocation,
writable: true,
configurable: true,
})
})
it('should return plain signin URL for non-OAuth pages', () => {
const url = buildSigninUrlWithRedirect()
expect(url).toBe('https://example.com/app/signin')
})
it('should append redirect_url for OAuth authorize pages', () => {
const oauthHref = 'https://example.com/account/oauth/authorize?client_id=abc&state=xyz'
Object.defineProperty(globalThis, 'location', {
value: {
origin: 'https://example.com',
pathname: '/account/oauth/authorize',
href: oauthHref,
},
writable: true,
configurable: true,
})
const url = buildSigninUrlWithRedirect()
expect(url).toBe(`https://example.com/app/signin?redirect_url=${encodeURIComponent(oauthHref)}`)
})
it('should not include redirect_url for other paths containing partial match', () => {
Object.defineProperty(globalThis, 'location', {
value: {
origin: 'https://example.com',
pathname: '/settings/oauth',
href: 'https://example.com/settings/oauth',
},
writable: true,
configurable: true,
})
const url = buildSigninUrlWithRedirect()
expect(url).toBe('https://example.com/app/signin')
})
})

View File

@ -192,3 +192,11 @@ export const updateTracingConfig = ({ appId, body }: { appId: string, body: Trac
export const removeTracingConfig = ({ appId, provider }: { appId: string, provider: TracingProvider }): Promise<CommonResponse> => {
return del<CommonResponse>(`/apps/${appId}/trace-config?tracing_provider=${provider}`)
}
type PublishToCreatorsPlatformResponse = {
redirect_url: string
}
export const publishToCreatorsPlatform = ({ appID }: { appID: string }): Promise<PublishToCreatorsPlatformResponse> => {
return post<PublishToCreatorsPlatformResponse>(`apps/${appID}/publish-to-creators-platform`, { body: {} })
}

View File

@ -140,6 +140,20 @@ function jumpTo(url: string) {
globalThis.location.href = url
}
const OAUTH_AUTHORIZE_PATH = '/account/oauth/authorize'
export const buildSigninUrlWithRedirect = (): string => {
const loginUrl = `${globalThis.location.origin}${basePath}/signin`
// Only preserve redirect URL for OAuth authorize pages
if (globalThis.location.pathname.includes(OAUTH_AUTHORIZE_PATH)) {
const currentUrl = globalThis.location.href
return `${loginUrl}?redirect_url=${encodeURIComponent(currentUrl)}`
}
return loginUrl
}
function unicodeToChar(text: string) {
if (!text)
return ''
@ -795,14 +809,14 @@ export const request = async<T>(url: string, options = {}, otherOptions?: IOther
if (refreshErr === null)
return baseFetch<T>(url, options, otherOptionsForBaseFetch)
if (location.pathname !== `${basePath}/signin` || !IS_CE_EDITION) {
jumpTo(loginUrl)
jumpTo(buildSigninUrlWithRedirect())
return Promise.reject(err)
}
if (!silent) {
toast.error(message)
return Promise.reject(err)
}
jumpTo(loginUrl)
jumpTo(buildSigninUrlWithRedirect())
return Promise.reject(err)
}
else {

View File

@ -0,0 +1,18 @@
import { useQuery } from '@tanstack/react-query'
import { MARKETPLACE_API_PREFIX } from '@/config'
import { marketplaceQuery } from './client'
export const useMarketplaceTemplateDetail = (templateId: string | null) => {
return useQuery({
...marketplaceQuery.templateDetail.queryOptions({ input: { params: { templateId: templateId ?? '' } } }),
enabled: !!templateId,
})
}
export const fetchMarketplaceTemplateDSL = async (templateId: string): Promise<string> => {
const url = `${MARKETPLACE_API_PREFIX}/templates/${templateId}/dsl`
const response = await fetch(url)
if (!response.ok)
throw new Error(`Failed to fetch DSL: ${response.statusText}`)
return response.text()
}

View File

@ -64,6 +64,7 @@ export type SystemFeatures = {
allow_email_code_login: boolean
allow_email_password_login: boolean
}
enable_creators_platform: boolean
enable_trial_app: boolean
enable_explore_banner: boolean
}
@ -108,6 +109,7 @@ export const defaultSystemFeatures: SystemFeatures = {
allow_email_code_login: false,
allow_email_password_login: false,
},
enable_creators_platform: false,
enable_trial_app: false,
enable_explore_banner: false,
}

View File

@ -0,0 +1,11 @@
export type MarketplaceTemplate = {
id: string
template_name: string
overview: string
icon: string
icon_background: string
icon_file_key: string
publisher_unique_handle: string
usage_count: number
categories: string[]
}