mirror of
https://github.com/langgenius/dify.git
synced 2026-04-17 20:09:34 +08:00
refactor: migrate apikey from marshal_with/api.model to Pydantic BaseModel (#34932)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
a3170f744c
commit
e37aaa482d
@ -1,12 +1,16 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import flask_restx
|
import flask_restx
|
||||||
from flask_restx import Resource, fields, marshal_with
|
from flask_restx import Resource
|
||||||
from flask_restx._http import HTTPStatus
|
from flask_restx._http import HTTPStatus
|
||||||
|
from pydantic import field_validator
|
||||||
from sqlalchemy import delete, func, select
|
from sqlalchemy import delete, func, select
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TimestampField
|
from fields.base import ResponseModel
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from models.enums import ApiTokenType
|
from models.enums import ApiTokenType
|
||||||
@ -16,21 +20,31 @@ from services.api_token_service import ApiTokenCache
|
|||||||
from . import console_ns
|
from . import console_ns
|
||||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
|
|
||||||
api_key_fields = {
|
|
||||||
"id": fields.String,
|
|
||||||
"type": fields.String,
|
|
||||||
"token": fields.String,
|
|
||||||
"last_used_at": TimestampField,
|
|
||||||
"created_at": TimestampField,
|
|
||||||
}
|
|
||||||
|
|
||||||
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
|
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
return int(value.timestamp())
|
||||||
|
return value
|
||||||
|
|
||||||
api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
|
|
||||||
|
|
||||||
api_key_list_model = console_ns.model(
|
class ApiKeyItem(ResponseModel):
|
||||||
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
|
id: str
|
||||||
)
|
type: str
|
||||||
|
token: str
|
||||||
|
last_used_at: int | None = None
|
||||||
|
created_at: int | None = None
|
||||||
|
|
||||||
|
@field_validator("last_used_at", "created_at", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||||
|
return _to_timestamp(value)
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeyList(ResponseModel):
|
||||||
|
data: list[ApiKeyItem]
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||||
|
|
||||||
|
|
||||||
def _get_resource(resource_id, tenant_id, resource_model):
|
def _get_resource(resource_id, tenant_id, resource_model):
|
||||||
@ -54,7 +68,6 @@ class BaseApiKeyListResource(Resource):
|
|||||||
token_prefix: str | None = None
|
token_prefix: str | None = None
|
||||||
max_keys = 10
|
max_keys = 10
|
||||||
|
|
||||||
@marshal_with(api_key_list_model)
|
|
||||||
def get(self, resource_id):
|
def get(self, resource_id):
|
||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
@ -66,9 +79,8 @@ class BaseApiKeyListResource(Resource):
|
|||||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
return {"items": keys}
|
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||||
|
|
||||||
@marshal_with(api_key_item_model)
|
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, resource_id):
|
def post(self, resource_id):
|
||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
@ -100,7 +112,7 @@ class BaseApiKeyListResource(Resource):
|
|||||||
api_token.type = self.resource_type
|
api_token.type = self.resource_type
|
||||||
db.session.add(api_token)
|
db.session.add(api_token)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return api_token, 201
|
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
|
||||||
|
|
||||||
|
|
||||||
class BaseApiKeyResource(Resource):
|
class BaseApiKeyResource(Resource):
|
||||||
@ -147,7 +159,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
|||||||
@console_ns.doc("get_app_api_keys")
|
@console_ns.doc("get_app_api_keys")
|
||||||
@console_ns.doc(description="Get all API keys for an app")
|
@console_ns.doc(description="Get all API keys for an app")
|
||||||
@console_ns.doc(params={"resource_id": "App ID"})
|
@console_ns.doc(params={"resource_id": "App ID"})
|
||||||
@console_ns.response(200, "Success", api_key_list_model)
|
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||||
def get(self, resource_id): # type: ignore
|
def get(self, resource_id): # type: ignore
|
||||||
"""Get all API keys for an app"""
|
"""Get all API keys for an app"""
|
||||||
return super().get(resource_id)
|
return super().get(resource_id)
|
||||||
@ -155,7 +167,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
|||||||
@console_ns.doc("create_app_api_key")
|
@console_ns.doc("create_app_api_key")
|
||||||
@console_ns.doc(description="Create a new API key for an app")
|
@console_ns.doc(description="Create a new API key for an app")
|
||||||
@console_ns.doc(params={"resource_id": "App ID"})
|
@console_ns.doc(params={"resource_id": "App ID"})
|
||||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||||
@console_ns.response(400, "Maximum keys exceeded")
|
@console_ns.response(400, "Maximum keys exceeded")
|
||||||
def post(self, resource_id): # type: ignore
|
def post(self, resource_id): # type: ignore
|
||||||
"""Create a new API key for an app"""
|
"""Create a new API key for an app"""
|
||||||
@ -187,7 +199,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
|||||||
@console_ns.doc("get_dataset_api_keys")
|
@console_ns.doc("get_dataset_api_keys")
|
||||||
@console_ns.doc(description="Get all API keys for a dataset")
|
@console_ns.doc(description="Get all API keys for a dataset")
|
||||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||||
@console_ns.response(200, "Success", api_key_list_model)
|
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||||
def get(self, resource_id): # type: ignore
|
def get(self, resource_id): # type: ignore
|
||||||
"""Get all API keys for a dataset"""
|
"""Get all API keys for a dataset"""
|
||||||
return super().get(resource_id)
|
return super().get(resource_id)
|
||||||
@ -195,7 +207,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
|||||||
@console_ns.doc("create_dataset_api_key")
|
@console_ns.doc("create_dataset_api_key")
|
||||||
@console_ns.doc(description="Create a new API key for a dataset")
|
@console_ns.doc(description="Create a new API key for a dataset")
|
||||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||||
@console_ns.response(400, "Maximum keys exceeded")
|
@console_ns.response(400, "Maximum keys exceeded")
|
||||||
def post(self, resource_id): # type: ignore
|
def post(self, resource_id): # type: ignore
|
||||||
"""Create a new API key for a dataset"""
|
"""Create a new API key for a dataset"""
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.error import AlreadyActivateError
|
from controllers.console.error import AlreadyActivateError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -11,8 +12,6 @@ from libs.helper import EmailStr, timezone
|
|||||||
from models import AccountStatus
|
from models import AccountStatus
|
||||||
from services.account_service import RegisterService
|
from services.account_service import RegisterService
|
||||||
|
|
||||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
|
||||||
|
|
||||||
|
|
||||||
class ActivateCheckQuery(BaseModel):
|
class ActivateCheckQuery(BaseModel):
|
||||||
workspace_id: str | None = Field(default=None)
|
workspace_id: str | None = Field(default=None)
|
||||||
@ -39,8 +38,16 @@ class ActivatePayload(BaseModel):
|
|||||||
return timezone(value)
|
return timezone(value)
|
||||||
|
|
||||||
|
|
||||||
for model in (ActivateCheckQuery, ActivatePayload):
|
class ActivationCheckResponse(BaseModel):
|
||||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
is_valid: bool = Field(description="Whether token is valid")
|
||||||
|
data: dict | None = Field(default=None, description="Activation data if valid")
|
||||||
|
|
||||||
|
|
||||||
|
class ActivationResponse(BaseModel):
|
||||||
|
result: str = Field(description="Operation result")
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/activate/check")
|
@console_ns.route("/activate/check")
|
||||||
@ -51,13 +58,7 @@ class ActivateCheckApi(Resource):
|
|||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Success",
|
"Success",
|
||||||
console_ns.model(
|
console_ns.models[ActivationCheckResponse.__name__],
|
||||||
"ActivationCheckResponse",
|
|
||||||
{
|
|
||||||
"is_valid": fields.Boolean(description="Whether token is valid"),
|
|
||||||
"data": fields.Raw(description="Activation data if valid"),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
def get(self):
|
def get(self):
|
||||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
@ -95,12 +96,7 @@ class ActivateApi(Resource):
|
|||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Account activated successfully",
|
"Account activated successfully",
|
||||||
console_ns.model(
|
console_ns.models[ActivationResponse.__name__],
|
||||||
"ActivationResponse",
|
|
||||||
{
|
|
||||||
"result": fields.String(description="Operation result"),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
@console_ns.response(400, "Already activated or invalid token")
|
@console_ns.response(400, "Already activated or invalid token")
|
||||||
def post(self):
|
def post(self):
|
||||||
|
|||||||
@ -11,10 +11,7 @@ import services
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.apikey import (
|
from controllers.console.apikey import ApiKeyItem, ApiKeyList
|
||||||
api_key_item_model,
|
|
||||||
api_key_list_model,
|
|
||||||
)
|
|
||||||
from controllers.console.app.error import ProviderNotInitializeError
|
from controllers.console.app.error import ProviderNotInitializeError
|
||||||
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
|
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
@ -785,23 +782,23 @@ class DatasetApiKeyApi(Resource):
|
|||||||
|
|
||||||
@console_ns.doc("get_dataset_api_keys")
|
@console_ns.doc("get_dataset_api_keys")
|
||||||
@console_ns.doc(description="Get dataset API keys")
|
@console_ns.doc(description="Get dataset API keys")
|
||||||
@console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
|
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_key_list_model)
|
|
||||||
def get(self):
|
def get(self):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
keys = db.session.scalars(
|
keys = db.session.scalars(
|
||||||
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||||
).all()
|
).all()
|
||||||
return {"items": keys}
|
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||||
|
|
||||||
|
@console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||||
|
@console_ns.response(400, "Maximum keys exceeded")
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@is_admin_or_owner_required
|
@is_admin_or_owner_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_key_item_model)
|
|
||||||
def post(self):
|
def post(self):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
@ -828,7 +825,7 @@ class DatasetApiKeyApi(Resource):
|
|||||||
api_token.type = self.resource_type
|
api_token.type = self.resource_type
|
||||||
db.session.add(api_token)
|
db.session.add(api_token)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return api_token, 200
|
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
|
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
|
||||||
|
|||||||
@ -1555,7 +1555,17 @@ class TestDatasetApiKeyApi:
|
|||||||
method = unwrap(api.get)
|
method = unwrap(api.get)
|
||||||
|
|
||||||
mock_key_1 = MagicMock(spec=ApiToken)
|
mock_key_1 = MagicMock(spec=ApiToken)
|
||||||
|
mock_key_1.id = "key-1"
|
||||||
|
mock_key_1.type = "dataset"
|
||||||
|
mock_key_1.token = "ds-abc"
|
||||||
|
mock_key_1.last_used_at = None
|
||||||
|
mock_key_1.created_at = None
|
||||||
mock_key_2 = MagicMock(spec=ApiToken)
|
mock_key_2 = MagicMock(spec=ApiToken)
|
||||||
|
mock_key_2.id = "key-2"
|
||||||
|
mock_key_2.type = "dataset"
|
||||||
|
mock_key_2.token = "ds-def"
|
||||||
|
mock_key_2.last_used_at = None
|
||||||
|
mock_key_2.created_at = None
|
||||||
|
|
||||||
with (
|
with (
|
||||||
app.test_request_context("/"),
|
app.test_request_context("/"),
|
||||||
@ -1570,13 +1580,26 @@ class TestDatasetApiKeyApi:
|
|||||||
):
|
):
|
||||||
response = method(api)
|
response = method(api)
|
||||||
|
|
||||||
assert "items" in response
|
assert "data" in response
|
||||||
assert response["items"] == [mock_key_1, mock_key_2]
|
assert len(response["data"]) == 2
|
||||||
|
assert response["data"][0]["id"] == "key-1"
|
||||||
|
assert response["data"][0]["token"] == "ds-abc"
|
||||||
|
assert response["data"][1]["id"] == "key-2"
|
||||||
|
assert response["data"][1]["token"] == "ds-def"
|
||||||
|
|
||||||
def test_post_create_api_key_success(self, app):
|
def test_post_create_api_key_success(self, app):
|
||||||
api = DatasetApiKeyApi()
|
api = DatasetApiKeyApi()
|
||||||
method = unwrap(api.post)
|
method = unwrap(api.post)
|
||||||
|
|
||||||
|
mock_token = MagicMock()
|
||||||
|
mock_token.id = "new-key-id"
|
||||||
|
mock_token.last_used_at = None
|
||||||
|
mock_token.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
|
||||||
|
|
||||||
|
mock_api_token_cls = MagicMock()
|
||||||
|
mock_api_token_cls.return_value = mock_token
|
||||||
|
mock_api_token_cls.generate_api_key.return_value = "dataset-abc123"
|
||||||
|
|
||||||
with (
|
with (
|
||||||
app.test_request_context("/"),
|
app.test_request_context("/"),
|
||||||
patch(
|
patch(
|
||||||
@ -1588,8 +1611,8 @@ class TestDatasetApiKeyApi:
|
|||||||
return_value=3,
|
return_value=3,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.datasets.ApiToken.generate_api_key",
|
"controllers.console.datasets.datasets.ApiToken",
|
||||||
return_value="dataset-abc123",
|
mock_api_token_cls,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.datasets.db.session.add",
|
"controllers.console.datasets.datasets.db.session.add",
|
||||||
@ -1603,9 +1626,11 @@ class TestDatasetApiKeyApi:
|
|||||||
response, status = method(api)
|
response, status = method(api)
|
||||||
|
|
||||||
assert status == 200
|
assert status == 200
|
||||||
assert isinstance(response, ApiToken)
|
assert isinstance(response, dict)
|
||||||
assert response.token == "dataset-abc123"
|
assert response["id"] == "new-key-id"
|
||||||
assert response.type == "dataset"
|
assert response["token"] == "dataset-abc123"
|
||||||
|
assert response["type"] == "dataset"
|
||||||
|
assert response["created_at"] is not None
|
||||||
|
|
||||||
def test_post_exceed_max_keys(self, app):
|
def test_post_exceed_max_keys(self, app):
|
||||||
api = DatasetApiKeyApi()
|
api = DatasetApiKeyApi()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user