support notion import documents

Jyong 2023-05-25 00:15:54 +08:00
parent 201d9943bb
commit e2ef272f48
5 changed files with 134 additions and 86 deletions

View File

@@ -220,7 +220,7 @@ class DatasetDocumentListApi(Resource):
         DocumentService.document_create_args_validate(args)
         try:
-            document = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
+            documents = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError:
             raise ProviderNotInitializeError()
         except QuotaExceededError:
@@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        return document
+        return documents

 class DatasetInitApi(Resource):
@@ -257,7 +257,7 @@ class DatasetInitApi(Resource):
         DocumentService.document_create_args_validate(args)
         try:
-            dataset, document = DocumentService.save_document_without_dataset_id(
+            dataset, documents = DocumentService.save_document_without_dataset_id(
                 tenant_id=current_user.current_tenant_id,
                 document_data=args,
                 account=current_user
@@ -271,7 +271,7 @@ class DatasetInitApi(Resource):
         response = {
             'dataset': dataset,
-            'document': document
+            'documents': documents
         }
         return response
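Both console endpoints now return the full batch, so a consumer that used to read a single `document` object reads the `documents` list instead. A minimal sketch of the client-side adjustment, assuming `resp` is the parsed JSON from the dataset-init call (variable names and the handler are illustrative, not part of the commit):

    documents = resp['documents']      # previously resp['document'] held a single item
    for doc in documents:
        handle(doc)                    # hypothetical per-document handler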

View File

@@ -69,12 +69,16 @@ class DocumentListApi(DatasetApiResource):
         document_data = {
             'data_source': {
                 'type': 'upload_file',
-                'info': upload_file.id
+                'info': [
+                    {
+                        'upload_file_id': upload_file.id
+                    }
+                ]
             }
         }
         try:
-            document = DocumentService.save_document_with_dataset_id(
+            documents = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
                 document_data=document_data,
                 account=dataset.created_by_account,
@@ -83,7 +87,7 @@ class DocumentListApi(DatasetApiResource):
             )
         except ProviderTokenNotInitError:
             raise ProviderNotInitializeError()
+        document = documents[0]
         if doc_type and doc_metadata:
             metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
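The service API now wraps the upload in the list-based `data_source.info` shape that `DocumentService.save_document_with_dataset_id` expects, which also leaves room for submitting several files at once. A minimal sketch of that payload, assuming `file_a` and `file_b` are existing `UploadFile` rows (hypothetical names):

    document_data = {
        'data_source': {
            'type': 'upload_file',
            'info': [
                {'upload_file_id': file_a.id},   # one entry per uploaded file
                {'upload_file_id': file_b.id},
            ]
        }
    }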

View File

@@ -38,42 +38,43 @@ class IndexingRunner:
         self.storage = storage
         self.embedding_model_name = embedding_model_name

-    def run(self, document: Document):
+    def run(self, documents: List[Document]):
         """Run the indexing process."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
-        if not dataset:
-            raise ValueError("no dataset found")
-        # load file
-        text_docs = self._load_data(document)
-        # get the process rule
-        processing_rule = db.session.query(DatasetProcessRule). \
-            filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
-            first()
-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
-        # split to nodes
-        nodes = self._step_split(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            dataset=dataset,
-            document=document,
-            processing_rule=processing_rule
-        )
-        # build index
-        self._build_index(
-            dataset=dataset,
-            document=document,
-            nodes=nodes
-        )
+        for document in documents:
+            # get dataset
+            dataset = Dataset.query.filter_by(
+                id=document.dataset_id
+            ).first()
+            if not dataset:
+                raise ValueError("no dataset found")
+            # load file
+            text_docs = self._load_data(document)
+            # get the process rule
+            processing_rule = db.session.query(DatasetProcessRule). \
+                filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
+                first()
+            # get node parser for splitting
+            node_parser = self._get_node_parser(processing_rule)
+            # split to nodes
+            nodes = self._step_split(
+                text_docs=text_docs,
+                node_parser=node_parser,
+                dataset=dataset,
+                document=document,
+                processing_rule=processing_rule
+            )
+            # build index
+            self._build_index(
+                dataset=dataset,
+                document=document,
+                nodes=nodes
+            )

     def run_in_splitting_status(self, document: Document):
         """Run the indexing process when the index_status is splitting."""
@@ -362,7 +363,7 @@ class IndexingRunner:
             embedding_model_name=self.embedding_model_name,
             document_id=document.id
         )
+        # add document segments
         doc_store.add_documents(nodes)

         # update document status to indexing
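Because `run` now iterates over a batch, existing callers that index a single document should wrap it in a list; the new signature also relies on `List` being imported from `typing` in this module. A minimal sketch of the adjusted call site (surrounding code is illustrative):

    indexing_runner = IndexingRunner()
    indexing_runner.run([document])    # previously: indexing_runner.run(document)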

View File

@@ -14,6 +14,7 @@ from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin
 from models.model import UploadFile
+from models.source import DataSourceBinding
 from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
@@ -374,47 +375,85 @@ class DocumentService:
             )
             db.session.add(dataset_process_rule)
             db.session.commit()
-        file_name = ''
-        data_source_info = {}
-        if document_data["data_source"]["type"] == "upload_file":
-            file_id = document_data["data_source"]["info"]
-            file = db.session.query(UploadFile).filter(
-                UploadFile.tenant_id == dataset.tenant_id,
-                UploadFile.id == file_id
-            ).first()
-            # raise error if file not found
-            if not file:
-                raise FileNotExistsError()
-            file_name = file.name
-            data_source_info = {
-                "upload_file_id": file_id,
-            }
-        # save document
         position = DocumentService.get_documents_position(dataset.id)
+        document_ids = []
+        documents = []
+        if document_data["data_source"]["type"] == "upload_file":
+            upload_file_list = document_data["data_source"]["info"]
+            for upload_file in upload_file_list:
+                file_id = upload_file["upload_file_id"]
+                file = db.session.query(UploadFile).filter(
+                    UploadFile.tenant_id == dataset.tenant_id,
+                    UploadFile.id == file_id
+                ).first()
+                # raise error if file not found
+                if not file:
+                    raise FileNotExistsError()
+                file_name = file.name
+                data_source_info = {
+                    "upload_file_id": file_id,
+                }
+                document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                         document_data["data_source"]["type"],
+                                                         data_source_info, created_from, position,
+                                                         account, file_name)
+                db.session.add(document)
+                db.session.flush()
+                document_ids.append(document.id)
+                documents.append(document)
+                position += 1
+        elif document_data["data_source"]["type"] == "notion_import":
+            notion_info_list = document_data["data_source"]['info']
+            for notion_info in notion_info_list:
+                workspace_id = notion_info['workspace_id']
+                data_source_binding = DataSourceBinding.query.filter(
+                    db.and_(
+                        DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                        DataSourceBinding.provider == 'notion',
+                        DataSourceBinding.disabled == False,
+                        DataSourceBinding.source_info['workspace_id'] == workspace_id
+                    )
+                ).first()
+                if not data_source_binding:
+                    raise ValueError('Data source binding not found.')
+                for page in notion_info['pages']:
+                    data_source_info = {
+                        "notion_page_id": page['page_id'],
+                    }
+                    document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                             document_data["data_source"]["type"],
+                                                             data_source_info, created_from, position,
+                                                             account, page['page_name'])
+                    db.session.add(document)
+                    db.session.flush()
+                    document_ids.append(document.id)
+                    documents.append(document)
+                    position += 1
+        db.session.commit()
+
+        # trigger async task
+        document_indexing_task.delay(dataset.id, document_ids)
+
+        return documents
+
+    @staticmethod
+    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
+                      created_from: str, position: int, account: Account, name: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
             position=position,
-            data_source_type=document_data["data_source"]["type"],
+            data_source_type=data_source_type,
             data_source_info=json.dumps(data_source_info),
-            dataset_process_rule_id=dataset_process_rule.id,
+            dataset_process_rule_id=process_rule_id,
             batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
-            name=file_name,
+            name=name,
             created_from=created_from,
             created_by=account.id,
-            # created_api_request_id = db.Column(UUID, nullable=True)
         )
-
-        db.session.add(document)
-        db.session.commit()
-
-        # trigger async task
-        document_indexing_task.delay(document.dataset_id, document.id)

         return document

     @staticmethod
@@ -431,15 +470,15 @@ class DocumentService:
         db.session.add(dataset)
         db.session.flush()

-        document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
+        documents = DocumentService.save_document_with_dataset_id(dataset, document_data, account)

         cut_length = 18
-        cut_name = document.name[:cut_length]
-        dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
-        dataset.description = 'useful for when you want to answer queries about the ' + document.name
+        cut_name = documents[0].name[:cut_length]
+        dataset.name = cut_name + '...' if len(documents[0].name) > cut_length else cut_name
+        dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name
         db.session.commit()

-        return dataset, document
+        return dataset, documents

     @classmethod
     def document_create_args_validate(cls, args: dict):
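For the new `notion_import` source type, `data_source.info` carries one entry per Notion workspace, each listing the pages to import; every page becomes its own `Document` whose `data_source_info` stores the `notion_page_id`, and the workspace must have an enabled `DataSourceBinding` for the `notion` provider or a `ValueError` is raised. A minimal sketch of the expected payload (IDs and names are placeholders):

    document_data = {
        'data_source': {
            'type': 'notion_import',
            'info': [
                {
                    'workspace_id': 'notion-workspace-uuid',
                    'pages': [
                        {'page_id': 'page-uuid-1', 'page_name': 'Team Handbook'},
                        {'page_id': 'page-uuid-2', 'page_name': 'Release Notes'}
                    ]
                }
            ]
        }
    }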

View File

@@ -13,32 +13,36 @@ from models.dataset import Document
 @shared_task
-def document_indexing_task(dataset_id: str, document_id: str):
+def document_indexing_task(dataset_id: str, document_ids: list):
     """
     Async process document
     :param dataset_id:
-    :param document_id:
+    :param document_ids:

     Usage: document_indexing_task.delay(dataset_id, document_id)
     """
-    logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
-    start_at = time.perf_counter()
-    document = db.session.query(Document).filter(
-        Document.id == document_id,
-        Document.dataset_id == dataset_id
-    ).first()
-    if not document:
-        raise NotFound('Document not found')
-    document.indexing_status = 'parsing'
-    document.processing_started_at = datetime.datetime.utcnow()
+    documents = []
+    for document_id in document_ids:
+        logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
+        start_at = time.perf_counter()
+        document = db.session.query(Document).filter(
+            Document.id == document_id,
+            Document.dataset_id == dataset_id
+        ).first()
+        if not document:
+            raise NotFound('Document not found')
+        document.indexing_status = 'parsing'
+        document.processing_started_at = datetime.datetime.utcnow()
+        documents.append(document)
+        db.session.add(document)
     db.session.commit()

     try:
         indexing_runner = IndexingRunner()
-        indexing_runner.run(document)
+        indexing_runner.run(documents)
         end_at = time.perf_counter()
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
     except DocumentIsPausedException:
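The task now receives every document ID created in one request, marks each as `parsing`, and hands the whole batch to `IndexingRunner.run`. A minimal sketch of how the service layer enqueues it after this change (mirroring the `document_indexing_task.delay(dataset.id, document_ids)` call above; variable names are illustrative):

    document_ids = [doc.id for doc in documents]
    document_indexing_task.delay(dataset.id, document_ids)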