diff --git a/api/controllers/console/workspace/sandbox_providers.py b/api/controllers/console/workspace/sandbox_providers.py index 4ffbd3baef..b9f18b49ea 100644 --- a/api/controllers/console/workspace/sandbox_providers.py +++ b/api/controllers/console/workspace/sandbox_providers.py @@ -47,6 +47,7 @@ class SandboxProviderConfigApi(Resource): tenant_id=current_tenant_id, provider_type=provider_type, config=args["config"], + activate=args["activate"], ) return result except ValueError as e: @@ -71,6 +72,10 @@ class SandboxProviderConfigApi(Resource): return {"message": str(e)}, 400 +activate_parser = reqparse.RequestParser() +activate_parser.add_argument("type", type=str, required=True, location="json") + + @console_ns.route("/workspaces/current/sandbox-provider//activate") class SandboxProviderActivateApi(Resource): """Activate a sandbox provider.""" @@ -86,9 +91,11 @@ class SandboxProviderActivateApi(Resource): _, current_tenant_id = current_account_with_tenant() try: + args = activate_parser.parse_args() result = SandboxProviderService.activate_provider( tenant_id=current_tenant_id, provider_type=provider_type, + type=args["type"], ) return result except ValueError as e: diff --git a/api/services/sandbox/sandbox_provider_service.py b/api/services/sandbox/sandbox_provider_service.py index 22ee009405..69cb76fe81 100644 --- a/api/services/sandbox/sandbox_provider_service.py +++ b/api/services/sandbox/sandbox_provider_service.py @@ -83,7 +83,9 @@ class SandboxProviderService: SandboxBuilder.validate(SandboxType(provider_type), config) @classmethod - def save_config(cls, tenant_id: str, provider_type: str, config: Mapping[str, Any]) -> dict[str, Any]: + def save_config( + cls, tenant_id: str, provider_type: str, config: Mapping[str, Any], activate: bool + ) -> dict[str, Any]: if provider_type not in SandboxType.get_all(): raise ValueError(f"Invalid provider type: {provider_type}") @@ -107,7 +109,7 @@ class SandboxProviderService: cls.validate_config(provider_type, new_config) provider.encrypted_config = json.dumps(encrypter.encrypt(new_config)) - provider.is_active = provider.is_active or cls.is_system_default_config(session, tenant_id) + provider.is_active = activate or provider.is_active or cls.is_system_default_config(session, tenant_id) provider.configure_type = "user" session.commit() return {"result": "success"} @@ -129,7 +131,7 @@ class SandboxProviderService: return active_config.id == system_configed.id @classmethod - def activate_provider(cls, tenant_id: str, provider_type: str) -> dict[str, Any]: + def activate_provider(cls, tenant_id: str, provider_type: str, type: str | None = None) -> dict[str, Any]: if provider_type not in SandboxType.get_all(): raise ValueError(f"Invalid provider type: {provider_type}") @@ -142,6 +144,7 @@ class SandboxProviderService: # using tenant config if tenant_config: tenant_config.is_active = True + tenant_config.configure_type = type or tenant_config.configure_type session.commit() return {"result": "success"}