mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor: migrate apikey from marshal_with/api.model to Pydantic BaseModel (#34932)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
a3170f744c
commit
e37aaa482d
@ -1,12 +1,16 @@
|
||||
from datetime import datetime
|
||||
|
||||
import flask_restx
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource
|
||||
from flask_restx._http import HTTPStatus
|
||||
from pydantic import field_validator
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
@ -16,21 +20,31 @@ from services.api_token_service import ApiTokenCache
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
api_key_fields = {
|
||||
"id": fields.String,
|
||||
"type": fields.String,
|
||||
"token": fields.String,
|
||||
"last_used_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
|
||||
|
||||
api_key_list_model = console_ns.model(
|
||||
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
|
||||
)
|
||||
class ApiKeyItem(ResponseModel):
|
||||
id: str
|
||||
type: str
|
||||
token: str
|
||||
last_used_at: int | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("last_used_at", "created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class ApiKeyList(ResponseModel):
|
||||
data: list[ApiKeyItem]
|
||||
|
||||
|
||||
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
@ -54,7 +68,6 @@ class BaseApiKeyListResource(Resource):
|
||||
token_prefix: str | None = None
|
||||
max_keys = 10
|
||||
|
||||
@marshal_with(api_key_list_model)
|
||||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
@ -66,9 +79,8 @@ class BaseApiKeyListResource(Resource):
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
)
|
||||
).all()
|
||||
return {"items": keys}
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@marshal_with(api_key_item_model)
|
||||
@edit_permission_required
|
||||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
@ -100,7 +112,7 @@ class BaseApiKeyListResource(Resource):
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return api_token, 201
|
||||
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
|
||||
|
||||
|
||||
class BaseApiKeyResource(Resource):
|
||||
@ -147,7 +159,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("get_app_api_keys")
|
||||
@console_ns.doc(description="Get all API keys for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(200, "Success", api_key_list_model)
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for an app"""
|
||||
return super().get(resource_id)
|
||||
@ -155,7 +167,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("create_app_api_key")
|
||||
@console_ns.doc(description="Create a new API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for an app"""
|
||||
@ -187,7 +199,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get all API keys for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Success", api_key_list_model)
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for a dataset"""
|
||||
return super().get(resource_id)
|
||||
@ -195,7 +207,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("create_dataset_api_key")
|
||||
@console_ns.doc(description="Create a new API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for a dataset"""
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
@ -11,8 +12,6 @@ from libs.helper import EmailStr, timezone
|
||||
from models import AccountStatus
|
||||
from services.account_service import RegisterService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ActivateCheckQuery(BaseModel):
|
||||
workspace_id: str | None = Field(default=None)
|
||||
@ -39,8 +38,16 @@ class ActivatePayload(BaseModel):
|
||||
return timezone(value)
|
||||
|
||||
|
||||
for model in (ActivateCheckQuery, ActivatePayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class ActivationCheckResponse(BaseModel):
|
||||
is_valid: bool = Field(description="Whether token is valid")
|
||||
data: dict | None = Field(default=None, description="Activation data if valid")
|
||||
|
||||
|
||||
class ActivationResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
|
||||
|
||||
|
||||
@console_ns.route("/activate/check")
|
||||
@ -51,13 +58,7 @@ class ActivateCheckApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"ActivationCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether token is valid"),
|
||||
"data": fields.Raw(description="Activation data if valid"),
|
||||
},
|
||||
),
|
||||
console_ns.models[ActivationCheckResponse.__name__],
|
||||
)
|
||||
def get(self):
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
@ -95,12 +96,7 @@ class ActivateApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Account activated successfully",
|
||||
console_ns.model(
|
||||
"ActivationResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
},
|
||||
),
|
||||
console_ns.models[ActivationResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Already activated or invalid token")
|
||||
def post(self):
|
||||
|
||||
@ -11,10 +11,7 @@ import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.apikey import (
|
||||
api_key_item_model,
|
||||
api_key_list_model,
|
||||
)
|
||||
from controllers.console.apikey import ApiKeyItem, ApiKeyList
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
|
||||
from controllers.console.wraps import (
|
||||
@ -785,23 +782,23 @@ class DatasetApiKeyApi(Resource):
|
||||
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get dataset API keys")
|
||||
@console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_list_model)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||
).all()
|
||||
return {"items": keys}
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_item_model)
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -828,7 +825,7 @@ class DatasetApiKeyApi(Resource):
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return api_token, 200
|
||||
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
|
||||
|
||||
@ -1555,7 +1555,17 @@ class TestDatasetApiKeyApi:
|
||||
method = unwrap(api.get)
|
||||
|
||||
mock_key_1 = MagicMock(spec=ApiToken)
|
||||
mock_key_1.id = "key-1"
|
||||
mock_key_1.type = "dataset"
|
||||
mock_key_1.token = "ds-abc"
|
||||
mock_key_1.last_used_at = None
|
||||
mock_key_1.created_at = None
|
||||
mock_key_2 = MagicMock(spec=ApiToken)
|
||||
mock_key_2.id = "key-2"
|
||||
mock_key_2.type = "dataset"
|
||||
mock_key_2.token = "ds-def"
|
||||
mock_key_2.last_used_at = None
|
||||
mock_key_2.created_at = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -1570,13 +1580,26 @@ class TestDatasetApiKeyApi:
|
||||
):
|
||||
response = method(api)
|
||||
|
||||
assert "items" in response
|
||||
assert response["items"] == [mock_key_1, mock_key_2]
|
||||
assert "data" in response
|
||||
assert len(response["data"]) == 2
|
||||
assert response["data"][0]["id"] == "key-1"
|
||||
assert response["data"][0]["token"] == "ds-abc"
|
||||
assert response["data"][1]["id"] == "key-2"
|
||||
assert response["data"][1]["token"] == "ds-def"
|
||||
|
||||
def test_post_create_api_key_success(self, app):
|
||||
api = DatasetApiKeyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
mock_token = MagicMock()
|
||||
mock_token.id = "new-key-id"
|
||||
mock_token.last_used_at = None
|
||||
mock_token.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
|
||||
|
||||
mock_api_token_cls = MagicMock()
|
||||
mock_api_token_cls.return_value = mock_token
|
||||
mock_api_token_cls.generate_api_key.return_value = "dataset-abc123"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
@ -1588,8 +1611,8 @@ class TestDatasetApiKeyApi:
|
||||
return_value=3,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.ApiToken.generate_api_key",
|
||||
return_value="dataset-abc123",
|
||||
"controllers.console.datasets.datasets.ApiToken",
|
||||
mock_api_token_cls,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.add",
|
||||
@ -1603,9 +1626,11 @@ class TestDatasetApiKeyApi:
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert isinstance(response, ApiToken)
|
||||
assert response.token == "dataset-abc123"
|
||||
assert response.type == "dataset"
|
||||
assert isinstance(response, dict)
|
||||
assert response["id"] == "new-key-id"
|
||||
assert response["token"] == "dataset-abc123"
|
||||
assert response["type"] == "dataset"
|
||||
assert response["created_at"] is not None
|
||||
|
||||
def test_post_exceed_max_keys(self, app):
|
||||
api = DatasetApiKeyApi()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user