From 9ca453f7f7097c83e4ac6629f4785908dbb80ab4 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Tue, 26 Nov 2024 18:35:30 +0800
Subject: [PATCH] update text splitter

---
 api/core/indexing_runner.py                     |  6 ++----
 api/core/model_manager.py                       |  2 +-
 api/core/rag/docstore/dataset_docstore.py       | 13 ++++++-------
 api/services/dataset_service.py                 | 10 +++++-----
 api/tasks/batch_create_segment_to_index_task.py |  8 +++++---
 5 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 29e161cb74..9e6d698122 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -720,10 +720,8 @@ class IndexingRunner:
 
             tokens = 0
             if embedding_model_instance:
-                tokens += sum(
-                    embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
-                    for document in chunk_documents
-                )
+                page_content_list = [document.page_content for document in chunk_documents]
+                tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
 
             # load index
             index_processor.load(dataset, chunk_documents, with_keywords=False)
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index 1986688551..5115e60118 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -183,7 +183,7 @@ class ModelInstance:
             input_type=input_type,
         )
 
-    def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
+    def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
         """
         Get number of tokens for text embedding
 
diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
index 319a2612c7..08cc7a1c34 100644
--- a/api/core/rag/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -78,8 +78,13 @@ class DatasetDocumentStore:
             model_type=ModelType.TEXT_EMBEDDING,
             model=self._dataset.embedding_model,
         )
+        if embedding_model:
+            page_content_list = [doc.page_content for doc in docs]
+            tokens_list = embedding_model.get_text_embedding_num_tokens(page_content_list)
+        else:
+            tokens_list = [0] * len(docs)
 
-        for doc in docs:
+        for doc, tokens in zip(docs, tokens_list):
             if not isinstance(doc, Document):
                 raise ValueError("doc must be a Document")
 
@@ -91,12 +96,6 @@ class DatasetDocumentStore:
                     f"doc_id {doc.metadata['doc_id']} already exists. Set allow_update to True to overwrite."
                 )
 
-            # calc embedding use tokens
-            if embedding_model:
-                tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content])
-            else:
-                tokens = 0
-
             if not segment_document:
                 max_position += 1
 
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index d38729f31e..6e071c17ec 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -1390,7 +1390,7 @@ class SegmentService:
                 model=dataset.embedding_model,
             )
             # calc embedding use tokens
-            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
         lock_name = "add_segment_lock_document_id_{}".format(document.id)
         with redis_client.lock(lock_name, timeout=600):
             max_position = (
@@ -1467,9 +1467,9 @@ class SegmentService:
                 if dataset.indexing_technique == "high_quality" and embedding_model:
                     # calc embedding use tokens
                     if document.doc_form == "qa_model":
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]])
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]])[0]
                     else:
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
                 segment_document = DocumentSegment(
                     tenant_id=current_user.current_tenant_id,
                     dataset_id=document.dataset_id,
@@ -1577,9 +1577,9 @@ class SegmentService:
 
                 # calc embedding use tokens
                 if document.doc_form == "qa_model":
-                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])
+                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]
                 else:
-                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
             segment.content = content
             segment.index_node_hash = segment_hash
             segment.word_count = len(content)
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index dcb7009e44..39c032caad 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -58,12 +58,14 @@ def batch_create_segment_to_index_task(
                 model=dataset.embedding_model,
            )
         word_count_change = 0
-        for segment in content:
+        if embedding_model:
+            tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
+        else:
+            tokens_list = [0] * len(content)
+        for segment, tokens in zip(content, tokens_list):
             content = segment["content"]
             doc_id = str(uuid.uuid4())
             segment_hash = helper.generate_text_hash(content)
-            # calc embedding use tokens
-            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0
             max_position = (
                 db.session.query(func.max(DocumentSegment.position))
                 .filter(DocumentSegment.document_id == dataset_document.id)
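
Note on the refactor: get_text_embedding_num_tokens previously took a list of texts but returned a single int, so callers counted tokens one text at a time; after this patch it returns one count per input text, which is why single-text call sites gain a trailing [0] and loop call sites zip the returned counts back onto their segments. Below is a minimal sketch of the new contract, assuming a hypothetical StubEmbeddingModel in place of a real ModelInstance (the whitespace tokenizer is illustrative only, not the library's actual token counting):

# Sketch of the batched token-counting contract introduced by this patch.
# StubEmbeddingModel is hypothetical; only the method name and signature
# mirror ModelInstance.get_text_embedding_num_tokens.
class StubEmbeddingModel:
    def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
        # One count per input text; a real model would use its own tokenizer.
        return [len(text.split()) for text in texts]


embedding_model = StubEmbeddingModel()
segments = [{"content": "first chunk"}, {"content": "second, longer chunk"}]

# Batched path, as in batch_create_segment_to_index_task: one call covers
# every segment, then each segment is paired with its own count via zip.
tokens_list = embedding_model.get_text_embedding_num_tokens(
    texts=[segment["content"] for segment in segments]
)
for segment, tokens in zip(segments, tokens_list):
    print(segment["content"], tokens)

# Single-text path, as in dataset_service: the result is now a list,
# hence the trailing [0] at those call sites.
tokens = embedding_model.get_text_embedding_num_tokens(texts=["one chunk"])[0]

Batching the counts means one tokenizer round trip per chunk of documents instead of one per document, which appears to be the motivation for the change.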