diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index af7b99480a..2c2b9a1fd8 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -9,7 +9,7 @@ from pathlib import Path
 from cachetools import TTLCache
 from flask import request, current_app
 from flask_login import login_required, current_user
-from flask_restful import Resource, marshal_with, fields
+from flask_restful import Resource, marshal_with, fields, reqparse
 from werkzeug.exceptions import NotFound

 from controllers.console import api
@@ -20,6 +20,7 @@ from controllers.console.wraps import account_initialization_required
 from core.data_source.notion import NotionPageReader
 from core.index.readers.html_parser import HTMLParser
 from core.index.readers.pdf_parser import PDFParser
+from core.indexing_runner import IndexingRunner
 from extensions.ext_storage import storage
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -184,10 +185,15 @@ class DataSourceNotionApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        segment_rule = request.get_json()
-
+        notion_import_info = request.get_json()
+        parser = reqparse.RequestParser()
+        parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
+        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        args = parser.parse_args()
+        # validate args
+        DocumentService.notion_estimate_args_validate(args)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.notion_indexing_estimate(file_detail, segment_rule['process_rule'])
+        response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])

         return response, 200

diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 403a46c6cc..12aee6e030 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -5,6 +5,8 @@ import tempfile
 import time
 from pathlib import Path
 from typing import Optional, List
+
+from flask_login import current_user
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from llama_index import SimpleDirectoryReader
@@ -14,6 +16,7 @@ from llama_index.node_parser import SimpleNodeParser, NodeParser
 from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
 from llama_index.readers.file.markdown_parser import MarkdownParser

+from core.data_source.notion import NotionPageReader
 from core.docstore.dataset_docstore import DatesetDocumentStore
 from core.index.keyword_table_index import KeywordTableIndex
 from core.index.readers.html_parser import HTMLParser
@@ -26,6 +29,7 @@ from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
 from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
 from models.model import UploadFile
+from models.source import DataSourceBinding


 class IndexingRunner:
@@ -201,43 +205,59 @@ class IndexingRunner:
             "preview": preview_texts
         }

-    def notion_indexing_estimate(self, notion_info, tmp_processing_rule: dict) -> dict:
+    def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict) -> dict:
         """
         Estimate the indexing for the document.
""" - # load data from file - text_docs = self._load_data_from_file(file_detail) - - processing_rule = DatasetProcessRule( - mode=tmp_processing_rule["mode"], - rules=json.dumps(tmp_processing_rule["rules"]) - ) - - # get node parser for splitting - node_parser = self._get_node_parser(processing_rule) - - # split to nodes - nodes = self._split_to_nodes( - text_docs=text_docs, - node_parser=node_parser, - processing_rule=processing_rule - ) - + # load data from notion tokens = 0 preview_texts = [] - for node in nodes: - if len(preview_texts) < 5: - preview_texts.append(node.get_text()) + total_segments = 0 + for notion_info in notion_info_list: + data_source_binding = DataSourceBinding.query.filter( + db.and_( + DataSourceBinding.tenant_id == current_user.current_tenant_id, + DataSourceBinding.provider == 'notion', + DataSourceBinding.disabled == False, + DataSourceBinding.source_info['workspace_id'] == notion_info['workspace_id'] + ) + ).first() + if not data_source_binding: + raise ValueError('Data source binding not found.') + reader = NotionPageReader(integration_token=data_source_binding.access_token) + for page in notion_info['pages']: + page_ids = [page['page_id']] + documents = reader.load_data(page_ids=page_ids) - tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) + processing_rule = DatasetProcessRule( + mode=tmp_processing_rule["mode"], + rules=json.dumps(tmp_processing_rule["rules"]) + ) + + # get node parser for splitting + node_parser = self._get_node_parser(processing_rule) + + # split to nodes + nodes = self._split_to_nodes( + text_docs=documents, + node_parser=node_parser, + processing_rule=processing_rule + ) + total_segments += len(nodes) + for node in nodes: + if len(preview_texts) < 5: + preview_texts.append(node.get_text()) + + tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) return { - "total_segments": len(nodes), + "total_segments": len(total_segments), "tokens": tokens, "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), "currency": TokenCalculator.get_currency(self.embedding_model_name), "preview": preview_texts } + def _load_data(self, document: Document) -> List[Document]: # load file if document.data_source_type != "upload_file": diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 39004c3437..1a032c9137 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -519,3 +519,78 @@ class DocumentService: if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): raise ValueError("Process rule segmentation max_tokens is invalid") + + @classmethod + def notion_estimate_args_validate(cls, args: dict): + if 'notion_info_list' not in args or not args['notion_info_list']: + raise ValueError("Notion info is required") + + if not isinstance(args['notion_info_list'], dict): + raise ValueError("Notion info is invalid") + + if 'process_rule' not in args or not args['process_rule']: + raise ValueError("Process rule is required") + + if not isinstance(args['process_rule'], dict): + raise ValueError("Process rule is invalid") + + if 'mode' not in args['process_rule'] or not args['process_rule']['mode']: + raise ValueError("Process rule mode is required") + + if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + raise ValueError("Process rule mode is invalid") + + if args['process_rule']['mode'] == 'automatic': + args['process_rule']['rules'] = {} + else: + if 'rules' 
not in args['process_rule'] or not args['process_rule']['rules']: + raise ValueError("Process rule rules is required") + + if not isinstance(args['process_rule']['rules'], dict): + raise ValueError("Process rule rules is invalid") + + if 'pre_processing_rules' not in args['process_rule']['rules'] \ + or args['process_rule']['rules']['pre_processing_rules'] is None: + raise ValueError("Process rule pre_processing_rules is required") + + if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + raise ValueError("Process rule pre_processing_rules is invalid") + + unique_pre_processing_rule_dicts = {} + for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: + if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + raise ValueError("Process rule pre_processing_rules id is required") + + if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + raise ValueError("Process rule pre_processing_rules id is invalid") + + if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + raise ValueError("Process rule pre_processing_rules enabled is required") + + if not isinstance(pre_processing_rule['enabled'], bool): + raise ValueError("Process rule pre_processing_rules enabled is invalid") + + unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + + args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + + if 'segmentation' not in args['process_rule']['rules'] \ + or args['process_rule']['rules']['segmentation'] is None: + raise ValueError("Process rule segmentation is required") + + if not isinstance(args['process_rule']['rules']['segmentation'], dict): + raise ValueError("Process rule segmentation is invalid") + + if 'separator' not in args['process_rule']['rules']['segmentation'] \ + or not args['process_rule']['rules']['segmentation']['separator']: + raise ValueError("Process rule segmentation separator is required") + + if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + raise ValueError("Process rule segmentation separator is invalid") + + if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ + or not args['process_rule']['rules']['segmentation']['max_tokens']: + raise ValueError("Process rule segmentation max_tokens is required") + + if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + raise ValueError("Process rule segmentation max_tokens is invalid")
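
For reference, a request body shaped roughly as follows should satisfy DocumentService.notion_estimate_args_validate as added above. This is a minimal sketch only: the workspace and page IDs are placeholders, and the 'custom' mode plus the remove_extra_spaces / remove_urls_emails rule ids are assumptions about DatasetProcessRule.MODES and PRE_PROCESSING_RULES rather than values taken from this patch.

import json

# Sketch of a POST body for the Notion indexing-estimate endpoint in this patch.
# All IDs are hypothetical placeholders; the rule ids assume common
# DatasetProcessRule.PRE_PROCESSING_RULES values and are not defined in this diff.
payload = {
    "notion_info_list": [
        {
            "workspace_id": "11111111-2222-3333-4444-555555555555",  # placeholder
            "pages": [
                {"page_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"},  # placeholder
            ],
        }
    ],
    "process_rule": {
        "mode": "custom",  # 'automatic' skips the 'rules' validation entirely
        "rules": {
            "pre_processing_rules": [
                {"id": "remove_extra_spaces", "enabled": True},
                {"id": "remove_urls_emails", "enabled": False},
            ],
            "segmentation": {"separator": "\n", "max_tokens": 500},
        },
    },
}

print(json.dumps(payload, indent=2))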