diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py new file mode 100644 index 0000000000..95bc59fac6 --- /dev/null +++ b/api/controllers/console/auth/data_source_oauth.py @@ -0,0 +1,70 @@ +import logging +from datetime import datetime +from typing import Optional + +import flask_login +import requests +from flask import request, redirect, current_app, session +from flask_restful import Resource + +from libs.oauth import OAuthUserInfo, NotionOAuth +from extensions.ext_database import db +from models.account import Account, AccountStatus +from services.account_service import AccountService, RegisterService +from .. import api + + +def get_oauth_providers(): + with current_app.app_context(): + notion_oauth = NotionOAuth(client_id=current_app.config.get('NOTION_CLIENT_ID'), + client_secret=current_app.config.get( + 'NOTION_CLIENT_SECRET'), + redirect_uri=current_app.config.get( + 'CONSOLE_URL') + '/console/api/oauth/authorize/github') + + OAUTH_PROVIDERS = { + 'notion': notion_oauth + } + return OAUTH_PROVIDERS + + +class OAuthDataSource(Resource): + def get(self, provider: str): + OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() + with current_app.app_context(): + oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) + print(vars(oauth_provider)) + if not oauth_provider: + return {'error': 'Invalid provider'}, 400 + + auth_url = oauth_provider.get_authorization_url() + return redirect(auth_url) + + +class OAuthCallback(Resource): + def get(self, provider: str): + OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() + with current_app.app_context(): + oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) + if not oauth_provider: + return {'error': 'Invalid provider'}, 400 + + code = request.args.get('code') + try: + token = oauth_provider.get_access_token(code) + except requests.exceptions.HTTPError as e: + logging.exception( + f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") + return {'error': 'OAuth process failed'}, 400 + + return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_login=success') + + +def _bind_access_token(provider: str, user_info: OAuthUserInfo): + + # Link account + return + + +api.add_resource(OAuthDataSource, '/oauth/login/') +api.add_resource(OAuthCallback, '/oauth/authorize/') diff --git a/api/libs/oauth.py b/api/libs/oauth.py index ce41f0c22c..057713222e 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -134,3 +134,36 @@ class GoogleOAuth(OAuth): name=None, email=raw_info['email'] ) + + +class NotionOAuth(OAuth): + _AUTH_URL = 'https://api.notion.com/v1/oauth/authorize' + _TOKEN_URL = 'https://api.notion.com/v1/oauth/token' + + def get_authorization_url(self): + params = { + 'client_id': self.client_id, + 'response_type': 'code', + 'redirect_uri': self.redirect_uri, + 'owner': 'user' + } + return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + + def get_access_token(self, code: str): + data = { + 'code': code, + 'grant_type': 'authorization_code', + 'redirect_uri': self.redirect_uri + } + headers = {'Accept': 'application/json'} + response = requests.post(self._TOKEN_URL, data=data, headers=headers) + + response_json = response.json() + access_token = response_json.get('access_token') + workspace_name = response_json.get('workspace_name') + workspace_icon = response_json.get('workspace_icon') + workspace_id = response_json.get('workspace_id') + if not access_token: + raise ValueError(f"Error in Notion OAuth: {response_json}") + + return access_token