feat(enterprise): auto-join newly registered accounts to the default workspace (#32308)

Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai>
This commit is contained in:
L1nSn0w 2026-03-01 16:53:09 +08:00 committed by GitHub
parent 6a3db151a8
commit 337161cdb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 371 additions and 2 deletions

View File

@ -289,6 +289,12 @@ class AccountService:
TenantService.create_owner_tenant_if_not_exist(account=account)
# Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace).
if dify_config.ENTERPRISE_ENABLED:
from services.enterprise.enterprise_service import try_join_default_workspace
try_join_default_workspace(str(account.id))
return account
@staticmethod
@ -1407,6 +1413,12 @@ class RegisterService:
tenant_was_created.send(tenant)
db.session.commit()
# Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace).
if dify_config.ENTERPRISE_ENABLED:
from services.enterprise.enterprise_service import try_join_default_workspace
try_join_default_workspace(str(account.id))
except WorkSpaceNotAllowedCreateError:
db.session.rollback()
logger.exception("Register failed")

View File

@ -39,6 +39,9 @@ class BaseRequest:
endpoint: str,
json: Any | None = None,
params: Mapping[str, Any] | None = None,
*,
timeout: float | httpx.Timeout | None = None,
raise_for_status: bool = False,
) -> Any:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
@ -53,7 +56,16 @@ class BaseRequest:
logger.debug("Failed to generate traceparent header", exc_info=True)
with httpx.Client(mounts=mounts) as client:
response = client.request(method, url, json=json, params=params, headers=headers)
# IMPORTANT:
# - In httpx, passing timeout=None disables timeouts (infinite) and overrides the library default.
# - To preserve httpx's default timeout behavior for existing call sites, only pass the kwarg when set.
request_kwargs: dict[str, Any] = {"json": json, "params": params, "headers": headers}
if timeout is not None:
request_kwargs["timeout"] = timeout
response = client.request(method, url, **request_kwargs)
if raise_for_status:
response.raise_for_status()
return response.json()

View File

@ -1,9 +1,16 @@
import logging
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
from services.enterprise.base import EnterpriseRequest
logger = logging.getLogger(__name__)
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
class WebAppSettings(BaseModel):
access_mode: str = Field(
@ -30,6 +37,55 @@ class WorkspacePermission(BaseModel):
)
class DefaultWorkspaceJoinResult(BaseModel):
"""
Result of ensuring an account is a member of the enterprise default workspace.
- joined=True is idempotent (already a member also returns True)
- joined=False means enterprise default workspace is not configured or invalid/archived
"""
workspace_id: str = Field(default="", alias="workspaceId")
joined: bool
message: str
model_config = ConfigDict(extra="forbid", populate_by_name=True)
@model_validator(mode="after")
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
if self.joined and not self.workspace_id:
raise ValueError("workspace_id must be non-empty when joined is True")
return self
def try_join_default_workspace(account_id: str) -> None:
"""
Enterprise-only side-effect: ensure account is a member of the default workspace.
This is a best-effort integration. Failures must not block user registration.
"""
if not dify_config.ENTERPRISE_ENABLED:
return
try:
result = EnterpriseService.join_default_workspace(account_id=account_id)
if result.joined:
logger.info(
"Joined enterprise default workspace for account %s (workspace_id=%s)",
account_id,
result.workspace_id,
)
else:
logger.info(
"Skipped joining enterprise default workspace for account %s (message=%s)",
account_id,
result.message,
)
except Exception:
logger.warning("Failed to join enterprise default workspace for account %s", account_id, exc_info=True)
class EnterpriseService:
@classmethod
def get_info(cls):
@ -39,6 +95,34 @@ class EnterpriseService:
def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
@classmethod
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
"""
Call enterprise inner API to add an account to the default workspace.
NOTE: EnterpriseRequest.base_url is expected to already include the `/inner/api` prefix,
so the endpoint here is `/default-workspace/members`.
"""
# Ensure we are sending a UUID-shaped string (enterprise side validates too).
try:
uuid.UUID(account_id)
except ValueError as e:
raise ValueError(f"account_id must be a valid UUID: {account_id}") from e
data = EnterpriseRequest.send_request(
"POST",
"/default-workspace/members",
json={"account_id": account_id},
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
raise_for_status=True,
)
if not isinstance(data, dict):
raise ValueError("Invalid response format from enterprise default workspace API")
if "joined" not in data or "message" not in data:
raise ValueError("Invalid response payload from enterprise default workspace API")
return DefaultWorkspaceJoinResult.model_validate(data)
@classmethod
def get_app_sso_settings_last_update_time(cls) -> datetime:
data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time")

View File

@ -0,0 +1,141 @@
"""Unit tests for enterprise service integrations.
This module covers the enterprise-only default workspace auto-join behavior:
- Enterprise mode disabled: no external calls
- Successful join / skipped join: no errors
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
"""
from unittest.mock import patch
import pytest
from services.enterprise.enterprise_service import (
DefaultWorkspaceJoinResult,
EnterpriseService,
try_join_default_workspace,
)
class TestJoinDefaultWorkspace:
def test_join_default_workspace_success(self):
account_id = "11111111-1111-1111-1111-111111111111"
response = {"workspace_id": "22222222-2222-2222-2222-222222222222", "joined": True, "message": "ok"}
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
mock_send_request.return_value = response
result = EnterpriseService.join_default_workspace(account_id=account_id)
assert isinstance(result, DefaultWorkspaceJoinResult)
assert result.workspace_id == response["workspace_id"]
assert result.joined is True
assert result.message == "ok"
mock_send_request.assert_called_once_with(
"POST",
"/default-workspace/members",
json={"account_id": account_id},
timeout=1.0,
raise_for_status=True,
)
def test_join_default_workspace_invalid_response_format_raises(self):
account_id = "11111111-1111-1111-1111-111111111111"
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
mock_send_request.return_value = "not-a-dict"
with pytest.raises(ValueError, match="Invalid response format"):
EnterpriseService.join_default_workspace(account_id=account_id)
def test_join_default_workspace_invalid_account_id_raises(self):
with pytest.raises(ValueError):
EnterpriseService.join_default_workspace(account_id="not-a-uuid")
def test_join_default_workspace_missing_required_fields_raises(self):
account_id = "11111111-1111-1111-1111-111111111111"
response = {"workspace_id": "", "message": "ok"} # missing "joined"
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
mock_send_request.return_value = response
with pytest.raises(ValueError, match="Invalid response payload"):
EnterpriseService.join_default_workspace(account_id=account_id)
def test_join_default_workspace_joined_without_workspace_id_raises(self):
with pytest.raises(ValueError, match="workspace_id must be non-empty when joined is True"):
DefaultWorkspaceJoinResult(workspace_id="", joined=True, message="ok")
class TestTryJoinDefaultWorkspace:
def test_try_join_default_workspace_enterprise_disabled_noop(self):
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = False
try_join_default_workspace("11111111-1111-1111-1111-111111111111")
mock_join.assert_not_called()
def test_try_join_default_workspace_successful_join_does_not_raise(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.return_value = DefaultWorkspaceJoinResult(
workspace_id="22222222-2222-2222-2222-222222222222",
joined=True,
message="ok",
)
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_skipped_join_does_not_raise(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.return_value = DefaultWorkspaceJoinResult(
workspace_id="",
joined=False,
message="no default workspace configured",
)
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_api_failure_soft_fails(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.side_effect = Exception("network failure")
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_invalid_account_id_soft_fails(self):
with patch("services.enterprise.enterprise_service.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = True
# Should not raise even though UUID parsing fails inside join_default_workspace
try_join_default_workspace("not-a-uuid")

View File

@ -1064,6 +1064,67 @@ class TestRegisterService:
# ==================== Registration Tests ====================
def test_create_account_and_tenant_calls_default_workspace_join_when_enterprise_enabled(
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
):
"""Enterprise-only side effect should be invoked when ENTERPRISE_ENABLED is True."""
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False)
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
account_id="11111111-1111-1111-1111-111111111111"
)
with (
patch("services.account_service.AccountService.create_account") as mock_create_account,
patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace,
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
):
mock_create_account.return_value = mock_account
result = AccountService.create_account_and_tenant(
email="test@example.com",
name="Test User",
interface_language="en-US",
password=None,
)
assert result == mock_account
mock_create_workspace.assert_called_once_with(account=mock_account)
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
def test_create_account_and_tenant_does_not_call_default_workspace_join_when_enterprise_disabled(
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
):
"""Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False."""
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False)
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
account_id="11111111-1111-1111-1111-111111111111"
)
with (
patch("services.account_service.AccountService.create_account") as mock_create_account,
patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace,
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
):
mock_create_account.return_value = mock_account
AccountService.create_account_and_tenant(
email="test@example.com",
name="Test User",
interface_language="en-US",
password=None,
)
mock_create_workspace.assert_called_once_with(account=mock_account)
mock_join_default_workspace.assert_not_called()
def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies):
"""Test successful account registration."""
# Setup mocks
@ -1115,6 +1176,65 @@ class TestRegisterService:
mock_event.send.assert_called_once_with(mock_tenant)
self._assert_database_operations_called(mock_db_dependencies["db"])
def test_register_calls_default_workspace_join_when_enterprise_enabled(
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
):
"""Enterprise-only side effect should be invoked after successful register commit."""
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False)
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
account_id="11111111-1111-1111-1111-111111111111"
)
with (
patch("services.account_service.AccountService.create_account") as mock_create_account,
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
):
mock_create_account.return_value = mock_account
result = RegisterService.register(
email="test@example.com",
name="Test User",
password="password123",
language="en-US",
create_workspace_required=False,
)
assert result == mock_account
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
def test_register_does_not_call_default_workspace_join_when_enterprise_disabled(
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
):
"""Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False."""
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False)
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
account_id="11111111-1111-1111-1111-111111111111"
)
with (
patch("services.account_service.AccountService.create_account") as mock_create_account,
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
):
mock_create_account.return_value = mock_account
RegisterService.register(
email="test@example.com",
name="Test User",
password="password123",
language="en-US",
create_workspace_required=False,
)
mock_join_default_workspace.assert_not_called()
def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies):
"""Test account registration with OAuth integration."""
# Setup mocks