diff --git a/api/services/account_service.py b/api/services/account_service.py index 648b5e834f..f0eac2a522 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -74,6 +74,16 @@ from tasks.mail_reset_password_task import ( logger = logging.getLogger(__name__) +def _try_join_enterprise_default_workspace(account_id: str) -> None: + """Best-effort join to enterprise default workspace.""" + if not dify_config.ENTERPRISE_ENABLED: + return + + from services.enterprise.enterprise_service import try_join_default_workspace + + try_join_default_workspace(account_id) + + class TokenPair(BaseModel): access_token: str refresh_token: str @@ -287,13 +297,14 @@ class AccountService: email=email, name=name, interface_language=interface_language, password=password ) - TenantService.create_owner_tenant_if_not_exist(account=account) + try: + TenantService.create_owner_tenant_if_not_exist(account=account) + except Exception: + # Enterprise-only side-effect should run independently from personal workspace creation. + _try_join_enterprise_default_workspace(str(account.id)) + raise - # Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace). - if dify_config.ENTERPRISE_ENABLED: - from services.enterprise.enterprise_service import try_join_default_workspace - - try_join_default_workspace(str(account.id)) + _try_join_enterprise_default_workspace(str(account.id)) return account @@ -1407,18 +1418,18 @@ class RegisterService: and create_workspace_required and FeatureService.get_system_features().license.workspaces.is_available() ): - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role="owner") - account.current_tenant = tenant - tenant_was_created.send(tenant) + try: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + tenant_was_created.send(tenant) + except Exception: + _try_join_enterprise_default_workspace(str(account.id)) + raise db.session.commit() - # Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace). - if dify_config.ENTERPRISE_ENABLED: - from services.enterprise.enterprise_service import try_join_default_workspace - - try_join_default_workspace(str(account.id)) + _try_join_enterprise_default_workspace(str(account.id)) except WorkSpaceNotAllowedCreateError: db.session.rollback() logger.exception("Register failed") diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 635c86a14b..dcd6785464 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1125,6 +1125,38 @@ class TestRegisterService: mock_create_workspace.assert_called_once_with(account=mock_account) mock_join_default_workspace.assert_not_called() + def test_create_account_and_tenant_still_calls_default_workspace_join_when_workspace_creation_fails( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Default workspace join should still be attempted when personal workspace creation fails.""" + from services.errors.workspace import WorkSpaceNotAllowedCreateError + + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + mock_create_workspace.side_effect = WorkSpaceNotAllowedCreateError() + + with pytest.raises(WorkSpaceNotAllowedCreateError): + AccountService.create_account_and_tenant( + email="test@example.com", + name="Test User", + interface_language="en-US", + password=None, + ) + + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies): """Test successful account registration.""" # Setup mocks @@ -1235,6 +1267,84 @@ class TestRegisterService: mock_join_default_workspace.assert_not_called() + def test_register_still_calls_default_workspace_join_when_personal_workspace_creation_fails( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Default workspace join should run even when personal workspace creation raises.""" + from services.errors.workspace import WorkSpaceNotAllowedCreateError + + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_tenant") as mock_create_tenant, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + mock_create_tenant.side_effect = WorkSpaceNotAllowedCreateError() + + with pytest.raises(AccountRegisterError, match="Workspace is not allowed to create."): + RegisterService.register( + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + ) + + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + mock_db_dependencies["db"].session.commit.assert_not_called() + + def test_register_still_calls_default_workspace_join_when_workspace_limit_exceeded( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Default workspace join should run before propagating workspace-limit registration failure.""" + from services.errors.workspace import WorkspacesLimitExceededError + + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_tenant") as mock_create_tenant, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + mock_create_tenant.side_effect = WorkspacesLimitExceededError() + + with pytest.raises(AccountRegisterError, match="Registration failed:"): + RegisterService.register( + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + ) + + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + mock_db_dependencies["db"].session.commit.assert_not_called() + def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies): """Test account registration with OAuth integration.""" # Setup mocks