From 6e725e2ed46cef6a65d5ac5cd0d17af8a6037dcc Mon Sep 17 00:00:00 2001 From: Jyong <718720800@qq.com> Date: Thu, 18 May 2023 00:22:08 +0800 Subject: [PATCH] add data source binding and notion auth and callback --- .../console/auth/data_source_oauth.py | 36 +++---- api/libs/oauth.py | 36 +------ api/libs/oauth_data_source.py | 98 +++++++++++++++++++ api/models/data_source.py | 18 ++++ 4 files changed, 139 insertions(+), 49 deletions(-) create mode 100644 api/libs/oauth_data_source.py create mode 100644 api/models/data_source.py diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 95bc59fac6..75da505cd9 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -5,13 +5,13 @@ from typing import Optional import flask_login import requests from flask import request, redirect, current_app, session +from flask_login import current_user, login_required from flask_restful import Resource - -from libs.oauth import OAuthUserInfo, NotionOAuth -from extensions.ext_database import db -from models.account import Account, AccountStatus -from services.account_service import AccountService, RegisterService +from werkzeug.exceptions import Forbidden +from libs.oauth_data_source import NotionOAuth from .. import api +from ..setup import setup_required +from ..wraps import account_initialization_required def get_oauth_providers(): @@ -20,7 +20,7 @@ def get_oauth_providers(): client_secret=current_app.config.get( 'NOTION_CLIENT_SECRET'), redirect_uri=current_app.config.get( - 'CONSOLE_URL') + '/console/api/oauth/authorize/github') + 'CONSOLE_URL') + '/console/api/oauth/data-source/authorize/notion') OAUTH_PROVIDERS = { 'notion': notion_oauth @@ -29,7 +29,13 @@ def get_oauth_providers(): class OAuthDataSource(Resource): + @setup_required + @login_required + @account_initialization_required def get(self, provider: str): + # The role of the current user in the table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) @@ -41,7 +47,7 @@ class OAuthDataSource(Resource): return redirect(auth_url) -class OAuthCallback(Resource): +class OAuthDataSourceCallback(Resource): def get(self, provider: str): OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -51,20 +57,14 @@ class OAuthCallback(Resource): code = request.args.get('code') try: - token = oauth_provider.get_access_token(code) + oauth_provider.get_access_token(code) except requests.exceptions.HTTPError as e: logging.exception( f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") - return {'error': 'OAuth process failed'}, 400 + return {'error': 'OAuth data source process failed'}, 400 - return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_login=success') + return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success') -def _bind_access_token(provider: str, user_info: OAuthUserInfo): - - # Link account - return - - -api.add_resource(OAuthDataSource, '/oauth/login/') -api.add_resource(OAuthCallback, '/oauth/authorize/') +api.add_resource(OAuthDataSource, '/oauth/data-source/') +api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/authorize/') diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 057713222e..7a4491670b 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,7 +1,12 @@ +import json import urllib.parse from dataclasses import dataclass import requests +from flask_login import current_user + +from extensions.ext_database import db +from models.data_source import DataSourceBinding @dataclass @@ -136,34 +141,3 @@ class GoogleOAuth(OAuth): ) -class NotionOAuth(OAuth): - _AUTH_URL = 'https://api.notion.com/v1/oauth/authorize' - _TOKEN_URL = 'https://api.notion.com/v1/oauth/token' - - def get_authorization_url(self): - params = { - 'client_id': self.client_id, - 'response_type': 'code', - 'redirect_uri': self.redirect_uri, - 'owner': 'user' - } - return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" - - def get_access_token(self, code: str): - data = { - 'code': code, - 'grant_type': 'authorization_code', - 'redirect_uri': self.redirect_uri - } - headers = {'Accept': 'application/json'} - response = requests.post(self._TOKEN_URL, data=data, headers=headers) - - response_json = response.json() - access_token = response_json.get('access_token') - workspace_name = response_json.get('workspace_name') - workspace_icon = response_json.get('workspace_icon') - workspace_id = response_json.get('workspace_id') - if not access_token: - raise ValueError(f"Error in Notion OAuth: {response_json}") - - return access_token diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py new file mode 100644 index 0000000000..5b59929178 --- /dev/null +++ b/api/libs/oauth_data_source.py @@ -0,0 +1,98 @@ +import json +import urllib.parse + +import requests +from flask_login import current_user + +from extensions.ext_database import db +from models.data_source import DataSourceBinding + + +class OAuthDataSource: + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + + def get_authorization_url(self): + raise NotImplementedError() + + def get_access_token(self, code: str): + raise NotImplementedError() + + +class NotionOAuth(OAuthDataSource): + _AUTH_URL = 'https://api.notion.com/v1/oauth/authorize' + _TOKEN_URL = 'https://api.notion.com/v1/oauth/token' + _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search" + + def get_authorization_url(self): + params = { + 'client_id': self.client_id, + 'response_type': 'code', + 'redirect_uri': self.redirect_uri, + 'owner': 'user' + } + return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + + def get_access_token(self, code: str): + data = { + 'code': code, + 'grant_type': 'authorization_code', + 'redirect_uri': self.redirect_uri + } + headers = {'Accept': 'application/json'} + auth = (self.client_id, self.client_secret) + response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) + + response_json = response.json() + access_token = response_json.get('access_token') + if not access_token: + raise ValueError(f"Error in Notion OAuth: {response_json}") + workspace_name = response_json.get('workspace_name') + workspace_icon = response_json.get('workspace_icon') + workspace_id = response_json.get('workspace_id') + # get all authorized pages + pages = self.get_authorized_pages(access_token) + source_info = { + 'workspace_name': workspace_name, + 'workspace_icon': workspace_icon, + 'workspace_id': workspace_id, + 'pages': pages + } + # save data source binding + data_source_binding = DataSourceBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=json.dumps(source_info) + ) + db.session.add(data_source_binding) + db.session.commit() + + def get_authorized_pages(self, access_token: str): + pages = [] + data = { + 'filter': { + "value": "page", + "property": "object" + } + } + headers = { + 'Content-Type': 'application/json', + 'Authorization': f"Bearer {access_token}", + 'Notion-Version': '2022-06-28', + } + response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) + response_json = response.json() + results = response_json['results'] + for result in results: + page_id = result['id'] + page_name = result['properties']['title']['title'][0]['plain_text'] + page_icon = result['icon'] + page = { + 'page_id': page_id, + 'page_name': page_name, + 'page_icon': page_icon + } + pages.append(page) + return pages diff --git a/api/models/data_source.py b/api/models/data_source.py new file mode 100644 index 0000000000..9dde6ff555 --- /dev/null +++ b/api/models/data_source.py @@ -0,0 +1,18 @@ +from sqlalchemy.dialects.postgresql import UUID + +from extensions.ext_database import db + + +class DataSourceBinding(db.Model): + __tablename__ = 'data_source_bindings' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='app_pkey'), + db.Index('app_tenant_id_idx', 'tenant_id') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + access_token = db.Column(db.String(255), nullable=False) + source_info = db.Column(db.Text, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))