mirror of https://github.com/langgenius/dify.git
add data source binding and notion auth and callback
This commit is contained in:
parent
f7539ce4c5
commit
6e725e2ed4
|
|
@ -5,13 +5,13 @@ from typing import Optional
|
|||
import flask_login
|
||||
import requests
|
||||
from flask import request, redirect, current_app, session
|
||||
from flask_login import current_user, login_required
|
||||
from flask_restful import Resource
|
||||
|
||||
from libs.oauth import OAuthUserInfo, NotionOAuth
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, AccountStatus
|
||||
from services.account_service import AccountService, RegisterService
|
||||
from werkzeug.exceptions import Forbidden
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
from .. import api
|
||||
from ..setup import setup_required
|
||||
from ..wraps import account_initialization_required
|
||||
|
||||
|
||||
def get_oauth_providers():
|
||||
|
|
@ -20,7 +20,7 @@ def get_oauth_providers():
|
|||
client_secret=current_app.config.get(
|
||||
'NOTION_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_URL') + '/console/api/oauth/authorize/github')
|
||||
'CONSOLE_URL') + '/console/api/oauth/data-source/authorize/notion')
|
||||
|
||||
OAUTH_PROVIDERS = {
|
||||
'notion': notion_oauth
|
||||
|
|
@ -29,7 +29,13 @@ def get_oauth_providers():
|
|||
|
||||
|
||||
class OAuthDataSource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
|
|
@ -41,7 +47,7 @@ class OAuthDataSource(Resource):
|
|||
return redirect(auth_url)
|
||||
|
||||
|
||||
class OAuthCallback(Resource):
|
||||
class OAuthDataSourceCallback(Resource):
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
|
|
@ -51,20 +57,14 @@ class OAuthCallback(Resource):
|
|||
|
||||
code = request.args.get('code')
|
||||
try:
|
||||
token = oauth_provider.get_access_token(code)
|
||||
oauth_provider.get_access_token(code)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logging.exception(
|
||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
||||
return {'error': 'OAuth process failed'}, 400
|
||||
return {'error': 'OAuth data source process failed'}, 400
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_login=success')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
|
||||
|
||||
|
||||
def _bind_access_token(provider: str, user_info: OAuthUserInfo):
|
||||
|
||||
# Link account
|
||||
return
|
||||
|
||||
|
||||
api.add_resource(OAuthDataSource, '/oauth/login/<provider>')
|
||||
api.add_resource(OAuthCallback, '/oauth/authorize/<provider>')
|
||||
api.add_resource(OAuthDataSource, '/oauth/data-source/<provider>')
|
||||
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/authorize/<provider>')
|
||||
|
|
|
|||
|
|
@ -1,7 +1,12 @@
|
|||
import json
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
|
||||
import requests
|
||||
from flask_login import current_user
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.data_source import DataSourceBinding
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -136,34 +141,3 @@ class GoogleOAuth(OAuth):
|
|||
)
|
||||
|
||||
|
||||
class NotionOAuth(OAuth):
|
||||
_AUTH_URL = 'https://api.notion.com/v1/oauth/authorize'
|
||||
_TOKEN_URL = 'https://api.notion.com/v1/oauth/token'
|
||||
|
||||
def get_authorization_url(self):
|
||||
params = {
|
||||
'client_id': self.client_id,
|
||||
'response_type': 'code',
|
||||
'redirect_uri': self.redirect_uri,
|
||||
'owner': 'user'
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
data = {
|
||||
'code': code,
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': self.redirect_uri
|
||||
}
|
||||
headers = {'Accept': 'application/json'}
|
||||
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
access_token = response_json.get('access_token')
|
||||
workspace_name = response_json.get('workspace_name')
|
||||
workspace_icon = response_json.get('workspace_icon')
|
||||
workspace_id = response_json.get('workspace_id')
|
||||
if not access_token:
|
||||
raise ValueError(f"Error in Notion OAuth: {response_json}")
|
||||
|
||||
return access_token
|
||||
|
|
|
|||
|
|
@ -0,0 +1,98 @@
|
|||
import json
|
||||
import urllib.parse
|
||||
|
||||
import requests
|
||||
from flask_login import current_user
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.data_source import DataSourceBinding
|
||||
|
||||
|
||||
class OAuthDataSource:
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class NotionOAuth(OAuthDataSource):
|
||||
_AUTH_URL = 'https://api.notion.com/v1/oauth/authorize'
|
||||
_TOKEN_URL = 'https://api.notion.com/v1/oauth/token'
|
||||
_NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search"
|
||||
|
||||
def get_authorization_url(self):
|
||||
params = {
|
||||
'client_id': self.client_id,
|
||||
'response_type': 'code',
|
||||
'redirect_uri': self.redirect_uri,
|
||||
'owner': 'user'
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
data = {
|
||||
'code': code,
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': self.redirect_uri
|
||||
}
|
||||
headers = {'Accept': 'application/json'}
|
||||
auth = (self.client_id, self.client_secret)
|
||||
response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
access_token = response_json.get('access_token')
|
||||
if not access_token:
|
||||
raise ValueError(f"Error in Notion OAuth: {response_json}")
|
||||
workspace_name = response_json.get('workspace_name')
|
||||
workspace_icon = response_json.get('workspace_icon')
|
||||
workspace_id = response_json.get('workspace_id')
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(access_token)
|
||||
source_info = {
|
||||
'workspace_name': workspace_name,
|
||||
'workspace_icon': workspace_icon,
|
||||
'workspace_id': workspace_id,
|
||||
'pages': pages
|
||||
}
|
||||
# save data source binding
|
||||
data_source_binding = DataSourceBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=json.dumps(source_info)
|
||||
)
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
|
||||
def get_authorized_pages(self, access_token: str):
|
||||
pages = []
|
||||
data = {
|
||||
'filter': {
|
||||
"value": "page",
|
||||
"property": "object"
|
||||
}
|
||||
}
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
'Notion-Version': '2022-06-28',
|
||||
}
|
||||
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
|
||||
response_json = response.json()
|
||||
results = response_json['results']
|
||||
for result in results:
|
||||
page_id = result['id']
|
||||
page_name = result['properties']['title']['title'][0]['plain_text']
|
||||
page_icon = result['icon']
|
||||
page = {
|
||||
'page_id': page_id,
|
||||
'page_name': page_name,
|
||||
'page_icon': page_icon
|
||||
}
|
||||
pages.append(page)
|
||||
return pages
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
class DataSourceBinding(db.Model):
|
||||
__tablename__ = 'data_source_bindings'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='app_pkey'),
|
||||
db.Index('app_tenant_id_idx', 'tenant_id')
|
||||
)
|
||||
|
||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(UUID, nullable=False)
|
||||
access_token = db.Column(db.String(255), nullable=False)
|
||||
source_info = db.Column(db.Text, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
Loading…
Reference in New Issue