From 337161cdb96177aebe0066db62f558b3486f4710 Mon Sep 17 00:00:00 2001 From: L1nSn0w Date: Sun, 1 Mar 2026 16:53:09 +0800 Subject: [PATCH] feat(enterprise): auto-join newly registered accounts to the default workspace (#32308) Co-authored-by: Yunlu Wen --- api/services/account_service.py | 12 ++ api/services/enterprise/base.py | 14 +- api/services/enterprise/enterprise_service.py | 86 ++++++++++- .../enterprise/test_enterprise_service.py | 141 ++++++++++++++++++ .../services/test_account_service.py | 120 +++++++++++++++ 5 files changed, 371 insertions(+), 2 deletions(-) create mode 100644 api/tests/unit_tests/services/enterprise/test_enterprise_service.py diff --git a/api/services/account_service.py b/api/services/account_service.py index b4b25a1194..648b5e834f 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -289,6 +289,12 @@ class AccountService: TenantService.create_owner_tenant_if_not_exist(account=account) + # Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace). + if dify_config.ENTERPRISE_ENABLED: + from services.enterprise.enterprise_service import try_join_default_workspace + + try_join_default_workspace(str(account.id)) + return account @staticmethod @@ -1407,6 +1413,12 @@ class RegisterService: tenant_was_created.send(tenant) db.session.commit() + + # Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace). + if dify_config.ENTERPRISE_ENABLED: + from services.enterprise.enterprise_service import try_join_default_workspace + + try_join_default_workspace(str(account.id)) except WorkSpaceNotAllowedCreateError: db.session.rollback() logger.exception("Register failed") diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index e3832475aa..744b7992f8 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -39,6 +39,9 @@ class BaseRequest: endpoint: str, json: Any | None = None, params: Mapping[str, Any] | None = None, + *, + timeout: float | httpx.Timeout | None = None, + raise_for_status: bool = False, ) -> Any: headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" @@ -53,7 +56,16 @@ class BaseRequest: logger.debug("Failed to generate traceparent header", exc_info=True) with httpx.Client(mounts=mounts) as client: - response = client.request(method, url, json=json, params=params, headers=headers) + # IMPORTANT: + # - In httpx, passing timeout=None disables timeouts (infinite) and overrides the library default. + # - To preserve httpx's default timeout behavior for existing call sites, only pass the kwarg when set. + request_kwargs: dict[str, Any] = {"json": json, "params": params, "headers": headers} + if timeout is not None: + request_kwargs["timeout"] = timeout + + response = client.request(method, url, **request_kwargs) + if raise_for_status: + response.raise_for_status() return response.json() diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index a5133dfcb4..71d456aa2d 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,9 +1,16 @@ +import logging +import uuid from datetime import datetime -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator +from configs import dify_config from services.enterprise.base import EnterpriseRequest +logger = logging.getLogger(__name__) + +DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0 + class WebAppSettings(BaseModel): access_mode: str = Field( @@ -30,6 +37,55 @@ class WorkspacePermission(BaseModel): ) +class DefaultWorkspaceJoinResult(BaseModel): + """ + Result of ensuring an account is a member of the enterprise default workspace. + + - joined=True is idempotent (already a member also returns True) + - joined=False means enterprise default workspace is not configured or invalid/archived + """ + + workspace_id: str = Field(default="", alias="workspaceId") + joined: bool + message: str + + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + @model_validator(mode="after") + def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult": + if self.joined and not self.workspace_id: + raise ValueError("workspace_id must be non-empty when joined is True") + return self + + +def try_join_default_workspace(account_id: str) -> None: + """ + Enterprise-only side-effect: ensure account is a member of the default workspace. + + This is a best-effort integration. Failures must not block user registration. + """ + + if not dify_config.ENTERPRISE_ENABLED: + return + + try: + result = EnterpriseService.join_default_workspace(account_id=account_id) + if result.joined: + logger.info( + "Joined enterprise default workspace for account %s (workspace_id=%s)", + account_id, + result.workspace_id, + ) + else: + logger.info( + "Skipped joining enterprise default workspace for account %s (message=%s)", + account_id, + result.message, + ) + except Exception: + logger.warning("Failed to join enterprise default workspace for account %s", account_id, exc_info=True) + + class EnterpriseService: @classmethod def get_info(cls): @@ -39,6 +95,34 @@ class EnterpriseService: def get_workspace_info(cls, tenant_id: str): return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info") + @classmethod + def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult: + """ + Call enterprise inner API to add an account to the default workspace. + + NOTE: EnterpriseRequest.base_url is expected to already include the `/inner/api` prefix, + so the endpoint here is `/default-workspace/members`. + """ + + # Ensure we are sending a UUID-shaped string (enterprise side validates too). + try: + uuid.UUID(account_id) + except ValueError as e: + raise ValueError(f"account_id must be a valid UUID: {account_id}") from e + + data = EnterpriseRequest.send_request( + "POST", + "/default-workspace/members", + json={"account_id": account_id}, + timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS, + raise_for_status=True, + ) + if not isinstance(data, dict): + raise ValueError("Invalid response format from enterprise default workspace API") + if "joined" not in data or "message" not in data: + raise ValueError("Invalid response payload from enterprise default workspace API") + return DefaultWorkspaceJoinResult.model_validate(data) + @classmethod def get_app_sso_settings_last_update_time(cls) -> datetime: data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time") diff --git a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py new file mode 100644 index 0000000000..03c4f793cf --- /dev/null +++ b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py @@ -0,0 +1,141 @@ +"""Unit tests for enterprise service integrations. + +This module covers the enterprise-only default workspace auto-join behavior: +- Enterprise mode disabled: no external calls +- Successful join / skipped join: no errors +- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise +""" + +from unittest.mock import patch + +import pytest + +from services.enterprise.enterprise_service import ( + DefaultWorkspaceJoinResult, + EnterpriseService, + try_join_default_workspace, +) + + +class TestJoinDefaultWorkspace: + def test_join_default_workspace_success(self): + account_id = "11111111-1111-1111-1111-111111111111" + response = {"workspace_id": "22222222-2222-2222-2222-222222222222", "joined": True, "message": "ok"} + + with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request: + mock_send_request.return_value = response + + result = EnterpriseService.join_default_workspace(account_id=account_id) + + assert isinstance(result, DefaultWorkspaceJoinResult) + assert result.workspace_id == response["workspace_id"] + assert result.joined is True + assert result.message == "ok" + + mock_send_request.assert_called_once_with( + "POST", + "/default-workspace/members", + json={"account_id": account_id}, + timeout=1.0, + raise_for_status=True, + ) + + def test_join_default_workspace_invalid_response_format_raises(self): + account_id = "11111111-1111-1111-1111-111111111111" + + with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request: + mock_send_request.return_value = "not-a-dict" + + with pytest.raises(ValueError, match="Invalid response format"): + EnterpriseService.join_default_workspace(account_id=account_id) + + def test_join_default_workspace_invalid_account_id_raises(self): + with pytest.raises(ValueError): + EnterpriseService.join_default_workspace(account_id="not-a-uuid") + + def test_join_default_workspace_missing_required_fields_raises(self): + account_id = "11111111-1111-1111-1111-111111111111" + response = {"workspace_id": "", "message": "ok"} # missing "joined" + + with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request: + mock_send_request.return_value = response + + with pytest.raises(ValueError, match="Invalid response payload"): + EnterpriseService.join_default_workspace(account_id=account_id) + + def test_join_default_workspace_joined_without_workspace_id_raises(self): + with pytest.raises(ValueError, match="workspace_id must be non-empty when joined is True"): + DefaultWorkspaceJoinResult(workspace_id="", joined=True, message="ok") + + +class TestTryJoinDefaultWorkspace: + def test_try_join_default_workspace_enterprise_disabled_noop(self): + with ( + patch("services.enterprise.enterprise_service.dify_config") as mock_config, + patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join, + ): + mock_config.ENTERPRISE_ENABLED = False + + try_join_default_workspace("11111111-1111-1111-1111-111111111111") + + mock_join.assert_not_called() + + def test_try_join_default_workspace_successful_join_does_not_raise(self): + account_id = "11111111-1111-1111-1111-111111111111" + + with ( + patch("services.enterprise.enterprise_service.dify_config") as mock_config, + patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_join.return_value = DefaultWorkspaceJoinResult( + workspace_id="22222222-2222-2222-2222-222222222222", + joined=True, + message="ok", + ) + + # Should not raise + try_join_default_workspace(account_id) + + mock_join.assert_called_once_with(account_id=account_id) + + def test_try_join_default_workspace_skipped_join_does_not_raise(self): + account_id = "11111111-1111-1111-1111-111111111111" + + with ( + patch("services.enterprise.enterprise_service.dify_config") as mock_config, + patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_join.return_value = DefaultWorkspaceJoinResult( + workspace_id="", + joined=False, + message="no default workspace configured", + ) + + # Should not raise + try_join_default_workspace(account_id) + + mock_join.assert_called_once_with(account_id=account_id) + + def test_try_join_default_workspace_api_failure_soft_fails(self): + account_id = "11111111-1111-1111-1111-111111111111" + + with ( + patch("services.enterprise.enterprise_service.dify_config") as mock_config, + patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_join.side_effect = Exception("network failure") + + # Should not raise + try_join_default_workspace(account_id) + + mock_join.assert_called_once_with(account_id=account_id) + + def test_try_join_default_workspace_invalid_account_id_soft_fails(self): + with patch("services.enterprise.enterprise_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Should not raise even though UUID parsing fails inside join_default_workspace + try_join_default_workspace("not-a-uuid") diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 1fc45d1c35..635c86a14b 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1064,6 +1064,67 @@ class TestRegisterService: # ==================== Registration Tests ==================== + def test_create_account_and_tenant_calls_default_workspace_join_when_enterprise_enabled( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Enterprise-only side effect should be invoked when ENTERPRISE_ENABLED is True.""" + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + + result = AccountService.create_account_and_tenant( + email="test@example.com", + name="Test User", + interface_language="en-US", + password=None, + ) + + assert result == mock_account + mock_create_workspace.assert_called_once_with(account=mock_account) + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + + def test_create_account_and_tenant_does_not_call_default_workspace_join_when_enterprise_disabled( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False.""" + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False) + + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + + AccountService.create_account_and_tenant( + email="test@example.com", + name="Test User", + interface_language="en-US", + password=None, + ) + + mock_create_workspace.assert_called_once_with(account=mock_account) + mock_join_default_workspace.assert_not_called() + def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies): """Test successful account registration.""" # Setup mocks @@ -1115,6 +1176,65 @@ class TestRegisterService: mock_event.send.assert_called_once_with(mock_tenant) self._assert_database_operations_called(mock_db_dependencies["db"]) + def test_register_calls_default_workspace_join_when_enterprise_enabled( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Enterprise-only side effect should be invoked after successful register commit.""" + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + + result = RegisterService.register( + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + create_workspace_required=False, + ) + + assert result == mock_account + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + + def test_register_does_not_call_default_workspace_join_when_enterprise_disabled( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False.""" + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False) + + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + + RegisterService.register( + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + create_workspace_required=False, + ) + + mock_join_default_workspace.assert_not_called() + def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies): """Test account registration with OAuth integration.""" # Setup mocks