diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 3b9efeaab4..b2d61992c3 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -220,7 +220,7 @@ class DatasetDocumentListApi(Resource):
         DocumentService.document_create_args_validate(args)

         try:
-            document = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
+            documents = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError:
             raise ProviderNotInitializeError()
         except QuotaExceededError:
@@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()

-        return document
+        return documents


 class DatasetInitApi(Resource):
@@ -257,7 +257,7 @@ class DatasetInitApi(Resource):
         DocumentService.document_create_args_validate(args)

         try:
-            dataset, document = DocumentService.save_document_without_dataset_id(
+            dataset, documents = DocumentService.save_document_without_dataset_id(
                 tenant_id=current_user.current_tenant_id,
                 document_data=args,
                 account=current_user
@@ -271,7 +271,7 @@ class DatasetInitApi(Resource):

         response = {
             'dataset': dataset,
-            'document': document
+            'documents': documents
         }

         return response
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 47a90756db..e6ab37ebc1 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -69,12 +69,16 @@ class DocumentListApi(DatasetApiResource):
         document_data = {
             'data_source': {
                 'type': 'upload_file',
-                'info': upload_file.id
+                'info': [
+                    {
+                        'upload_file_id': upload_file.id
+                    }
+                ]
             }
         }

         try:
-            document = DocumentService.save_document_with_dataset_id(
+            documents = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
                 document_data=document_data,
                 account=dataset.created_by_account,
@@ -83,7 +87,7 @@
             )
         except ProviderTokenNotInitError:
             raise ProviderNotInitializeError()
-
+        document = documents[0]

         if doc_type and doc_metadata:
             metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 74aff357c3..635700cd91 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -38,42 +38,43 @@ class IndexingRunner:
         self.storage = storage
         self.embedding_model_name = embedding_model_name

-    def run(self, document: Document):
+    def run(self, documents: List[Document]):
         """Run the indexing process."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
+        for document in documents:
+            # get dataset
+            dataset = Dataset.query.filter_by(
+                id=document.dataset_id
+            ).first()

-        if not dataset:
-            raise ValueError("no dataset found")
+            if not dataset:
+                raise ValueError("no dataset found")

-        # load file
-        text_docs = self._load_data(document)
+            # load file
+            text_docs = self._load_data(document)

-        # get the process rule
-        processing_rule = db.session.query(DatasetProcessRule). \
-            filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
-            first()
+            # get the process rule
+            processing_rule = db.session.query(DatasetProcessRule). \
+                filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
+                first()

-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
+            # get node parser for splitting
+            node_parser = self._get_node_parser(processing_rule)

-        # split to nodes
-        nodes = self._step_split(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            dataset=dataset,
-            document=document,
-            processing_rule=processing_rule
-        )
+            # split to nodes
+            nodes = self._step_split(
+                text_docs=text_docs,
+                node_parser=node_parser,
+                dataset=dataset,
+                document=document,
+                processing_rule=processing_rule
+            )

-        # build index
-        self._build_index(
-            dataset=dataset,
-            document=document,
-            nodes=nodes
-        )
+            # build index
+            self._build_index(
+                dataset=dataset,
+                document=document,
+                nodes=nodes
+            )

     def run_in_splitting_status(self, document: Document):
         """Run the indexing process when the index_status is splitting."""
@@ -362,7 +363,7 @@ class IndexingRunner:
             embedding_model_name=self.embedding_model_name,
             document_id=document.id
         )
-
+
         # add document segments
         doc_store.add_documents(nodes)
         # update document status to indexing
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 9007dd825b..77ffba61be 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -14,6 +14,7 @@ from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin
 from models.model import UploadFile
+from models.source import DataSourceBinding
 from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
@@ -374,47 +375,85 @@ class DocumentService:
                 )
             db.session.add(dataset_process_rule)
             db.session.commit()
-
-        file_name = ''
-        data_source_info = {}
-        if document_data["data_source"]["type"] == "upload_file":
-            file_id = document_data["data_source"]["info"]
-            file = db.session.query(UploadFile).filter(
-                UploadFile.tenant_id == dataset.tenant_id,
-                UploadFile.id == file_id
-            ).first()
-
-            # raise error if file not found
-            if not file:
-                raise FileNotExistsError()
-
-            file_name = file.name
-            data_source_info = {
-                "upload_file_id": file_id,
-            }
-
-        # save document
         position = DocumentService.get_documents_position(dataset.id)
+        document_ids = []
+        documents = []
+        if document_data["data_source"]["type"] == "upload_file":
+            upload_file_list = document_data["data_source"]["info"]
+            for upload_file in upload_file_list:
+                file_id = upload_file["upload_file_id"]
+                file = db.session.query(UploadFile).filter(
+                    UploadFile.tenant_id == dataset.tenant_id,
+                    UploadFile.id == file_id
+                ).first()
+
+                # raise error if file not found
+                if not file:
+                    raise FileNotExistsError()
+
+                file_name = file.name
+                data_source_info = {
+                    "upload_file_id": file_id,
+                }
+                document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                         document_data["data_source"]["type"],
+                                                         data_source_info, created_from, position,
+                                                         account, file_name)
+                db.session.add(document)
+                db.session.flush()
+                document_ids.append(document.id)
+                documents.append(document)
+                position += 1
+        elif document_data["data_source"]["type"] == "notion_import":
+            notion_info_list = document_data["data_source"]['info']
+            for notion_info in notion_info_list:
+                workspace_id = notion_info['workspace_id']
+                data_source_binding = DataSourceBinding.query.filter(
+                    db.and_(
+                        DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                        DataSourceBinding.provider == 'notion',
+                        DataSourceBinding.disabled == False,
+                        DataSourceBinding.source_info['workspace_id'] == workspace_id
+                    )
+                ).first()
+                if not data_source_binding:
+                    raise ValueError('Data source binding not found.')
+                for page in notion_info['pages']:
+                    data_source_info = {
+                        "notion_page_id": page['page_id'],
+                    }
+                    document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                             document_data["data_source"]["type"],
+                                                             data_source_info, created_from, position,
+                                                             account, page['page_name'])
+                    db.session.add(document)
+                    db.session.flush()
+                    document_ids.append(document.id)
+                    documents.append(document)
+                    position += 1
+
+        db.session.commit()
+
+        # trigger async task
+        document_indexing_task.delay(dataset.id, document_ids)
+
+        return documents
+
+    @staticmethod
+    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
+                      created_from: str, position: int, account: Account, name: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
             position=position,
-            data_source_type=document_data["data_source"]["type"],
+            data_source_type=data_source_type,
             data_source_info=json.dumps(data_source_info),
-            dataset_process_rule_id=dataset_process_rule.id,
+            dataset_process_rule_id=process_rule_id,
             batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
-            name=file_name,
+            name=name,
             created_from=created_from,
             created_by=account.id,
-            # created_api_request_id = db.Column(UUID, nullable=True)
         )
-
-        db.session.add(document)
-        db.session.commit()
-
-        # trigger async task
-        document_indexing_task.delay(document.dataset_id, document.id)
-
         return document

     @staticmethod
@@ -431,15 +470,15 @@ class DocumentService:
         db.session.add(dataset)
         db.session.flush()

-        document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
+        documents = DocumentService.save_document_with_dataset_id(dataset, document_data, account)

         cut_length = 18
-        cut_name = document.name[:cut_length]
-        dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
-        dataset.description = 'useful for when you want to answer queries about the ' + document.name
+        cut_name = documents[0].name[:cut_length]
+        dataset.name = cut_name + '...' if len(documents[0].name) > cut_length else cut_name
+        dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name
         db.session.commit()

-        return dataset, document
+        return dataset, documents

     @classmethod
     def document_create_args_validate(cls, args: dict):
diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py
index 59bbd4dc98..211d110fa8 100644
--- a/api/tasks/document_indexing_task.py
+++ b/api/tasks/document_indexing_task.py
@@ -13,32 +13,36 @@ from models.dataset import Document


 @shared_task
-def document_indexing_task(dataset_id: str, document_id: str):
+def document_indexing_task(dataset_id: str, document_ids: list):
     """
     Async process document
     :param dataset_id:
-    :param document_id:
+    :param document_ids:

     Usage: document_indexing_task.delay(dataset_id, document_id)
     """
-    logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
-    start_at = time.perf_counter()
+    documents = []
+    for document_id in document_ids:
+        logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
+        start_at = time.perf_counter()

-    document = db.session.query(Document).filter(
-        Document.id == document_id,
-        Document.dataset_id == dataset_id
-    ).first()
+        document = db.session.query(Document).filter(
+            Document.id == document_id,
+            Document.dataset_id == dataset_id
+        ).first()

-    if not document:
-        raise NotFound('Document not found')
+        if not document:
+            raise NotFound('Document not found')

-    document.indexing_status = 'parsing'
-    document.processing_started_at = datetime.datetime.utcnow()
+        document.indexing_status = 'parsing'
+        document.processing_started_at = datetime.datetime.utcnow()
+        documents.append(document)
+        db.session.add(document)
     db.session.commit()

     try:
         indexing_runner = IndexingRunner()
-        indexing_runner.run(document)
+        indexing_runner.run(documents)

         end_at = time.perf_counter()
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
     except DocumentIsPausedException:
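
Reviewer note (illustrative, not part of the diff): a minimal sketch of the request payload and task-argument shapes implied by the hunks above. It only reflects fields the changed code visibly reads; the concrete ID strings are made-up placeholders.

# Illustrative sketch only -- not part of the diff. All IDs are placeholders.

# Before this change, data_source.info carried a single upload file ID:
old_upload_payload = {
    'data_source': {
        'type': 'upload_file',
        'info': 'upload-file-id-1'
    }
}

# After this change, info is a list of objects, so one request can create
# several documents; DocumentService.save_document_with_dataset_id now
# returns a list of Document rows instead of a single one.
new_upload_payload = {
    'data_source': {
        'type': 'upload_file',
        'info': [
            {'upload_file_id': 'upload-file-id-1'},
            {'upload_file_id': 'upload-file-id-2'},
        ]
    }
}

# The notion_import branch added in dataset_service.py reads this shape:
notion_payload = {
    'data_source': {
        'type': 'notion_import',
        'info': [
            {
                'workspace_id': 'notion-workspace-id',
                'pages': [
                    {'page_id': 'notion-page-id', 'page_name': 'Example Page'},
                ]
            }
        ]
    }
}

# The async task signature changes from a single document ID to a list:
#   before: document_indexing_task.delay(dataset_id, document_id)
#   after:  document_indexing_task.delay(dataset_id, document_ids)
# and IndexingRunner.run(documents) now loops over the whole batch internally.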