From a3170f744c5e4abc9bfa25a29522d16e8e5af8d8 Mon Sep 17 00:00:00 2001 From: NVIDIAN Date: Sun, 12 Apr 2026 22:18:16 -0700 Subject: [PATCH] refactor: migrate app site from marshal_with/api.model to Pydantic BaseModel (#34933) Co-authored-by: ai-hpc --- api/controllers/console/app/site.py | 42 ++++++++++++------- .../controllers/console/app/test_app_apis.py | 36 +++++++++++++++- .../controllers/console/app/test_message.py | 10 +++-- 3 files changed, 67 insertions(+), 21 deletions(-) diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 7f44a99ff1..9991d78d94 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,11 +1,12 @@ from typing import Literal -from flask_restx import Resource, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from werkzeug.exceptions import NotFound from constants.languages import supported_language +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( @@ -15,13 +16,11 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.app_fields import app_site_fields +from fields.base import ResponseModel from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import Site -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AppSiteUpdatePayload(BaseModel): title: str | None = Field(default=None) @@ -49,13 +48,26 @@ class AppSiteUpdatePayload(BaseModel): return supported_language(value) -console_ns.schema_model( - AppSiteUpdatePayload.__name__, - AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +class AppSiteResponse(ResponseModel): + app_id: str + access_token: str | None = Field(default=None, validation_alias="code") + code: str | None = None + title: str + icon: str | None = None + icon_background: str | None = None + description: str | None = None + default_language: str + customize_domain: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + customize_token_strategy: str + prompt_public: bool + show_workflow_steps: bool + use_icon_as_answer_icon: bool -# Register model for flask_restx to avoid dict type issues in Swagger -app_site_model = console_ns.model("AppSite", app_site_fields) + +register_schema_models(console_ns, AppSiteUpdatePayload, AppSiteResponse) @console_ns.route("/apps//site") @@ -64,7 +76,7 @@ class AppSite(Resource): @console_ns.doc(description="Update application site configuration") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__]) - @console_ns.response(200, "Site configuration updated successfully", app_site_model) + @console_ns.response(200, "Site configuration updated successfully", console_ns.models[AppSiteResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "App not found") @setup_required @@ -72,7 +84,6 @@ class AppSite(Resource): @edit_permission_required @account_initialization_required @get_app_model - @marshal_with(app_site_model) def post(self, app_model): args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) current_user, _ = current_account_with_tenant() @@ -106,7 +117,7 @@ class AppSite(Resource): site.updated_at = naive_utc_now() db.session.commit() - return site + return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json") @console_ns.route("/apps//site/access-token-reset") @@ -114,7 +125,7 @@ class AppSiteAccessTokenReset(Resource): @console_ns.doc("reset_app_site_access_token") @console_ns.doc(description="Reset access token for application site") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "Access token reset successfully", app_site_model) + @console_ns.response(200, "Access token reset successfully", console_ns.models[AppSiteResponse.__name__]) @console_ns.response(403, "Insufficient permissions (admin/owner required)") @console_ns.response(404, "App or site not found") @setup_required @@ -122,7 +133,6 @@ class AppSiteAccessTokenReset(Resource): @is_admin_or_owner_required @account_initialization_required @get_app_model - @marshal_with(app_site_model) def post(self, app_model): current_user, _ = current_account_with_tenant() site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) @@ -135,4 +145,4 @@ class AppSiteAccessTokenReset(Resource): site.updated_at = naive_utc_now() db.session.commit() - return site + return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json") diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index c3a861c3e1..54e0496dbd 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -313,6 +313,21 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() + site.app_id = "app-1" + site.code = "test-code" + site.title = "My Site" + site.icon = None + site.icon_background = None + site.description = "Test site" + site.default_language = "en-US" + site.customize_domain = None + site.copyright = None + site.privacy_policy = None + site.custom_disclaimer = "" + site.customize_token_strategy = "not_allow" + site.prompt_public = False + site.show_workflow_steps = True + site.use_icon_as_answer_icon = False monkeypatch.setattr( site_module.db, "session", @@ -328,13 +343,29 @@ class TestSiteEndpoints: with app.test_request_context("/", json={"title": "My Site"}): result = method(app_model=SimpleNamespace(id="app-1")) - assert result is site + assert isinstance(result, dict) + assert result["title"] == "My Site" def test_app_site_access_token_reset(self, app, monkeypatch): api = site_module.AppSiteAccessTokenReset() method = _unwrap(api.post) site = MagicMock() + site.app_id = "app-1" + site.code = "old-code" + site.title = "My Site" + site.icon = None + site.icon_background = None + site.description = None + site.default_language = "en-US" + site.customize_domain = None + site.copyright = None + site.privacy_policy = None + site.custom_disclaimer = "" + site.customize_token_strategy = "not_allow" + site.prompt_public = False + site.show_workflow_steps = True + site.use_icon_as_answer_icon = False monkeypatch.setattr( site_module.db, "session", @@ -351,7 +382,8 @@ class TestSiteEndpoints: with app.test_request_context("/"): result = method(app_model=SimpleNamespace(id="app-1")) - assert result is site + assert isinstance(result, dict) + assert result["access_token"] == "code" class TestWorkflowEndpoints: diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py index 6b51ec98bc..eff6dd789d 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py @@ -148,14 +148,18 @@ def test_chat_message_list_success( account.id, created_at_offset_seconds=1, ) + # Capture IDs before the HTTP request detaches ORM instances from the session + app_id = app.id + conversation_id = conversation.id + second_id = second.id with patch( "controllers.console.app.message.attach_message_extra_contents", side_effect=_attach_message_extra_contents, ): response = test_client_with_containers.get( - f"/console/api/apps/{app.id}/chat-messages", - query_string={"conversation_id": conversation.id, "limit": 1}, + f"/console/api/apps/{app_id}/chat-messages", + query_string={"conversation_id": conversation_id, "limit": 1}, headers=authenticate_console_client(test_client_with_containers, account), ) @@ -165,7 +169,7 @@ def test_chat_message_list_success( assert payload["limit"] == 1 assert payload["has_more"] is True assert len(payload["data"]) == 1 - assert payload["data"][0]["id"] == second.id + assert payload["data"][0]["id"] == second_id def test_message_feedback_not_found(