diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 75da505cd9..576e8c0d0d 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -67,4 +67,4 @@ class OAuthDataSourceCallback(Resource): api.add_resource(OAuthDataSource, '/oauth/data-source/') -api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/authorize/') +api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/') diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py new file mode 100644 index 0000000000..0dec41890e --- /dev/null +++ b/api/controllers/console/datasets/data_source.py @@ -0,0 +1,163 @@ +import datetime +import hashlib +import json +import tempfile +import time +import uuid +from pathlib import Path + +from cachetools import TTLCache +from flask import request, current_app +from flask_login import login_required, current_user +from flask_restful import Resource, marshal_with, fields +from werkzeug.exceptions import NotFound + +from controllers.console import api +from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \ + UnsupportedFileTypeError +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from core.index.readers.html_parser import HTMLParser +from core.index.readers.pdf_parser import PDFParser +from extensions.ext_storage import storage +from libs.helper import TimestampField +from extensions.ext_database import db +from libs.oauth_data_source import NotionOAuth +from models.dataset import Document +from models.model import UploadFile +from models.source import DataSourceBinding +from services.dataset_service import DatasetService, DocumentService + +cache = TTLCache(maxsize=None, ttl=30) + +FILE_SIZE_LIMIT = 15 * 1024 * 1024 # 15MB +ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm'] +PREVIEW_WORDS_LIMIT = 3000 + + +class DataSourceApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def get(self): + # get workspace data source integrates + data_source_integrates = db.session.query(DataSourceBinding).filter( + DataSourceBinding.tenant_id == current_user.current_tenant_id, + DataSourceBinding.disabled == False + ).all() + + base_url = request.url_root.rstrip('/') + data_source_oauth_base_path = "/console/api/oauth/data-source" + providers = ["notion"] + + integrate_data = [] + for provider in providers: + existing_integrate = next((ai for ai in data_source_integrates if ai.provider == provider), None) + if existing_integrate: + integrate_data.append({ + 'id': existing_integrate.id, + 'provider': provider, + 'created_at': existing_integrate.created_at, + 'is_bound': True, + 'disabled': existing_integrate.disabled, + 'source_info': json.loads(existing_integrate.source_info), + 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' + }) + else: + integrate_data.append({ + 'id': None, + 'provider': provider, + 'created_at': None, + 'source_info': None, + 'is_bound': False, + 'disabled': None, + 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' + }) + + return {'data': integrate_data} + + @setup_required + @login_required + @account_initialization_required + def patch(self, binding_id, action): + data_source_binding = DataSourceBinding.query.filter_by( + id=binding_id + ).first() + if data_source_binding is None: + raise NotFound('Data source binding not found.') + # enable binding + if action == 'enable': + if data_source_binding.disabled: + data_source_binding.disabled = False + data_source_binding.updated_at = datetime.datetime.utcnow() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError('Data source is not disabled.') + # disable binding + if action == 'disable': + if not data_source_binding.disabled: + data_source_binding.disabled = True + data_source_binding.updated_at = datetime.datetime.utcnow() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError('Data source is disabled.') + return {'result': 'success'}, 200 + + +class DataSourceNotionApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def get(self): + dataset_id = request.args.get('dataset_id', default=None, type=str) + exist_page_ids = [] + # import notion in the exist dataset + if dataset_id: + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound('Dataset not found.') + if dataset.data_source_type != 'notion': + raise ValueError('Dataset is not notion type.') + documents = Document.query.filter_by( + dataset_id=dataset_id, + tenant_id=current_user.current_tenant_id, + data_source_type='notion', + enabled=True + ).all() + if documents: + page_ids = list(map(lambda item: item.data_source_info, documents)) + exist_page_ids.append(page_ids) + # get all authorized pages + data_source_bindings = DataSourceBinding.query.filter_by( + tenant_id=current_user.current_tenant_id, + provider='notion', + disabled=False + ).all() + if not data_source_bindings: + raise NotFound('Data source binding not found.') + pre_import_info_list = [] + for data_source_binding in data_source_bindings: + pages = NotionOAuth.get_authorized_pages(data_source_binding.access_token) + # Filter out already bound pages + filter_pages = filter(lambda page: page['page_id'] not in exist_page_ids, pages) + source_info = json.loads(data_source_binding.source_info) + pre_import_info = { + 'workspace_name': source_info['workspace_name'], + 'workspace_icon': source_info['workspace_icon'], + 'workspace_id': source_info['workspace_id'], + 'pages': filter_pages, + } + pre_import_info_list.append(pre_import_info) + return { + 'notion_info': pre_import_info_list + }, 200 + + +api.add_resource(DataSourceApi, '/oauth/data-source/integrates') +api.add_resource(DataSourceApi, '/oauth/data-source/integrates//') +api.add_resource(DataSourceNotionApi, '/notion/pre-import/pages') + diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 7a4491670b..c89ac6d653 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -6,7 +6,7 @@ import requests from flask_login import current_user from extensions.ext_database import db -from models.data_source import DataSourceBinding +from models.source import DataSourceBinding @dataclass diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 5b59929178..d8ea7dd58d 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -5,7 +5,7 @@ import requests from flask_login import current_user from extensions.ext_database import db -from models.data_source import DataSourceBinding +from models.source import DataSourceBinding class OAuthDataSource: @@ -58,13 +58,15 @@ class NotionOAuth(OAuthDataSource): 'workspace_name': workspace_name, 'workspace_icon': workspace_icon, 'workspace_id': workspace_id, - 'pages': pages + 'pages': pages, + 'total': len(pages) } # save data source binding data_source_binding = DataSourceBinding( tenant_id=current_user.current_tenant_id, access_token=access_token, - source_info=json.dumps(source_info) + source_info=json.dumps(source_info), + provider='notion' ) db.session.add(data_source_binding) db.session.commit() diff --git a/api/models/data_source.py b/api/models/source.py similarity index 84% rename from api/models/data_source.py rename to api/models/source.py index 9dde6ff555..53d47975ba 100644 --- a/api/models/data_source.py +++ b/api/models/source.py @@ -13,6 +13,8 @@ class DataSourceBinding(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(UUID, nullable=False) access_token = db.Column(db.String(255), nullable=False) + provider = db.Column(db.String(255), nullable=False) source_info = db.Column(db.Text, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))