From dbd2babb05d627e855765307fea351e5a822f4f0 Mon Sep 17 00:00:00 2001 From: Jyong <718720800@qq.com> Date: Mon, 22 May 2023 21:18:58 +0800 Subject: [PATCH] notion index estimate --- .../console/datasets/data_source.py | 40 ++++- api/core/data_source/notion.py | 169 ++++++++++++++++++ api/core/indexing_runner.py | 37 ++++ api/models/source.py | 4 +- 4 files changed, 245 insertions(+), 5 deletions(-) create mode 100644 api/core/data_source/notion.py diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 0dec41890e..af7b99480a 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -17,6 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles UnsupportedFileTypeError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.data_source.notion import NotionPageReader from core.index.readers.html_parser import HTMLParser from core.index.readers.pdf_parser import PDFParser from extensions.ext_storage import storage @@ -107,7 +108,7 @@ class DataSourceApi(Resource): return {'result': 'success'}, 200 -class DataSourceNotionApi(Resource): +class DataSourceNotionListApi(Resource): @setup_required @login_required @@ -157,7 +158,40 @@ class DataSourceNotionApi(Resource): }, 200 +class DataSourceNotionApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def get(self, workspace_id, page_id): + data_source_binding = DataSourceBinding.query.filter( + db.and_( + DataSourceBinding.tenant_id == current_user.current_tenant_id, + DataSourceBinding.provider == 'notion', + DataSourceBinding.disabled == False, + DataSourceBinding.source_info['workspace_id'] == workspace_id + ) + ).first() + if not data_source_binding: + raise NotFound('Data source binding not found.') + reader = 
NotionPageReader(integration_token=data_source_binding.access_token)
+        page_content = reader.read_page(page_id)
+        return {
+            'content': page_content
+        }, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, workspace_id, page_id):
+        segment_rule = request.get_json()
+
+        indexing_runner = IndexingRunner()
+        response = indexing_runner.notion_indexing_estimate(segment_rule['notion_info'], segment_rule['process_rule'])
+        return response, 200
+
+
 api.add_resource(DataSourceApi, '/oauth/data-source/integrates')
 api.add_resource(DataSourceApi, '/oauth/data-source/integrates/<uuid:binding_id>/<string:action>')
-api.add_resource(DataSourceNotionApi, '/notion/pre-import/pages')
-
+api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
+api.add_resource(DataSourceNotionApi, '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/preview')
diff --git a/api/core/data_source/notion.py b/api/core/data_source/notion.py
new file mode 100644
index 0000000000..496d2bcac8
--- /dev/null
+++ b/api/core/data_source/notion.py
@@ -0,0 +1,169 @@
+"""Notion reader."""
+import logging
+import os
+from typing import Any, Dict, List, Optional
+
+import requests  # type: ignore
+
+from llama_index.readers.base import BaseReader
+from llama_index.readers.schema.base import Document
+
+INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
+BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
+DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
+SEARCH_URL = "https://api.notion.com/v1/search"
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: Notion DB reader coming soon!
+class NotionPageReader(BaseReader):
+    """Notion Page reader.
+
+    Reads a set of Notion pages.
+
+    Args:
+        integration_token (str): Notion integration token.
+ + """ + + def __init__(self, integration_token: Optional[str] = None) -> None: + """Initialize with parameters.""" + if integration_token is None: + integration_token = os.getenv(INTEGRATION_TOKEN_NAME) + if integration_token is None: + raise ValueError( + "Must specify `integration_token` or set environment " + "variable `NOTION_INTEGRATION_TOKEN`." + ) + self.token = integration_token + self.headers = { + "Authorization": "Bearer " + self.token, + "Content-Type": "application/json", + "Notion-Version": "2022-06-28", + } + + def _read_block(self, block_id: str, num_tabs: int = 0) -> str: + """Read a block.""" + done = False + result_lines_arr = [] + cur_block_id = block_id + while not done: + block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) + query_dict: Dict[str, Any] = {} + + res = requests.request( + "GET", block_url, headers=self.headers, json=query_dict + ) + data = res.json() + + for result in data["results"]: + result_type = result["type"] + result_obj = result[result_type] + + cur_result_text_arr = [] + if "rich_text" in result_obj: + for rich_text in result_obj["rich_text"]: + # skip if doesn't have text object + if "text" in rich_text: + text = rich_text["text"]["content"] + prefix = "\t" * num_tabs + cur_result_text_arr.append(prefix + text) + + result_block_id = result["id"] + has_children = result["has_children"] + if has_children: + children_text = self._read_block( + result_block_id, num_tabs=num_tabs + 1 + ) + cur_result_text_arr.append(children_text) + + cur_result_text = "\n".join(cur_result_text_arr) + result_lines_arr.append(cur_result_text) + + if data["next_cursor"] is None: + done = True + break + else: + cur_block_id = data["next_cursor"] + + result_lines = "\n".join(result_lines_arr) + return result_lines + + def read_page(self, page_id: str) -> str: + """Read a page.""" + return self._read_block(page_id) + + def query_database( + self, database_id: str, query_dict: Dict[str, Any] = {} + ) -> List[str]: + """Get all the 
pages from a Notion database.""" + res = requests.post( + DATABASE_URL_TMPL.format(database_id=database_id), + headers=self.headers, + json=query_dict, + ) + data = res.json() + page_ids = [] + for result in data["results"]: + page_id = result["id"] + page_ids.append(page_id) + + return page_ids + + def search(self, query: str) -> List[str]: + """Search Notion page given a text query.""" + done = False + next_cursor: Optional[str] = None + page_ids = [] + while not done: + query_dict = { + "query": query, + } + if next_cursor is not None: + query_dict["start_cursor"] = next_cursor + res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict) + data = res.json() + for result in data["results"]: + page_id = result["id"] + page_ids.append(page_id) + + if data["next_cursor"] is None: + done = True + break + else: + next_cursor = data["next_cursor"] + return page_ids + + def load_data( + self, page_ids: List[str] = [], database_id: Optional[str] = None + ) -> List[Document]: + """Load data from the input directory. + + Args: + page_ids (List[str]): List of page ids to load. + + Returns: + List[Document]: List of documents. 
+ + """ + if not page_ids and not database_id: + raise ValueError("Must specify either `page_ids` or `database_id`.") + docs = [] + if database_id is not None: + # get all the pages in the database + page_ids = self.query_database(database_id) + for page_id in page_ids: + page_text = self.read_page(page_id) + docs.append(Document(page_text, extra_info={"page_id": page_id})) + else: + for page_id in page_ids: + page_text = self.read_page(page_id) + docs.append(Document(page_text, extra_info={"page_id": page_id})) + + return docs + + +if __name__ == "__main__": + reader = NotionPageReader() + logger.info(reader.search("What I")) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index f06f3a0034..403a46c6cc 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -201,6 +201,43 @@ class IndexingRunner: "preview": preview_texts } + def notion_indexing_estimate(self, notion_info, tmp_processing_rule: dict) -> dict: + """ + Estimate the indexing for the document. 
+ """ + # load data from file + text_docs = self._load_data_from_file(file_detail) + + processing_rule = DatasetProcessRule( + mode=tmp_processing_rule["mode"], + rules=json.dumps(tmp_processing_rule["rules"]) + ) + + # get node parser for splitting + node_parser = self._get_node_parser(processing_rule) + + # split to nodes + nodes = self._split_to_nodes( + text_docs=text_docs, + node_parser=node_parser, + processing_rule=processing_rule + ) + + tokens = 0 + preview_texts = [] + for node in nodes: + if len(preview_texts) < 5: + preview_texts.append(node.get_text()) + + tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) + + return { + "total_segments": len(nodes), + "tokens": tokens, + "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), + "currency": TokenCalculator.get_currency(self.embedding_model_name), + "preview": preview_texts + } def _load_data(self, document: Document) -> List[Document]: # load file if document.data_source_type != "upload_file": diff --git a/api/models/source.py b/api/models/source.py index 53d47975ba..33eb38eea3 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,7 +1,7 @@ from sqlalchemy.dialects.postgresql import UUID from extensions.ext_database import db - +from sqlalchemy.dialects.postgresql import JSONB class DataSourceBinding(db.Model): __tablename__ = 'data_source_bindings' @@ -14,7 +14,7 @@ class DataSourceBinding(db.Model): tenant_id = db.Column(UUID, nullable=False) access_token = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) - source_info = db.Column(db.Text, nullable=False) + source_info = db.Column(JSONB, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) disabled = db.Column(db.Boolean, nullable=True, 
server_default=db.text('false'))