mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 12:59:18 +08:00
feat: marketplace and oauth fixes (#35509)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
775f9212f3
commit
df28c99817
@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
|
||||
MARKETPLACE_ENABLED=true
|
||||
MARKETPLACE_API_URL=https://marketplace.dify.ai
|
||||
|
||||
# Creators Platform configuration
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED=true
|
||||
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
|
||||
|
||||
# Endpoint configuration
|
||||
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from core.tools.utils.system_encryption import encrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
oauth_client_params = encrypt_system_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
oauth_client_params = encrypt_system_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
|
||||
@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CreatorsPlatformConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for Creators Platform integration
|
||||
"""
|
||||
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
|
||||
description="Enable or disable Creators Platform features",
|
||||
default=True,
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
|
||||
description="Creators Platform API URL",
|
||||
default=HttpUrl("https://creators.dify.ai"),
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
|
||||
description="OAuth client ID for Creators Platform integration",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for various application endpoints and URLs
|
||||
@ -1379,6 +1400,7 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
CreatorsPlatformConfig,
|
||||
TriggerConfig,
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
|
||||
@ -692,6 +692,32 @@ class AppExportApi(Resource):
|
||||
return payload.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
|
||||
class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
"""Publish app to Creators Platform"""
|
||||
from configs import dify_config
|
||||
from core.helper.creators import get_redirect_url, upload_dsl
|
||||
|
||||
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
|
||||
return {"error": "Creators Platform features are not enabled"}, 403
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
|
||||
dsl_bytes = dsl_content.encode("utf-8")
|
||||
|
||||
claim_code = upload_dsl(dsl_bytes)
|
||||
redirect_url = get_redirect_url(str(current_user.id), claim_code)
|
||||
|
||||
return {"redirect_url": redirect_url}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@console_ns.doc("check_app_name")
|
||||
|
||||
41
api/core/helper/creators.py
Normal file
41
api/core/helper/creators.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""
|
||||
Helper module for Creators Platform integration.
|
||||
|
||||
Provides functionality to upload DSL files to the Creators Platform
|
||||
and generate redirect URLs with OAuth authorization codes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
|
||||
|
||||
|
||||
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
|
||||
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
|
||||
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
claim_code = data.get("data", {}).get("claim_code")
|
||||
if not claim_code:
|
||||
raise ValueError("Creators Platform did not return a valid claim_code")
|
||||
return claim_code
|
||||
|
||||
|
||||
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
|
||||
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
|
||||
params: dict[str, str] = {"dsl_claim_code": claim_code}
|
||||
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
|
||||
if client_id:
|
||||
from services.oauth_server import OAuthServerService
|
||||
|
||||
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
|
||||
params["oauth_code"] = oauth_code
|
||||
return f"{base_url}?{urlencode(params)}"
|
||||
@ -14,23 +14,23 @@ from configs import dify_config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthEncryptionError(Exception):
|
||||
"""OAuth encryption/decryption specific error"""
|
||||
class EncryptionError(Exception):
|
||||
"""Encryption/decryption specific error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SystemOAuthEncrypter:
|
||||
class SystemEncrypter:
|
||||
"""
|
||||
A simple OAuth parameters encrypter using AES-CBC encryption.
|
||||
A simple parameters encrypter using AES-CBC encryption.
|
||||
|
||||
This class provides methods to encrypt and decrypt OAuth parameters
|
||||
This class provides methods to encrypt and decrypt parameters
|
||||
using AES-CBC mode with a key derived from the application's SECRET_KEY.
|
||||
"""
|
||||
|
||||
def __init__(self, secret_key: str | None = None):
|
||||
"""
|
||||
Initialize the OAuth encrypter.
|
||||
Initialize the encrypter.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
|
||||
# Generate a fixed 256-bit key using SHA-256
|
||||
self.key = hashlib.sha256(secret_key.encode()).digest()
|
||||
|
||||
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
|
||||
def encrypt_params(self, params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters.
|
||||
Encrypt parameters.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If encryption fails
|
||||
ValueError: If oauth_params is invalid
|
||||
EncryptionError: If encryption fails
|
||||
ValueError: If params is invalid
|
||||
"""
|
||||
|
||||
try:
|
||||
@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Encrypt data
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
|
||||
# Combine IV and encrypted data
|
||||
@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
|
||||
return base64.b64encode(combined).decode()
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
raise EncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
|
||||
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters.
|
||||
Decrypt parameters.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
Decrypted parameters dictionary
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If decryption fails
|
||||
EncryptionError: If decryption fails
|
||||
ValueError: If encrypted_data is invalid
|
||||
"""
|
||||
if not isinstance(encrypted_data, str):
|
||||
@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
|
||||
unpadded_data = unpad(decrypted_data, AES.block_size)
|
||||
|
||||
# Parse JSON
|
||||
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
|
||||
if not isinstance(oauth_params, dict):
|
||||
if not isinstance(params, dict):
|
||||
raise ValueError("Decrypted data is not a valid dictionary")
|
||||
|
||||
return oauth_params
|
||||
return params
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
raise EncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
|
||||
|
||||
# Factory function for creating encrypter instances
|
||||
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
|
||||
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
|
||||
"""
|
||||
Create an OAuth encrypter instance.
|
||||
Create an encrypter instance.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
SystemEncrypter instance
|
||||
"""
|
||||
return SystemOAuthEncrypter(secret_key=secret_key)
|
||||
return SystemEncrypter(secret_key=secret_key)
|
||||
|
||||
|
||||
# Global encrypter instance (for backward compatibility)
|
||||
_oauth_encrypter: SystemOAuthEncrypter | None = None
|
||||
_encrypter: SystemEncrypter | None = None
|
||||
|
||||
|
||||
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
|
||||
def get_system_encrypter() -> SystemEncrypter:
|
||||
"""
|
||||
Get the global OAuth encrypter instance.
|
||||
Get the global encrypter instance.
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
SystemEncrypter instance
|
||||
"""
|
||||
global _oauth_encrypter
|
||||
if _oauth_encrypter is None:
|
||||
_oauth_encrypter = SystemOAuthEncrypter()
|
||||
return _oauth_encrypter
|
||||
global _encrypter
|
||||
if _encrypter is None:
|
||||
_encrypter = SystemEncrypter()
|
||||
return _encrypter
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
|
||||
def encrypt_system_params(params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters using the global encrypter.
|
||||
Encrypt parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary
|
||||
params: Parameters dictionary
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
"""
|
||||
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
|
||||
return get_system_encrypter().encrypt_params(params)
|
||||
|
||||
|
||||
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters using the global encrypter.
|
||||
Decrypt parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
Decrypted parameters dictionary
|
||||
"""
|
||||
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
|
||||
return get_system_encrypter().decrypt_params(encrypted_data)
|
||||
@ -177,6 +177,7 @@ class SystemFeatureModel(BaseModel):
|
||||
enable_change_email: bool = True
|
||||
plugin_manager: PluginManagerModel = PluginManagerModel()
|
||||
trial_models: list[str] = []
|
||||
enable_creators_platform: bool = False
|
||||
enable_trial_app: bool = False
|
||||
enable_explore_banner: bool = False
|
||||
|
||||
@ -241,6 +242,9 @@ class FeatureService:
|
||||
if dify_config.MARKETPLACE_ENABLED:
|
||||
system_features.enable_marketplace = True
|
||||
|
||||
if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
|
||||
system_features.enable_creators_platform = True
|
||||
|
||||
return system_features
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.tools.utils.system_encryption import decrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import ToolProviderID
|
||||
@ -521,7 +521,7 @@ class BuiltinToolManageService:
|
||||
)
|
||||
if system_client:
|
||||
try:
|
||||
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
|
||||
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error decrypting system oauth params: {e}")
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.tools.utils.system_encryption import decrypt_system_params
|
||||
from core.trigger.entities.api_entities import (
|
||||
TriggerProviderApiEntity,
|
||||
TriggerProviderSubscriptionApiEntity,
|
||||
@ -635,7 +635,7 @@ class TriggerProviderService:
|
||||
|
||||
if system_client:
|
||||
try:
|
||||
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
|
||||
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error decrypting system oauth params: {e}")
|
||||
|
||||
|
||||
106
api/tests/unit_tests/core/helper/test_creators.py
Normal file
106
api/tests/unit_tests/core/helper/test_creators.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Tests for the Creators Platform helper module."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from yarl import URL
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_creators_url(monkeypatch):
|
||||
"""Patch the module-level creators_platform_api_url for all tests."""
|
||||
monkeypatch.setattr(
|
||||
"core.helper.creators.creators_platform_api_url",
|
||||
URL("https://creators.example.com"),
|
||||
)
|
||||
|
||||
|
||||
class TestUploadDSL:
|
||||
@patch("core.helper.creators.httpx.post")
|
||||
def test_returns_claim_code(self, mock_post):
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.json.return_value = {"data": {"claim_code": "abc123"}}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
from core.helper.creators import upload_dsl
|
||||
|
||||
result = upload_dsl(b"app: demo", "demo.yaml")
|
||||
|
||||
assert result == "abc123"
|
||||
mock_post.assert_called_once()
|
||||
call_kwargs = mock_post.call_args
|
||||
assert "anonymous-upload" in call_kwargs.args[0]
|
||||
assert call_kwargs.kwargs["timeout"] == 30
|
||||
|
||||
@patch("core.helper.creators.httpx.post")
|
||||
def test_raises_on_missing_claim_code(self, mock_post):
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.json.return_value = {"data": {}}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
from core.helper.creators import upload_dsl
|
||||
|
||||
with pytest.raises(ValueError, match="claim_code"):
|
||||
upload_dsl(b"app: demo")
|
||||
|
||||
@patch("core.helper.creators.httpx.post")
|
||||
def test_raises_on_http_error(self, mock_post):
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"Server Error",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(),
|
||||
)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
from core.helper.creators import upload_dsl
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
upload_dsl(b"app: demo")
|
||||
|
||||
|
||||
class TestGetRedirectUrl:
|
||||
@patch("core.helper.creators.dify_config")
|
||||
def test_without_oauth_client_id(self, mock_config):
|
||||
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com"
|
||||
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = ""
|
||||
|
||||
from core.helper.creators import get_redirect_url
|
||||
|
||||
url = get_redirect_url("user-1", "claim-abc")
|
||||
|
||||
assert "dsl_claim_code=claim-abc" in url
|
||||
assert "oauth_code" not in url
|
||||
assert url.startswith("https://creators.example.com")
|
||||
|
||||
@patch("core.helper.creators.dify_config")
|
||||
def test_with_oauth_client_id(self, mock_config):
|
||||
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com"
|
||||
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "client-xyz"
|
||||
|
||||
with patch(
|
||||
"services.oauth_server.OAuthServerService.sign_oauth_authorization_code",
|
||||
return_value="oauth-code-123",
|
||||
) as mock_sign:
|
||||
from core.helper.creators import get_redirect_url
|
||||
|
||||
url = get_redirect_url("user-1", "claim-abc")
|
||||
|
||||
mock_sign.assert_called_once_with("client-xyz", "user-1")
|
||||
assert "dsl_claim_code=claim-abc" in url
|
||||
assert "oauth_code=oauth-code-123" in url
|
||||
|
||||
@patch("core.helper.creators.dify_config")
|
||||
def test_strips_trailing_slash(self, mock_config):
|
||||
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com/"
|
||||
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = ""
|
||||
|
||||
from core.helper.creators import get_redirect_url
|
||||
|
||||
url = get_redirect_url("user-1", "claim-abc")
|
||||
|
||||
assert url.startswith("https://creators.example.com?")
|
||||
assert "creators.example.com/?" not in url
|
||||
@ -2,50 +2,50 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.utils import system_oauth_encryption as oauth_encryption
|
||||
from core.tools.utils.system_oauth_encryption import OAuthEncryptionError, SystemOAuthEncrypter
|
||||
from core.tools.utils import system_encryption as encryption
|
||||
from core.tools.utils.system_encryption import EncryptionError, SystemEncrypter
|
||||
|
||||
|
||||
def test_system_oauth_encrypter_roundtrip():
|
||||
encrypter = SystemOAuthEncrypter(secret_key="test-secret")
|
||||
def test_system_encrypter_roundtrip():
|
||||
encrypter = SystemEncrypter(secret_key="test-secret")
|
||||
payload = {"client_id": "cid", "client_secret": "csecret", "grant_type": "authorization_code"}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(payload)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypted = encrypter.encrypt_params(payload)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
|
||||
assert encrypted
|
||||
assert dict(decrypted) == payload
|
||||
|
||||
|
||||
def test_system_oauth_encrypter_decrypt_validates_input():
|
||||
encrypter = SystemOAuthEncrypter(secret_key="test-secret")
|
||||
def test_system_encrypter_decrypt_validates_input():
|
||||
encrypter = SystemEncrypter(secret_key="test-secret")
|
||||
|
||||
with pytest.raises(ValueError, match="must be a string"):
|
||||
encrypter.decrypt_oauth_params(123) # type: ignore[arg-type]
|
||||
encrypter.decrypt_params(123) # type: ignore[arg-type]
|
||||
|
||||
with pytest.raises(ValueError, match="cannot be empty"):
|
||||
encrypter.decrypt_oauth_params("")
|
||||
encrypter.decrypt_params("")
|
||||
|
||||
|
||||
def test_system_oauth_encrypter_raises_oauth_error_for_invalid_ciphertext():
|
||||
encrypter = SystemOAuthEncrypter(secret_key="test-secret")
|
||||
def test_system_encrypter_raises_error_for_invalid_ciphertext():
|
||||
encrypter = SystemEncrypter(secret_key="test-secret")
|
||||
|
||||
with pytest.raises(OAuthEncryptionError, match="Decryption failed"):
|
||||
encrypter.decrypt_oauth_params("not-base64")
|
||||
with pytest.raises(EncryptionError, match="Decryption failed"):
|
||||
encrypter.decrypt_params("not-base64")
|
||||
|
||||
|
||||
def test_system_oauth_helpers_use_global_cached_instance(monkeypatch):
|
||||
monkeypatch.setattr(oauth_encryption, "_oauth_encrypter", None)
|
||||
monkeypatch.setattr("core.tools.utils.system_oauth_encryption.dify_config.SECRET_KEY", "global-secret")
|
||||
def test_system_helpers_use_global_cached_instance(monkeypatch):
|
||||
monkeypatch.setattr(encryption, "_encrypter", None)
|
||||
monkeypatch.setattr("core.tools.utils.system_encryption.dify_config.SECRET_KEY", "global-secret")
|
||||
|
||||
first = oauth_encryption.get_system_oauth_encrypter()
|
||||
second = oauth_encryption.get_system_oauth_encrypter()
|
||||
first = encryption.get_system_encrypter()
|
||||
second = encryption.get_system_encrypter()
|
||||
assert first is second
|
||||
|
||||
encrypted = oauth_encryption.encrypt_system_oauth_params({"k": "v"})
|
||||
assert oauth_encryption.decrypt_system_oauth_params(encrypted) == {"k": "v"}
|
||||
encrypted = encryption.encrypt_system_params({"k": "v"})
|
||||
assert encryption.decrypt_system_params(encrypted) == {"k": "v"}
|
||||
|
||||
|
||||
def test_create_system_oauth_encrypter_factory():
|
||||
encrypter = oauth_encryption.create_system_oauth_encrypter(secret_key="factory-secret")
|
||||
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||
def test_create_system_encrypter_factory():
|
||||
encrypter = encryption.create_system_encrypter(secret_key="factory-secret")
|
||||
assert isinstance(encrypter, SystemEncrypter)
|
||||
|
||||
@ -694,7 +694,7 @@ def test_get_oauth_client_should_return_decrypted_system_client_when_verified(
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
|
||||
mocker.patch(
|
||||
"services.trigger.trigger_provider_service.decrypt_system_oauth_params",
|
||||
"services.trigger.trigger_provider_service.decrypt_system_params",
|
||||
return_value={"client_id": "system"},
|
||||
)
|
||||
|
||||
@ -716,7 +716,7 @@ def test_get_oauth_client_should_raise_error_when_system_decryption_fails(
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
|
||||
mocker.patch(
|
||||
"services.trigger.trigger_provider_service.decrypt_system_oauth_params",
|
||||
"services.trigger.trigger_provider_service.decrypt_system_params",
|
||||
side_effect=RuntimeError("bad data"),
|
||||
)
|
||||
|
||||
|
||||
@ -280,7 +280,7 @@ class TestGetOauthClient:
|
||||
|
||||
assert result == {"client_id": "id", "client_secret": "secret"}
|
||||
|
||||
@patch(f"{MODULE}.decrypt_system_oauth_params", return_value={"sys_key": "sys_val"})
|
||||
@patch(f"{MODULE}.decrypt_system_params", return_value={"sys_key": "sys_val"})
|
||||
@patch(f"{MODULE}.PluginService")
|
||||
@patch(f"{MODULE}.create_provider_encrypter")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
|
||||
619
api/tests/unit_tests/utils/encryption/test_system_encryption.py
Normal file
619
api/tests/unit_tests/utils/encryption/test_system_encryption.py
Normal file
@ -0,0 +1,619 @@
|
||||
import base64
|
||||
import hashlib
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Random import get_random_bytes
|
||||
from Crypto.Util.Padding import pad
|
||||
|
||||
from core.tools.utils.system_encryption import (
|
||||
EncryptionError,
|
||||
SystemEncrypter,
|
||||
create_system_encrypter,
|
||||
decrypt_system_params,
|
||||
encrypt_system_params,
|
||||
get_system_encrypter,
|
||||
)
|
||||
|
||||
|
||||
class TestSystemEncrypter:
|
||||
"""Test cases for SystemEncrypter class"""
|
||||
|
||||
def test_init_with_secret_key(self):
|
||||
"""Test initialization with provided secret key"""
|
||||
secret_key = "test_secret_key"
|
||||
encrypter = SystemEncrypter(secret_key=secret_key)
|
||||
expected_key = hashlib.sha256(secret_key.encode()).digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_init_with_none_secret_key(self):
|
||||
"""Test initialization with None secret key falls back to config"""
|
||||
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "config_secret"
|
||||
encrypter = SystemEncrypter(secret_key=None)
|
||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_init_with_empty_secret_key(self):
|
||||
"""Test initialization with empty secret key"""
|
||||
encrypter = SystemEncrypter(secret_key="")
|
||||
expected_key = hashlib.sha256(b"").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_init_without_secret_key_uses_config(self):
|
||||
"""Test initialization without secret key uses config"""
|
||||
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "default_secret"
|
||||
encrypter = SystemEncrypter()
|
||||
expected_key = hashlib.sha256(b"default_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_encrypt_params_basic(self):
|
||||
"""Test basic parameters encryption"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
# Should be valid base64
|
||||
try:
|
||||
base64.b64decode(encrypted)
|
||||
except Exception:
|
||||
pytest.fail("Encrypted result is not valid base64")
|
||||
|
||||
def test_encrypt_params_empty_dict(self):
|
||||
"""Test encryption with empty dictionary"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
params = {}
|
||||
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_params_complex_data(self):
|
||||
"""Test encryption with complex data structures"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
params = {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"scopes": ["read", "write", "admin"],
|
||||
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
|
||||
"numeric_value": 42,
|
||||
"boolean_value": False,
|
||||
"null_value": None,
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_params_unicode_data(self):
|
||||
"""Test encryption with unicode data"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"}
|
||||
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_params_large_data(self):
|
||||
"""Test encryption with large data"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
params = {
|
||||
"client_id": "test_id",
|
||||
"large_data": "x" * 10000, # 10KB of data
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_params_invalid_input(self):
|
||||
"""Test encryption with invalid input types"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
encrypter.encrypt_params(None)
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
encrypter.encrypt_params("not_a_dict")
|
||||
|
||||
def test_decrypt_params_basic(self):
|
||||
"""Test basic parameters decryption"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypter.encrypt_params(original_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_params_empty_dict(self):
|
||||
"""Test decryption of empty dictionary"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
original_params = {}
|
||||
|
||||
encrypted = encrypter.encrypt_params(original_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_params_complex_data(self):
|
||||
"""Test decryption with complex data structures"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
original_params = {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"scopes": ["read", "write", "admin"],
|
||||
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
|
||||
"numeric_value": 42,
|
||||
"boolean_value": False,
|
||||
"null_value": None,
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_params(original_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_params_unicode_data(self):
|
||||
"""Test decryption with unicode data"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
original_params = {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"description": "This is a test case 🚀",
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_params(original_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_params_large_data(self):
|
||||
"""Test decryption with large data"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
original_params = {
|
||||
"client_id": "test_id",
|
||||
"large_data": "x" * 10000, # 10KB of data
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_params(original_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_params_invalid_base64(self):
|
||||
"""Test decryption with invalid base64 data"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(EncryptionError):
|
||||
encrypter.decrypt_params("invalid_base64!")
|
||||
|
||||
def test_decrypt_params_empty_string(self):
|
||||
"""Test decryption with empty string"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_params("")
|
||||
|
||||
assert "encrypted_data cannot be empty" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_params_non_string_input(self):
|
||||
"""Test decryption with non-string input"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_params(123)
|
||||
|
||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_params(None)
|
||||
|
||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_params_too_short_data(self):
|
||||
"""Test decryption with too short encrypted data"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
# Create data that's too short (less than 32 bytes)
|
||||
short_data = base64.b64encode(b"short").decode()
|
||||
|
||||
with pytest.raises(EncryptionError) as exc_info:
|
||||
encrypter.decrypt_params(short_data)
|
||||
|
||||
assert "Invalid encrypted data format" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_params_corrupted_data(self):
|
||||
"""Test decryption with corrupted data"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
# Create corrupted data (valid base64 but invalid encrypted content)
|
||||
corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage
|
||||
|
||||
with pytest.raises(EncryptionError):
|
||||
encrypter.decrypt_params(corrupted_data)
|
||||
|
||||
def test_decrypt_params_wrong_key(self):
|
||||
"""Test decryption with wrong key"""
|
||||
encrypter1 = SystemEncrypter("secret1")
|
||||
encrypter2 = SystemEncrypter("secret2")
|
||||
|
||||
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
encrypted = encrypter1.encrypt_params(original_params)
|
||||
|
||||
with pytest.raises(EncryptionError):
|
||||
encrypter2.decrypt_params(encrypted)
|
||||
|
||||
def test_encryption_decryption_consistency(self):
|
||||
"""Test that encryption and decryption are consistent"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
test_cases = [
|
||||
{},
|
||||
{"simple": "value"},
|
||||
{"client_id": "id", "client_secret": "secret"},
|
||||
{"complex": {"nested": {"deep": "value"}}},
|
||||
{"unicode": "test 🚀"},
|
||||
{"numbers": 42, "boolean": True, "null": None},
|
||||
{"array": [1, 2, 3, "four", {"five": 5}]},
|
||||
]
|
||||
|
||||
for original_params in test_cases:
|
||||
encrypted = encrypter.encrypt_params(original_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == original_params, f"Failed for case: {original_params}"
|
||||
|
||||
def test_encryption_randomness(self):
|
||||
"""Test that encryption produces different results for same input"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted1 = encrypter.encrypt_params(params)
|
||||
encrypted2 = encrypter.encrypt_params(params)
|
||||
|
||||
# Should be different due to random IV
|
||||
assert encrypted1 != encrypted2
|
||||
|
||||
# But should decrypt to same result
|
||||
decrypted1 = encrypter.decrypt_params(encrypted1)
|
||||
decrypted2 = encrypter.decrypt_params(encrypted2)
|
||||
assert decrypted1 == decrypted2 == params
|
||||
|
||||
def test_different_secret_keys_produce_different_results(self):
|
||||
"""Test that different secret keys produce different encrypted results"""
|
||||
encrypter1 = SystemEncrypter("secret1")
|
||||
encrypter2 = SystemEncrypter("secret2")
|
||||
|
||||
params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted1 = encrypter1.encrypt_params(params)
|
||||
encrypted2 = encrypter2.encrypt_params(params)
|
||||
|
||||
# Should produce different encrypted results
|
||||
assert encrypted1 != encrypted2
|
||||
|
||||
# But each should decrypt correctly with its own key
|
||||
decrypted1 = encrypter1.decrypt_params(encrypted1)
|
||||
decrypted2 = encrypter2.decrypt_params(encrypted2)
|
||||
assert decrypted1 == decrypted2 == params
|
||||
|
||||
@patch("core.tools.utils.system_encryption.get_random_bytes")
|
||||
def test_encrypt_params_crypto_error(self, mock_get_random_bytes):
|
||||
"""Test encryption when crypto operation fails"""
|
||||
mock_get_random_bytes.side_effect = Exception("Crypto error")
|
||||
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
params = {"client_id": "test_id"}
|
||||
|
||||
with pytest.raises(EncryptionError) as exc_info:
|
||||
encrypter.encrypt_params(params)
|
||||
|
||||
assert "Encryption failed" in str(exc_info.value)
|
||||
|
||||
@patch("core.tools.utils.system_encryption.TypeAdapter")
|
||||
def test_encrypt_params_serialization_error(self, mock_type_adapter):
|
||||
"""Test encryption when JSON serialization fails"""
|
||||
mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
|
||||
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
params = {"client_id": "test_id"}
|
||||
|
||||
with pytest.raises(EncryptionError) as exc_info:
|
||||
encrypter.encrypt_params(params)
|
||||
|
||||
assert "Encryption failed" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_params_invalid_json(self):
|
||||
"""Test decryption with invalid JSON data"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
# Create valid encrypted data but with invalid JSON content
|
||||
iv = get_random_bytes(16)
|
||||
cipher = AES.new(encrypter.key, AES.MODE_CBC, iv)
|
||||
invalid_json = b"invalid json content"
|
||||
padded_data = pad(invalid_json, AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
combined = iv + encrypted_data
|
||||
encoded = base64.b64encode(combined).decode()
|
||||
|
||||
with pytest.raises(EncryptionError):
|
||||
encrypter.decrypt_params(encoded)
|
||||
|
||||
def test_key_derivation_consistency(self):
|
||||
"""Test that key derivation is consistent"""
|
||||
secret_key = "test_secret"
|
||||
encrypter1 = SystemEncrypter(secret_key)
|
||||
encrypter2 = SystemEncrypter(secret_key)
|
||||
|
||||
assert encrypter1.key == encrypter2.key
|
||||
|
||||
# Keys should be 32 bytes (256 bits)
|
||||
assert len(encrypter1.key) == 32
|
||||
|
||||
|
||||
class TestFactoryFunctions:
|
||||
"""Test cases for factory functions"""
|
||||
|
||||
def test_create_system_encrypter_with_secret(self):
|
||||
"""Test factory function with secret key"""
|
||||
secret_key = "test_secret"
|
||||
encrypter = create_system_encrypter(secret_key)
|
||||
|
||||
assert isinstance(encrypter, SystemEncrypter)
|
||||
expected_key = hashlib.sha256(secret_key.encode()).digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_create_system_encrypter_without_secret(self):
|
||||
"""Test factory function without secret key"""
|
||||
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "config_secret"
|
||||
encrypter = create_system_encrypter()
|
||||
|
||||
assert isinstance(encrypter, SystemEncrypter)
|
||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_create_system_encrypter_with_none_secret(self):
|
||||
"""Test factory function with None secret key"""
|
||||
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "config_secret"
|
||||
encrypter = create_system_encrypter(None)
|
||||
|
||||
assert isinstance(encrypter, SystemEncrypter)
|
||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
|
||||
class TestGlobalEncrypterInstance:
|
||||
"""Test cases for global encrypter instance"""
|
||||
|
||||
def test_get_system_encrypter_singleton(self):
|
||||
"""Test that get_system_encrypter returns singleton instance"""
|
||||
# Clear the global instance first
|
||||
import core.tools.utils.system_encryption
|
||||
|
||||
core.tools.utils.system_encryption._encrypter = None
|
||||
|
||||
encrypter1 = get_system_encrypter()
|
||||
encrypter2 = get_system_encrypter()
|
||||
|
||||
assert encrypter1 is encrypter2
|
||||
assert isinstance(encrypter1, SystemEncrypter)
|
||||
|
||||
def test_get_system_encrypter_uses_config(self):
|
||||
"""Test that global encrypter uses config"""
|
||||
# Clear the global instance first
|
||||
import core.tools.utils.system_encryption
|
||||
|
||||
core.tools.utils.system_encryption._encrypter = None
|
||||
|
||||
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "global_secret"
|
||||
encrypter = get_system_encrypter()
|
||||
|
||||
expected_key = hashlib.sha256(b"global_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
|
||||
"""Test cases for convenience functions"""
|
||||
|
||||
def test_encrypt_system_params(self):
|
||||
"""Test encrypt_system_params convenience function"""
|
||||
params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypt_system_params(params)
|
||||
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_decrypt_system_params(self):
|
||||
"""Test decrypt_system_params convenience function"""
|
||||
params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypt_system_params(params)
|
||||
decrypted = decrypt_system_params(encrypted)
|
||||
|
||||
assert decrypted == params
|
||||
|
||||
def test_convenience_functions_consistency(self):
|
||||
"""Test that convenience functions work consistently"""
|
||||
test_cases = [
|
||||
{},
|
||||
{"simple": "value"},
|
||||
{"client_id": "id", "client_secret": "secret"},
|
||||
{"complex": {"nested": {"deep": "value"}}},
|
||||
{"unicode": "test 🚀"},
|
||||
{"numbers": 42, "boolean": True, "null": None},
|
||||
]
|
||||
|
||||
for original_params in test_cases:
|
||||
encrypted = encrypt_system_params(original_params)
|
||||
decrypted = decrypt_system_params(encrypted)
|
||||
assert decrypted == original_params, f"Failed for case: {original_params}"
|
||||
|
||||
def test_convenience_functions_with_errors(self):
|
||||
"""Test convenience functions with error conditions"""
|
||||
# Test encryption with invalid input
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
encrypt_system_params(None)
|
||||
|
||||
# Test decryption with invalid input
|
||||
with pytest.raises(ValueError):
|
||||
decrypt_system_params("")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
decrypt_system_params(None)
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test cases for error handling"""
|
||||
|
||||
def test_encryption_error_inheritance(self):
|
||||
"""Test that EncryptionError is a proper exception"""
|
||||
error = EncryptionError("Test error")
|
||||
assert isinstance(error, Exception)
|
||||
assert str(error) == "Test error"
|
||||
|
||||
def test_encryption_error_with_cause(self):
|
||||
"""Test EncryptionError with cause"""
|
||||
original_error = ValueError("Original error")
|
||||
error = EncryptionError("Wrapper error")
|
||||
error.__cause__ = original_error
|
||||
|
||||
assert isinstance(error, Exception)
|
||||
assert str(error) == "Wrapper error"
|
||||
assert error.__cause__ is original_error
|
||||
|
||||
def test_error_messages_are_informative(self):
|
||||
"""Test that error messages are informative"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
# Test empty string error
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_params("")
|
||||
assert "encrypted_data cannot be empty" in str(exc_info.value)
|
||||
|
||||
# Test non-string error
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_params(123)
|
||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||
|
||||
# Test invalid format error
|
||||
short_data = base64.b64encode(b"short").decode()
|
||||
with pytest.raises(EncryptionError) as exc_info:
|
||||
encrypter.decrypt_params(short_data)
|
||||
assert "Invalid encrypted data format" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test cases for edge cases and boundary conditions"""
|
||||
|
||||
def test_very_long_secret_key(self):
|
||||
"""Test with very long secret key"""
|
||||
long_secret = "x" * 10000
|
||||
encrypter = SystemEncrypter(long_secret)
|
||||
|
||||
# Key should still be 32 bytes due to SHA-256
|
||||
assert len(encrypter.key) == 32
|
||||
|
||||
# Should still work normally
|
||||
params = {"client_id": "test_id"}
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == params
|
||||
|
||||
def test_special_characters_in_secret_key(self):
|
||||
"""Test with special characters in secret key"""
|
||||
special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀"
|
||||
encrypter = SystemEncrypter(special_secret)
|
||||
|
||||
params = {"client_id": "test_id"}
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == params
|
||||
|
||||
def test_empty_values_in_params(self):
|
||||
"""Test with empty values in params"""
|
||||
params = {
|
||||
"client_id": "",
|
||||
"client_secret": "",
|
||||
"empty_dict": {},
|
||||
"empty_list": [],
|
||||
"empty_string": "",
|
||||
"zero": 0,
|
||||
"false": False,
|
||||
"none": None,
|
||||
}
|
||||
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == params
|
||||
|
||||
def test_deeply_nested_params(self):
|
||||
"""Test with deeply nested params"""
|
||||
params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}}
|
||||
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == params
|
||||
|
||||
def test_params_with_all_json_types(self):
|
||||
"""Test with all JSON-supported data types"""
|
||||
params = {
|
||||
"string": "test_string",
|
||||
"integer": 42,
|
||||
"float": 3.14159,
|
||||
"boolean_true": True,
|
||||
"boolean_false": False,
|
||||
"null_value": None,
|
||||
"empty_string": "",
|
||||
"array": [1, "two", 3.0, True, False, None],
|
||||
"object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True},
|
||||
}
|
||||
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == params
|
||||
|
||||
|
||||
class TestPerformance:
|
||||
"""Test cases for performance considerations"""
|
||||
|
||||
def test_large_params(self):
|
||||
"""Test with large params"""
|
||||
large_value = "x" * 100000 # 100KB
|
||||
params = {"client_id": "test_id", "large_data": large_value}
|
||||
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == params
|
||||
|
||||
def test_many_fields_params(self):
|
||||
"""Test with many fields in params"""
|
||||
params = {f"field_{i}": f"value_{i}" for i in range(1000)}
|
||||
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == params
|
||||
|
||||
def test_repeated_encryption_decryption(self):
|
||||
"""Test repeated encryption and decryption operations"""
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
# Test multiple rounds of encryption/decryption
|
||||
for i in range(100):
|
||||
encrypted = encrypter.encrypt_params(params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == params
|
||||
@ -1,619 +0,0 @@
|
||||
import base64
|
||||
import hashlib
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Random import get_random_bytes
|
||||
from Crypto.Util.Padding import pad
|
||||
|
||||
from core.tools.utils.system_oauth_encryption import (
|
||||
OAuthEncryptionError,
|
||||
SystemOAuthEncrypter,
|
||||
create_system_oauth_encrypter,
|
||||
decrypt_system_oauth_params,
|
||||
encrypt_system_oauth_params,
|
||||
get_system_oauth_encrypter,
|
||||
)
|
||||
|
||||
|
||||
class TestSystemOAuthEncrypter:
|
||||
"""Test cases for SystemOAuthEncrypter class"""
|
||||
|
||||
def test_init_with_secret_key(self):
|
||||
"""Test initialization with provided secret key"""
|
||||
secret_key = "test_secret_key"
|
||||
encrypter = SystemOAuthEncrypter(secret_key=secret_key)
|
||||
expected_key = hashlib.sha256(secret_key.encode()).digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_init_with_none_secret_key(self):
|
||||
"""Test initialization with None secret key falls back to config"""
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "config_secret"
|
||||
encrypter = SystemOAuthEncrypter(secret_key=None)
|
||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_init_with_empty_secret_key(self):
|
||||
"""Test initialization with empty secret key"""
|
||||
encrypter = SystemOAuthEncrypter(secret_key="")
|
||||
expected_key = hashlib.sha256(b"").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_init_without_secret_key_uses_config(self):
|
||||
"""Test initialization without secret key uses config"""
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "default_secret"
|
||||
encrypter = SystemOAuthEncrypter()
|
||||
expected_key = hashlib.sha256(b"default_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_encrypt_oauth_params_basic(self):
|
||||
"""Test basic OAuth parameters encryption"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
# Should be valid base64
|
||||
try:
|
||||
base64.b64decode(encrypted)
|
||||
except Exception:
|
||||
pytest.fail("Encrypted result is not valid base64")
|
||||
|
||||
def test_encrypt_oauth_params_empty_dict(self):
|
||||
"""Test encryption with empty dictionary"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_oauth_params_complex_data(self):
|
||||
"""Test encryption with complex data structures"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"scopes": ["read", "write", "admin"],
|
||||
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
|
||||
"numeric_value": 42,
|
||||
"boolean_value": False,
|
||||
"null_value": None,
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_oauth_params_unicode_data(self):
|
||||
"""Test encryption with unicode data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_oauth_params_large_data(self):
|
||||
"""Test encryption with large data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {
|
||||
"client_id": "test_id",
|
||||
"large_data": "x" * 10000, # 10KB of data
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_oauth_params_invalid_input(self):
|
||||
"""Test encryption with invalid input types"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
encrypter.encrypt_oauth_params(None)
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
encrypter.encrypt_oauth_params("not_a_dict")
|
||||
|
||||
def test_decrypt_oauth_params_basic(self):
|
||||
"""Test basic OAuth parameters decryption"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_empty_dict(self):
|
||||
"""Test decryption of empty dictionary"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
original_params = {}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_complex_data(self):
|
||||
"""Test decryption with complex data structures"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
original_params = {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"scopes": ["read", "write", "admin"],
|
||||
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
|
||||
"numeric_value": 42,
|
||||
"boolean_value": False,
|
||||
"null_value": None,
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_unicode_data(self):
|
||||
"""Test decryption with unicode data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
original_params = {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"description": "This is a test case 🚀",
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_large_data(self):
|
||||
"""Test decryption with large data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
original_params = {
|
||||
"client_id": "test_id",
|
||||
"large_data": "x" * 10000, # 10KB of data
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_invalid_base64(self):
|
||||
"""Test decryption with invalid base64 data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(OAuthEncryptionError):
|
||||
encrypter.decrypt_oauth_params("invalid_base64!")
|
||||
|
||||
def test_decrypt_oauth_params_empty_string(self):
|
||||
"""Test decryption with empty string"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params("")
|
||||
|
||||
assert "encrypted_data cannot be empty" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_oauth_params_non_string_input(self):
|
||||
"""Test decryption with non-string input"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(123)
|
||||
|
||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(None)
|
||||
|
||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_oauth_params_too_short_data(self):
|
||||
"""Test decryption with too short encrypted data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
# Create data that's too short (less than 32 bytes)
|
||||
short_data = base64.b64encode(b"short").decode()
|
||||
|
||||
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(short_data)
|
||||
|
||||
assert "Invalid encrypted data format" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_oauth_params_corrupted_data(self):
|
||||
"""Test decryption with corrupted data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
# Create corrupted data (valid base64 but invalid encrypted content)
|
||||
corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage
|
||||
|
||||
with pytest.raises(OAuthEncryptionError):
|
||||
encrypter.decrypt_oauth_params(corrupted_data)
|
||||
|
||||
def test_decrypt_oauth_params_wrong_key(self):
|
||||
"""Test decryption with wrong key"""
|
||||
encrypter1 = SystemOAuthEncrypter("secret1")
|
||||
encrypter2 = SystemOAuthEncrypter("secret2")
|
||||
|
||||
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
encrypted = encrypter1.encrypt_oauth_params(original_params)
|
||||
|
||||
with pytest.raises(OAuthEncryptionError):
|
||||
encrypter2.decrypt_oauth_params(encrypted)
|
||||
|
||||
def test_encryption_decryption_consistency(self):
|
||||
"""Test that encryption and decryption are consistent"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
test_cases = [
|
||||
{},
|
||||
{"simple": "value"},
|
||||
{"client_id": "id", "client_secret": "secret"},
|
||||
{"complex": {"nested": {"deep": "value"}}},
|
||||
{"unicode": "test 🚀"},
|
||||
{"numbers": 42, "boolean": True, "null": None},
|
||||
{"array": [1, 2, 3, "four", {"five": 5}]},
|
||||
]
|
||||
|
||||
for original_params in test_cases:
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == original_params, f"Failed for case: {original_params}"
|
||||
|
||||
def test_encryption_randomness(self):
|
||||
"""Test that encryption produces different results for same input"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted1 = encrypter.encrypt_oauth_params(oauth_params)
|
||||
encrypted2 = encrypter.encrypt_oauth_params(oauth_params)
|
||||
|
||||
# Should be different due to random IV
|
||||
assert encrypted1 != encrypted2
|
||||
|
||||
# But should decrypt to same result
|
||||
decrypted1 = encrypter.decrypt_oauth_params(encrypted1)
|
||||
decrypted2 = encrypter.decrypt_oauth_params(encrypted2)
|
||||
assert decrypted1 == decrypted2 == oauth_params
|
||||
|
||||
def test_different_secret_keys_produce_different_results(self):
|
||||
"""Test that different secret keys produce different encrypted results"""
|
||||
encrypter1 = SystemOAuthEncrypter("secret1")
|
||||
encrypter2 = SystemOAuthEncrypter("secret2")
|
||||
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted1 = encrypter1.encrypt_oauth_params(oauth_params)
|
||||
encrypted2 = encrypter2.encrypt_oauth_params(oauth_params)
|
||||
|
||||
# Should produce different encrypted results
|
||||
assert encrypted1 != encrypted2
|
||||
|
||||
# But each should decrypt correctly with its own key
|
||||
decrypted1 = encrypter1.decrypt_oauth_params(encrypted1)
|
||||
decrypted2 = encrypter2.decrypt_oauth_params(encrypted2)
|
||||
assert decrypted1 == decrypted2 == oauth_params
|
||||
|
||||
@patch("core.tools.utils.system_oauth_encryption.get_random_bytes")
|
||||
def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes):
|
||||
"""Test encryption when crypto operation fails"""
|
||||
mock_get_random_bytes.side_effect = Exception("Crypto error")
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id"}
|
||||
|
||||
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||
encrypter.encrypt_oauth_params(oauth_params)
|
||||
|
||||
assert "Encryption failed" in str(exc_info.value)
|
||||
|
||||
@patch("core.tools.utils.system_oauth_encryption.TypeAdapter")
|
||||
def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter):
|
||||
"""Test encryption when JSON serialization fails"""
|
||||
mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id"}
|
||||
|
||||
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||
encrypter.encrypt_oauth_params(oauth_params)
|
||||
|
||||
assert "Encryption failed" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_oauth_params_invalid_json(self):
|
||||
"""Test decryption with invalid JSON data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
# Create valid encrypted data but with invalid JSON content
|
||||
iv = get_random_bytes(16)
|
||||
cipher = AES.new(encrypter.key, AES.MODE_CBC, iv)
|
||||
invalid_json = b"invalid json content"
|
||||
padded_data = pad(invalid_json, AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
combined = iv + encrypted_data
|
||||
encoded = base64.b64encode(combined).decode()
|
||||
|
||||
with pytest.raises(OAuthEncryptionError):
|
||||
encrypter.decrypt_oauth_params(encoded)
|
||||
|
||||
def test_key_derivation_consistency(self):
|
||||
"""Test that key derivation is consistent"""
|
||||
secret_key = "test_secret"
|
||||
encrypter1 = SystemOAuthEncrypter(secret_key)
|
||||
encrypter2 = SystemOAuthEncrypter(secret_key)
|
||||
|
||||
assert encrypter1.key == encrypter2.key
|
||||
|
||||
# Keys should be 32 bytes (256 bits)
|
||||
assert len(encrypter1.key) == 32
|
||||
|
||||
|
||||
class TestFactoryFunctions:
|
||||
"""Test cases for factory functions"""
|
||||
|
||||
def test_create_system_oauth_encrypter_with_secret(self):
|
||||
"""Test factory function with secret key"""
|
||||
secret_key = "test_secret"
|
||||
encrypter = create_system_oauth_encrypter(secret_key)
|
||||
|
||||
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||
expected_key = hashlib.sha256(secret_key.encode()).digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_create_system_oauth_encrypter_without_secret(self):
|
||||
"""Test factory function without secret key"""
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "config_secret"
|
||||
encrypter = create_system_oauth_encrypter()
|
||||
|
||||
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_create_system_oauth_encrypter_with_none_secret(self):
|
||||
"""Test factory function with None secret key"""
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "config_secret"
|
||||
encrypter = create_system_oauth_encrypter(None)
|
||||
|
||||
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
|
||||
class TestGlobalEncrypterInstance:
|
||||
"""Test cases for global encrypter instance"""
|
||||
|
||||
def test_get_system_oauth_encrypter_singleton(self):
|
||||
"""Test that get_system_oauth_encrypter returns singleton instance"""
|
||||
# Clear the global instance first
|
||||
import core.tools.utils.system_oauth_encryption
|
||||
|
||||
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
|
||||
|
||||
encrypter1 = get_system_oauth_encrypter()
|
||||
encrypter2 = get_system_oauth_encrypter()
|
||||
|
||||
assert encrypter1 is encrypter2
|
||||
assert isinstance(encrypter1, SystemOAuthEncrypter)
|
||||
|
||||
def test_get_system_oauth_encrypter_uses_config(self):
|
||||
"""Test that global encrypter uses config"""
|
||||
# Clear the global instance first
|
||||
import core.tools.utils.system_oauth_encryption
|
||||
|
||||
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
|
||||
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "global_secret"
|
||||
encrypter = get_system_oauth_encrypter()
|
||||
|
||||
expected_key = hashlib.sha256(b"global_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
|
||||
"""Test cases for convenience functions"""
|
||||
|
||||
def test_encrypt_system_oauth_params(self):
|
||||
"""Test encrypt_system_oauth_params convenience function"""
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypt_system_oauth_params(oauth_params)
|
||||
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_decrypt_system_oauth_params(self):
|
||||
"""Test decrypt_system_oauth_params convenience function"""
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypt_system_oauth_params(oauth_params)
|
||||
decrypted = decrypt_system_oauth_params(encrypted)
|
||||
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_convenience_functions_consistency(self):
|
||||
"""Test that convenience functions work consistently"""
|
||||
test_cases = [
|
||||
{},
|
||||
{"simple": "value"},
|
||||
{"client_id": "id", "client_secret": "secret"},
|
||||
{"complex": {"nested": {"deep": "value"}}},
|
||||
{"unicode": "test 🚀"},
|
||||
{"numbers": 42, "boolean": True, "null": None},
|
||||
]
|
||||
|
||||
for original_params in test_cases:
|
||||
encrypted = encrypt_system_oauth_params(original_params)
|
||||
decrypted = decrypt_system_oauth_params(encrypted)
|
||||
assert decrypted == original_params, f"Failed for case: {original_params}"
|
||||
|
||||
def test_convenience_functions_with_errors(self):
|
||||
"""Test convenience functions with error conditions"""
|
||||
# Test encryption with invalid input
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
encrypt_system_oauth_params(None)
|
||||
|
||||
# Test decryption with invalid input
|
||||
with pytest.raises(ValueError):
|
||||
decrypt_system_oauth_params("")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
decrypt_system_oauth_params(None)
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test cases for error handling"""
|
||||
|
||||
def test_oauth_encryption_error_inheritance(self):
|
||||
"""Test that OAuthEncryptionError is a proper exception"""
|
||||
error = OAuthEncryptionError("Test error")
|
||||
assert isinstance(error, Exception)
|
||||
assert str(error) == "Test error"
|
||||
|
||||
def test_oauth_encryption_error_with_cause(self):
|
||||
"""Test OAuthEncryptionError with cause"""
|
||||
original_error = ValueError("Original error")
|
||||
error = OAuthEncryptionError("Wrapper error")
|
||||
error.__cause__ = original_error
|
||||
|
||||
assert isinstance(error, Exception)
|
||||
assert str(error) == "Wrapper error"
|
||||
assert error.__cause__ is original_error
|
||||
|
||||
def test_error_messages_are_informative(self):
|
||||
"""Test that error messages are informative"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
# Test empty string error
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params("")
|
||||
assert "encrypted_data cannot be empty" in str(exc_info.value)
|
||||
|
||||
# Test non-string error
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(123)
|
||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||
|
||||
# Test invalid format error
|
||||
short_data = base64.b64encode(b"short").decode()
|
||||
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(short_data)
|
||||
assert "Invalid encrypted data format" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test cases for edge cases and boundary conditions"""
|
||||
|
||||
def test_very_long_secret_key(self):
|
||||
"""Test with very long secret key"""
|
||||
long_secret = "x" * 10000
|
||||
encrypter = SystemOAuthEncrypter(long_secret)
|
||||
|
||||
# Key should still be 32 bytes due to SHA-256
|
||||
assert len(encrypter.key) == 32
|
||||
|
||||
# Should still work normally
|
||||
oauth_params = {"client_id": "test_id"}
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_special_characters_in_secret_key(self):
|
||||
"""Test with special characters in secret key"""
|
||||
special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀"
|
||||
encrypter = SystemOAuthEncrypter(special_secret)
|
||||
|
||||
oauth_params = {"client_id": "test_id"}
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_empty_values_in_oauth_params(self):
|
||||
"""Test with empty values in oauth params"""
|
||||
oauth_params = {
|
||||
"client_id": "",
|
||||
"client_secret": "",
|
||||
"empty_dict": {},
|
||||
"empty_list": [],
|
||||
"empty_string": "",
|
||||
"zero": 0,
|
||||
"false": False,
|
||||
"none": None,
|
||||
}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_deeply_nested_oauth_params(self):
|
||||
"""Test with deeply nested oauth params"""
|
||||
oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_oauth_params_with_all_json_types(self):
|
||||
"""Test with all JSON-supported data types"""
|
||||
oauth_params = {
|
||||
"string": "test_string",
|
||||
"integer": 42,
|
||||
"float": 3.14159,
|
||||
"boolean_true": True,
|
||||
"boolean_false": False,
|
||||
"null_value": None,
|
||||
"empty_string": "",
|
||||
"array": [1, "two", 3.0, True, False, None],
|
||||
"object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True},
|
||||
}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
|
||||
class TestPerformance:
|
||||
"""Test cases for performance considerations"""
|
||||
|
||||
def test_large_oauth_params(self):
|
||||
"""Test with large oauth params"""
|
||||
large_value = "x" * 100000 # 100KB
|
||||
oauth_params = {"client_id": "test_id", "large_data": large_value}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_many_fields_oauth_params(self):
|
||||
"""Test with many fields in oauth params"""
|
||||
oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_repeated_encryption_decryption(self):
|
||||
"""Test repeated encryption and decryption operations"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
# Test multiple rounds of encryption/decryption
|
||||
for i in range(100):
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
@ -1467,6 +1467,11 @@ ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id}
|
||||
MARKETPLACE_ENABLED=true
|
||||
MARKETPLACE_API_URL=https://marketplace.dify.ai
|
||||
|
||||
# Creators Platform configuration
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED=true
|
||||
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
|
||||
|
||||
FORCE_VERIFYING_SIGNATURE=true
|
||||
ENFORCE_LANGGENIUS_PLUGIN_SIGNATURES=true
|
||||
|
||||
|
||||
@ -629,6 +629,9 @@ x-shared-env: &shared-api-worker-env
|
||||
ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}}
|
||||
MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true}
|
||||
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai}
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED: ${CREATORS_PLATFORM_FEATURES_ENABLED:-true}
|
||||
CREATORS_PLATFORM_API_URL: ${CREATORS_PLATFORM_API_URL:-https://creators.dify.ai}
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID: ${CREATORS_PLATFORM_OAUTH_CLIENT_ID:-}
|
||||
FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true}
|
||||
ENFORCE_LANGGENIUS_PLUGIN_SIGNATURES: ${ENFORCE_LANGGENIUS_PLUGIN_SIGNATURES:-true}
|
||||
PLUGIN_STDIO_BUFFER_SIZE: ${PLUGIN_STDIO_BUFFER_SIZE:-1024}
|
||||
|
||||
@ -16,9 +16,9 @@ import { useEffect, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
|
||||
import { setOAuthPendingRedirect } from '@/app/signin/utils/post-login-redirect'
|
||||
import { useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { isLegacyBase401, userProfileQueryOptions } from '@/service/use-common'
|
||||
import { isLegacyBase401, useLogout, userProfileQueryOptions } from '@/service/use-common'
|
||||
import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth'
|
||||
|
||||
function buildReturnUrl(pathname: string, search: string) {
|
||||
@ -73,14 +73,17 @@ export default function OAuthAuthorize() {
|
||||
const userProfile = userProfileResp?.profile
|
||||
const { data: authAppInfo, isLoading: isOAuthLoading, isError } = useOAuthAppInfo(client_id, redirect_uri)
|
||||
const { mutateAsync: authorize, isPending: authorizing } = useAuthorizeOAuthApp()
|
||||
const { mutateAsync: logout } = useLogout()
|
||||
const hasNotifiedRef = useRef(false)
|
||||
|
||||
const isLoading = isOAuthLoading || isProfileLoading
|
||||
const onLoginSwitchClick = () => {
|
||||
const onLoginSwitchClick = async () => {
|
||||
try {
|
||||
const returnUrl = buildReturnUrl('/account/oauth/authorize', `?client_id=${encodeURIComponent(client_id)}&redirect_uri=${encodeURIComponent(redirect_uri)}`)
|
||||
setPostLoginRedirect(returnUrl)
|
||||
router.push('/signin')
|
||||
const returnUrl = buildReturnUrl('/account/oauth/authorize', `?${searchParams.toString()}`)
|
||||
setOAuthPendingRedirect(returnUrl)
|
||||
if (isLoggedIn)
|
||||
await logout()
|
||||
router.push(`/signin?redirect_url=${encodeURIComponent(returnUrl)}`)
|
||||
}
|
||||
catch {
|
||||
router.push('/signin')
|
||||
|
||||
@ -85,7 +85,7 @@ export const AppInitializer = ({
|
||||
return
|
||||
}
|
||||
|
||||
const redirectUrl = resolvePostLoginRedirect()
|
||||
const redirectUrl = resolvePostLoginRedirect(searchParams)
|
||||
if (redirectUrl) {
|
||||
location.replace(redirectUrl)
|
||||
return
|
||||
|
||||
@ -80,8 +80,11 @@ vi.mock('@/service/explore', () => ({
|
||||
fetchInstalledAppList: (...args: unknown[]) => mockFetchInstalledAppList(...args),
|
||||
}))
|
||||
|
||||
const mockPublishToCreatorsPlatform = vi.fn()
|
||||
|
||||
vi.mock('@/service/apps', () => ({
|
||||
fetchAppDetailDirect: (...args: unknown[]) => mockFetchAppDetailDirect(...args),
|
||||
publishToCreatorsPlatform: (...args: unknown[]) => mockPublishToCreatorsPlatform(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-workflow', () => ({
|
||||
@ -434,6 +437,76 @@ describe('AppPublisher', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should show marketplace button and open redirect URL on success', async () => {
|
||||
mockPublishToCreatorsPlatform.mockResolvedValue({ redirect_url: 'https://marketplace.example.com/publish?code=abc' })
|
||||
const windowOpenSpy = vi.spyOn(window, 'open').mockImplementation(() => null)
|
||||
|
||||
renderWithSystemFeatures(
|
||||
<AppPublisher
|
||||
publishedAt={Date.now()}
|
||||
onPublish={mockOnPublish}
|
||||
/>,
|
||||
{ systemFeatures: { webapp_auth: { enabled: true }, enable_creators_platform: true } },
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('common.publish'))
|
||||
fireEvent.click(screen.getByText('common.publishToMarketplace'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockPublishToCreatorsPlatform).toHaveBeenCalledWith({ appID: 'app-1' })
|
||||
expect(windowOpenSpy).toHaveBeenCalledWith('https://marketplace.example.com/publish?code=abc', '_blank')
|
||||
})
|
||||
|
||||
windowOpenSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should show toast error when publish to marketplace fails', async () => {
|
||||
mockPublishToCreatorsPlatform.mockRejectedValue(new Error('network error'))
|
||||
|
||||
renderWithSystemFeatures(
|
||||
<AppPublisher
|
||||
publishedAt={Date.now()}
|
||||
onPublish={mockOnPublish}
|
||||
/>,
|
||||
{ systemFeatures: { webapp_auth: { enabled: true }, enable_creators_platform: true } },
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('common.publish'))
|
||||
fireEvent.click(screen.getByText('common.publishToMarketplace'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastError).toHaveBeenCalledWith('common.publishToMarketplaceFailed')
|
||||
})
|
||||
})
|
||||
|
||||
it('should disable marketplace button when not yet published', () => {
|
||||
renderWithSystemFeatures(
|
||||
<AppPublisher
|
||||
onPublish={mockOnPublish}
|
||||
/>,
|
||||
{ systemFeatures: { webapp_auth: { enabled: true }, enable_creators_platform: true } },
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('common.publish'))
|
||||
const marketplaceButton = screen.getByText('common.publishToMarketplace').closest('a, button, div[role="button"]') as HTMLElement
|
||||
expect(marketplaceButton).toBeInTheDocument()
|
||||
// clicking should not call the API because publishedAt is undefined
|
||||
fireEvent.click(screen.getByText('common.publishToMarketplace'))
|
||||
expect(mockPublishToCreatorsPlatform).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should hide marketplace button when enable_creators_platform is false', () => {
|
||||
render(
|
||||
<AppPublisher
|
||||
publishedAt={Date.now()}
|
||||
onPublish={mockOnPublish}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('common.publish'))
|
||||
expect(screen.queryByText('common.publishToMarketplace')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should keep access control open when app detail is unavailable during confirmation', async () => {
|
||||
mockAppDetail = null
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ import type { PublishWorkflowParams } from '@/types/workflow'
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
import { Popover, PopoverContent, PopoverTrigger } from '@langgenius/dify-ui/popover'
|
||||
import { toast } from '@langgenius/dify-ui/toast'
|
||||
import { RiStoreLine } from '@remixicon/react'
|
||||
import { useSuspenseQuery } from '@tanstack/react-query'
|
||||
import { useKeyPress } from 'ahooks'
|
||||
import {
|
||||
@ -26,7 +27,7 @@ import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
|
||||
import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/access-control'
|
||||
import { fetchAppDetailDirect } from '@/service/apps'
|
||||
import { fetchAppDetailDirect, publishToCreatorsPlatform } from '@/service/apps'
|
||||
import { fetchInstalledAppList } from '@/service/explore'
|
||||
import { systemFeaturesQueryOptions } from '@/service/system-features'
|
||||
import { useInvalidateAppWorkflow } from '@/service/use-workflow'
|
||||
@ -40,6 +41,7 @@ import {
|
||||
PublisherActionsSection,
|
||||
PublisherSummarySection,
|
||||
} from './sections'
|
||||
import SuggestedAction from './suggested-action'
|
||||
import {
|
||||
getDisabledFunctionTooltip,
|
||||
getPublisherAppUrl,
|
||||
@ -100,6 +102,7 @@ const AppPublisher = ({
|
||||
const [showAppAccessControl, setShowAppAccessControl] = useState(false)
|
||||
|
||||
const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false)
|
||||
const [publishingToMarketplace, setPublishingToMarketplace] = useState(false)
|
||||
|
||||
const workflowStore = useContext(WorkflowContext)
|
||||
const appDetail = useAppStore(state => state.appDetail)
|
||||
@ -219,6 +222,23 @@ const AppPublisher = ({
|
||||
}
|
||||
}, [appDetail, setAppDetail])
|
||||
|
||||
const handlePublishToMarketplace = useCallback(async () => {
|
||||
if (!appDetail?.id || publishingToMarketplace)
|
||||
return
|
||||
setPublishingToMarketplace(true)
|
||||
try {
|
||||
const res = await publishToCreatorsPlatform({ appID: appDetail.id })
|
||||
if (res.redirect_url)
|
||||
window.open(res.redirect_url, '_blank')
|
||||
}
|
||||
catch {
|
||||
toast.error(t('common.publishToMarketplaceFailed', { ns: 'workflow' }))
|
||||
}
|
||||
finally {
|
||||
setPublishingToMarketplace(false)
|
||||
}
|
||||
}, [appDetail?.id, publishingToMarketplace, t])
|
||||
|
||||
useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.shift.p`, (e) => {
|
||||
e.preventDefault()
|
||||
if (publishDisabled || published)
|
||||
@ -336,6 +356,19 @@ const AppPublisher = ({
|
||||
workflowToolAvailable={workflowToolAvailable}
|
||||
workflowToolMessage={workflowToolMessage}
|
||||
/>
|
||||
{systemFeatures.enable_creators_platform && (
|
||||
<div className="border-t border-divider-subtle p-4">
|
||||
<SuggestedAction
|
||||
icon={<RiStoreLine className="h-4 w-4" />}
|
||||
disabled={!publishedAt || publishingToMarketplace}
|
||||
onClick={handlePublishToMarketplace}
|
||||
>
|
||||
{publishingToMarketplace
|
||||
? t('common.publishingToMarketplace', { ns: 'workflow' })
|
||||
: t('common.publishToMarketplace', { ns: 'workflow' })}
|
||||
</SuggestedAction>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</PopoverContent>
|
||||
<EmbeddedModal
|
||||
|
||||
@ -137,7 +137,7 @@ describe('CreateFromDSLModal', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('importFromDSL'))!.toBeInTheDocument()
|
||||
expect(screen.getByText('importApp'))!.toBeInTheDocument()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('demo.yml'))!.toBeInTheDocument()
|
||||
@ -161,7 +161,7 @@ describe('CreateFromDSLModal', () => {
|
||||
})
|
||||
expect(screen.getByPlaceholderText('importFromDSLUrlPlaceholder'))!.toBeInTheDocument()
|
||||
|
||||
const closeTrigger = screen.getByText('importFromDSL').parentElement?.querySelector('.cursor-pointer.items-center') as HTMLElement
|
||||
const closeTrigger = screen.getByText('importApp').parentElement?.querySelector('.cursor-pointer.items-center') as HTMLElement
|
||||
fireEvent.click(closeTrigger)
|
||||
expect(handleClose).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
@ -225,7 +225,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
|
||||
onClose={noop}
|
||||
>
|
||||
<div className="flex items-center justify-between pt-6 pr-5 pb-3 pl-6 title-2xl-semi-bold text-text-primary">
|
||||
{t('importFromDSL', { ns: 'app' })}
|
||||
{t('importApp', { ns: 'app' })}
|
||||
<div
|
||||
className="flex h-8 w-8 cursor-pointer items-center"
|
||||
onClick={() => onClose()}
|
||||
|
||||
@ -7,9 +7,21 @@ import { useContextSelector } from 'use-context-selector'
|
||||
import AppListContext from '@/context/app-list-context'
|
||||
import { fetchAppDetail } from '@/service/explore'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
|
||||
import Apps from '../index'
|
||||
|
||||
vi.mock('@/next/dynamic', () => ({
|
||||
default: (loader: () => Promise<{ default: React.ComponentType }>) => {
|
||||
const LazyComp = React.lazy(loader)
|
||||
return function DynamicWrapper(props: Record<string, unknown>) {
|
||||
return React.createElement(
|
||||
React.Suspense,
|
||||
{ fallback: null },
|
||||
React.createElement(LazyComp, props),
|
||||
)
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
||||
let documentTitleCalls: string[] = []
|
||||
let educationInitCalls: number = 0
|
||||
const mockHandleImportDSL = vi.fn()
|
||||
@ -65,6 +77,16 @@ vi.mock('@/hooks/use-import-dsl', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockReplace = vi.fn()
|
||||
let mockSearchParams = new URLSearchParams()
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
replace: mockReplace,
|
||||
}),
|
||||
useSearchParams: () => mockSearchParams,
|
||||
}))
|
||||
|
||||
vi.mock('../list', () => {
|
||||
const MockList = () => {
|
||||
const setShowTryAppPanel = useContextSelector(AppListContext, ctx => ctx.setShowTryAppPanel)
|
||||
@ -129,6 +151,16 @@ vi.mock('../../app/create-from-dsl-modal/dsl-confirm-modal', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('../import-from-marketplace-template-modal', () => ({
|
||||
default: ({ templateId, onClose, onConfirm }: { templateId: string, onClose: () => void, onConfirm: (dsl: string) => void }) => (
|
||||
<div data-testid="marketplace-template-modal">
|
||||
<span data-testid="template-id">{templateId}</span>
|
||||
<button data-testid="close-template" onClick={onClose}>Close Template</button>
|
||||
<button data-testid="confirm-template" onClick={() => onConfirm('yaml-dsl-content')}>Confirm Template</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/explore', () => ({
|
||||
fetchAppDetail: vi.fn(),
|
||||
}))
|
||||
@ -161,6 +193,8 @@ describe('Apps', () => {
|
||||
vi.clearAllMocks()
|
||||
documentTitleCalls = []
|
||||
educationInitCalls = 0
|
||||
mockSearchParams = new URLSearchParams()
|
||||
mockReplace.mockClear()
|
||||
mockFetchAppDetail.mockResolvedValue({
|
||||
id: 'template-1',
|
||||
name: 'Sample App',
|
||||
@ -304,6 +338,66 @@ describe('Apps', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('Marketplace Template', () => {
|
||||
it('should render the template modal when template-id is in search params', async () => {
|
||||
mockSearchParams = new URLSearchParams('template-id=tpl-42')
|
||||
renderWithClient(<Apps />)
|
||||
|
||||
expect(await screen.findByTestId('marketplace-template-modal')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('template-id')).toHaveTextContent('tpl-42')
|
||||
})
|
||||
|
||||
it('should not render the template modal when no template-id is present', () => {
|
||||
renderWithClient(<Apps />)
|
||||
|
||||
expect(screen.queryByTestId('marketplace-template-modal')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should close the template modal and remove template-id from URL', async () => {
|
||||
mockSearchParams = new URLSearchParams('template-id=tpl-42')
|
||||
renderWithClient(<Apps />)
|
||||
|
||||
fireEvent.click(await screen.findByTestId('close-template'))
|
||||
|
||||
expect(mockReplace).toHaveBeenCalledTimes(1)
|
||||
const replaceArg = mockReplace.mock.calls[0]![0] as string
|
||||
expect(replaceArg).not.toContain('template-id')
|
||||
})
|
||||
|
||||
it('should import DSL from marketplace template on confirm', async () => {
|
||||
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
|
||||
options.onSuccess?.()
|
||||
})
|
||||
mockSearchParams = new URLSearchParams('template-id=tpl-42')
|
||||
renderWithClient(<Apps />)
|
||||
|
||||
fireEvent.click(await screen.findByTestId('confirm-template'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandleImportDSL).toHaveBeenCalledWith(
|
||||
{ mode: 'yaml-content', yaml_content: 'yaml-dsl-content' },
|
||||
expect.objectContaining({ onSuccess: expect.any(Function) }),
|
||||
)
|
||||
expect(mockReplace).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should show DSL confirm modal when marketplace import is pending', async () => {
|
||||
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onPending?: () => void }) => {
|
||||
options.onPending?.()
|
||||
})
|
||||
mockSearchParams = new URLSearchParams('template-id=tpl-42')
|
||||
renderWithClient(<Apps />)
|
||||
|
||||
fireEvent.click(await screen.findByTestId('confirm-template'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('dsl-confirm-modal')).toBeInTheDocument()
|
||||
expect(mockReplace).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Styling', () => {
|
||||
it('should have overflow-y-auto class', () => {
|
||||
const { container } = renderWithClient(<Apps />)
|
||||
|
||||
@ -0,0 +1,182 @@
|
||||
'use client'
|
||||
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
import { Dialog, DialogContent } from '@langgenius/dify-ui/dialog'
|
||||
import { toast } from '@langgenius/dify-ui/toast'
|
||||
import { RiCloseLine } from '@remixicon/react'
|
||||
import { useCallback, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { MARKETPLACE_API_PREFIX } from '@/config'
|
||||
import {
|
||||
fetchMarketplaceTemplateDSL,
|
||||
useMarketplaceTemplateDetail,
|
||||
} from '@/service/marketplace-templates'
|
||||
|
||||
type ImportFromMarketplaceTemplateModalProps = {
|
||||
templateId: string
|
||||
onClose: () => void
|
||||
onConfirm: (dslContent: string) => void
|
||||
}
|
||||
|
||||
const ImportFromMarketplaceTemplateModal = ({
|
||||
templateId,
|
||||
onClose,
|
||||
onConfirm,
|
||||
}: ImportFromMarketplaceTemplateModalProps) => {
|
||||
const { t } = useTranslation()
|
||||
const { data, isLoading, isError } = useMarketplaceTemplateDetail(templateId)
|
||||
const template = data?.data
|
||||
const [importing, setImporting] = useState(false)
|
||||
const isImportingRef = useRef(false)
|
||||
|
||||
const CATEGORY_I18N_MAP: Record<string, string> = useMemo(() => ({
|
||||
marketing: t('marketplace.template.category.marketing', { ns: 'app' }),
|
||||
sales: t('marketplace.template.category.sales', { ns: 'app' }),
|
||||
support: t('marketplace.template.category.support', { ns: 'app' }),
|
||||
operations: t('marketplace.template.category.operations', { ns: 'app' }),
|
||||
it: t('marketplace.template.category.it', { ns: 'app' }),
|
||||
knowledge: t('marketplace.template.category.knowledge', { ns: 'app' }),
|
||||
design: t('marketplace.template.category.design', { ns: 'app' }),
|
||||
}), [t])
|
||||
|
||||
const translateCategory = useCallback((slug: string) => {
|
||||
return CATEGORY_I18N_MAP[slug] ?? slug
|
||||
}, [CATEGORY_I18N_MAP])
|
||||
|
||||
const handleConfirm = useCallback(async () => {
|
||||
if (isImportingRef.current)
|
||||
return
|
||||
isImportingRef.current = true
|
||||
setImporting(true)
|
||||
try {
|
||||
const dsl = await fetchMarketplaceTemplateDSL(templateId)
|
||||
onConfirm(dsl)
|
||||
}
|
||||
catch {
|
||||
toast.error(t('marketplace.template.importFailed', { ns: 'app' }))
|
||||
}
|
||||
finally {
|
||||
setImporting(false)
|
||||
isImportingRef.current = false
|
||||
}
|
||||
}, [templateId, onConfirm, t])
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open
|
||||
onOpenChange={(open) => {
|
||||
if (!open)
|
||||
onClose()
|
||||
}}
|
||||
>
|
||||
<DialogContent
|
||||
className="w-[520px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg p-0 shadow-xl"
|
||||
>
|
||||
<div className="flex items-center justify-between pt-6 pr-5 pb-3 pl-6 title-2xl-semi-bold text-text-primary">
|
||||
{t('marketplace.template.modalTitle', { ns: 'app' })}
|
||||
<div
|
||||
className="flex h-8 w-8 cursor-pointer items-center"
|
||||
onClick={onClose}
|
||||
>
|
||||
<RiCloseLine className="h-5 w-5 text-text-tertiary" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="px-6 py-4">
|
||||
{isLoading && (
|
||||
<div className="flex items-center justify-center py-8">
|
||||
<div className="system-md-regular text-text-tertiary">Loading...</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isError && (
|
||||
<div className="flex items-center justify-center py-8">
|
||||
<div className="system-md-regular text-text-destructive">
|
||||
{t('marketplace.template.fetchFailed', { ns: 'app' })}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{template && (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center gap-3">
|
||||
{template.icon_file_key
|
||||
? (
|
||||
<img
|
||||
src={`${MARKETPLACE_API_PREFIX}/templates/${template.id}/icon`}
|
||||
alt={template.template_name}
|
||||
className="h-10 w-10 rounded-lg object-cover"
|
||||
/>
|
||||
)
|
||||
: (
|
||||
<div
|
||||
className="flex h-10 w-10 items-center justify-center rounded-lg text-xl"
|
||||
style={{ background: template.icon_background || '#F3F4F6' }}
|
||||
>
|
||||
{template.icon || '📄'}
|
||||
</div>
|
||||
)}
|
||||
<div className="flex flex-col">
|
||||
<div className="system-md-semibold text-text-primary">{template.template_name}</div>
|
||||
<div className="flex items-center gap-1 system-xs-regular text-text-tertiary">
|
||||
<span>
|
||||
{t('marketplace.template.publishedBy', { ns: 'app' })}
|
||||
{' '}
|
||||
{template.publisher_unique_handle}
|
||||
</span>
|
||||
<span>·</span>
|
||||
<span>
|
||||
{t('marketplace.template.usageCount', { ns: 'app' })}
|
||||
{' '}
|
||||
{template.usage_count}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{template.overview && (
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="system-xs-medium-uppercase text-text-tertiary">
|
||||
{t('marketplace.template.overview', { ns: 'app' })}
|
||||
</div>
|
||||
<div className="system-sm-regular text-text-secondary">
|
||||
{template.overview}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{template.categories.length > 0 && (
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
{template.categories.map(cat => (
|
||||
<span
|
||||
key={cat}
|
||||
className="inline-flex items-center rounded-full bg-components-label-gray px-2.5 py-1 system-sm-regular text-text-secondary"
|
||||
>
|
||||
{translateCategory(cat)}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end px-6 py-5">
|
||||
<Button className="mr-2" onClick={onClose}>
|
||||
{t('newApp.Cancel', { ns: 'app' })}
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
disabled={isLoading || isError || importing}
|
||||
loading={importing}
|
||||
onClick={handleConfirm}
|
||||
>
|
||||
{t('marketplace.template.importConfirm', { ns: 'app' })}
|
||||
</Button>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
export default ImportFromMarketplaceTemplateModal
|
||||
@ -9,6 +9,7 @@ import useDocumentTitle from '@/hooks/use-document-title'
|
||||
import { useImportDSL } from '@/hooks/use-import-dsl'
|
||||
import { DSLImportMode } from '@/models/app'
|
||||
import dynamic from '@/next/dynamic'
|
||||
import { useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { fetchAppDetail } from '@/service/explore'
|
||||
import { trackCreateApp } from '@/utils/create-app-tracking'
|
||||
import List from './list'
|
||||
@ -16,9 +17,14 @@ import List from './list'
|
||||
const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false })
|
||||
const CreateAppModal = dynamic(() => import('../explore/create-app-modal'), { ssr: false })
|
||||
const TryApp = dynamic(() => import('../explore/try-app'), { ssr: false })
|
||||
const ImportFromMarketplaceTemplateModal = dynamic(() => import('./import-from-marketplace-template-modal'), { ssr: false })
|
||||
|
||||
const Apps = () => {
|
||||
const { t } = useTranslation()
|
||||
const searchParams = useSearchParams()
|
||||
const { replace } = useRouter()
|
||||
const templateId = searchParams.get('template-id')
|
||||
const templateDismissedRef = useRef(false)
|
||||
|
||||
useDocumentTitle(t('menus.apps', { ns: 'common' }))
|
||||
useEducationInit()
|
||||
@ -58,6 +64,14 @@ const Apps = () => {
|
||||
|
||||
const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false)
|
||||
|
||||
const handleCloseTemplateModal = useCallback(() => {
|
||||
templateDismissedRef.current = true
|
||||
const params = new URLSearchParams(searchParams.toString())
|
||||
params.delete('template-id')
|
||||
const query = params.toString()
|
||||
replace(query ? `?${query}` : window.location.pathname, { scroll: false })
|
||||
}, [searchParams, replace])
|
||||
|
||||
const {
|
||||
handleImportDSL,
|
||||
handleImportDSLConfirm,
|
||||
@ -74,6 +88,22 @@ const Apps = () => {
|
||||
})
|
||||
}, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp])
|
||||
|
||||
const handleMarketplaceTemplateConfirm = useCallback(async (dslContent: string) => {
|
||||
await handleImportDSL({
|
||||
mode: DSLImportMode.YAML_CONTENT,
|
||||
yaml_content: dslContent,
|
||||
}, {
|
||||
onSuccess: () => {
|
||||
handleCloseTemplateModal()
|
||||
onSuccess()
|
||||
},
|
||||
onPending: () => {
|
||||
handleCloseTemplateModal()
|
||||
setShowDSLConfirmModal(true)
|
||||
},
|
||||
})
|
||||
}, [handleImportDSL, handleCloseTemplateModal, onSuccess])
|
||||
|
||||
const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({
|
||||
name,
|
||||
icon_type,
|
||||
@ -152,6 +182,14 @@ const Apps = () => {
|
||||
onHide={() => setIsShowCreateModal(false)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{templateId && !templateDismissedRef.current && (
|
||||
<ImportFromMarketplaceTemplateModal
|
||||
templateId={templateId}
|
||||
onClose={handleCloseTemplateModal}
|
||||
onConfirm={handleMarketplaceTemplateConfirm}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</AppListContext.Provider>
|
||||
)
|
||||
|
||||
@ -156,7 +156,7 @@ describe('PanelContextmenu', () => {
|
||||
fireEvent.click(screen.getByText('common.run'))
|
||||
fireEvent.click(screen.getByText('common.pasteHere'))
|
||||
fireEvent.click(screen.getByText('export'))
|
||||
fireEvent.click(screen.getByText('common.importDSL'))
|
||||
fireEvent.click(screen.getByText('importApp'))
|
||||
clickAwayHandler?.()
|
||||
|
||||
expect(mockHandleAddNote).toHaveBeenCalledTimes(1)
|
||||
|
||||
@ -137,7 +137,7 @@ const PanelContextmenu = () => {
|
||||
className="flex h-8 cursor-pointer items-center justify-between rounded-lg px-3 text-sm text-text-secondary hover:bg-state-base-hover"
|
||||
onClick={() => setShowImportDSLModal(true)}
|
||||
>
|
||||
{t('common.importDSL', { ns: 'workflow' })}
|
||||
{t('importApp', { ns: 'app' })}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -205,7 +205,7 @@ const UpdateDSLModal = ({
|
||||
onClose={onCancel}
|
||||
>
|
||||
<div className="mb-3 flex items-center justify-between">
|
||||
<div className="title-2xl-semi-bold text-text-primary">{t('common.importDSL', { ns: 'workflow' })}</div>
|
||||
<div className="title-2xl-semi-bold text-text-primary">{t('importApp', { ns: 'app' })}</div>
|
||||
<div className="flex h-[22px] w-[22px] cursor-pointer items-center justify-center" onClick={onCancel}>
|
||||
<RiCloseLine className="h-[18px] w-[18px] text-text-tertiary" />
|
||||
</div>
|
||||
|
||||
@ -1,18 +1,23 @@
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Link from '@/next/link'
|
||||
import { redirect } from '@/next/navigation'
|
||||
|
||||
const Home = async () => {
|
||||
return (
|
||||
<div className="flex min-h-screen flex-col justify-center py-12 sm:px-6 lg:px-8">
|
||||
type HomePageProps = {
|
||||
searchParams: Promise<Record<string, string | string[] | undefined>>
|
||||
}
|
||||
|
||||
<div className="sm:mx-auto sm:w-full sm:max-w-md">
|
||||
<Loading type="area" />
|
||||
<div className="mt-10 text-center">
|
||||
<Link href="/apps">🚀</Link>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
const Home = async ({ searchParams }: HomePageProps) => {
|
||||
const resolvedSearchParams = await searchParams
|
||||
const urlSearchParams = new URLSearchParams()
|
||||
Object.entries(resolvedSearchParams).forEach(([key, value]) => {
|
||||
if (value === undefined)
|
||||
return
|
||||
if (Array.isArray(value)) {
|
||||
value.forEach(item => urlSearchParams.append(key, item))
|
||||
return
|
||||
}
|
||||
urlSearchParams.set(key, value)
|
||||
})
|
||||
const queryString = urlSearchParams.toString()
|
||||
redirect(queryString ? `/apps?${queryString}` : '/apps')
|
||||
}
|
||||
|
||||
export default Home
|
||||
|
||||
@ -51,7 +51,7 @@ export default function CheckCode() {
|
||||
router.replace(`/signin/invite-settings?${searchParams.toString()}`)
|
||||
}
|
||||
else {
|
||||
const redirectUrl = resolvePostLoginRedirect()
|
||||
const redirectUrl = resolvePostLoginRedirect(searchParams)
|
||||
router.replace(redirectUrl || '/apps')
|
||||
}
|
||||
}
|
||||
|
||||
@ -75,7 +75,7 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis
|
||||
router.replace(`/signin/invite-settings?${searchParams.toString()}`)
|
||||
}
|
||||
else {
|
||||
const redirectUrl = resolvePostLoginRedirect()
|
||||
const redirectUrl = resolvePostLoginRedirect(searchParams)
|
||||
router.replace(redirectUrl || '/apps')
|
||||
}
|
||||
}
|
||||
|
||||
@ -65,7 +65,7 @@ export default function InviteSettingsPage() {
|
||||
if (res.result === 'success') {
|
||||
// Tokens are now stored in cookies by the backend
|
||||
await setLocaleOnClient(language!, false)
|
||||
const redirectUrl = resolvePostLoginRedirect()
|
||||
const redirectUrl = resolvePostLoginRedirect(searchParams)
|
||||
router.replace(redirectUrl || '/apps')
|
||||
}
|
||||
}
|
||||
|
||||
@ -49,7 +49,7 @@ const NormalForm = () => {
|
||||
try {
|
||||
if (isLoggedIn) {
|
||||
setIsRedirecting(true)
|
||||
const redirectUrl = resolvePostLoginRedirect()
|
||||
const redirectUrl = resolvePostLoginRedirect(searchParams)
|
||||
router.replace(redirectUrl || '/apps')
|
||||
return
|
||||
}
|
||||
|
||||
@ -1,15 +1,63 @@
|
||||
let postLoginRedirect: string | null = null
|
||||
import type { ReadonlyURLSearchParams } from '@/next/navigation'
|
||||
|
||||
export const setPostLoginRedirect = (value: string | null) => {
|
||||
postLoginRedirect = value
|
||||
const OAUTH_AUTHORIZE_PENDING_KEY = 'oauth_authorize_pending_redirect'
|
||||
const REDIRECT_URL_KEY = 'redirect_url'
|
||||
|
||||
type OAuthPendingRedirect = {
|
||||
value?: string
|
||||
expiry?: number
|
||||
}
|
||||
|
||||
export const resolvePostLoginRedirect = () => {
|
||||
if (postLoginRedirect) {
|
||||
const redirectUrl = postLoginRedirect
|
||||
postLoginRedirect = null
|
||||
return redirectUrl
|
||||
const getCurrentUnixTimestamp = () => Math.floor(Date.now() / 1000)
|
||||
|
||||
function removeOAuthPendingRedirect() {
|
||||
try {
|
||||
localStorage.removeItem(OAUTH_AUTHORIZE_PENDING_KEY)
|
||||
}
|
||||
|
||||
return null
|
||||
catch {}
|
||||
}
|
||||
|
||||
function getOAuthPendingRedirect(): string | null {
|
||||
try {
|
||||
const raw = localStorage.getItem(OAUTH_AUTHORIZE_PENDING_KEY)
|
||||
if (!raw)
|
||||
return null
|
||||
removeOAuthPendingRedirect()
|
||||
const item: OAuthPendingRedirect = JSON.parse(raw)
|
||||
if (!item.value || typeof item.expiry !== 'number')
|
||||
return null
|
||||
return getCurrentUnixTimestamp() > item.expiry ? null : item.value
|
||||
}
|
||||
catch {
|
||||
removeOAuthPendingRedirect()
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export function setOAuthPendingRedirect(url: string, ttlSeconds: number = 300) {
|
||||
try {
|
||||
const item: OAuthPendingRedirect = {
|
||||
value: url,
|
||||
expiry: getCurrentUnixTimestamp() + ttlSeconds,
|
||||
}
|
||||
localStorage.setItem(OAUTH_AUTHORIZE_PENDING_KEY, JSON.stringify(item))
|
||||
}
|
||||
catch {}
|
||||
}
|
||||
|
||||
export const resolvePostLoginRedirect = (searchParams?: ReadonlyURLSearchParams) => {
|
||||
if (searchParams) {
|
||||
const redirectUrl = searchParams.get(REDIRECT_URL_KEY)
|
||||
if (redirectUrl) {
|
||||
try {
|
||||
removeOAuthPendingRedirect()
|
||||
return decodeURIComponent(redirectUrl)
|
||||
}
|
||||
catch {
|
||||
removeOAuthPendingRedirect()
|
||||
return redirectUrl
|
||||
}
|
||||
}
|
||||
}
|
||||
return getOAuthPendingRedirect()
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import type { CollectionsAndPluginsSearchParams, MarketplaceCollection, PluginsSearchParams } from '@/app/components/plugins/marketplace/types'
|
||||
import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types'
|
||||
import type { MarketplaceTemplate } from '@/types/marketplace-template'
|
||||
import { type } from '@orpc/contract'
|
||||
import { base } from './base'
|
||||
|
||||
@ -54,3 +55,15 @@ export const searchAdvancedContract = base
|
||||
body: Omit<PluginsSearchParams, 'type'>
|
||||
}>())
|
||||
.output(type<{ data: PluginsFromMarketplaceResponse }>())
|
||||
|
||||
export const templateDetailContract = base
|
||||
.route({
|
||||
path: '/templates/{templateId}',
|
||||
method: 'GET',
|
||||
})
|
||||
.input(type<{
|
||||
params: {
|
||||
templateId: string
|
||||
}
|
||||
}>())
|
||||
.output(type<{ data: MarketplaceTemplate }>())
|
||||
|
||||
@ -42,12 +42,13 @@ import {
|
||||
workflowDraftUpdateFeaturesContract,
|
||||
} from './console/workflow'
|
||||
import { workflowCommentContracts } from './console/workflow-comment'
|
||||
import { collectionPluginsContract, collectionsContract, searchAdvancedContract } from './marketplace'
|
||||
import { collectionPluginsContract, collectionsContract, searchAdvancedContract, templateDetailContract } from './marketplace'
|
||||
|
||||
export const marketplaceRouterContract = {
|
||||
collections: collectionsContract,
|
||||
collectionPlugins: collectionPluginsContract,
|
||||
searchAdvanced: searchAdvancedContract,
|
||||
templateDetail: templateDetailContract,
|
||||
}
|
||||
|
||||
export type MarketPlaceInputs = InferContractRouterInputs<typeof marketplaceRouterContract>
|
||||
|
||||
@ -118,12 +118,29 @@
|
||||
"iconPicker.emoji": "Emoji",
|
||||
"iconPicker.image": "Image",
|
||||
"iconPicker.ok": "OK",
|
||||
"importApp": "Import App",
|
||||
"importDSL": "Import DSL file",
|
||||
"importFromDSL": "Import from DSL",
|
||||
"importFromDSLFile": "From DSL file",
|
||||
"importFromDSLUrl": "From URL",
|
||||
"importFromDSLUrlPlaceholder": "Paste DSL link here",
|
||||
"join": "Join the community",
|
||||
"marketplace.template.categories": "Categories",
|
||||
"marketplace.template.category.design": "Design",
|
||||
"marketplace.template.category.it": "IT",
|
||||
"marketplace.template.category.knowledge": "Knowledge",
|
||||
"marketplace.template.category.marketing": "Marketing",
|
||||
"marketplace.template.category.operations": "Operations",
|
||||
"marketplace.template.category.sales": "Sales",
|
||||
"marketplace.template.category.support": "Support",
|
||||
"marketplace.template.fetchFailed": "Failed to fetch template",
|
||||
"marketplace.template.importConfirm": "Import",
|
||||
"marketplace.template.importFailed": "Failed to import template",
|
||||
"marketplace.template.modalTitle": "Import from Marketplace",
|
||||
"marketplace.template.overview": "Overview",
|
||||
"marketplace.template.publishedBy": "By",
|
||||
"marketplace.template.usageCount": "Usage",
|
||||
"marketplace.template.viewOnMarketplace": "View on Marketplace",
|
||||
"maxActiveRequests": "Max concurrent requests",
|
||||
"maxActiveRequestsPlaceholder": "Enter 0 for unlimited",
|
||||
"maxActiveRequestsTip": "Maximum number of concurrent active requests per app (0 for unlimited)",
|
||||
|
||||
@ -229,9 +229,12 @@
|
||||
"common.previewPlaceholder": "Enter content in the box below to start debugging the Chatbot",
|
||||
"common.processData": "Process Data",
|
||||
"common.publish": "Publish",
|
||||
"common.publishToMarketplace": "Publish to Marketplace",
|
||||
"common.publishToMarketplaceFailed": "Failed to publish to Marketplace",
|
||||
"common.publishUpdate": "Publish Update",
|
||||
"common.published": "Published",
|
||||
"common.publishedAt": "Published",
|
||||
"common.publishingToMarketplace": "Publishing...",
|
||||
"common.redo": "Redo",
|
||||
"common.restart": "Restart",
|
||||
"common.restore": "Restore",
|
||||
|
||||
@ -118,12 +118,29 @@
|
||||
"iconPicker.emoji": "表情符号",
|
||||
"iconPicker.image": "图片",
|
||||
"iconPicker.ok": "确认",
|
||||
"importApp": "导入应用",
|
||||
"importDSL": "导入 DSL 文件",
|
||||
"importFromDSL": "导入 DSL",
|
||||
"importFromDSLFile": "文件",
|
||||
"importFromDSLUrl": "URL",
|
||||
"importFromDSLUrlPlaceholder": "输入 DSL 文件的 URL",
|
||||
"join": "参与社区",
|
||||
"marketplace.template.categories": "分类",
|
||||
"marketplace.template.category.design": "设计",
|
||||
"marketplace.template.category.it": "IT",
|
||||
"marketplace.template.category.knowledge": "知识",
|
||||
"marketplace.template.category.marketing": "营销",
|
||||
"marketplace.template.category.operations": "运营",
|
||||
"marketplace.template.category.sales": "销售",
|
||||
"marketplace.template.category.support": "支持",
|
||||
"marketplace.template.fetchFailed": "获取模板失败",
|
||||
"marketplace.template.importConfirm": "导入",
|
||||
"marketplace.template.importFailed": "导入模板失败",
|
||||
"marketplace.template.modalTitle": "从市场导入",
|
||||
"marketplace.template.overview": "概述",
|
||||
"marketplace.template.publishedBy": "来自",
|
||||
"marketplace.template.usageCount": "使用次数",
|
||||
"marketplace.template.viewOnMarketplace": "在市场查看",
|
||||
"maxActiveRequests": "最大活跃请求数",
|
||||
"maxActiveRequestsPlaceholder": "0 表示不限制",
|
||||
"maxActiveRequestsTip": "当前应用的最大活跃请求数(0 表示不限制)",
|
||||
|
||||
@ -229,9 +229,12 @@
|
||||
"common.previewPlaceholder": "在下面的框中输入内容开始调试聊天机器人",
|
||||
"common.processData": "数据处理",
|
||||
"common.publish": "发布",
|
||||
"common.publishToMarketplace": "发布到市场",
|
||||
"common.publishToMarketplaceFailed": "发布到市场失败",
|
||||
"common.publishUpdate": "发布更新",
|
||||
"common.published": "已发布",
|
||||
"common.publishedAt": "发布于",
|
||||
"common.publishingToMarketplace": "发布中...",
|
||||
"common.redo": "重做",
|
||||
"common.restart": "重新开始",
|
||||
"common.restore": "恢复",
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
export {
|
||||
redirect,
|
||||
useParams,
|
||||
usePathname,
|
||||
useRouter,
|
||||
@ -6,3 +7,4 @@ export {
|
||||
useSelectedLayoutSegment,
|
||||
useSelectedLayoutSegments,
|
||||
} from 'next/navigation'
|
||||
export type { ReadonlyURLSearchParams } from 'next/navigation'
|
||||
|
||||
68
web/service/__tests__/base.spec.ts
Normal file
68
web/service/__tests__/base.spec.ts
Normal file
@ -0,0 +1,68 @@
|
||||
import { buildSigninUrlWithRedirect } from '../base'
|
||||
|
||||
vi.mock('@/utils/var', () => ({
|
||||
basePath: '/app',
|
||||
API_PREFIX: '/console/api',
|
||||
PUBLIC_API_PREFIX: '/api',
|
||||
IS_CE_EDITION: false,
|
||||
}))
|
||||
|
||||
describe('buildSigninUrlWithRedirect', () => {
|
||||
const originalLocation = globalThis.location
|
||||
|
||||
beforeEach(() => {
|
||||
Object.defineProperty(globalThis, 'location', {
|
||||
value: {
|
||||
origin: 'https://example.com',
|
||||
pathname: '/apps',
|
||||
href: 'https://example.com/apps',
|
||||
},
|
||||
writable: true,
|
||||
configurable: true,
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
Object.defineProperty(globalThis, 'location', {
|
||||
value: originalLocation,
|
||||
writable: true,
|
||||
configurable: true,
|
||||
})
|
||||
})
|
||||
|
||||
it('should return plain signin URL for non-OAuth pages', () => {
|
||||
const url = buildSigninUrlWithRedirect()
|
||||
expect(url).toBe('https://example.com/app/signin')
|
||||
})
|
||||
|
||||
it('should append redirect_url for OAuth authorize pages', () => {
|
||||
const oauthHref = 'https://example.com/account/oauth/authorize?client_id=abc&state=xyz'
|
||||
Object.defineProperty(globalThis, 'location', {
|
||||
value: {
|
||||
origin: 'https://example.com',
|
||||
pathname: '/account/oauth/authorize',
|
||||
href: oauthHref,
|
||||
},
|
||||
writable: true,
|
||||
configurable: true,
|
||||
})
|
||||
|
||||
const url = buildSigninUrlWithRedirect()
|
||||
expect(url).toBe(`https://example.com/app/signin?redirect_url=${encodeURIComponent(oauthHref)}`)
|
||||
})
|
||||
|
||||
it('should not include redirect_url for other paths containing partial match', () => {
|
||||
Object.defineProperty(globalThis, 'location', {
|
||||
value: {
|
||||
origin: 'https://example.com',
|
||||
pathname: '/settings/oauth',
|
||||
href: 'https://example.com/settings/oauth',
|
||||
},
|
||||
writable: true,
|
||||
configurable: true,
|
||||
})
|
||||
|
||||
const url = buildSigninUrlWithRedirect()
|
||||
expect(url).toBe('https://example.com/app/signin')
|
||||
})
|
||||
})
|
||||
@ -192,3 +192,11 @@ export const updateTracingConfig = ({ appId, body }: { appId: string, body: Trac
|
||||
export const removeTracingConfig = ({ appId, provider }: { appId: string, provider: TracingProvider }): Promise<CommonResponse> => {
|
||||
return del<CommonResponse>(`/apps/${appId}/trace-config?tracing_provider=${provider}`)
|
||||
}
|
||||
|
||||
type PublishToCreatorsPlatformResponse = {
|
||||
redirect_url: string
|
||||
}
|
||||
|
||||
export const publishToCreatorsPlatform = ({ appID }: { appID: string }): Promise<PublishToCreatorsPlatformResponse> => {
|
||||
return post<PublishToCreatorsPlatformResponse>(`apps/${appID}/publish-to-creators-platform`, { body: {} })
|
||||
}
|
||||
|
||||
@ -140,6 +140,20 @@ function jumpTo(url: string) {
|
||||
globalThis.location.href = url
|
||||
}
|
||||
|
||||
const OAUTH_AUTHORIZE_PATH = '/account/oauth/authorize'
|
||||
|
||||
export const buildSigninUrlWithRedirect = (): string => {
|
||||
const loginUrl = `${globalThis.location.origin}${basePath}/signin`
|
||||
|
||||
// Only preserve redirect URL for OAuth authorize pages
|
||||
if (globalThis.location.pathname.includes(OAUTH_AUTHORIZE_PATH)) {
|
||||
const currentUrl = globalThis.location.href
|
||||
return `${loginUrl}?redirect_url=${encodeURIComponent(currentUrl)}`
|
||||
}
|
||||
|
||||
return loginUrl
|
||||
}
|
||||
|
||||
function unicodeToChar(text: string) {
|
||||
if (!text)
|
||||
return ''
|
||||
@ -795,14 +809,14 @@ export const request = async<T>(url: string, options = {}, otherOptions?: IOther
|
||||
if (refreshErr === null)
|
||||
return baseFetch<T>(url, options, otherOptionsForBaseFetch)
|
||||
if (location.pathname !== `${basePath}/signin` || !IS_CE_EDITION) {
|
||||
jumpTo(loginUrl)
|
||||
jumpTo(buildSigninUrlWithRedirect())
|
||||
return Promise.reject(err)
|
||||
}
|
||||
if (!silent) {
|
||||
toast.error(message)
|
||||
return Promise.reject(err)
|
||||
}
|
||||
jumpTo(loginUrl)
|
||||
jumpTo(buildSigninUrlWithRedirect())
|
||||
return Promise.reject(err)
|
||||
}
|
||||
else {
|
||||
|
||||
18
web/service/marketplace-templates.ts
Normal file
18
web/service/marketplace-templates.ts
Normal file
@ -0,0 +1,18 @@
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { MARKETPLACE_API_PREFIX } from '@/config'
|
||||
import { marketplaceQuery } from './client'
|
||||
|
||||
export const useMarketplaceTemplateDetail = (templateId: string | null) => {
|
||||
return useQuery({
|
||||
...marketplaceQuery.templateDetail.queryOptions({ input: { params: { templateId: templateId ?? '' } } }),
|
||||
enabled: !!templateId,
|
||||
})
|
||||
}
|
||||
|
||||
export const fetchMarketplaceTemplateDSL = async (templateId: string): Promise<string> => {
|
||||
const url = `${MARKETPLACE_API_PREFIX}/templates/${templateId}/dsl`
|
||||
const response = await fetch(url)
|
||||
if (!response.ok)
|
||||
throw new Error(`Failed to fetch DSL: ${response.statusText}`)
|
||||
return response.text()
|
||||
}
|
||||
@ -64,6 +64,7 @@ export type SystemFeatures = {
|
||||
allow_email_code_login: boolean
|
||||
allow_email_password_login: boolean
|
||||
}
|
||||
enable_creators_platform: boolean
|
||||
enable_trial_app: boolean
|
||||
enable_explore_banner: boolean
|
||||
}
|
||||
@ -108,6 +109,7 @@ export const defaultSystemFeatures: SystemFeatures = {
|
||||
allow_email_code_login: false,
|
||||
allow_email_password_login: false,
|
||||
},
|
||||
enable_creators_platform: false,
|
||||
enable_trial_app: false,
|
||||
enable_explore_banner: false,
|
||||
}
|
||||
|
||||
11
web/types/marketplace-template.ts
Normal file
11
web/types/marketplace-template.ts
Normal file
@ -0,0 +1,11 @@
|
||||
export type MarketplaceTemplate = {
|
||||
id: string
|
||||
template_name: string
|
||||
overview: string
|
||||
icon: string
|
||||
icon_background: string
|
||||
icon_file_key: string
|
||||
publisher_unique_handle: string
|
||||
usage_count: number
|
||||
categories: string[]
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user