refactor: migrate apikey from marshal_with/api.model to Pydantic BaseModel (#34932)

Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
NVIDIAN 2026-04-12 22:18:42 -07:00 committed by GitHub
parent a3170f744c
commit e37aaa482d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 86 additions and 56 deletions

View File

@ -1,12 +1,16 @@
from datetime import datetime
import flask_restx
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from pydantic import field_validator
from sqlalchemy import delete, func, select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from extensions.ext_database import db
from libs.helper import TimestampField
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.enums import ApiTokenType
@ -16,21 +20,31 @@ from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = {
"id": fields.String,
"type": fields.String,
"token": fields.String,
"last_used_at": TimestampField,
"created_at": TimestampField,
}
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
api_key_list_model = console_ns.model(
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
)
class ApiKeyItem(ResponseModel):
id: str
type: str
token: str
last_used_at: int | None = None
created_at: int | None = None
@field_validator("last_used_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class ApiKeyList(ResponseModel):
data: list[ApiKeyItem]
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
def _get_resource(resource_id, tenant_id, resource_model):
@ -54,7 +68,6 @@ class BaseApiKeyListResource(Resource):
token_prefix: str | None = None
max_keys = 10
@marshal_with(api_key_list_model)
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
@ -66,9 +79,8 @@ class BaseApiKeyListResource(Resource):
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
).all()
return {"items": keys}
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
@marshal_with(api_key_item_model)
@edit_permission_required
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
@ -100,7 +112,7 @@ class BaseApiKeyListResource(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return api_token, 201
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
class BaseApiKeyResource(Resource):
@ -147,7 +159,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_app_api_keys")
@console_ns.doc(description="Get all API keys for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(200, "Success", api_key_list_model)
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id): # type: ignore
"""Get all API keys for an app"""
return super().get(resource_id)
@ -155,7 +167,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("create_app_api_key")
@console_ns.doc(description="Create a new API key for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(201, "API key created successfully", api_key_item_model)
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore
"""Create a new API key for an app"""
@ -187,7 +199,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get all API keys for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(200, "Success", api_key_list_model)
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id): # type: ignore
"""Get all API keys for a dataset"""
return super().get(resource_id)
@ -195,7 +207,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("create_dataset_api_key")
@console_ns.doc(description="Create a new API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(201, "API key created successfully", api_key_item_model)
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore
"""Create a new API key for a dataset"""

View File

@ -1,8 +1,9 @@
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
@ -11,8 +12,6 @@ from libs.helper import EmailStr, timezone
from models import AccountStatus
from services.account_service import RegisterService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ActivateCheckQuery(BaseModel):
workspace_id: str | None = Field(default=None)
@ -39,8 +38,16 @@ class ActivatePayload(BaseModel):
return timezone(value)
for model in (ActivateCheckQuery, ActivatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class ActivationCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether token is valid")
data: dict | None = Field(default=None, description="Activation data if valid")
class ActivationResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
@console_ns.route("/activate/check")
@ -51,13 +58,7 @@ class ActivateCheckApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"ActivationCheckResponse",
{
"is_valid": fields.Boolean(description="Whether token is valid"),
"data": fields.Raw(description="Activation data if valid"),
},
),
console_ns.models[ActivationCheckResponse.__name__],
)
def get(self):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -95,12 +96,7 @@ class ActivateApi(Resource):
@console_ns.response(
200,
"Account activated successfully",
console_ns.model(
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
},
),
console_ns.models[ActivationResponse.__name__],
)
@console_ns.response(400, "Already activated or invalid token")
def post(self):

View File

@ -11,10 +11,7 @@ import services
from configs import dify_config
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import (
api_key_item_model,
api_key_list_model,
)
from controllers.console.apikey import ApiKeyItem, ApiKeyList
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import (
@ -785,23 +782,23 @@ class DatasetApiKeyApi(Resource):
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get dataset API keys")
@console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars(
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
).all()
return {"items": keys}
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
@console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@marshal_with(api_key_item_model)
def post(self):
_, current_tenant_id = current_account_with_tenant()
@ -828,7 +825,7 @@ class DatasetApiKeyApi(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return api_token, 200
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")

View File

@ -1555,7 +1555,17 @@ class TestDatasetApiKeyApi:
method = unwrap(api.get)
mock_key_1 = MagicMock(spec=ApiToken)
mock_key_1.id = "key-1"
mock_key_1.type = "dataset"
mock_key_1.token = "ds-abc"
mock_key_1.last_used_at = None
mock_key_1.created_at = None
mock_key_2 = MagicMock(spec=ApiToken)
mock_key_2.id = "key-2"
mock_key_2.type = "dataset"
mock_key_2.token = "ds-def"
mock_key_2.last_used_at = None
mock_key_2.created_at = None
with (
app.test_request_context("/"),
@ -1570,13 +1580,26 @@ class TestDatasetApiKeyApi:
):
response = method(api)
assert "items" in response
assert response["items"] == [mock_key_1, mock_key_2]
assert "data" in response
assert len(response["data"]) == 2
assert response["data"][0]["id"] == "key-1"
assert response["data"][0]["token"] == "ds-abc"
assert response["data"][1]["id"] == "key-2"
assert response["data"][1]["token"] == "ds-def"
def test_post_create_api_key_success(self, app):
api = DatasetApiKeyApi()
method = unwrap(api.post)
mock_token = MagicMock()
mock_token.id = "new-key-id"
mock_token.last_used_at = None
mock_token.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
mock_api_token_cls = MagicMock()
mock_api_token_cls.return_value = mock_token
mock_api_token_cls.generate_api_key.return_value = "dataset-abc123"
with (
app.test_request_context("/"),
patch(
@ -1588,8 +1611,8 @@ class TestDatasetApiKeyApi:
return_value=3,
),
patch(
"controllers.console.datasets.datasets.ApiToken.generate_api_key",
return_value="dataset-abc123",
"controllers.console.datasets.datasets.ApiToken",
mock_api_token_cls,
),
patch(
"controllers.console.datasets.datasets.db.session.add",
@ -1603,9 +1626,11 @@ class TestDatasetApiKeyApi:
response, status = method(api)
assert status == 200
assert isinstance(response, ApiToken)
assert response.token == "dataset-abc123"
assert response.type == "dataset"
assert isinstance(response, dict)
assert response["id"] == "new-key-id"
assert response["token"] == "dataset-abc123"
assert response["type"] == "dataset"
assert response["created_at"] is not None
def test_post_exceed_max_keys(self, app):
api = DatasetApiKeyApi()