refactor: migrate apikey from marshal_with/api.model to Pydantic BaseModel (#34932)

Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
NVIDIAN 2026-04-12 22:18:42 -07:00 committed by GitHub
parent a3170f744c
commit e37aaa482d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 86 additions and 56 deletions

View File

@ -1,12 +1,16 @@
from datetime import datetime
import flask_restx import flask_restx
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource
from flask_restx._http import HTTPStatus from flask_restx._http import HTTPStatus
from pydantic import field_validator
from sqlalchemy import delete, func, select from sqlalchemy import delete, func, select
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset from models.dataset import Dataset
from models.enums import ApiTokenType from models.enums import ApiTokenType
@ -16,21 +20,31 @@ from services.api_token_service import ApiTokenCache
from . import console_ns from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = {
"id": fields.String,
"type": fields.String,
"token": fields.String,
"last_used_at": TimestampField,
"created_at": TimestampField,
}
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields) def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
api_key_list_model = console_ns.model( class ApiKeyItem(ResponseModel):
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")} id: str
) type: str
token: str
last_used_at: int | None = None
created_at: int | None = None
@field_validator("last_used_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class ApiKeyList(ResponseModel):
data: list[ApiKeyItem]
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
def _get_resource(resource_id, tenant_id, resource_model): def _get_resource(resource_id, tenant_id, resource_model):
@ -54,7 +68,6 @@ class BaseApiKeyListResource(Resource):
token_prefix: str | None = None token_prefix: str | None = None
max_keys = 10 max_keys = 10
@marshal_with(api_key_list_model)
def get(self, resource_id): def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
@ -66,9 +79,8 @@ class BaseApiKeyListResource(Resource):
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
) )
).all() ).all()
return {"items": keys} return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
@marshal_with(api_key_item_model)
@edit_permission_required @edit_permission_required
def post(self, resource_id): def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
@ -100,7 +112,7 @@ class BaseApiKeyListResource(Resource):
api_token.type = self.resource_type api_token.type = self.resource_type
db.session.add(api_token) db.session.add(api_token)
db.session.commit() db.session.commit()
return api_token, 201 return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
class BaseApiKeyResource(Resource): class BaseApiKeyResource(Resource):
@ -147,7 +159,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_app_api_keys") @console_ns.doc("get_app_api_keys")
@console_ns.doc(description="Get all API keys for an app") @console_ns.doc(description="Get all API keys for an app")
@console_ns.doc(params={"resource_id": "App ID"}) @console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(200, "Success", api_key_list_model) @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id): # type: ignore def get(self, resource_id): # type: ignore
"""Get all API keys for an app""" """Get all API keys for an app"""
return super().get(resource_id) return super().get(resource_id)
@ -155,7 +167,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("create_app_api_key") @console_ns.doc("create_app_api_key")
@console_ns.doc(description="Create a new API key for an app") @console_ns.doc(description="Create a new API key for an app")
@console_ns.doc(params={"resource_id": "App ID"}) @console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(201, "API key created successfully", api_key_item_model) @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded") @console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore def post(self, resource_id): # type: ignore
"""Create a new API key for an app""" """Create a new API key for an app"""
@ -187,7 +199,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_dataset_api_keys") @console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get all API keys for a dataset") @console_ns.doc(description="Get all API keys for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"}) @console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(200, "Success", api_key_list_model) @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id): # type: ignore def get(self, resource_id): # type: ignore
"""Get all API keys for a dataset""" """Get all API keys for a dataset"""
return super().get(resource_id) return super().get(resource_id)
@ -195,7 +207,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("create_dataset_api_key") @console_ns.doc("create_dataset_api_key")
@console_ns.doc(description="Create a new API key for a dataset") @console_ns.doc(description="Create a new API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"}) @console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(201, "API key created successfully", api_key_item_model) @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded") @console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore def post(self, resource_id): # type: ignore
"""Create a new API key for a dataset""" """Create a new API key for a dataset"""

View File

@ -1,8 +1,9 @@
from flask import request from flask import request
from flask_restx import Resource, fields from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from constants.languages import supported_language from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
@ -11,8 +12,6 @@ from libs.helper import EmailStr, timezone
from models import AccountStatus from models import AccountStatus
from services.account_service import RegisterService from services.account_service import RegisterService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ActivateCheckQuery(BaseModel): class ActivateCheckQuery(BaseModel):
workspace_id: str | None = Field(default=None) workspace_id: str | None = Field(default=None)
@ -39,8 +38,16 @@ class ActivatePayload(BaseModel):
return timezone(value) return timezone(value)
for model in (ActivateCheckQuery, ActivatePayload): class ActivationCheckResponse(BaseModel):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) is_valid: bool = Field(description="Whether token is valid")
data: dict | None = Field(default=None, description="Activation data if valid")
class ActivationResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
@console_ns.route("/activate/check") @console_ns.route("/activate/check")
@ -51,13 +58,7 @@ class ActivateCheckApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Success", "Success",
console_ns.model( console_ns.models[ActivationCheckResponse.__name__],
"ActivationCheckResponse",
{
"is_valid": fields.Boolean(description="Whether token is valid"),
"data": fields.Raw(description="Activation data if valid"),
},
),
) )
def get(self): def get(self):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -95,12 +96,7 @@ class ActivateApi(Resource):
@console_ns.response( @console_ns.response(
200, 200,
"Account activated successfully", "Account activated successfully",
console_ns.model( console_ns.models[ActivationResponse.__name__],
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
},
),
) )
@console_ns.response(400, "Already activated or invalid token") @console_ns.response(400, "Already activated or invalid token")
def post(self): def post(self):

View File

@ -11,10 +11,7 @@ import services
from configs import dify_config from configs import dify_config
from controllers.common.schema import get_or_create_model, register_schema_models from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.apikey import ( from controllers.console.apikey import ApiKeyItem, ApiKeyList
api_key_item_model,
api_key_list_model,
)
from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import ( from controllers.console.wraps import (
@ -785,23 +782,23 @@ class DatasetApiKeyApi(Resource):
@console_ns.doc("get_dataset_api_keys") @console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get dataset API keys") @console_ns.doc(description="Get dataset API keys")
@console_ns.response(200, "API keys retrieved successfully", api_key_list_model) @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(api_key_list_model)
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars( keys = db.session.scalars(
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id) select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
).all() ).all()
return {"items": keys} return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
@console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required @is_admin_or_owner_required
@account_initialization_required @account_initialization_required
@marshal_with(api_key_item_model)
def post(self): def post(self):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
@ -828,7 +825,7 @@ class DatasetApiKeyApi(Resource):
api_token.type = self.resource_type api_token.type = self.resource_type
db.session.add(api_token) db.session.add(api_token)
db.session.commit() db.session.commit()
return api_token, 200 return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>") @console_ns.route("/datasets/api-keys/<uuid:api_key_id>")

View File

@ -1555,7 +1555,17 @@ class TestDatasetApiKeyApi:
method = unwrap(api.get) method = unwrap(api.get)
mock_key_1 = MagicMock(spec=ApiToken) mock_key_1 = MagicMock(spec=ApiToken)
mock_key_1.id = "key-1"
mock_key_1.type = "dataset"
mock_key_1.token = "ds-abc"
mock_key_1.last_used_at = None
mock_key_1.created_at = None
mock_key_2 = MagicMock(spec=ApiToken) mock_key_2 = MagicMock(spec=ApiToken)
mock_key_2.id = "key-2"
mock_key_2.type = "dataset"
mock_key_2.token = "ds-def"
mock_key_2.last_used_at = None
mock_key_2.created_at = None
with ( with (
app.test_request_context("/"), app.test_request_context("/"),
@ -1570,13 +1580,26 @@ class TestDatasetApiKeyApi:
): ):
response = method(api) response = method(api)
assert "items" in response assert "data" in response
assert response["items"] == [mock_key_1, mock_key_2] assert len(response["data"]) == 2
assert response["data"][0]["id"] == "key-1"
assert response["data"][0]["token"] == "ds-abc"
assert response["data"][1]["id"] == "key-2"
assert response["data"][1]["token"] == "ds-def"
def test_post_create_api_key_success(self, app): def test_post_create_api_key_success(self, app):
api = DatasetApiKeyApi() api = DatasetApiKeyApi()
method = unwrap(api.post) method = unwrap(api.post)
mock_token = MagicMock()
mock_token.id = "new-key-id"
mock_token.last_used_at = None
mock_token.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
mock_api_token_cls = MagicMock()
mock_api_token_cls.return_value = mock_token
mock_api_token_cls.generate_api_key.return_value = "dataset-abc123"
with ( with (
app.test_request_context("/"), app.test_request_context("/"),
patch( patch(
@ -1588,8 +1611,8 @@ class TestDatasetApiKeyApi:
return_value=3, return_value=3,
), ),
patch( patch(
"controllers.console.datasets.datasets.ApiToken.generate_api_key", "controllers.console.datasets.datasets.ApiToken",
return_value="dataset-abc123", mock_api_token_cls,
), ),
patch( patch(
"controllers.console.datasets.datasets.db.session.add", "controllers.console.datasets.datasets.db.session.add",
@ -1603,9 +1626,11 @@ class TestDatasetApiKeyApi:
response, status = method(api) response, status = method(api)
assert status == 200 assert status == 200
assert isinstance(response, ApiToken) assert isinstance(response, dict)
assert response.token == "dataset-abc123" assert response["id"] == "new-key-id"
assert response.type == "dataset" assert response["token"] == "dataset-abc123"
assert response["type"] == "dataset"
assert response["created_at"] is not None
def test_post_exceed_max_keys(self, app): def test_post_exceed_max_keys(self, app):
api = DatasetApiKeyApi() api = DatasetApiKeyApi()