From f1f5d45d2e9a002dc522ae4399bbfb03da5c914e Mon Sep 17 00:00:00 2001
From: Jyong <718720800@qq.com>
Date: Thu, 25 May 2023 23:15:36 +0800
Subject: [PATCH] support multiple files and Notion pages

---
 api/core/data_source/notion.py  | 53 +++++++++++++++++++++++++++++++--
 api/core/indexing_runner.py     | 41 ++++++++++++++++++-------
 api/services/dataset_service.py |  1 +
 3 files changed, 83 insertions(+), 12 deletions(-)

diff --git a/api/core/data_source/notion.py b/api/core/data_source/notion.py
index 496d2bcac8..7bb693c308 100644
--- a/api/core/data_source/notion.py
+++ b/api/core/data_source/notion.py
@@ -90,12 +90,61 @@ class NotionPageReader(BaseReader):
         result_lines = "\n".join(result_lines_arr)
         return result_lines
 
+    def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]:
+        """Read a block's children, returning one text chunk per top-level child block."""
+        done = False
+        result_lines_arr = []
+        cur_block_id = block_id
+        while not done:
+            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
+            query_dict: Dict[str, Any] = {}
+
+            res = requests.request(
+                "GET", block_url, headers=self.headers, json=query_dict
+            )
+            data = res.json()
+
+            for result in data["results"]:
+                result_type = result["type"]
+                result_obj = result[result_type]
+
+                cur_result_text_arr = []
+                if "rich_text" in result_obj:
+                    for rich_text in result_obj["rich_text"]:
+                        # skip if doesn't have text object
+                        if "text" in rich_text:
+                            text = rich_text["text"]["content"]
+                            prefix = "\t" * num_tabs
+                            cur_result_text_arr.append(prefix + text)
+
+                result_block_id = result["id"]
+                has_children = result["has_children"]
+                if has_children:
+                    children_text = self._read_block(
+                        result_block_id, num_tabs=num_tabs + 1
+                    )
+                    cur_result_text_arr.append(children_text)
+
+                cur_result_text = "\n".join(cur_result_text_arr)
+                result_lines_arr.append(cur_result_text)
+
+            if data["next_cursor"] is None:
+                done = True
+                break
+            else:
+                cur_block_id = data["next_cursor"]
+        return result_lines_arr
+
     def read_page(self, page_id: str) -> str:
         """Read a page."""
         return self._read_block(page_id)
 
+    def read_page_as_documents(self, page_id: str) -> List[str]:
+        """Read a page as a list of text chunks, one per top-level block."""
+        return self._read_parent_blocks(page_id)
+
     def query_database(
-        self, database_id: str, query_dict: Dict[str, Any] = {}
+            self, database_id: str, query_dict: Dict[str, Any] = {}
     ) -> List[str]:
         """Get all the pages from a Notion database."""
         res = requests.post(
@@ -136,7 +185,7 @@ class NotionPageReader(BaseReader):
         return page_ids
 
     def load_data(
-        self, page_ids: List[str] = [], database_id: Optional[str] = None
+            self, page_ids: List[str] = [], database_id: Optional[str] = None
     ) -> List[Document]:
         """Load data from the input directory.
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 635700cd91..db200b2ab8 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -252,7 +252,7 @@ class IndexingRunner: tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) return { - "total_segments": len(total_segments), + "total_segments": total_segments, "tokens": tokens, "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), "currency": TokenCalculator.get_currency(self.embedding_model_name), @@ -261,25 +261,30 @@ class IndexingRunner: def _load_data(self, document: Document) -> List[Document]: # load file - if document.data_source_type != "upload_file": + if document.data_source_type not in ["upload_file", "notion_import"]: return [] data_source_info = document.data_source_info_dict - if not data_source_info or 'upload_file_id' not in data_source_info: - raise ValueError("no upload file found") + text_docs = [] + if document.data_source_type == 'upload_file': + if not data_source_info or 'upload_file_id' not in data_source_info: + raise ValueError("no upload file found") - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == data_source_info['upload_file_id']). \ - one_or_none() - - text_docs = self._load_data_from_file(file_detail) + file_detail = db.session.query(UploadFile). \ + filter(UploadFile.id == data_source_info['upload_file_id']). \ + one_or_none() + text_docs = self._load_data_from_file(file_detail) + elif document.data_source_type == 'notion_import': + if not data_source_info or 'notion_page_id' not in data_source_info \ + or 'notion_workspace_id' not in data_source_info: + raise ValueError("no notion page found") + text_docs = self._load_data_from_notion(data_source_info['notion_workspace_id'], data_source_info['notion_page_id']) # update document status to splitting self._update_document_index_status( document_id=document.id, after_indexing_status="splitting", extra_update_params={ - Document.file_id: file_detail.id, Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]), Document.parsing_completed_at: datetime.datetime.utcnow() } @@ -314,6 +319,22 @@ class IndexingRunner: return text_docs + def _load_data_from_notion(self, workspace_id: str, page_id: str) -> List[Document]: + data_source_binding = DataSourceBinding.query.filter( + db.and_( + DataSourceBinding.tenant_id == current_user.current_tenant_id, + DataSourceBinding.provider == 'notion', + DataSourceBinding.disabled == False, + DataSourceBinding.source_info['workspace_id'] == workspace_id + ) + ).first() + if not data_source_binding: + raise ValueError('Data source binding not found.') + page_ids = [page_id] + reader = NotionPageReader(integration_token=data_source_binding.access_token) + text_docs = reader.load_data(page_ids=page_ids) + return text_docs + def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser: """ Get the NodeParser object according to the processing rule. diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 77ffba61be..793070dfff 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -420,6 +420,7 @@ class DocumentService: raise ValueError('Data source binding not found.') for page in notion_info['pages']: data_source_info = { + "notion_workspace_id": workspace_id, "notion_page_id": page['page_id'], } document = DocumentService.save_document(dataset, dataset_process_rule.id,
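
Illustrative sketch (not part of the patch): the `_read_parent_blocks` helper added to `NotionPageReader` walks Notion's block-children endpoint with cursor pagination and yields one text chunk per top-level block, which is what `read_page_as_documents` exposes. Below is a minimal standalone version of that pagination pattern against Notion's public API; the function name `iter_block_children`, passing `page_size`/`start_cursor` as query parameters, and the pinned `Notion-Version` value are illustrative choices, not the patch's exact code.

```python
from typing import Any, Dict, List

import requests

NOTION_VERSION = "2022-06-28"  # any supported Notion-Version header value


def iter_block_children(block_id: str, integration_token: str) -> List[Dict[str, Any]]:
    """Collect every child block of `block_id`, following next_cursor until exhausted."""
    headers = {
        "Authorization": f"Bearer {integration_token}",
        "Notion-Version": NOTION_VERSION,
    }
    url = f"https://api.notion.com/v1/blocks/{block_id}/children"
    results: List[Dict[str, Any]] = []
    params: Dict[str, Any] = {"page_size": 100}
    while True:
        data = requests.get(url, headers=headers, params=params).json()
        results.extend(data["results"])
        if not data.get("has_more"):                      # no further pages
            break
        params["start_cursor"] = data["next_cursor"]      # resume from the cursor
    return results
```

Each returned block carries a `rich_text` array under its type key; joining the text contents of that array per top-level block gives the per-block chunks the reader hands back.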
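A hedged usage sketch of the new `notion_import` path: `data_source_info` now stores `notion_workspace_id` alongside `notion_page_id` (the dataset_service change), so the indexing runner can resolve the workspace's access token from `DataSourceBinding` and load the page through `NotionPageReader`, as `_load_data_from_notion` does. The import paths (`core.data_source.notion`, `extensions.ext_database`, `models.source`) and the helper name `load_notion_page` are assumptions inferred from the file layout in this diff, not verified against the repository.

```python
from core.data_source.notion import NotionPageReader   # extended by this patch
from extensions.ext_database import db                 # assumed Dify db handle
from models.source import DataSourceBinding            # assumed model location


def load_notion_page(tenant_id: str, workspace_id: str, page_id: str):
    # Find the enabled Notion binding for this tenant and workspace.
    binding = db.session.query(DataSourceBinding).filter(
        DataSourceBinding.tenant_id == tenant_id,
        DataSourceBinding.provider == 'notion',
        DataSourceBinding.disabled == False,
        DataSourceBinding.source_info['workspace_id'] == workspace_id,
    ).first()
    if not binding:
        raise ValueError('Data source binding not found.')
    # Load the selected page with the workspace's integration token.
    reader = NotionPageReader(integration_token=binding.access_token)
    return reader.load_data(page_ids=[page_id])  # -> List[Document], one per page
```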