diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 79e52d565b..11bdf89add 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -210,6 +210,7 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('duplicate', type=bool, nullable=False, location='json')
+        parser.add_argument('original_document_id', type=str, required=False, location='json')
         args = parser.parse_args()
 
         if not dataset.indexing_technique and not args['indexing_technique']:
@@ -244,8 +245,8 @@ class DatasetInitApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True,
                             nullable=False, location='json')
-        parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
-        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('data_source', type=dict, required=False, location='json')
+        parser.add_argument('process_rule', type=dict, required=False, location='json')
         args = parser.parse_args()
 
         # validate args
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 39004c3437..9b6cfff716 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -18,6 +18,7 @@ from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
 from services.errors.file import FileNotExistsError
+from tasks.document_indexing_update_task import document_indexing_update_task
 from tasks.document_indexing_task import document_indexing_task
 
 
@@ -270,6 +271,14 @@ class DocumentService:
 
         return document
 
+    @staticmethod
+    def get_document_by_id(document_id: str) -> Optional[Document]:
+        document = db.session.query(Document).filter(
+            Document.id == document_id
+        ).first()
+
+        return document
+
     @staticmethod
     def get_document_file_detail(file_id: str):
         file_detail = db.session.query(UploadFile). \
@@ -349,6 +358,8 @@ class DocumentService:
         if dataset.indexing_technique == 'high_quality':
             IndexBuilder.get_default_service_context(dataset.tenant_id)
 
+        if document_data.get("original_document_id"):
+            return DocumentService.update_document_with_dataset_id(dataset, document_data, account)
         # save process rule
         if not dataset_process_rule:
             process_rule = document_data["process_rule"]
@@ -411,6 +422,71 @@ class DocumentService:
 
         return document
 
+    @staticmethod
+    def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
+                                        account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
+                                        created_from: str = 'web'):
+        document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
+        if document.display_status != 'available':
+            raise ValueError("Document is not available")
+        # save process rule
+        if 'process_rule' in document_data and document_data['process_rule']:
+            process_rule = document_data["process_rule"]
+            if process_rule["mode"] == "custom":
+                dataset_process_rule = DatasetProcessRule(
+                    dataset_id=dataset.id,
+                    mode=process_rule["mode"],
+                    rules=json.dumps(process_rule["rules"]),
+                    created_by=account.id
+                )
+            elif process_rule["mode"] == "automatic":
+                dataset_process_rule = DatasetProcessRule(
+                    dataset_id=dataset.id,
+                    mode=process_rule["mode"],
+                    rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
+                    created_by=account.id
+                )
+            db.session.add(dataset_process_rule)
+            db.session.commit()
+            document.dataset_process_rule_id = dataset_process_rule.id
+        # update document data source
+        if 'data_source' in document_data and document_data['data_source']:
+            file_name = ''
+            data_source_info = {}
+            if document_data["data_source"]["type"] == "upload_file":
+                file_id = document_data["data_source"]["info"]
+                file = db.session.query(UploadFile).filter(
+                    UploadFile.tenant_id == dataset.tenant_id,
+                    UploadFile.id == file_id
+                ).first()
+
+                # raise error if file not found
+                if not file:
+                    raise FileNotExistsError()
+
+                file_name = file.name
+                data_source_info = {
+                    "upload_file_id": file_id,
+                }
+            document.data_source_type = document_data["data_source"]["type"]
+            document.data_source_info = json.dumps(data_source_info)
+            document.name = file_name
+        # update document to be waiting
+        document.indexing_status = 'waiting'
+        document.completed_at = None
+        document.processing_started_at = None
+        document.parsing_completed_at = None
+        document.cleaning_completed_at = None
+        document.splitting_completed_at = None
+        document.updated_at = datetime.datetime.utcnow()
+        document.created_from = created_from
+        db.session.add(document)
+        db.session.commit()
+        # trigger async task
+        document_indexing_update_task.delay(document.dataset_id, document.id)
+
+        return document
+
     @staticmethod
     def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
         # save dataset
@@ -437,6 +513,21 @@ class DocumentService:
 
     @classmethod
     def document_create_args_validate(cls, args: dict):
+        if 'original_document_id' not in args or not args['original_document_id']:
+            DocumentService.data_source_args_validate(args)
+            DocumentService.process_rule_args_validate(args)
+        else:
+            if ('data_source' not in args or not args['data_source']) and (
+                    'process_rule' not in args or not args['process_rule']):
+                raise ValueError("Data source or Process rule is required")
+            else:
+                if 'data_source' in args and args['data_source']:
+                    DocumentService.data_source_args_validate(args)
+                if 'process_rule' in args and args['process_rule']:
+                    DocumentService.process_rule_args_validate(args)
+
+    @classmethod
+    def data_source_args_validate(cls, args: dict):
         if 'data_source' not in args or not args['data_source']:
             raise ValueError("Data source is required")
 
@@ -453,6 +544,8 @@ class DocumentService:
         if 'info' not in args['data_source'] or not args['data_source']['info']:
             raise ValueError("Data source info is required")
 
+    @classmethod
+    def process_rule_args_validate(cls, args: dict):
         if 'process_rule' not in args or not args['process_rule']:
             raise ValueError("Process rule is required")
 
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index 5ca7f2d5c2..63d66e4ea6 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -35,8 +35,7 @@ def clean_document_task(document_id: str, dataset_id: str):
         index_node_ids = [segment.index_node_id for segment in segments]
 
         # delete from vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.del_nodes(index_node_ids)
+        vector_index.del_nodes(index_node_ids)
 
         # delete from keyword index
         if index_node_ids:
diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py
new file mode 100644
index 0000000000..8ba8a8fc26
--- /dev/null
+++ b/api/tasks/document_indexing_update_task.py
@@ -0,0 +1,86 @@
+import datetime
+import logging
+import time
+
+import click
+from celery import shared_task
+from werkzeug.exceptions import NotFound
+
+from core.index.keyword_table_index import KeywordTableIndex
+from core.index.vector_index import VectorIndex
+from core.indexing_runner import IndexingRunner, DocumentIsPausedException
+from core.llm.error import ProviderTokenNotInitError
+from extensions.ext_database import db
+from models.dataset import Document, Dataset, DocumentSegment
+
+
+@shared_task
+def document_indexing_update_task(dataset_id: str, document_id: str):
+    """
+    Async process document
+    :param dataset_id:
+    :param document_id:
+
+    Usage: document_indexing_update_task.delay(dataset_id, document_id)
+    """
+    logging.info(click.style('Start update document: {}'.format(document_id), fg='green'))
+    start_at = time.perf_counter()
+
+    document = db.session.query(Document).filter(
+        Document.id == document_id,
+        Document.dataset_id == dataset_id
+    ).first()
+
+    if not document:
+        raise NotFound('Document not found')
+
+    document.indexing_status = 'parsing'
+    document.processing_started_at = datetime.datetime.utcnow()
+    db.session.commit()
+
+    # delete all document segment and index
+    try:
+        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+        if not dataset:
+            raise Exception('Dataset not found')
+
+        vector_index = VectorIndex(dataset=dataset)
+        keyword_table_index = KeywordTableIndex(dataset=dataset)
+
+        segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+        index_node_ids = [segment.index_node_id for segment in segments]
+
+        # delete from vector index
+        vector_index.del_nodes(index_node_ids)
+
+        # delete from keyword index
+        if index_node_ids:
+            keyword_table_index.del_nodes(index_node_ids)
+
+        for segment in segments:
+            db.session.delete(segment)
+
+        end_at = time.perf_counter()
+        logging.info(
+            click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
+    except Exception:
+        logging.exception("Cleaned document when document update data source or process rule failed")
+    # start document re_segment
+    try:
+        indexing_runner = IndexingRunner()
+        indexing_runner.run(document)
+        end_at = time.perf_counter()
+        logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
+    except DocumentIsPausedException:
+        logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
+    except ProviderTokenNotInitError as e:
+        document.indexing_status = 'error'
+        document.error = str(e.description)
+        document.stopped_at = datetime.datetime.utcnow()
+        db.session.commit()
+    except Exception as e:
+        logging.exception("consume update document failed")
+        document.indexing_status = 'error'
+        document.error = str(e)
+        document.stopped_at = datetime.datetime.utcnow()
+        db.session.commit()
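Usage note (not part of the patch): a minimal sketch of how the new update path is expected to be exercised. Posting to the documents endpoint with original_document_id routes the request through DocumentService.update_document_with_dataset_id(), which re-queues indexing via document_indexing_update_task.delay(). The payload shape follows the parser arguments and validators added above; the host, endpoint path, auth header, and all IDs below are assumptions for illustration only.

import requests  # hypothetical client-side illustration, not part of the patch

# Payload shape mirrors the reqparse arguments and the data_source/process_rule
# validators added in dataset_service.py; IDs are placeholders.
payload = {
    "indexing_technique": "high_quality",
    "original_document_id": "<existing-document-id>",
    "data_source": {"type": "upload_file", "info": "<upload-file-id>"},
    "process_rule": {"mode": "automatic"},
}

# Endpoint path and auth scheme are assumed here, not taken from the diff.
resp = requests.post(
    "http://localhost:5001/console/api/datasets/<dataset-id>/documents",
    json=payload,
    headers={"Authorization": "Bearer <console-session-token>"},
)
resp.raise_for_status()
print(resp.json())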