From 2c3859f39a5ad30dce56972ec3986856b8af1471 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Wed, 20 Aug 2025 20:09:34 +0800 Subject: [PATCH] feat(be): impl oauth server --- api/controllers/console/__init__.py | 2 +- api/controllers/console/auth/oauth_server.py | 181 ++++++++++++++++++ ...47-8d289573e1da_add_oauth_provider_apps.py | 45 +++++ api/models/model.py | 10 +- api/services/oauth_server.py | 94 +++++++++ 5 files changed, 327 insertions(+), 5 deletions(-) create mode 100644 api/controllers/console/auth/oauth_server.py create mode 100644 api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py create mode 100644 api/services/oauth_server.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 57dbc8da64..0455a3e072 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -70,7 +70,7 @@ from .app import ( ) # Import auth controllers -from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth +from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server # Import billing controllers from .billing import billing, compliance diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py new file mode 100644 index 0000000000..add1b904ae --- /dev/null +++ b/api/controllers/console/auth/oauth_server.py @@ -0,0 +1,181 @@ +from functools import wraps +from typing import cast + +import flask_login +from flask import request +from flask_restful import Resource, reqparse +from werkzeug.exceptions import BadRequest, NotFound + +from controllers.console.wraps import account_initialization_required, setup_required +from core.model_runtime.utils.encoders import jsonable_encoder +from libs.login import login_required +from models.account import Account +from models.model import OAuthProviderApp +from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService + +from .. import api + + +def oauth_server_client_id_required(view): + @wraps(view) + def decorated(*args, **kwargs): + parser = reqparse.RequestParser() + parser.add_argument("client_id", type=str, required=True, location="json") + args = parser.parse_args() + client_id = args.get("client_id") + if not client_id: + raise BadRequest("client_id is required") + + oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id) + if not oauth_provider_app: + raise NotFound("client_id is invalid") + + kwargs["oauth_provider_app"] = oauth_provider_app + + return view(*args, **kwargs) + + return decorated + + +def oauth_server_access_token_required(view): + @wraps(view) + def decorated(*args, **kwargs): + oauth_provider_app: OAuthProviderApp = kwargs.get("oauth_provider_app") + + if not request.headers.get("Authorization"): + raise BadRequest("Authorization is required") + + token_type = request.headers.get("Authorization") + if not token_type: + raise BadRequest("token_type is required") + token_type = token_type.split(" ")[0] + if token_type != "Bearer": + raise BadRequest("token_type is invalid") + access_token = request.headers.get("Authorization").split(" ")[1] + if not access_token: + raise BadRequest("access_token is required") + + account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token) + if not account: + raise BadRequest("access_token or client_id is invalid") + + kwargs["account"] = account + + return view(*args, **kwargs) + + return decorated + + +class OAuthServerAppApi(Resource): + @setup_required + @oauth_server_client_id_required + def post(self, oauth_provider_app: OAuthProviderApp): + parser = reqparse.RequestParser() + parser.add_argument("redirect_uri", type=str, required=True, location="json") + args = parser.parse_args() + redirect_uri = args.get("redirect_uri") + + # check if redirect_uri is valid + if redirect_uri not in oauth_provider_app.redirect_uris: + raise BadRequest("redirect_uri is invalid") + + return jsonable_encoder( + { + "app_icon": oauth_provider_app.app_icon, + "app_label": oauth_provider_app.app_label, + "scope": oauth_provider_app.scope, + } + ) + + +class OAuthServerUserAuthorizeApi(Resource): + @setup_required + @login_required + @account_initialization_required + @oauth_server_client_id_required + def post(self, oauth_provider_app: OAuthProviderApp): + account = cast(Account, flask_login.current_user) + user_account_id = account.id + + code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id) + return jsonable_encoder( + { + "code": code, + } + ) + + +class OAuthServerUserTokenApi(Resource): + @setup_required + @oauth_server_client_id_required + def post(self, oauth_provider_app: OAuthProviderApp): + parser = reqparse.RequestParser() + parser.add_argument("grant_type", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=False, location="json") + parser.add_argument("client_secret", type=str, required=False, location="json") + parser.add_argument("redirect_uri", type=str, required=False, location="json") + parser.add_argument("refresh_token", type=str, required=False, location="json") + args = parser.parse_args() + + grant_type = OAuthGrantType(args["grant_type"]) + + if grant_type == OAuthGrantType.AUTHORIZATION_CODE: + if not args["code"]: + raise BadRequest("code is required") + + if args["client_secret"] != oauth_provider_app.client_secret: + raise BadRequest("client_secret is invalid") + + if args["redirect_uri"] not in oauth_provider_app.redirect_uris: + raise BadRequest("redirect_uri is invalid") + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, code=args["code"], client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) + elif grant_type == OAuthGrantType.REFRESH_TOKEN: + if not args["refresh_token"]: + raise BadRequest("refresh_token is required") + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, refresh_token=args["refresh_token"], client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) + else: + raise BadRequest("invalid grant_type") + + +class OAuthServerUserAccountApi(Resource): + @setup_required + @oauth_server_client_id_required + @oauth_server_access_token_required + def post(self, oauth_provider_app: OAuthProviderApp, account: Account): + return jsonable_encoder( + { + "name": account.name, + "email": account.email, + "avatar": account.avatar, + "interface_language": account.interface_language, + "timezone": account.timezone, + } + ) + + +api.add_resource(OAuthServerAppApi, "/oauth/provider") +api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize") +api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token") +api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account") diff --git a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py new file mode 100644 index 0000000000..e804f55bac --- /dev/null +++ b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py @@ -0,0 +1,45 @@ +"""empty message + +Revision ID: 8d289573e1da +Revises: fa8b0fa6f407 +Create Date: 2025-08-20 17:47:17.015695 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '8d289573e1da' +down_revision = 'fa8b0fa6f407' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('oauth_provider_apps', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_icon', sa.String(length=255), nullable=False), + sa.Column('app_label', sa.JSON(), server_default='{}', nullable=False), + sa.Column('client_id', sa.String(length=255), nullable=False), + sa.Column('client_secret', sa.String(length=255), nullable=False), + sa.Column('redirect_uris', sa.JSON(), server_default='[]', nullable=False), + sa.Column('scope', sa.String(length=255), server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='oauth_provider_app_pkey') + ) + with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op: + batch_op.create_index('oauth_provider_app_client_id_idx', ['client_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op: + batch_op.drop_index('oauth_provider_app_client_id_idx') + + op.drop_table('oauth_provider_apps') + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index e7ddcf2d92..6c9f51a75c 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -616,17 +616,19 @@ class OAuthProviderApp(Base): __tablename__ = "oauth_provider_apps" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="oauth_provider_app_pkey"), - sa.Index("oauth_provider_app_app_id_idx", "app_id"), + sa.Index("oauth_provider_app_client_id_idx", "client_id"), ) id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_icon = mapped_column(String(255), nullable=False) - app_label = mapped_column(sa.Text, nullable=False, server_default="{}") + app_label = mapped_column(sa.JSON, nullable=False, server_default="{}") client_id = mapped_column(String(255), nullable=False) client_secret = mapped_column(String(255), nullable=False) - redirect_uri = mapped_column(String(255), nullable=False) + redirect_uris = mapped_column(sa.JSON, nullable=False, server_default="[]") scope = mapped_column( - String(255), nullable=False, server_default=sa.text("'name email avatar interface_language timezone'") + String(255), + nullable=False, + server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"), ) created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/services/oauth_server.py b/api/services/oauth_server.py new file mode 100644 index 0000000000..b722dbee22 --- /dev/null +++ b/api/services/oauth_server.py @@ -0,0 +1,94 @@ +import enum +import uuid + +from sqlalchemy import select +from sqlalchemy.orm import Session +from werkzeug.exceptions import BadRequest + +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import Account +from models.model import OAuthProviderApp +from services.account_service import AccountService + + +class OAuthGrantType(enum.StrEnum): + AUTHORIZATION_CODE = "authorization_code" + REFRESH_TOKEN = "refresh_token" + + +OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}" +OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}" +OAUTH_ACCESS_TOKEN_EXPIRES_IN = 60 * 60 * 12 # 12 hours +OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}" +OAUTH_REFRESH_TOKEN_EXPIRES_IN = 60 * 60 * 24 * 30 # 30 days + + +class OAuthServerService: + @staticmethod + def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None: + query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id) + + with Session(db.engine) as session: + return session.execute(query).scalar_one_or_none() + + @staticmethod + def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str: + code = str(uuid.uuid4()) + redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code) + redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes + return code + + @staticmethod + def sign_oauth_access_token( + grant_type: OAuthGrantType, + code: str = "", + client_id: str = "", + refresh_token: str = "", + ) -> tuple[str, str]: + match grant_type: + case OAuthGrantType.AUTHORIZATION_CODE: + redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code) + user_account_id = redis_client.get(redis_key) + if not user_account_id: + raise BadRequest("invalid code") + + # delete code + redis_client.delete(redis_key) + + access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id) + refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id) + return access_token, refresh_token + case OAuthGrantType.REFRESH_TOKEN: + redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token) + user_account_id = redis_client.get(redis_key) + if not user_account_id: + raise BadRequest("invalid refresh token") + + access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id) + return access_token, refresh_token + + @staticmethod + def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str: + token = str(uuid.uuid4()) + redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token) + redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN) + return token + + @staticmethod + def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str: + token = str(uuid.uuid4()) + redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token) + redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN) + return token + + @staticmethod + def validate_oauth_access_token(client_id: str, token: str) -> Account | None: + redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token) + user_account_id = redis_client.get(redis_key) + if not user_account_id: + return None + + user_id_str = user_account_id.decode("utf-8") + + return AccountService.load_user(user_id_str)