From ef396ac84eca9ed16c9a2c3be928fda965f25817 Mon Sep 17 00:00:00 2001 From: NVIDIAN Date: Tue, 14 Apr 2026 12:48:09 -0700 Subject: [PATCH] refactor(api): migrate workspace current response from marshal_with to BaseModel (#35207) Co-authored-by: ai-hpc --- .../console/workspace/workspace.py | 48 +++++++++++++++++-- .../console/workspace/test_workspace.py | 18 +++++++ 2 files changed, 62 insertions(+), 4 deletions(-) diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 42874e6033..565099db61 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,8 +1,9 @@ import logging +from datetime import datetime from flask import request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource, fields, marshal +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from werkzeug.exceptions import Unauthorized @@ -26,6 +27,7 @@ from controllers.console.wraps import ( ) from enums.cloud_plan import CloudPlan from extensions.ext_database import db +from fields.base import ResponseModel from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.account import Tenant, TenantCustomConfigDict, TenantStatus @@ -58,6 +60,37 @@ class WorkspaceInfoPayload(BaseModel): name: str +class TenantInfoResponse(ResponseModel): + id: str + name: str | None = None + plan: str | None = None + status: str | None = None + created_at: int | None = None + role: str | None = None + in_trial: bool | None = None + trial_end_reason: str | None = None + custom_config: dict | None = None + trial_credits: int | None = None + trial_credits_used: int | None = None + next_credit_reset_date: int | None = None + + @field_validator("plan", "status", "trial_end_reason", mode="before") + @classmethod + def _normalize_enum_like(cls, value): + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None): + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + def reg(cls: type[BaseModel]): console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) @@ -66,6 +99,7 @@ reg(WorkspaceListQuery) reg(SwitchWorkspacePayload) reg(WorkspaceCustomConfigPayload) reg(WorkspaceInfoPayload) +reg(TenantInfoResponse) provider_fields = { "provider_name": fields.String, @@ -180,7 +214,7 @@ class TenantApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(tenant_fields) + @console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__]) def post(self): if request.path == "/info": logger.warning("Deprecated URL /info was used.") @@ -200,7 +234,13 @@ class TenantApi(Resource): else: raise Unauthorized("workspace is archived") - return WorkspaceService.get_tenant_info(tenant), 200 + return ( + TenantInfoResponse.model_validate( + WorkspaceService.get_tenant_info(tenant), + from_attributes=True, + ).model_dump(mode="json"), + 200, + ) @console_ns.route("/workspaces/switch") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index b2d13dbbdf..e82a29f045 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -18,6 +18,7 @@ from controllers.console.workspace.workspace import ( CustomConfigWorkspaceApi, SwitchWorkspaceApi, TenantApi, + TenantInfoResponse, TenantListApi, WebappLogoWorkspaceApi, WorkspaceInfoApi, @@ -435,6 +436,23 @@ class TestTenantApi: assert status == 200 +class TestTenantInfoResponse: + def test_tenant_info_response_normalizes_enum_and_datetime(self): + created_at = naive_utc_now() + payload = TenantInfoResponse.model_validate( + { + "id": "t1", + "status": TenantStatus.NORMAL, + "plan": CloudPlan.TEAM, + "created_at": created_at, + } + ).model_dump(mode="json") + + assert payload["status"] == "normal" + assert payload["plan"] == "team" + assert payload["created_at"] == int(created_at.timestamp()) + + class TestSwitchWorkspaceApi: def test_switch_success(self, app): api = SwitchWorkspaceApi()