diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index fcafbce798..acb9cd807e 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -10,7 +10,9 @@ body: options: - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + required: true + - label: "Please do not modify this template :) and fill in all the required fields." required: true - type: input diff --git a/.github/ISSUE_TEMPLATE/document_issue.yml b/.github/ISSUE_TEMPLATE/document_issue.yml index 750bc3501a..44115b2097 100644 --- a/.github/ISSUE_TEMPLATE/document_issue.yml +++ b/.github/ISSUE_TEMPLATE/document_issue.yml @@ -10,7 +10,9 @@ body: options: - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + required: true + - label: "Please do not modify this template :) and fill in all the required fields." required: true - type: textarea attributes: diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index bebf72efc3..694bd3975d 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -10,7 +10,9 @@ body: options: - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + required: true + - label: "Please do not modify this template :) and fill in all the required fields." required: true - type: textarea attributes: diff --git a/.github/ISSUE_TEMPLATE/help_wanted.yml b/.github/ISSUE_TEMPLATE/help_wanted.yml index 88ffae7f09..1834d63f52 100644 --- a/.github/ISSUE_TEMPLATE/help_wanted.yml +++ b/.github/ISSUE_TEMPLATE/help_wanted.yml @@ -10,7 +10,9 @@ body: options: - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + required: true + - label: "Please do not modify this template :) and fill in all the required fields."
required: true - type: textarea attributes: diff --git a/.github/ISSUE_TEMPLATE/translation_issue.yml b/.github/ISSUE_TEMPLATE/translation_issue.yml index aa6d077c61..589e071e14 100644 --- a/.github/ISSUE_TEMPLATE/translation_issue.yml +++ b/.github/ISSUE_TEMPLATE/translation_issue.yml @@ -10,7 +10,9 @@ body: options: - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + required: true + - label: "Please do not modify this template :) and fill in all the required fields." required: true - type: input attributes: diff --git a/api/.env.example b/api/.env.example index d492c1f8be..89d550ba5a 100644 --- a/api/.env.example +++ b/api/.env.example @@ -130,3 +130,5 @@ UNSTRUCTURED_API_URL= SSRF_PROXY_HTTP_URL= SSRF_PROXY_HTTPS_URL= + +BATCH_UPLOAD_LIMIT=10 \ No newline at end of file diff --git a/api/celerybeat-schedule.db b/api/celerybeat-schedule.db new file mode 100644 index 0000000000..b8c01de27b Binary files /dev/null and b/api/celerybeat-schedule.db differ diff --git a/api/config.py b/api/config.py index b6a8ce1438..83336e6c45 100644 --- a/api/config.py +++ b/api/config.py @@ -56,6 +56,8 @@ DEFAULTS = { 'BILLING_ENABLED': 'False', 'CAN_REPLACE_LOGO': 'False', 'ETL_TYPE': 'dify', + 'KEYWORD_STORE': 'jieba', + 'BATCH_UPLOAD_LIMIT': 20 } @@ -182,7 +184,7 @@ class Config: # Currently, only support: qdrant, milvus, zilliz, weaviate # ------------------------ self.VECTOR_STORE = get_env('VECTOR_STORE') - + self.KEYWORD_STORE = get_env('KEYWORD_STORE') # qdrant settings self.QDRANT_URL = get_env('QDRANT_URL') self.QDRANT_API_KEY = get_env('QDRANT_API_KEY') @@ -285,6 +287,8 @@ class Config: self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED') self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO') + self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT') + class CloudEditionConfig(Config): diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 86fcf704c7..c0c345baea 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.data_loader.loader.notion import NotionLoader from core.indexing_runner import IndexingRunner +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from libs.login import login_required @@ -173,14 +174,14 @@ class DataSourceNotionApi(Resource): if not data_source_binding: raise NotFound('Data source binding not found.') - loader = NotionLoader( - notion_access_token=data_source_binding.access_token, + extractor = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, - notion_page_type=page_type + notion_page_type=page_type, + notion_access_token=data_source_binding.access_token ) - text_docs = loader.load() + text_docs = 
extractor.extract() return { 'content': "\n".join([doc.page_content for doc in text_docs]) }, 200 @@ -192,11 +193,30 @@ class DataSourceNotionApi(Resource): parser = reqparse.RequestParser() parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') + parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') + parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) + notion_info_list = args['notion_info_list'] + extract_settings = [] + for notion_info in notion_info_list: + workspace_id = notion_info['workspace_id'] + for page in notion_info['pages']: + extract_setting = ExtractSetting( + datasource_type="notion_import", + notion_info={ + "notion_workspace_id": workspace_id, + "notion_obj_id": page['page_id'], + "notion_page_type": page['type'] + }, + document_model=args['doc_form'] + ) + extract_settings.append(extract_setting) indexing_runner = IndexingRunner() - response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule']) + response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, + args['process_rule'], args['doc_form'], + args['doc_language']) return response, 200 diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 2d26d0ecf4..f80b4de48d 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager +from core.rag.extractor.entity.extract_setting import ExtractSetting from extensions.ext_database import db from fields.app_fields import related_app_list from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields @@ -178,9 +179,9 @@ class DatasetApi(Resource): location='json', store_missing=False, type=_validate_description_length) parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help='Invalid indexing technique.') + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help='Invalid indexing technique.') parser.add_argument('permission', type=str, location='json', choices=( 'only_me', 'all_team_members'), help='Invalid permission.') parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.') @@ -258,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource): parser = reqparse.RequestParser() parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('indexing_technique', type=str, required=True, + parser.add_argument('indexing_technique', type=str, required=True, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=True, location='json') parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') @@ -268,6 +269,7 @@ class 
DatasetIndexingEstimateApi(Resource): args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) + extract_settings = [] if args['info_list']['data_source_type'] == 'upload_file': file_ids = args['info_list']['file_info_list']['file_ids'] file_details = db.session.query(UploadFile).filter( @@ -278,37 +280,44 @@ class DatasetIndexingEstimateApi(Resource): if file_details is None: raise NotFound("File not found.") - indexing_runner = IndexingRunner() - - try: - response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details, - args['process_rule'], args['doc_form'], - args['doc_language'], args['dataset_id'], - args['indexing_technique']) - except LLMBadRequestError: - raise ProviderNotInitializeError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) + if file_details: + for file_detail in file_details: + extract_setting = ExtractSetting( + datasource_type="upload_file", + upload_file=file_detail, + document_model=args['doc_form'] + ) + extract_settings.append(extract_setting) elif args['info_list']['data_source_type'] == 'notion_import': - - indexing_runner = IndexingRunner() - - try: - response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, - args['info_list']['notion_info_list'], - args['process_rule'], args['doc_form'], - args['doc_language'], args['dataset_id'], - args['indexing_technique']) - except LLMBadRequestError: - raise ProviderNotInitializeError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) + notion_info_list = args['info_list']['notion_info_list'] + for notion_info in notion_info_list: + workspace_id = notion_info['workspace_id'] + for page in notion_info['pages']: + extract_setting = ExtractSetting( + datasource_type="notion_import", + notion_info={ + "notion_workspace_id": workspace_id, + "notion_obj_id": page['page_id'], + "notion_page_type": page['type'] + }, + document_model=args['doc_form'] + ) + extract_settings.append(extract_setting) else: raise ValueError('Data source type not support') + indexing_runner = IndexingRunner() + try: + response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, + args['process_rule'], args['doc_form'], + args['doc_language'], args['dataset_id'], + args['indexing_technique']) + except LLMBadRequestError: + raise ProviderNotInitializeError( + "No Embedding Model available. 
Please configure a valid provider " + "in the Settings -> Model Provider.") + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + return response, 200 @@ -508,4 +517,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/') api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/') - diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 3fb6f16cd6..a990ef96ee 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -32,6 +32,7 @@ from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError +from core.rag.extractor.entity.extract_setting import ExtractSetting from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.document_fields import ( @@ -95,7 +96,7 @@ class GetProcessRuleApi(Resource): req_data = request.args document_id = req_data.get('document_id') - + # get default rules mode = DocumentService.DEFAULT_RULES['mode'] rules = DocumentService.DEFAULT_RULES['rules'] @@ -362,12 +363,18 @@ class DocumentIndexingEstimateApi(DocumentResource): if not file: raise NotFound('File not found.') + extract_setting = ExtractSetting( + datasource_type="upload_file", + upload_file=file, + document_model=document.doc_form + ) + indexing_runner = IndexingRunner() try: - response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file], - data_process_rule_dict, None, - 'English', dataset_id) + response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting], + data_process_rule_dict, document.doc_form, + 'English', dataset_id) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. 
Please configure a valid provider " @@ -402,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): data_process_rule = documents[0].dataset_process_rule data_process_rule_dict = data_process_rule.to_dict() info_list = [] + extract_settings = [] for document in documents: if document.indexing_status in ['completed', 'error']: raise DocumentAlreadyFinishedError() @@ -424,42 +432,48 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): } info_list.append(notion_info) - if dataset.data_source_type == 'upload_file': - file_details = db.session.query(UploadFile).filter( - UploadFile.tenant_id == current_user.current_tenant_id, - UploadFile.id.in_(info_list) - ).all() + if document.data_source_type == 'upload_file': + file_id = data_source_info['upload_file_id'] + file_detail = db.session.query(UploadFile).filter( + UploadFile.tenant_id == current_user.current_tenant_id, + UploadFile.id == file_id + ).first() - if file_details is None: - raise NotFound("File not found.") + if file_detail is None: + raise NotFound("File not found.") + extract_setting = ExtractSetting( + datasource_type="upload_file", + upload_file=file_detail, + document_model=document.doc_form + ) + extract_settings.append(extract_setting) + + elif document.data_source_type == 'notion_import': + extract_setting = ExtractSetting( + datasource_type="notion_import", + notion_info={ + "notion_workspace_id": data_source_info['notion_workspace_id'], + "notion_obj_id": data_source_info['notion_page_id'], + "notion_page_type": data_source_info['type'] + }, + document_model=document.doc_form + ) + extract_settings.append(extract_setting) + + else: + raise ValueError('Data source type not support') indexing_runner = IndexingRunner() try: - response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details, - data_process_rule_dict, None, - 'English', dataset_id) + response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, + data_process_rule_dict, document.doc_form, + 'English', dataset_id) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) - elif dataset.data_source_type == 'notion_import': - - indexing_runner = IndexingRunner() - try: - response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, - info_list, - data_process_rule_dict, - None, 'English', dataset_id) - except LLMBadRequestError: - raise ProviderNotInitializeError( - "No Embedding Model available. 
Please configure a valid provider " - "in the Settings -> Model Provider.") - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - else: - raise ValueError('Data source type not support') return response diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 903953486a..879c9df69d 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,8 +1,7 @@ -from langchain.schema import Document - from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.entities.application_entities import InvokeFrom +from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import DatasetQuery, DocumentSegment from models.model import DatasetRetrieverResource diff --git a/api/core/data_loader/file_extractor.py b/api/core/data_loader/file_extractor.py index 4741014c96..e69de29bb2 100644 --- a/api/core/data_loader/file_extractor.py +++ b/api/core/data_loader/file_extractor.py @@ -1,107 +0,0 @@ -import tempfile -from pathlib import Path -from typing import Optional, Union - -import requests -from flask import current_app -from langchain.document_loaders import Docx2txtLoader, TextLoader -from langchain.schema import Document - -from core.data_loader.loader.csv_loader import CSVLoader -from core.data_loader.loader.excel import ExcelLoader -from core.data_loader.loader.html import HTMLLoader -from core.data_loader.loader.markdown import MarkdownLoader -from core.data_loader.loader.pdf import PdfLoader -from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader -from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader -from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader -from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader -from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader -from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader -from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader -from extensions.ext_storage import storage -from models.model import UploadFile - -SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain'] -USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - - -class FileExtractor: - @classmethod - def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]: - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" - storage.download(upload_file.key, file_path) - - return cls.load_from_file(file_path, return_text, upload_file, is_automatic) - - @classmethod - def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: - response = requests.get(url, headers={ - "User-Agent": USER_AGENT - }) - - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(url).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" - with open(file_path, 'wb') as file: - file.write(response.content) - - return cls.load_from_file(file_path, return_text) - - @classmethod - def 
load_from_file(cls, file_path: str, return_text: bool = False, - upload_file: Optional[UploadFile] = None, - is_automatic: bool = False) -> Union[list[Document], str]: - input_file = Path(file_path) - delimiter = '\n' - file_extension = input_file.suffix.lower() - etl_type = current_app.config['ETL_TYPE'] - unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL'] - if etl_type == 'Unstructured': - if file_extension == '.xlsx': - loader = ExcelLoader(file_path) - elif file_extension == '.pdf': - loader = PdfLoader(file_path, upload_file=upload_file) - elif file_extension in ['.md', '.markdown']: - loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \ - else MarkdownLoader(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: - loader = HTMLLoader(file_path) - elif file_extension in ['.docx']: - loader = Docx2txtLoader(file_path) - elif file_extension == '.csv': - loader = CSVLoader(file_path, autodetect_encoding=True) - elif file_extension == '.msg': - loader = UnstructuredMsgLoader(file_path, unstructured_api_url) - elif file_extension == '.eml': - loader = UnstructuredEmailLoader(file_path, unstructured_api_url) - elif file_extension == '.ppt': - loader = UnstructuredPPTLoader(file_path, unstructured_api_url) - elif file_extension == '.pptx': - loader = UnstructuredPPTXLoader(file_path, unstructured_api_url) - elif file_extension == '.xml': - loader = UnstructuredXmlLoader(file_path, unstructured_api_url) - else: - # txt - loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \ - else TextLoader(file_path, autodetect_encoding=True) - else: - if file_extension == '.xlsx': - loader = ExcelLoader(file_path) - elif file_extension == '.pdf': - loader = PdfLoader(file_path, upload_file=upload_file) - elif file_extension in ['.md', '.markdown']: - loader = MarkdownLoader(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: - loader = HTMLLoader(file_path) - elif file_extension in ['.docx']: - loader = Docx2txtLoader(file_path) - elif file_extension == '.csv': - loader = CSVLoader(file_path, autodetect_encoding=True) - else: - # txt - loader = TextLoader(file_path, autodetect_encoding=True) - - return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load() diff --git a/api/core/data_loader/loader/html.py b/api/core/data_loader/loader/html.py deleted file mode 100644 index 6a9b48a5b2..0000000000 --- a/api/core/data_loader/loader/html.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging - -from bs4 import BeautifulSoup -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document - -logger = logging.getLogger(__name__) - - -class HTMLLoader(BaseLoader): - """Load html files. - - - Args: - file_path: Path to the file to load. 
- """ - - def __init__( - self, - file_path: str - ): - """Initialize with file path.""" - self._file_path = file_path - - def load(self) -> list[Document]: - return [Document(page_content=self._load_as_text())] - - def _load_as_text(self) -> str: - with open(self._file_path, "rb") as fp: - soup = BeautifulSoup(fp, 'html.parser') - text = soup.get_text() - text = text.strip() if text else '' - - return text diff --git a/api/core/data_loader/loader/pdf.py b/api/core/data_loader/loader/pdf.py deleted file mode 100644 index a3452b367b..0000000000 --- a/api/core/data_loader/loader/pdf.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging -from typing import Optional - -from langchain.document_loaders import PyPDFium2Loader -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document - -from extensions.ext_storage import storage -from models.model import UploadFile - -logger = logging.getLogger(__name__) - - -class PdfLoader(BaseLoader): - """Load pdf files. - - - Args: - file_path: Path to the file to load. - """ - - def __init__( - self, - file_path: str, - upload_file: Optional[UploadFile] = None - ): - """Initialize with file path.""" - self._file_path = file_path - self._upload_file = upload_file - - def load(self) -> list[Document]: - plaintext_file_key = '' - plaintext_file_exists = False - if self._upload_file: - if self._upload_file.hash: - plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \ - + self._upload_file.hash + '.0625.plaintext' - try: - text = storage.load(plaintext_file_key).decode('utf-8') - plaintext_file_exists = True - return [Document(page_content=text)] - except FileNotFoundError: - pass - documents = PyPDFium2Loader(file_path=self._file_path).load() - text_list = [] - for document in documents: - text_list.append(document.page_content) - text = "\n\n".join(text_list) - - # save plaintext file for caching - if not plaintext_file_exists and plaintext_file_key: - storage.save(plaintext_file_key, text.encode('utf-8')) - - return documents - diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index 556b3aceda..9a051fd4cb 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from typing import Any, Optional, cast -from langchain.schema import Document from sqlalchemy import func from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment diff --git a/api/core/features/annotation_reply.py b/api/core/features/annotation_reply.py index bdc5467e62..e1b64cf73f 100644 --- a/api/core/features/annotation_reply.py +++ b/api/core/features/annotation_reply.py @@ -1,13 +1,8 @@ import logging from typing import Optional -from flask import current_app - -from core.embedding.cached_embedding import CacheEmbedding from core.entities.application_entities import InvokeFrom -from core.index.vector_index.vector_index import VectorIndex -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db from models.dataset import Dataset from models.model import App, AppAnnotationSetting, 
Message, MessageAnnotation @@ -45,17 +40,6 @@ class AnnotationReplyFeature: embedding_provider_name = collection_binding_detail.provider_name embedding_model_name = collection_binding_detail.model_name - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=app_record.tenant_id, - provider=embedding_provider_name, - model_type=ModelType.TEXT_EMBEDDING, - model=embedding_model_name - ) - - # get embedding model - embeddings = CacheEmbedding(model_instance) - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_provider_name, embedding_model_name, @@ -71,22 +55,14 @@ class AnnotationReplyFeature: collection_binding_id=dataset_collection_binding.id ) - vector_index = VectorIndex( - dataset=dataset, - config=current_app.config, - embeddings=embeddings, - attributes=['doc_id', 'annotation_id', 'app_id'] - ) + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - documents = vector_index.search( + documents = vector.search_by_vector( query=query, - search_type='similarity_score_threshold', - search_kwargs={ - 'k': 1, - 'score_threshold': score_threshold, - 'filter': { - 'group_id': [dataset.id] - } + k=1, + score_threshold=score_threshold, + filter={ + 'group_id': [dataset.id] } ) diff --git a/api/core/index/index.py b/api/core/index/index.py deleted file mode 100644 index 42971c895e..0000000000 --- a/api/core/index/index.py +++ /dev/null @@ -1,51 +0,0 @@ -from flask import current_app -from langchain.embeddings import OpenAIEmbeddings - -from core.embedding.cached_embedding import CacheEmbedding -from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex -from core.index.vector_index.vector_index import VectorIndex -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from models.dataset import Dataset - - -class IndexBuilder: - @classmethod - def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False): - if indexing_technique == "high_quality": - if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality': - return None - - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - model_type=ModelType.TEXT_EMBEDDING, - provider=dataset.embedding_model_provider, - model=dataset.embedding_model - ) - - embeddings = CacheEmbedding(embedding_model) - - return VectorIndex( - dataset=dataset, - config=current_app.config, - embeddings=embeddings - ) - elif indexing_technique == "economy": - return KeywordTableIndex( - dataset=dataset, - config=KeywordTableConfig( - max_keywords_per_chunk=10 - ) - ) - else: - raise ValueError('Unknown indexing technique') - - @classmethod - def get_default_high_quality_index(cls, dataset: Dataset): - embeddings = OpenAIEmbeddings(openai_api_key=' ') - return VectorIndex( - dataset=dataset, - config=current_app.config, - embeddings=embeddings - ) diff --git a/api/core/index/vector_index/base.py b/api/core/index/vector_index/base.py deleted file mode 100644 index 36aa1917a6..0000000000 --- a/api/core/index/vector_index/base.py +++ /dev/null @@ -1,305 +0,0 @@ -import json -import logging -from abc import abstractmethod -from typing import Any, cast - -from langchain.embeddings.base import Embeddings -from langchain.schema import BaseRetriever, Document -from langchain.vectorstores import VectorStore - -from core.index.base import 
BaseIndex -from extensions.ext_database import db -from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment -from models.dataset import Document as DatasetDocument - - -class BaseVectorIndex(BaseIndex): - - def __init__(self, dataset: Dataset, embeddings: Embeddings): - super().__init__(dataset) - self._embeddings = embeddings - self._vector_store = None - - def get_type(self) -> str: - raise NotImplementedError - - @abstractmethod - def get_index_name(self, dataset: Dataset) -> str: - raise NotImplementedError - - @abstractmethod - def to_index_struct(self) -> dict: - raise NotImplementedError - - @abstractmethod - def _get_vector_store(self) -> VectorStore: - raise NotImplementedError - - @abstractmethod - def _get_vector_store_class(self) -> type: - raise NotImplementedError - - @abstractmethod - def search_by_full_text_index( - self, query: str, - **kwargs: Any - ) -> list[Document]: - raise NotImplementedError - - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity' - search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} - - if search_type == 'similarity_score_threshold': - score_threshold = search_kwargs.get("score_threshold") - if (score_threshold is None) or (not isinstance(score_threshold, float)): - search_kwargs['score_threshold'] = .0 - - docs_with_similarity = vector_store.similarity_search_with_relevance_scores( - query, **search_kwargs - ) - - docs = [] - for doc, similarity in docs_with_similarity: - doc.metadata['score'] = similarity - docs.append(doc) - - return docs - - # similarity k - # mmr k, fetch_k, lambda_mult - # similarity_score_threshold k - return vector_store.as_retriever( - search_type=search_type, - search_kwargs=search_kwargs - ).get_relevant_documents(query) - - def get_retriever(self, **kwargs: Any) -> BaseRetriever: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - return vector_store.as_retriever(**kwargs) - - def add_texts(self, texts: list[Document], **kwargs): - if self._is_origin(): - self.recreate_dataset(self.dataset) - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - if kwargs.get('duplicate_check', False): - texts = self._filter_duplicate_texts(texts) - - uuids = self._get_uuids(texts) - vector_store.add_documents(texts, uuids=uuids) - - def text_exists(self, id: str) -> bool: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - return vector_store.text_exists(id) - - def delete_by_ids(self, ids: list[str]) -> None: - if self._is_origin(): - self.recreate_dataset(self.dataset) - return - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - for node_id in ids: - vector_store.del_text(node_id) - - def delete_by_group_id(self, group_id: str) -> None: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - if self.dataset.collection_binding_id: - vector_store.delete_by_group_id(group_id) - else: - vector_store.delete() - - def delete(self) -> None: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - vector_store.delete() - 
- def _is_origin(self): - return False - - def recreate_dataset(self, dataset: Dataset): - logging.info(f"Recreating dataset {dataset.id}") - - try: - self.delete() - except Exception as e: - raise e - - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() - - documents = [] - for dataset_document in dataset_documents: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).all() - - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - ) - - documents.append(document) - - origin_index_struct = self.dataset.index_struct[:] - self.dataset.index_struct = None - - if documents: - try: - self.create(documents) - except Exception as e: - self.dataset.index_struct = origin_index_struct - raise e - - dataset.index_struct = json.dumps(self.to_index_struct()) - - db.session.commit() - - self.dataset = dataset - logging.info(f"Dataset {dataset.id} recreate successfully.") - - def create_qdrant_dataset(self, dataset: Dataset): - logging.info(f"create_qdrant_dataset {dataset.id}") - - try: - self.delete() - except Exception as e: - raise e - - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() - - documents = [] - for dataset_document in dataset_documents: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).all() - - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - ) - - documents.append(document) - - if documents: - try: - self.create(documents) - except Exception as e: - raise e - - logging.info(f"Dataset {dataset.id} recreate successfully.") - - def update_qdrant_dataset(self, dataset: Dataset): - logging.info(f"update_qdrant_dataset {dataset.id}") - - segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).first() - - if segment: - try: - exist = self.text_exists(segment.index_node_id) - if exist: - index_struct = { - "type": 'qdrant', - "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']} - } - dataset.index_struct = json.dumps(index_struct) - db.session.commit() - except Exception as e: - raise e - - logging.info(f"Dataset {dataset.id} recreate successfully.") - - def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding): - logging.info(f"restore dataset in_one,_dataset {dataset.id}") - - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - 
DatasetDocument.archived == False, - ).all() - - documents = [] - for dataset_document in dataset_documents: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).all() - - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - ) - - documents.append(document) - - if documents: - try: - self.add_texts(documents) - except Exception as e: - raise e - - logging.info(f"Dataset {dataset.id} recreate successfully.") - - def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding): - logging.info(f"delete original collection: {dataset.id}") - - self.delete() - - dataset.collection_binding_id = dataset_collection_binding.id - db.session.add(dataset) - db.session.commit() - - logging.info(f"Dataset {dataset.id} recreate successfully.") diff --git a/api/core/index/vector_index/milvus_vector_index.py b/api/core/index/vector_index/milvus_vector_index.py deleted file mode 100644 index a18cf35a27..0000000000 --- a/api/core/index/vector_index/milvus_vector_index.py +++ /dev/null @@ -1,165 +0,0 @@ -from typing import Any, cast - -from langchain.embeddings.base import Embeddings -from langchain.schema import Document -from langchain.vectorstores import VectorStore -from pydantic import BaseModel, root_validator - -from core.index.base import BaseIndex -from core.index.vector_index.base import BaseVectorIndex -from core.vector_store.milvus_vector_store import MilvusVectorStore -from models.dataset import Dataset - - -class MilvusConfig(BaseModel): - host: str - port: int - user: str - password: str - secure: bool = False - batch_size: int = 100 - - @root_validator() - def validate_config(cls, values: dict) -> dict: - if not values['host']: - raise ValueError("config MILVUS_HOST is required") - if not values['port']: - raise ValueError("config MILVUS_PORT is required") - if not values['user']: - raise ValueError("config MILVUS_USER is required") - if not values['password']: - raise ValueError("config MILVUS_PASSWORD is required") - return values - - def to_milvus_params(self): - return { - 'host': self.host, - 'port': self.port, - 'user': self.user, - 'password': self.password, - 'secure': self.secure - } - - -class MilvusVectorIndex(BaseVectorIndex): - def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings): - super().__init__(dataset, embeddings) - self._client_config = config - - def get_type(self) -> str: - return 'milvus' - - def get_index_name(self, dataset: Dataset) -> str: - if self.dataset.index_struct_dict: - class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): - # original class_prefix - class_prefix += '_Node' - - return class_prefix - - dataset_id = dataset.id - return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' - - def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self.get_index_name(self.dataset)} - } - - def create(self, texts: list[Document], **kwargs) -> BaseIndex: - uuids = self._get_uuids(texts) - index_params = { - 'metric_type': 'IP', - 'index_type': "HNSW", - 'params': {"M": 8, "efConstruction": 64} - } - self._vector_store = 
MilvusVectorStore.from_documents( - texts, - self._embeddings, - collection_name=self.get_index_name(self.dataset), - connection_args=self._client_config.to_milvus_params(), - index_params=index_params - ) - - return self - - def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: - uuids = self._get_uuids(texts) - self._vector_store = MilvusVectorStore.from_documents( - texts, - self._embeddings, - collection_name=collection_name, - ids=uuids, - content_payload_key='page_content' - ) - - return self - - def _get_vector_store(self) -> VectorStore: - """Only for created index.""" - if self._vector_store: - return self._vector_store - - return MilvusVectorStore( - collection_name=self.get_index_name(self.dataset), - embedding_function=self._embeddings, - connection_args=self._client_config.to_milvus_params() - ) - - def _get_vector_store_class(self) -> type: - return MilvusVectorStore - - def delete_by_document_id(self, document_id: str): - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - ids = vector_store.get_ids_by_document_id(document_id) - if ids: - vector_store.del_texts({ - 'filter': f'id in {ids}' - }) - - def delete_by_metadata_field(self, key: str, value: str): - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - ids = vector_store.get_ids_by_metadata_field(key, value) - if ids: - vector_store.del_texts({ - 'filter': f'id in {ids}' - }) - - def delete_by_ids(self, doc_ids: list[str]) -> None: - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - ids = vector_store.get_ids_by_doc_ids(doc_ids) - vector_store.del_texts({ - 'filter': f' id in {ids}' - }) - - def delete_by_group_id(self, group_id: str) -> None: - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - vector_store.delete() - - def delete(self) -> None: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - vector_store.del_texts(models.Filter( - must=[ - models.FieldCondition( - key="group_id", - match=models.MatchValue(value=self.dataset.id), - ), - ], - )) - - def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: - # milvus/zilliz doesn't support bm25 search - return [] diff --git a/api/core/index/vector_index/qdrant_vector_index.py b/api/core/index/vector_index/qdrant_vector_index.py deleted file mode 100644 index 046260d2f8..0000000000 --- a/api/core/index/vector_index/qdrant_vector_index.py +++ /dev/null @@ -1,229 +0,0 @@ -import os -from typing import Any, Optional, cast - -import qdrant_client -from langchain.embeddings.base import Embeddings -from langchain.schema import Document -from langchain.vectorstores import VectorStore -from pydantic import BaseModel -from qdrant_client.http.models import HnswConfigDiff - -from core.index.base import BaseIndex -from core.index.vector_index.base import BaseVectorIndex -from core.vector_store.qdrant_vector_store import QdrantVectorStore -from extensions.ext_database import db -from models.dataset import Dataset, DatasetCollectionBinding - - -class QdrantConfig(BaseModel): - endpoint: str - api_key: Optional[str] - timeout: float = 20 - root_path: Optional[str] - - def to_qdrant_params(self): - if self.endpoint and self.endpoint.startswith('path:'): - path = 
self.endpoint.replace('path:', '') - if not os.path.isabs(path): - path = os.path.join(self.root_path, path) - - return { - 'path': path - } - else: - return { - 'url': self.endpoint, - 'api_key': self.api_key, - 'timeout': self.timeout - } - - -class QdrantVectorIndex(BaseVectorIndex): - def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings): - super().__init__(dataset, embeddings) - self._client_config = config - - def get_type(self) -> str: - return 'qdrant' - - def get_index_name(self, dataset: Dataset) -> str: - if dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ - one_or_none() - if dataset_collection_binding: - return dataset_collection_binding.collection_name - else: - raise ValueError('Dataset Collection Bindings is not exist!') - else: - if self.dataset.index_struct_dict: - class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] - return class_prefix - - dataset_id = dataset.id - return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' - - def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self.get_index_name(self.dataset)} - } - - def create(self, texts: list[Document], **kwargs) -> BaseIndex: - uuids = self._get_uuids(texts) - self._vector_store = QdrantVectorStore.from_documents( - texts, - self._embeddings, - collection_name=self.get_index_name(self.dataset), - ids=uuids, - content_payload_key='page_content', - group_id=self.dataset.id, - group_payload_key='group_id', - hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, - max_indexing_threads=0, on_disk=False), - **self._client_config.to_qdrant_params() - ) - - return self - - def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: - uuids = self._get_uuids(texts) - self._vector_store = QdrantVectorStore.from_documents( - texts, - self._embeddings, - collection_name=collection_name, - ids=uuids, - content_payload_key='page_content', - group_id=self.dataset.id, - group_payload_key='group_id', - hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, - max_indexing_threads=0, on_disk=False), - **self._client_config.to_qdrant_params() - ) - - return self - - def _get_vector_store(self) -> VectorStore: - """Only for created index.""" - if self._vector_store: - return self._vector_store - attributes = ['doc_id', 'dataset_id', 'document_id'] - client = qdrant_client.QdrantClient( - **self._client_config.to_qdrant_params() - ) - - return QdrantVectorStore( - client=client, - collection_name=self.get_index_name(self.dataset), - embeddings=self._embeddings, - content_payload_key='page_content', - group_id=self.dataset.id, - group_payload_key='group_id' - ) - - def _get_vector_store_class(self) -> type: - return QdrantVectorStore - - def delete_by_document_id(self, document_id: str): - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - - vector_store.del_texts(models.Filter( - must=[ - models.FieldCondition( - key="metadata.document_id", - match=models.MatchValue(value=document_id), - ), - ], - )) - - def delete_by_metadata_field(self, key: str, value: str): - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), 
vector_store) - - from qdrant_client.http import models - - vector_store.del_texts(models.Filter( - must=[ - models.FieldCondition( - key=f"metadata.{key}", - match=models.MatchValue(value=value), - ), - ], - )) - - def delete_by_ids(self, ids: list[str]) -> None: - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - for node_id in ids: - vector_store.del_texts(models.Filter( - must=[ - models.FieldCondition( - key="metadata.doc_id", - match=models.MatchValue(value=node_id), - ), - ], - )) - - def delete_by_group_id(self, group_id: str) -> None: - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - vector_store.del_texts(models.Filter( - must=[ - models.FieldCondition( - key="group_id", - match=models.MatchValue(value=group_id), - ), - ], - )) - - def delete(self) -> None: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - vector_store.del_texts(models.Filter( - must=[ - models.FieldCondition( - key="group_id", - match=models.MatchValue(value=self.dataset.id), - ), - ], - )) - - def _is_origin(self): - if self.dataset.index_struct_dict: - class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): - # original class_prefix - return True - - return False - - def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - return vector_store.similarity_search_by_bm25(models.Filter( - must=[ - models.FieldCondition( - key="group_id", - match=models.MatchValue(value=self.dataset.id), - ), - models.FieldCondition( - key="page_content", - match=models.MatchText(text=query), - ) - ], - ), kwargs.get('top_k', 2)) diff --git a/api/core/index/vector_index/vector_index.py b/api/core/index/vector_index/vector_index.py deleted file mode 100644 index ed6e2699d6..0000000000 --- a/api/core/index/vector_index/vector_index.py +++ /dev/null @@ -1,90 +0,0 @@ -import json - -from flask import current_app -from langchain.embeddings.base import Embeddings - -from core.index.vector_index.base import BaseVectorIndex -from extensions.ext_database import db -from models.dataset import Dataset, Document - - -class VectorIndex: - def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings, - attributes: list = None): - if attributes is None: - attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] - self._dataset = dataset - self._embeddings = embeddings - self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes) - self._attributes = attributes - - def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings, - attributes: list) -> BaseVectorIndex: - vector_type = config.get('VECTOR_STORE') - - if self._dataset.index_struct_dict: - vector_type = self._dataset.index_struct_dict['type'] - - if not vector_type: - raise ValueError("Vector store must be specified.") - - if vector_type == "weaviate": - from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex - - return WeaviateVectorIndex( - dataset=dataset, - config=WeaviateConfig( - endpoint=config.get('WEAVIATE_ENDPOINT'), - 
api_key=config.get('WEAVIATE_API_KEY'), - batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) - ), - embeddings=embeddings, - attributes=attributes - ) - elif vector_type == "qdrant": - from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex - - return QdrantVectorIndex( - dataset=dataset, - config=QdrantConfig( - endpoint=config.get('QDRANT_URL'), - api_key=config.get('QDRANT_API_KEY'), - root_path=current_app.root_path, - timeout=config.get('QDRANT_CLIENT_TIMEOUT') - ), - embeddings=embeddings - ) - elif vector_type == "milvus": - from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex - - return MilvusVectorIndex( - dataset=dataset, - config=MilvusConfig( - host=config.get('MILVUS_HOST'), - port=config.get('MILVUS_PORT'), - user=config.get('MILVUS_USER'), - password=config.get('MILVUS_PASSWORD'), - secure=config.get('MILVUS_SECURE'), - ), - embeddings=embeddings - ) - else: - raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") - - def add_texts(self, texts: list[Document], **kwargs): - if not self._dataset.index_struct_dict: - self._vector_index.create(texts, **kwargs) - self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct()) - db.session.commit() - return - - self._vector_index.add_texts(texts, **kwargs) - - def __getattr__(self, name): - if self._vector_index is not None: - method = getattr(self._vector_index, name) - if callable(method): - return method - - raise AttributeError(f"'VectorIndex' object has no attribute '{name}'") - diff --git a/api/core/index/vector_index/weaviate_vector_index.py b/api/core/index/vector_index/weaviate_vector_index.py deleted file mode 100644 index 72a74a039f..0000000000 --- a/api/core/index/vector_index/weaviate_vector_index.py +++ /dev/null @@ -1,179 +0,0 @@ -from typing import Any, Optional, cast - -import requests -import weaviate -from langchain.embeddings.base import Embeddings -from langchain.schema import Document -from langchain.vectorstores import VectorStore -from pydantic import BaseModel, root_validator - -from core.index.base import BaseIndex -from core.index.vector_index.base import BaseVectorIndex -from core.vector_store.weaviate_vector_store import WeaviateVectorStore -from models.dataset import Dataset - - -class WeaviateConfig(BaseModel): - endpoint: str - api_key: Optional[str] - batch_size: int = 100 - - @root_validator() - def validate_config(cls, values: dict) -> dict: - if not values['endpoint']: - raise ValueError("config WEAVIATE_ENDPOINT is required") - return values - - -class WeaviateVectorIndex(BaseVectorIndex): - - def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list): - super().__init__(dataset, embeddings) - self._client = self._init_client(config) - self._attributes = attributes - - def _init_client(self, config: WeaviateConfig) -> weaviate.Client: - auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key) - - weaviate.connect.connection.has_grpc = False - - try: - client = weaviate.Client( - url=config.endpoint, - auth_client_secret=auth_config, - timeout_config=(5, 60), - startup_period=None - ) - except requests.exceptions.ConnectionError: - raise ConnectionError("Vector database connection error") - - client.batch.configure( - # `batch_size` takes an `int` value to enable auto-batching - # (`None` is used for manual batching) - batch_size=config.batch_size, - # dynamically update the `batch_size` based on import speed - dynamic=True, - # 
`timeout_retries` takes an `int` value to retry on time outs - timeout_retries=3, - ) - - return client - - def get_type(self) -> str: - return 'weaviate' - - def get_index_name(self, dataset: Dataset) -> str: - if self.dataset.index_struct_dict: - class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): - # original class_prefix - class_prefix += '_Node' - - return class_prefix - - dataset_id = dataset.id - return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' - - def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self.get_index_name(self.dataset)} - } - - def create(self, texts: list[Document], **kwargs) -> BaseIndex: - uuids = self._get_uuids(texts) - self._vector_store = WeaviateVectorStore.from_documents( - texts, - self._embeddings, - client=self._client, - index_name=self.get_index_name(self.dataset), - uuids=uuids, - by_text=False - ) - - return self - - def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: - uuids = self._get_uuids(texts) - self._vector_store = WeaviateVectorStore.from_documents( - texts, - self._embeddings, - client=self._client, - index_name=self.get_index_name(self.dataset), - uuids=uuids, - by_text=False - ) - - return self - - - def _get_vector_store(self) -> VectorStore: - """Only for created index.""" - if self._vector_store: - return self._vector_store - - attributes = self._attributes - if self._is_origin(): - attributes = ['doc_id'] - - return WeaviateVectorStore( - client=self._client, - index_name=self.get_index_name(self.dataset), - text_key='text', - embedding=self._embeddings, - attributes=attributes, - by_text=False - ) - - def _get_vector_store_class(self) -> type: - return WeaviateVectorStore - - def delete_by_document_id(self, document_id: str): - if self._is_origin(): - self.recreate_dataset(self.dataset) - return - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - vector_store.del_texts({ - "operator": "Equal", - "path": ["document_id"], - "valueText": document_id - }) - - def delete_by_metadata_field(self, key: str, value: str): - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - vector_store.del_texts({ - "operator": "Equal", - "path": [key], - "valueText": value - }) - - def delete_by_group_id(self, group_id: str): - if self._is_origin(): - self.recreate_dataset(self.dataset) - return - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - vector_store.delete() - - def _is_origin(self): - if self.dataset.index_struct_dict: - class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): - # original class_prefix - return True - - return False - - def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs) - diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index a14001d04e..c8a2e09443 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -9,21 +9,21 @@ from typing import Optional, cast from flask import Flask, current_app from flask_login import 
current_user -from langchain.schema import Document -from langchain.text_splitter import TextSplitter from sqlalchemy.orm.exc import ObjectDeletedError -from core.data_loader.file_extractor import FileExtractor -from core.data_loader.loader.notion import NotionLoader from core.docstore.dataset_docstore import DatasetDocumentStore from core.errors.error import ProviderTokenNotInitError from core.generator.llm_generator import LLMGenerator -from core.index.index import IndexBuilder from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import Document from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter +from core.splitter.text_splitter import TextSplitter from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -31,7 +31,7 @@ from libs import helper from models.dataset import Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile -from models.source import DataSourceBinding +from services.feature_service import FeatureService class IndexingRunner: @@ -56,38 +56,19 @@ class IndexingRunner: processing_rule = db.session.query(DatasetProcessRule). \ filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ first() + index_type = dataset_document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + # extract + text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) - # load file - text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic') + # transform + documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict()) + # save segment + self._load_segments(dataset, dataset_document, documents) - # get embedding model instance - embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': - if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) - else: - embedding_model_instance = self.model_manager.get_default_model_instance( - tenant_id=dataset.tenant_id, - model_type=ModelType.TEXT_EMBEDDING, - ) - - # get splitter - splitter = self._get_splitter(processing_rule, embedding_model_instance) - - # split to documents - documents = self._step_split( - text_docs=text_docs, - splitter=splitter, - dataset=dataset, - dataset_document=dataset_document, - processing_rule=processing_rule - ) - self._build_index( + # load + self._load( + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents @@ -133,39 +114,19 @@ class IndexingRunner: filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). 
\ first() - # load file - text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic') + index_type = dataset_document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + # extract + text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) - # get embedding model instance - embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': - if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) - else: - embedding_model_instance = self.model_manager.get_default_model_instance( - tenant_id=dataset.tenant_id, - model_type=ModelType.TEXT_EMBEDDING, - ) + # transform + documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict()) + # save segment + self._load_segments(dataset, dataset_document, documents) - # get splitter - splitter = self._get_splitter(processing_rule, embedding_model_instance) - - # split to documents - documents = self._step_split( - text_docs=text_docs, - splitter=splitter, - dataset=dataset, - dataset_document=dataset_document, - processing_rule=processing_rule - ) - - # build index - self._build_index( + # load + self._load( + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents @@ -219,7 +180,15 @@ class IndexingRunner: documents.append(document) # build index - self._build_index( + # get the process rule + processing_rule = db.session.query(DatasetProcessRule). \ + filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ + first() + + index_type = dataset_document.doc_form + index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor() + self._load( + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents @@ -238,12 +207,20 @@ class IndexingRunner: dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() - def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict, - doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, - indexing_technique: str = 'economy') -> dict: + def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, + doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, + indexing_technique: str = 'economy') -> dict: """ Estimate the indexing for the document. 
""" + # check document limit + features = FeatureService.get_features(tenant_id) + if features.billing.enabled: + count = len(extract_settings) + batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + embedding_model_instance = None if dataset_id: dataset = Dataset.query.filter_by( @@ -275,16 +252,18 @@ class IndexingRunner: total_segments = 0 total_price = 0 currency = 'USD' - for file_detail in file_details: - + index_type = doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + all_text_docs = [] + for extract_setting in extract_settings: + # extract + text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) + all_text_docs.extend(text_docs) processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) - # load data from file - text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic') - # get splitter splitter = self._get_splitter(processing_rule, embedding_model_instance) @@ -296,7 +275,6 @@ class IndexingRunner: ) total_segments += len(documents) - for document in documents: if len(preview_texts) < 5: preview_texts.append(document.page_content) @@ -355,146 +333,8 @@ class IndexingRunner: "preview": preview_texts } - def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, - doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, - indexing_technique: str = 'economy') -> dict: - """ - Estimate the indexing for the document. - """ - embedding_model_instance = None - if dataset_id: - dataset = Dataset.query.filter_by( - id=dataset_id - ).first() - if not dataset: - raise ValueError('Dataset not found.') - if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': - if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( - tenant_id=tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) - else: - embedding_model_instance = self.model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TEXT_EMBEDDING, - ) - else: - if indexing_technique == 'high_quality': - embedding_model_instance = self.model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TEXT_EMBEDDING - ) - # load data from notion - tokens = 0 - preview_texts = [] - total_segments = 0 - total_price = 0 - currency = 'USD' - for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - data_source_binding = DataSourceBinding.query.filter( - db.and_( - DataSourceBinding.tenant_id == current_user.current_tenant_id, - DataSourceBinding.provider == 'notion', - DataSourceBinding.disabled == False, - DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' - ) - ).first() - if not data_source_binding: - raise ValueError('Data source binding not found.') - - for page in notion_info['pages']: - loader = NotionLoader( - notion_access_token=data_source_binding.access_token, - notion_workspace_id=workspace_id, - notion_obj_id=page['page_id'], - notion_page_type=page['type'] - ) - documents = loader.load() - - processing_rule = DatasetProcessRule( - mode=tmp_processing_rule["mode"], - rules=json.dumps(tmp_processing_rule["rules"]) - ) 
- - # get splitter - splitter = self._get_splitter(processing_rule, embedding_model_instance) - - # split to documents - documents = self._split_to_documents_for_estimate( - text_docs=documents, - splitter=splitter, - processing_rule=processing_rule - ) - total_segments += len(documents) - - embedding_model_type_instance = None - if embedding_model_instance: - embedding_model_type_instance = embedding_model_instance.model_type_instance - embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) - - for document in documents: - if len(preview_texts) < 5: - preview_texts.append(document.page_content) - if indexing_technique == 'high_quality' and embedding_model_type_instance: - tokens += embedding_model_type_instance.get_num_tokens( - model=embedding_model_instance.model, - credentials=embedding_model_instance.credentials, - texts=[document.page_content] - ) - - if doc_form and doc_form == 'qa_model': - model_instance = self.model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM - ) - - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - if len(preview_texts) > 0: - # qa model document - response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], - doc_language) - document_qa_list = self.format_split_text(response) - - price_info = model_type_instance.get_price( - model=model_instance.model, - credentials=model_instance.credentials, - price_type=PriceType.INPUT, - tokens=total_segments * 2000, - ) - - return { - "total_segments": total_segments * 20, - "tokens": total_segments * 2000, - "total_price": '{:f}'.format(price_info.total_amount), - "currency": price_info.currency, - "qa_preview": document_qa_list, - "preview": preview_texts - } - if embedding_model_instance: - embedding_model_type_instance = embedding_model_instance.model_type_instance - embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) - embedding_price_info = embedding_model_type_instance.get_price( - model=embedding_model_instance.model, - credentials=embedding_model_instance.credentials, - price_type=PriceType.INPUT, - tokens=tokens - ) - total_price = '{:f}'.format(embedding_price_info.total_amount) - currency = embedding_price_info.currency - return { - "total_segments": total_segments, - "tokens": tokens, - "total_price": total_price, - "currency": currency, - "preview": preview_texts - } - - def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]: + def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \ + -> list[Document]: # load file if dataset_document.data_source_type not in ["upload_file", "notion_import"]: return [] @@ -510,11 +350,27 @@ class IndexingRunner: one_or_none() if file_detail: - text_docs = FileExtractor.load(file_detail, is_automatic=automatic) + extract_setting = ExtractSetting( + datasource_type="upload_file", + upload_file=file_detail, + document_model=dataset_document.doc_form + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) elif dataset_document.data_source_type == 'notion_import': - loader = NotionLoader.from_document(dataset_document) - text_docs = loader.load() - + if (not data_source_info or 'notion_workspace_id' not in data_source_info + or 'notion_page_id' not in data_source_info): + raise ValueError("no notion import info found") + 
extract_setting = ExtractSetting( + datasource_type="notion_import", + notion_info={ + "notion_workspace_id": data_source_info['notion_workspace_id'], + "notion_obj_id": data_source_info['notion_page_id'], + "notion_page_type": data_source_info['notion_page_type'], + "document": dataset_document + }, + document_model=dataset_document.doc_form + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, @@ -528,8 +384,6 @@ class IndexingRunner: # replace doc id to document model id text_docs = cast(list[Document], text_docs) for text_doc in text_docs: - # remove invalid symbol - text_doc.page_content = self.filter_string(text_doc.page_content) text_doc.metadata['document_id'] = dataset_document.id text_doc.metadata['dataset_id'] = dataset_document.dataset_id @@ -770,12 +624,12 @@ class IndexingRunner: for q, a in matches if q and a ] - def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None: + def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, + dataset_document: DatasetDocument, documents: list[Document]) -> None: """ - Build the index for the document. + insert index and update document/segment status to completed """ - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - keyword_table_index = IndexBuilder.get_index(dataset, 'economy') + embedding_model_instance = None if dataset.indexing_technique == 'high_quality': embedding_model_instance = self.model_manager.get_model_instance( @@ -808,13 +662,8 @@ class IndexingRunner: ) for document in chunk_documents ) - - # save vector index - if vector_index: - vector_index.add_texts(chunk_documents) - - # save keyword index - keyword_table_index.add_texts(chunk_documents) + # load index + index_processor.load(dataset, chunk_documents) document_ids = [document.metadata['doc_id'] for document in chunk_documents] db.session.query(DocumentSegment).filter( @@ -894,14 +743,64 @@ class IndexingRunner: ) documents.append(document) # save vector index - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts(documents, duplicate_check=True) + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.load(dataset, documents) - # save keyword index - index = IndexBuilder.get_index(dataset, 'economy') - if index: - index.add_texts(documents) + def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset, + text_docs: list[Document], process_rule: dict) -> list[Document]: + # get embedding model instance + embedding_model_instance = None + if dataset.indexing_technique == 'high_quality': + if dataset.embedding_model_provider: + embedding_model_instance = self.model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model + ) + else: + embedding_model_instance = self.model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + + documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance, + process_rule=process_rule) + + return documents + + def _load_segments(self, dataset, dataset_document, documents): + # save node to document segment + doc_store = DatasetDocumentStore( + dataset=dataset, + 
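Editor's note: the refactor above replaces the old _load_data/_build_index path with an extract -> transform -> load pipeline. A minimal sketch of that flow, using only the signatures introduced in this diff (dataset, dataset_document, upload_file and process_rule stand in for objects already loaded from the database; this is an illustration, not the runner code itself):

from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

# pick the processor for the document form stored on the dataset document
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()

# extract: turn the uploaded file into raw text documents
extract_setting = ExtractSetting(
    datasource_type="upload_file",
    upload_file=upload_file,
    document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])

# transform: clean and split according to the processing rule
documents = index_processor.transform(
    text_docs,
    embedding_model_instance=None,  # only needed for high_quality indexing
    process_rule=process_rule,
)

# load: write the chunks into the vector/keyword stores for this dataset
index_processor.load(dataset, documents)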
user_id=dataset_document.created_by, + document_id=dataset_document.id + ) + + # add document segments + doc_store.add_documents(documents) + + # update document status to indexing + cur_time = datetime.datetime.utcnow() + self._update_document_index_status( + document_id=dataset_document.id, + after_indexing_status="indexing", + extra_update_params={ + DatasetDocument.cleaning_completed_at: cur_time, + DatasetDocument.splitting_completed_at: cur_time, + } + ) + + # update segment status to indexing + self._update_segments_by_document( + dataset_document_id=dataset_document.id, + update_params={ + DocumentSegment.status: "indexing", + DocumentSegment.indexing_at: datetime.datetime.utcnow() + } + ) + pass class DocumentIsPausedException(Exception): diff --git a/api/core/rag/__init__.py b/api/core/rag/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py new file mode 100644 index 0000000000..eaad0e0f4c --- /dev/null +++ b/api/core/rag/cleaner/clean_processor.py @@ -0,0 +1,38 @@ +import re + + +class CleanProcessor: + + @classmethod + def clean(cls, text: str, process_rule: dict) -> str: + # default clean + # remove invalid symbol + text = re.sub(r'<\|', '<', text) + text = re.sub(r'\|>', '>', text) + text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) + # Unicode U+FFFE + text = re.sub('\uFFFE', '', text) + + rules = process_rule['rules'] if process_rule else None + if 'pre_processing_rules' in rules: + pre_processing_rules = rules["pre_processing_rules"] + for pre_processing_rule in pre_processing_rules: + if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: + # Remove extra spaces + pattern = r'\n{3,}' + text = re.sub(pattern, '\n\n', text) + pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' + text = re.sub(pattern, ' ', text) + elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: + # Remove email + pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' + text = re.sub(pattern, '', text) + + # Remove URL + pattern = r'https?://[^\s]+' + text = re.sub(pattern, '', text) + return text + + def filter_string(self, text): + + return text diff --git a/api/core/rag/cleaner/cleaner_base.py b/api/core/rag/cleaner/cleaner_base.py new file mode 100644 index 0000000000..523bd904f2 --- /dev/null +++ b/api/core/rag/cleaner/cleaner_base.py @@ -0,0 +1,12 @@ +"""Abstract interface for document cleaner implementations.""" +from abc import ABC, abstractmethod + + +class BaseCleaner(ABC): + """Interface for clean chunk content. 
+ """ + + @abstractmethod + def clean(self, content: str): + raise NotImplementedError + diff --git a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py new file mode 100644 index 0000000000..6a0b8c9046 --- /dev/null +++ b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py @@ -0,0 +1,12 @@ +"""Abstract interface for document clean implementations.""" +from core.rag.cleaner.cleaner_base import BaseCleaner + + +class UnstructuredNonAsciiCharsCleaner(BaseCleaner): + + def clean(self, content) -> str: + """clean document content.""" + from unstructured.cleaners.core import clean_extra_whitespace + + # Returns "ITEM 1A: RISK FACTORS" + return clean_extra_whitespace(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py new file mode 100644 index 0000000000..6fc3a408da --- /dev/null +++ b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py @@ -0,0 +1,15 @@ +"""Abstract interface for document clean implementations.""" +from core.rag.cleaner.cleaner_base import BaseCleaner + + +class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): + + def clean(self, content) -> str: + """clean document content.""" + import re + + from unstructured.cleaners.core import group_broken_paragraphs + + para_split_re = re.compile(r"(\s*\n\s*){3}") + + return group_broken_paragraphs(content, paragraph_split=para_split_re) diff --git a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py new file mode 100644 index 0000000000..ca1ae8dfd1 --- /dev/null +++ b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py @@ -0,0 +1,12 @@ +"""Abstract interface for document clean implementations.""" +from core.rag.cleaner.cleaner_base import BaseCleaner + + +class UnstructuredNonAsciiCharsCleaner(BaseCleaner): + + def clean(self, content) -> str: + """clean document content.""" + from unstructured.cleaners.core import clean_non_ascii_chars + + # Returns "This text containsnon-ascii characters!" 
+ return clean_non_ascii_chars(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py new file mode 100644 index 0000000000..974a28fef1 --- /dev/null +++ b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py @@ -0,0 +1,11 @@ +"""Abstract interface for document clean implementations.""" +from core.rag.cleaner.cleaner_base import BaseCleaner + + +class UnstructuredNonAsciiCharsCleaner(BaseCleaner): + + def clean(self, content) -> str: + """Replaces unicode quote characters, such as the \x91 character in a string.""" + + from unstructured.cleaners.core import replace_unicode_quotes + return replace_unicode_quotes(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py new file mode 100644 index 0000000000..dfaf3a2787 --- /dev/null +++ b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py @@ -0,0 +1,11 @@ +"""Abstract interface for document clean implementations.""" +from core.rag.cleaner.cleaner_base import BaseCleaner + + +class UnstructuredTranslateTextCleaner(BaseCleaner): + + def clean(self, content) -> str: + """clean document content.""" + from unstructured.cleaners.translate import translate_text + + return translate_text(content) diff --git a/api/core/rag/data_post_processor/__init__.py b/api/core/rag/data_post_processor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py new file mode 100644 index 0000000000..bdd69c27b1 --- /dev/null +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -0,0 +1,49 @@ +from typing import Optional + +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.invoke import InvokeAuthorizationError +from core.rag.data_post_processor.reorder import ReorderRunner +from core.rag.models.document import Document +from core.rerank.rerank import RerankRunner + + +class DataPostProcessor: + """Interface for data post-processing document. 
+ """ + + def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False): + self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id) + self.reorder_runner = self._get_reorder_runner(reorder_enabled) + + def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, + top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + if self.rerank_runner: + documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) + + if self.reorder_runner: + documents = self.reorder_runner.run(documents) + + return documents + + def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]: + if reranking_model: + try: + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=reranking_model['reranking_provider_name'], + model_type=ModelType.RERANK, + model=reranking_model['reranking_model_name'] + ) + except InvokeAuthorizationError: + return None + return RerankRunner(rerank_model_instance) + return None + + def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: + if reorder_enabled: + return ReorderRunner() + return None + + diff --git a/api/core/rag/data_post_processor/reorder.py b/api/core/rag/data_post_processor/reorder.py new file mode 100644 index 0000000000..71297588a4 --- /dev/null +++ b/api/core/rag/data_post_processor/reorder.py @@ -0,0 +1,18 @@ +from core.rag.models.document import Document + + +class ReorderRunner: + + def run(self, documents: list[Document]) -> list[Document]: + # Retrieve elements from odd indices (0, 2, 4, etc.) of the documents list + odd_elements = documents[::2] + + # Retrieve elements from even indices (1, 3, 5, etc.) 
of the documents list + even_elements = documents[1::2] + + # Reverse the list of elements from even indices + even_elements_reversed = even_elements[::-1] + + new_documents = odd_elements + even_elements_reversed + + return new_documents diff --git a/api/core/rag/datasource/__init__.py b/api/core/rag/datasource/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/entity/embedding.py b/api/core/rag/datasource/entity/embedding.py new file mode 100644 index 0000000000..126c1a3723 --- /dev/null +++ b/api/core/rag/datasource/entity/embedding.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod + + +class Embeddings(ABC): + """Interface for embedding models.""" + + @abstractmethod + def embed_documents(self, texts: list[str]) -> list[list[float]]: + """Embed search docs.""" + + @abstractmethod + def embed_query(self, text: str) -> list[float]: + """Embed query text.""" + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + """Asynchronous Embed search docs.""" + raise NotImplementedError + + async def aembed_query(self, text: str) -> list[float]: + """Asynchronous Embed query text.""" + raise NotImplementedError diff --git a/api/core/rag/datasource/keyword/__init__.py b/api/core/rag/datasource/keyword/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/keyword/jieba/__init__.py b/api/core/rag/datasource/keyword/jieba/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/index/keyword_table_index/keyword_table_index.py b/api/core/rag/datasource/keyword/jieba/jieba.py similarity index 71% rename from api/core/index/keyword_table_index/keyword_table_index.py rename to api/core/rag/datasource/keyword/jieba/jieba.py index 8bf0b13344..94a692637f 100644 --- a/api/core/index/keyword_table_index/keyword_table_index.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -2,11 +2,11 @@ import json from collections import defaultdict from typing import Any, Optional -from langchain.schema import BaseRetriever, Document -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel -from core.index.base import BaseIndex -from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment @@ -15,59 +15,19 @@ class KeywordTableConfig(BaseModel): max_keywords_per_chunk: int = 10 -class KeywordTableIndex(BaseIndex): - def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()): +class Jieba(BaseKeyword): + def __init__(self, dataset: Dataset): super().__init__(dataset) - self._config = config + self._config = KeywordTableConfig() - def create(self, texts: list[Document], **kwargs) -> BaseIndex: + def create(self, texts: list[Document], **kwargs) -> BaseKeyword: keyword_table_handler = JiebaKeywordTableHandler() - keyword_table = {} + keyword_table = self._get_dataset_keyword_table() for text in texts: keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) keyword_table = self._add_text_to_keyword_table(keyword_table, 
text.metadata['doc_id'], list(keywords)) - dataset_keyword_table = DatasetKeywordTable( - dataset_id=self.dataset.id, - keyword_table=json.dumps({ - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": {} - } - }, cls=SetEncoder) - ) - db.session.add(dataset_keyword_table) - db.session.commit() - - self._save_dataset_keyword_table(keyword_table) - - return self - - def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: - keyword_table_handler = JiebaKeywordTableHandler() - keyword_table = {} - for text in texts: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) - - dataset_keyword_table = DatasetKeywordTable( - dataset_id=self.dataset.id, - keyword_table=json.dumps({ - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": {} - } - }, cls=SetEncoder) - ) - db.session.add(dataset_keyword_table) - db.session.commit() - self._save_dataset_keyword_table(keyword_table) return self @@ -76,8 +36,13 @@ class KeywordTableIndex(BaseIndex): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() - for text in texts: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) + keywords_list = kwargs.get('keywords_list', None) + for i in range(len(texts)): + text = texts[i] + if keywords_list: + keywords = keywords_list[i] + else: + keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) @@ -107,20 +72,13 @@ class KeywordTableIndex(BaseIndex): self._save_dataset_keyword_table(keyword_table) - def delete_by_metadata_field(self, key: str, value: str): - pass - - def get_retriever(self, **kwargs: Any) -> BaseRetriever: - return KeywordTableRetriever(index=self, **kwargs) - def search( self, query: str, **kwargs: Any ) -> list[Document]: keyword_table = self._get_dataset_keyword_table() - search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} - k = search_kwargs.get('k') if search_kwargs.get('k') else 4 + k = kwargs.get('top_k', 4) sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) @@ -150,12 +108,6 @@ class KeywordTableIndex(BaseIndex): db.session.delete(dataset_keyword_table) db.session.commit() - def delete_by_group_id(self, group_id: str) -> None: - dataset_keyword_table = self.dataset.dataset_keyword_table - if dataset_keyword_table: - db.session.delete(dataset_keyword_table) - db.session.commit() - def _save_dataset_keyword_table(self, keyword_table): keyword_table_dict = { '__type__': 'keyword_table', @@ -242,6 +194,7 @@ class KeywordTableIndex(BaseIndex): ).first() if document_segment: document_segment.keywords = keywords + db.session.add(document_segment) db.session.commit() def create_segment_keywords(self, node_id: str, keywords: list[str]): @@ -272,31 +225,6 @@ class KeywordTableIndex(BaseIndex): self._save_dataset_keyword_table(keyword_table) -class KeywordTableRetriever(BaseRetriever, BaseModel): - index: 
KeywordTableIndex - search_kwargs: dict = Field(default_factory=dict) - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - def get_relevant_documents(self, query: str) -> list[Document]: - """Get documents relevant for a query. - - Args: - query: string to find relevant documents for - - Returns: - List of relevant documents - """ - return self.index.search(query, **self.search_kwargs) - - async def aget_relevant_documents(self, query: str) -> list[Document]: - raise NotImplementedError("KeywordTableRetriever does not support async") - - class SetEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, set): diff --git a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py similarity index 91% rename from api/core/index/keyword_table_index/jieba_keyword_table_handler.py rename to api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index df93a1903a..5f862b8d18 100644 --- a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -3,7 +3,7 @@ import re import jieba from jieba.analyse import default_tfidf -from core.index.keyword_table_index.stopwords import STOPWORDS +from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS class JiebaKeywordTableHandler: diff --git a/api/core/index/keyword_table_index/stopwords.py b/api/core/rag/datasource/keyword/jieba/stopwords.py similarity index 100% rename from api/core/index/keyword_table_index/stopwords.py rename to api/core/rag/datasource/keyword/jieba/stopwords.py diff --git a/api/core/index/base.py b/api/core/rag/datasource/keyword/keyword_base.py similarity index 63% rename from api/core/index/base.py rename to api/core/rag/datasource/keyword/keyword_base.py index f8eb1a134a..84a5800855 100644 --- a/api/core/index/base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -3,22 +3,17 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import Any -from langchain.schema import BaseRetriever, Document - +from core.rag.models.document import Document from models.dataset import Dataset -class BaseIndex(ABC): +class BaseKeyword(ABC): def __init__(self, dataset: Dataset): self.dataset = dataset @abstractmethod - def create(self, texts: list[Document], **kwargs) -> BaseIndex: - raise NotImplementedError - - @abstractmethod - def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: + def create(self, texts: list[Document], **kwargs) -> BaseKeyword: raise NotImplementedError @abstractmethod @@ -34,31 +29,18 @@ class BaseIndex(ABC): raise NotImplementedError @abstractmethod - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_document_id(self, document_id: str) -> None: raise NotImplementedError - @abstractmethod - def delete_by_group_id(self, group_id: str) -> None: + def delete(self) -> None: raise NotImplementedError - @abstractmethod - def delete_by_document_id(self, document_id: str): - raise NotImplementedError - - @abstractmethod - def get_retriever(self, **kwargs: Any) -> BaseRetriever: - raise NotImplementedError - - @abstractmethod def search( self, query: str, **kwargs: Any ) -> list[Document]: raise NotImplementedError - def delete(self) -> None: - raise NotImplementedError - def _filter_duplicate_texts(self, texts: list[Document]) -> 
list[Document]: for text in texts: doc_id = text.metadata['doc_id'] diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py new file mode 100644 index 0000000000..bccec20714 --- /dev/null +++ b/api/core/rag/datasource/keyword/keyword_factory.py @@ -0,0 +1,60 @@ +from typing import Any, cast + +from flask import current_app + +from core.rag.datasource.keyword.jieba.jieba import Jieba +from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.models.document import Document +from models.dataset import Dataset + + +class Keyword: + def __init__(self, dataset: Dataset): + self._dataset = dataset + self._keyword_processor = self._init_keyword() + + def _init_keyword(self) -> BaseKeyword: + config = cast(dict, current_app.config) + keyword_type = config.get('KEYWORD_STORE') + + if not keyword_type: + raise ValueError("Keyword store must be specified.") + + if keyword_type == "jieba": + return Jieba( + dataset=self._dataset + ) + else: + raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") + + def create(self, texts: list[Document], **kwargs): + self._keyword_processor.create(texts, **kwargs) + + def add_texts(self, texts: list[Document], **kwargs): + self._keyword_processor.add_texts(texts, **kwargs) + + def text_exists(self, id: str) -> bool: + return self._keyword_processor.text_exists(id) + + def delete_by_ids(self, ids: list[str]) -> None: + self._keyword_processor.delete_by_ids(ids) + + def delete_by_document_id(self, document_id: str) -> None: + self._keyword_processor.delete_by_document_id(document_id) + + def delete(self) -> None: + self._keyword_processor.delete() + + def search( + self, query: str, + **kwargs: Any + ) -> list[Document]: + return self._keyword_processor.search(query, **kwargs) + + def __getattr__(self, name): + if self._keyword_processor is not None: + method = getattr(self._keyword_processor, name) + if callable(method): + return method + + raise AttributeError(f"'Keyword' object has no attribute '{name}'") diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py new file mode 100644 index 0000000000..79673ffa83 --- /dev/null +++ b/api/core/rag/datasource/retrieval_service.py @@ -0,0 +1,165 @@ +import threading +from typing import Optional + +from flask import Flask, current_app +from flask_login import current_user + +from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.datasource.vdb.vector_factory import Vector +from extensions.ext_database import db +from models.dataset import Dataset + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} + + +class RetrievalService: + + @classmethod + def retrieve(cls, retrival_method: str, dataset_id: str, query: str, + top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None): + all_documents = [] + threads = [] + # retrieval_model source with keyword + if retrival_method == 'keyword_search': + keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'top_k': top_k + }) + threads.append(keyword_thread) + keyword_thread.start() + # 
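Editor's note: the Keyword facade reads the KEYWORD_STORE config and currently resolves only to the Jieba implementation, exposing the same create/add_texts/search surface. A sketch of how callers use it (dataset is a placeholder; each Document must carry metadata['doc_id'] because the keyword table is keyed by segment id):

from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.models.document import Document

documents = [
    Document(page_content="Dify supports keyword and vector retrieval.",
             metadata={"doc_id": "segment-1"}),
]

keyword = Keyword(dataset=dataset)

# keywords_list lets the caller supply keywords per document instead of
# relying on jieba's TF-IDF extraction
keyword.add_texts(documents, keywords_list=[["dify", "retrieval"]])

matches = keyword.search("retrieval", top_k=4)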
retrieval_model source with semantic + if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search': + embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'top_k': top_k, + 'score_threshold': score_threshold, + 'reranking_model': reranking_model, + 'all_documents': all_documents, + 'retrival_method': retrival_method + }) + threads.append(embedding_thread) + embedding_thread.start() + + # retrieval source with full text + if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search': + full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'retrival_method': retrival_method, + 'score_threshold': score_threshold, + 'top_k': top_k, + 'reranking_model': reranking_model, + 'all_documents': all_documents + }) + threads.append(full_text_index_thread) + full_text_index_thread.start() + + for thread in threads: + thread.join() + + if retrival_method == 'hybrid_search': + data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False) + all_documents = data_post_processor.invoke( + query=query, + documents=all_documents, + score_threshold=score_threshold, + top_n=top_k + ) + return all_documents + + @classmethod + def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, + top_k: int, all_documents: list): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + + keyword = Keyword( + dataset=dataset + ) + + documents = keyword.search( + query, + k=top_k + ) + all_documents.extend(documents) + + @classmethod + def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, + top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], + all_documents: list, retrival_method: str): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + + vector = Vector( + dataset=dataset + ) + + documents = vector.search_by_vector( + query, + search_type='similarity_score_threshold', + k=top_k, + score_threshold=score_threshold, + filter={ + 'group_id': [dataset.id] + } + ) + + if documents: + if reranking_model and retrival_method == 'semantic_search': + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) + all_documents.extend(data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents) + )) + else: + all_documents.extend(documents) + + @classmethod + def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, + top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], + all_documents: list, retrival_method: str): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + + vector_processor = Vector( + dataset=dataset, + ) + + documents = vector_processor.search_by_full_text( + query, + top_k=top_k + ) + if documents: + if reranking_model and retrival_method == 'full_text_search': + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) + all_documents.extend(data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents) + )) 
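Editor's note: RetrievalService.retrieve fans the query out to keyword, semantic and full-text threads depending on retrival_method (spelling as in this diff) and, for hybrid_search, merges and reranks the results through DataPostProcessor. A caller sketch, which must run inside a Flask app/request context because the implementation reads current_app and current_user (ids and model names are placeholders):

from core.rag.datasource.retrieval_service import RetrievalService

documents = RetrievalService.retrieve(
    retrival_method="hybrid_search",
    dataset_id="dataset-id",                       # placeholder
    query="How do I configure the vector store?",
    top_k=4,
    score_threshold=0.0,
    reranking_model={
        "reranking_provider_name": "cohere",       # placeholder provider/model
        "reranking_model_name": "rerank-english-v2.0",
    },
)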
+ else: + all_documents.extend(documents) diff --git a/api/core/rag/datasource/vdb/__init__.py b/api/core/rag/datasource/vdb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py new file mode 100644 index 0000000000..6a594a83ca --- /dev/null +++ b/api/core/rag/datasource/vdb/field.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class Field(Enum): + CONTENT_KEY = "page_content" + METADATA_KEY = "metadata" + GROUP_KEY = "group_id" + VECTOR = "vector" + TEXT_KEY = "text" + PRIMARY_KEY = " id" diff --git a/api/core/rag/datasource/vdb/milvus/__init__.py b/api/core/rag/datasource/vdb/milvus/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py new file mode 100644 index 0000000000..9a251ede97 --- /dev/null +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -0,0 +1,214 @@ +import logging +from typing import Any, Optional +from uuid import uuid4 + +from pydantic import BaseModel, root_validator +from pymilvus import MilvusClient, MilvusException, connections + +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document + +logger = logging.getLogger(__name__) + + +class MilvusConfig(BaseModel): + host: str + port: int + user: str + password: str + secure: bool = False + batch_size: int = 100 + + @root_validator() + def validate_config(cls, values: dict) -> dict: + if not values['host']: + raise ValueError("config MILVUS_HOST is required") + if not values['port']: + raise ValueError("config MILVUS_PORT is required") + if not values['user']: + raise ValueError("config MILVUS_USER is required") + if not values['password']: + raise ValueError("config MILVUS_PASSWORD is required") + return values + + def to_milvus_params(self): + return { + 'host': self.host, + 'port': self.port, + 'user': self.user, + 'password': self.password, + 'secure': self.secure + } + + +class MilvusVector(BaseVector): + + def __init__(self, collection_name: str, config: MilvusConfig): + super().__init__(collection_name) + self._client_config = config + self._client = self._init_client(config) + self._consistency_level = 'Session' + self._fields = [] + + def get_type(self) -> str: + return 'milvus' + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + index_params = { + 'metric_type': 'IP', + 'index_type': "HNSW", + 'params': {"M": 8, "efConstruction": 64} + } + metadatas = [d.metadata for d in texts] + + # Grab the existing collection if it exists + from pymilvus import utility + alias = uuid4().hex + if self._client_config.secure: + uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) + else: + uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) + connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password) + if not utility.has_collection(self._collection_name, using=alias): + self.create_collection(embeddings, metadatas, index_params) + self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + insert_dict_list = [] + for i in range(len(documents)): + insert_dict = { + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i], + Field.METADATA_KEY.value: 
documents[i].metadata + } + insert_dict_list.append(insert_dict) + # Total insert count + total_count = len(insert_dict_list) + + pks: list[str] = [] + + for i in range(0, total_count, 1000): + batch_insert_list = insert_dict_list[i:i + 1000] + # Insert into the collection. + try: + ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) + pks.extend(ids) + except MilvusException as e: + logger.error( + "Failed to insert batch starting at entity: %s/%s", i, total_count + ) + raise e + return pks + + def delete_by_document_id(self, document_id: str): + + ids = self.get_ids_by_metadata_field('document_id', document_id) + if ids: + self._client.delete(collection_name=self._collection_name, pks=ids) + + def get_ids_by_metadata_field(self, key: str, value: str): + result = self._client.query(collection_name=self._collection_name, + filter=f'metadata["{key}"] == "{value}"', + output_fields=["id"]) + if result: + return [item["id"] for item in result] + else: + return None + + def delete_by_metadata_field(self, key: str, value: str): + + ids = self.get_ids_by_metadata_field(key, value) + if ids: + self._client.delete(collection_name=self._collection_name, pks=ids) + + def delete_by_ids(self, doc_ids: list[str]) -> None: + + self._client.delete(collection_name=self._collection_name, pks=doc_ids) + + def delete(self) -> None: + + from pymilvus import utility + utility.drop_collection(self._collection_name, None) + + def text_exists(self, id: str) -> bool: + + result = self._client.query(collection_name=self._collection_name, + filter=f'metadata["doc_id"] == "{id}"', + output_fields=["id"]) + + return len(result) > 0 + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + + # Set search parameters. + results = self._client.search(collection_name=self._collection_name, + data=[query_vector], + limit=kwargs.get('top_k', 4), + output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + ) + # Organize results. 
+ docs = [] + for result in results[0]: + metadata = result['entity'].get(Field.METADATA_KEY.value) + metadata['score'] = result['distance'] + score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + if result['distance'] > score_threshold: + doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value), + metadata=metadata) + docs.append(doc) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # milvus/zilliz doesn't support bm25 search + return [] + + def create_collection( + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + ) -> str: + from pymilvus import CollectionSchema, DataType, FieldSchema + from pymilvus.orm.types import infer_dtype_bydata + + # Determine embedding dim + dim = len(embeddings[0]) + fields = [] + if metadatas: + fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) + + # Create the text field + fields.append( + FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) + ) + # Create the primary key field + fields.append( + FieldSchema( + Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True + ) + ) + # Create the vector field, supports binary or float vectors + fields.append( + FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) + ) + + # Create the schema for the collection + schema = CollectionSchema(fields) + + for x in schema.fields: + self._fields.append(x.name) + # Since primary field is auto-id, no need to track it + self._fields.remove(Field.PRIMARY_KEY.value) + + # Create the collection + collection_name = self._collection_name + self._client.create_collection_with_schema(collection_name=collection_name, + schema=schema, index_param=index_params, + consistency_level=self._consistency_level) + return collection_name + + def _init_client(self, config) -> MilvusClient: + if config.secure: + uri = "https://" + str(config.host) + ":" + str(config.port) + else: + uri = "http://" + str(config.host) + ":" + str(config.port) + client = MilvusClient(uri=uri, user=config.user, password=config.password) + return client diff --git a/api/core/rag/datasource/vdb/qdrant/__init__.py b/api/core/rag/datasource/vdb/qdrant/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py new file mode 100644 index 0000000000..2432931228 --- /dev/null +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -0,0 +1,360 @@ +import os +import uuid +from collections.abc import Generator, Iterable, Sequence +from itertools import islice +from typing import TYPE_CHECKING, Any, Optional, Union, cast + +import qdrant_client +from pydantic import BaseModel +from qdrant_client.http import models as rest +from qdrant_client.http.models import ( + FilterSelector, + HnswConfigDiff, + PayloadSchemaType, + TextIndexParams, + TextIndexType, + TokenizerType, +) +from qdrant_client.local.qdrant_local import QdrantLocal + +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document + +if TYPE_CHECKING: + from qdrant_client import grpc # noqa + from qdrant_client.conversions import common_types + from qdrant_client.http import models as rest + + DictFilter = dict[str, Union[str, int, bool, dict, list]] + MetadataFilter = Union[DictFilter, 
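Editor's note: MilvusVector is constructed directly from a MilvusConfig; create() provisions the collection on first use and add_texts() inserts in batches of 1000. A configuration sketch (host and credentials are placeholders; documents, embeddings and query_vector come from the indexing pipeline):

from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector

config = MilvusConfig(
    host="127.0.0.1",      # placeholders; secure=True switches the client URI to https
    port=19530,
    user="root",
    password="Milvus",
    secure=False,
)

vector = MilvusVector(collection_name="Vector_index_example_Node", config=config)

vector.create(texts=documents, embeddings=embeddings)
hits = vector.search_by_vector(query_vector, top_k=4, score_threshold=0.5)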
common_types.Filter] + + +class QdrantConfig(BaseModel): + endpoint: str + api_key: Optional[str] + timeout: float = 20 + root_path: Optional[str] + + def to_qdrant_params(self): + if self.endpoint and self.endpoint.startswith('path:'): + path = self.endpoint.replace('path:', '') + if not os.path.isabs(path): + path = os.path.join(self.root_path, path) + + return { + 'path': path + } + else: + return { + 'url': self.endpoint, + 'api_key': self.api_key, + 'timeout': self.timeout + } + + +class QdrantVector(BaseVector): + + def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'): + super().__init__(collection_name) + self._client_config = config + self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) + self._distance_func = distance_func.upper() + self._group_id = group_id + + def get_type(self) -> str: + return 'qdrant' + + def to_index_struct(self) -> dict: + return { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name} + } + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + if texts: + # get embedding vector size + vector_size = len(embeddings[0]) + # get collection name + collection_name = self._collection_name + collection_name = collection_name or uuid.uuid4().hex + all_collection_name = [] + collections_response = self._client.get_collections() + collection_list = collections_response.collections + for collection in collection_list: + all_collection_name.append(collection.name) + if collection_name not in all_collection_name: + # create collection + self.create_collection(collection_name, vector_size) + + self.add_texts(texts, embeddings, **kwargs) + + def create_collection(self, collection_name: str, vector_size: int): + from qdrant_client.http import models as rest + vectors_config = rest.VectorParams( + size=vector_size, + distance=rest.Distance[self._distance_func], + ) + hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, + max_indexing_threads=0, on_disk=False) + self._client.recreate_collection( + collection_name=collection_name, + vectors_config=vectors_config, + hnsw_config=hnsw_config, + timeout=int(self._client_config.timeout), + ) + + # create payload index + self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, + field_schema=PayloadSchemaType.KEYWORD, + field_type=PayloadSchemaType.KEYWORD) + # creat full text index + text_index_params = TextIndexParams( + type=TextIndexType.TEXT, + tokenizer=TokenizerType.MULTILINGUAL, + min_token_len=2, + max_token_len=20, + lowercase=True + ) + self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, + field_schema=text_index_params) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + uuids = self._get_uuids(documents) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + + added_ids = [] + for batch_ids, points in self._generate_rest_batches( + texts, embeddings, metadatas, uuids, 64, self._group_id + ): + self._client.upsert( + collection_name=self._collection_name, points=points + ) + added_ids.extend(batch_ids) + + return added_ids + + def _generate_rest_batches( + self, + texts: Iterable[str], + embeddings: list[list[float]], + metadatas: Optional[list[dict]] = None, + ids: Optional[Sequence[str]] = None, + batch_size: int = 64, + group_id: Optional[str] = None, + ) -> Generator[tuple[list[str], list[rest.PointStruct]], 
None, None]: + from qdrant_client.http import models as rest + texts_iterator = iter(texts) + embeddings_iterator = iter(embeddings) + metadatas_iterator = iter(metadatas or []) + ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) + while batch_texts := list(islice(texts_iterator, batch_size)): + # Take the corresponding metadata and id for each text in a batch + batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None + batch_ids = list(islice(ids_iterator, batch_size)) + + # Generate the embeddings for all the texts in a batch + batch_embeddings = list(islice(embeddings_iterator, batch_size)) + + points = [ + rest.PointStruct( + id=point_id, + vector=vector, + payload=payload, + ) + for point_id, vector, payload in zip( + batch_ids, + batch_embeddings, + self._build_payloads( + batch_texts, + batch_metadatas, + Field.CONTENT_KEY.value, + Field.METADATA_KEY.value, + group_id, + Field.GROUP_KEY.value, + ), + ) + ] + + yield batch_ids, points + + @classmethod + def _build_payloads( + cls, + texts: Iterable[str], + metadatas: Optional[list[dict]], + content_payload_key: str, + metadata_payload_key: str, + group_id: str, + group_payload_key: str + ) -> list[dict]: + payloads = [] + for i, text in enumerate(texts): + if text is None: + raise ValueError( + "At least one of the texts is None. Please remove it before " + "calling .from_texts or .add_texts on Qdrant instance." + ) + metadata = metadatas[i] if metadatas is not None else None + payloads.append( + { + content_payload_key: text, + metadata_payload_key: metadata, + group_payload_key: group_id + } + ) + + return payloads + + def delete_by_metadata_field(self, key: str, value: str): + + from qdrant_client.http import models + + filter = models.Filter( + must=[ + models.FieldCondition( + key=f"metadata.{key}", + match=models.MatchValue(value=value), + ), + ], + ) + + self._reload_if_needed() + + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector( + filter=filter + ), + ) + + def delete(self): + from qdrant_client.http import models + filter = models.Filter( + must=[ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self._group_id), + ), + ], + ) + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector( + filter=filter + ), + ) + + def delete_by_ids(self, ids: list[str]) -> None: + + from qdrant_client.http import models + for node_id in ids: + filter = models.Filter( + must=[ + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchValue(value=node_id), + ), + ], + ) + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector( + filter=filter + ), + ) + + def text_exists(self, id: str) -> bool: + response = self._client.retrieve( + collection_name=self._collection_name, + ids=[id] + ) + + return len(response) > 0 + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + from qdrant_client.http import models + filter = models.Filter( + must=[ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self._group_id), + ), + ], + ) + results = self._client.search( + collection_name=self._collection_name, + query_vector=query_vector, + query_filter=filter, + limit=kwargs.get("top_k", 4), + with_payload=True, + with_vectors=True, + score_threshold=kwargs.get("score_threshold", .0) + ) + docs = [] + for result in results: + metadata = result.payload.get(Field.METADATA_KEY.value) or {} + # duplicate check 
score threshold + score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + if result.score > score_threshold: + metadata['score'] = result.score + doc = Document( + page_content=result.payload.get(Field.CONTENT_KEY.value), + metadata=metadata, + ) + docs.append(doc) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Return docs most similar by bm25. + Returns: + List of documents most similar to the query text and distance for each. + """ + from qdrant_client.http import models + scroll_filter = models.Filter( + must=[ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self._group_id), + ), + models.FieldCondition( + key="page_content", + match=models.MatchText(text=query), + ) + ] + ) + response = self._client.scroll( + collection_name=self._collection_name, + scroll_filter=scroll_filter, + limit=kwargs.get('top_k', 2), + with_payload=True, + with_vectors=True + + ) + results = response[0] + documents = [] + for result in results: + if result: + documents.append(self._document_from_scored_point( + result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value + )) + + return documents + + def _reload_if_needed(self): + if isinstance(self._client, QdrantLocal): + self._client = cast(QdrantLocal, self._client) + self._client._load() + + @classmethod + def _document_from_scored_point( + cls, + scored_point: Any, + content_payload_key: str, + metadata_payload_key: str, + ) -> Document: + return Document( + page_content=scored_point.payload.get(content_payload_key), + metadata=scored_point.payload.get(metadata_payload_key) or {}, + ) diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py new file mode 100644 index 0000000000..69ed4ed51c --- /dev/null +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +from core.rag.models.document import Document + + +class BaseVector(ABC): + + def __init__(self, collection_name: str): + self._collection_name = collection_name + + @abstractmethod + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + raise NotImplementedError + + @abstractmethod + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + raise NotImplementedError + + @abstractmethod + def text_exists(self, id: str) -> bool: + raise NotImplementedError + + @abstractmethod + def delete_by_ids(self, ids: list[str]) -> None: + raise NotImplementedError + + @abstractmethod + def delete_by_metadata_field(self, key: str, value: str) -> None: + raise NotImplementedError + + @abstractmethod + def search_by_vector( + self, + query_vector: list[float], + **kwargs: Any + ) -> list[Document]: + raise NotImplementedError + + @abstractmethod + def search_by_full_text( + self, query: str, + **kwargs: Any + ) -> list[Document]: + raise NotImplementedError + + def delete(self) -> None: + raise NotImplementedError + + def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: + for text in texts: + doc_id = text.metadata['doc_id'] + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) + + return texts + + def _get_uuids(self, texts: list[Document]) -> list[str]: + return [text.metadata['doc_id'] for text in texts] diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py 
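# --- Illustrative aside, not part of the patch: a rough sketch of how the QdrantConfig
# introduced earlier in this patch resolves its connection parameters. An endpoint that
# begins with 'path:' selects qdrant's embedded local-storage mode (resolved against
# root_path); any other endpoint is treated as a remote server URL. The endpoint values,
# api_key and root_path below are placeholders, and the sketch assumes it runs inside the
# Dify api project so the import path resolves.
import qdrant_client

from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig

local_config = QdrantConfig(endpoint='path:storage/qdrant', root_path='/app/api')
# 'storage/qdrant' is relative, so it is joined onto root_path:
assert local_config.to_qdrant_params() == {'path': '/app/api/storage/qdrant'}

remote_config = QdrantConfig(
    endpoint='https://qdrant.example.internal:6333',  # placeholder URL
    api_key='example-api-key',                        # placeholder key
    timeout=20,
)
# Either form can be splatted directly into the client constructor:
client = qdrant_client.QdrantClient(**remote_config.to_qdrant_params())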
new file mode 100644 index 0000000000..dd8fc93041 --- /dev/null +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -0,0 +1,171 @@ +from typing import Any, cast + +from flask import current_app + +from core.embedding.cached_embedding import CacheEmbedding +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document +from extensions.ext_database import db +from models.dataset import Dataset, DatasetCollectionBinding + + +class Vector: + def __init__(self, dataset: Dataset, attributes: list = None): + if attributes is None: + attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self._dataset = dataset + self._embeddings = self._get_embeddings() + self._attributes = attributes + self._vector_processor = self._init_vector() + + def _init_vector(self) -> BaseVector: + config = cast(dict, current_app.config) + vector_type = config.get('VECTOR_STORE') + + if self._dataset.index_struct_dict: + vector_type = self._dataset.index_struct_dict['type'] + + if not vector_type: + raise ValueError("Vector store must be specified.") + + if vector_type == "weaviate": + from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector + if self._dataset.index_struct_dict: + class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = self._dataset.id + collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + return WeaviateVector( + collection_name=collection_name, + config=WeaviateConfig( + endpoint=config.get('WEAVIATE_ENDPOINT'), + api_key=config.get('WEAVIATE_API_KEY'), + batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) + ), + attributes=self._attributes + ) + elif vector_type == "qdrant": + from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector + if self._dataset.collection_binding_id: + dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ + filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). 
\ + one_or_none() + if dataset_collection_binding: + collection_name = dataset_collection_binding.collection_name + else: + raise ValueError('Dataset Collection Binding does not exist!') + else: + if self._dataset.index_struct_dict: + class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = self._dataset.id + collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + + return QdrantVector( + collection_name=collection_name, + group_id=self._dataset.id, + config=QdrantConfig( + endpoint=config.get('QDRANT_URL'), + api_key=config.get('QDRANT_API_KEY'), + root_path=current_app.root_path, + timeout=config.get('QDRANT_CLIENT_TIMEOUT') + ) + ) + elif vector_type == "milvus": + from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector + if self._dataset.index_struct_dict: + class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = self._dataset.id + collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + return MilvusVector( + collection_name=collection_name, + config=MilvusConfig( + host=config.get('MILVUS_HOST'), + port=config.get('MILVUS_PORT'), + user=config.get('MILVUS_USER'), + password=config.get('MILVUS_PASSWORD'), + secure=config.get('MILVUS_SECURE'), + ) + ) + else: + raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") + + def create(self, texts: list = None, **kwargs): + if texts: + embeddings = self._embeddings.embed_documents([document.page_content for document in texts]) + self._vector_processor.create( + texts=texts, + embeddings=embeddings, + **kwargs + ) + + def add_texts(self, documents: list[Document], **kwargs): + if kwargs.get('duplicate_check', False): + documents = self._filter_duplicate_texts(documents) + embeddings = self._embeddings.embed_documents([document.page_content for document in documents]) + self._vector_processor.add_texts( + documents=documents, + embeddings=embeddings, + **kwargs + ) + + def text_exists(self, id: str) -> bool: + return self._vector_processor.text_exists(id) + + def delete_by_ids(self, ids: list[str]) -> None: + self._vector_processor.delete_by_ids(ids) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + self._vector_processor.delete_by_metadata_field(key, value) + + def search_by_vector( + self, query: str, + **kwargs: Any + ) -> list[Document]: + query_vector = self._embeddings.embed_query(query) + return self._vector_processor.search_by_vector(query_vector, **kwargs) + + def search_by_full_text( + self, query: str, + **kwargs: Any + ) -> list[Document]: + return self._vector_processor.search_by_full_text(query, **kwargs) + + def delete(self) -> None: + self._vector_processor.delete() + + def _get_embeddings(self) -> Embeddings: + model_manager = ModelManager() + + embedding_model = model_manager.get_model_instance( + tenant_id=self._dataset.tenant_id, + provider=self._dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=self._dataset.embedding_model + + ) + return CacheEmbedding(embedding_model) + + def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: + for text in texts: + doc_id = text.metadata['doc_id'] + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) + + return texts + + def __getattr__(self, name): + if self._vector_processor is not None: + method =
getattr(self._vector_processor, name) + if callable(method): + return method + + raise AttributeError(f"'vector_processor' object has no attribute '{name}'") diff --git a/api/core/rag/datasource/vdb/weaviate/__init__.py b/api/core/rag/datasource/vdb/weaviate/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py new file mode 100644 index 0000000000..5c3a810fbf --- /dev/null +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -0,0 +1,235 @@ +import datetime +from typing import Any, Optional + +import requests +import weaviate +from pydantic import BaseModel, root_validator + +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document +from models.dataset import Dataset + + +class WeaviateConfig(BaseModel): + endpoint: str + api_key: Optional[str] + batch_size: int = 100 + + @root_validator() + def validate_config(cls, values: dict) -> dict: + if not values['endpoint']: + raise ValueError("config WEAVIATE_ENDPOINT is required") + return values + + +class WeaviateVector(BaseVector): + + def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): + super().__init__(collection_name) + self._client = self._init_client(config) + self._attributes = attributes + + def _init_client(self, config: WeaviateConfig) -> weaviate.Client: + auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key) + + weaviate.connect.connection.has_grpc = False + + try: + client = weaviate.Client( + url=config.endpoint, + auth_client_secret=auth_config, + timeout_config=(5, 60), + startup_period=None + ) + except requests.exceptions.ConnectionError: + raise ConnectionError("Vector database connection error") + + client.batch.configure( + # `batch_size` takes an `int` value to enable auto-batching + # (`None` is used for manual batching) + batch_size=config.batch_size, + # dynamically update the `batch_size` based on import speed + dynamic=True, + # `timeout_retries` takes an `int` value to retry on time outs + timeout_retries=3, + ) + + return client + + def get_type(self) -> str: + return 'weaviate' + + def get_collection_name(self, dataset: Dataset) -> str: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + if not class_prefix.endswith('_Node'): + # original class_prefix + class_prefix += '_Node' + + return class_prefix + + dataset_id = dataset.id + return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + + def to_index_struct(self) -> dict: + return { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name} + } + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + + schema = self._default_schema(self._collection_name) + + # check whether the index already exists + if not self._client.schema.contains(schema): + # create collection + self._client.schema.create_class(schema) + # create vector + self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + uuids = self._get_uuids(documents) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + + ids = [] + + with self._client.batch as batch: + for i, text in enumerate(texts): + data_properties = {Field.TEXT_KEY.value: text} + if metadatas is not None: + for key, val in 
metadatas[i].items(): + data_properties[key] = self._json_serializable(val) + + batch.add_data_object( + data_object=data_properties, + class_name=self._collection_name, + uuid=uuids[i], + vector=embeddings[i] if embeddings else None, + ) + ids.append(uuids[i]) + return ids + + def delete_by_metadata_field(self, key: str, value: str): + + where_filter = { + "operator": "Equal", + "path": [key], + "valueText": value + } + + self._client.batch.delete_objects( + class_name=self._collection_name, + where=where_filter, + output='minimal' + ) + + def delete(self): + self._client.schema.delete_class(self._collection_name) + + def text_exists(self, id: str) -> bool: + collection_name = self._collection_name + result = self._client.query.get(collection_name).with_additional(["id"]).with_where({ + "path": ["doc_id"], + "operator": "Equal", + "valueText": id, + }).with_limit(1).do() + + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") + + entries = result["data"]["Get"][collection_name] + if len(entries) == 0: + return False + + return True + + def delete_by_ids(self, ids: list[str]) -> None: + self._client.data_object.delete( + ids, + class_name=self._collection_name + ) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + """Look up similar documents by embedding vector in Weaviate.""" + collection_name = self._collection_name + properties = self._attributes + properties.append(Field.TEXT_KEY.value) + query_obj = self._client.query.get(collection_name, properties) + + vector = {"vector": query_vector} + if kwargs.get("where_filter"): + query_obj = query_obj.with_where(kwargs.get("where_filter")) + result = ( + query_obj.with_near_vector(vector) + .with_limit(kwargs.get("top_k", 4)) + .with_additional(["vector", "distance"]) + .do() + ) + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") + + docs_and_scores = [] + for res in result["data"]["Get"][collection_name]: + text = res.pop(Field.TEXT_KEY.value) + score = 1 - res["_additional"]["distance"] + docs_and_scores.append((Document(page_content=text, metadata=res), score)) + + docs = [] + for doc, score in docs_and_scores: + score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + # check score threshold + if score > score_threshold: + doc.metadata['score'] = score + docs.append(doc) + + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Return docs using BM25F. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents most similar to the query. 
+ """ + collection_name = self._collection_name + content: dict[str, Any] = {"concepts": [query]} + properties = self._attributes + properties.append(Field.TEXT_KEY.value) + if kwargs.get("search_distance"): + content["certainty"] = kwargs.get("search_distance") + query_obj = self._client.query.get(collection_name, properties) + if kwargs.get("where_filter"): + query_obj = query_obj.with_where(kwargs.get("where_filter")) + if kwargs.get("additional"): + query_obj = query_obj.with_additional(kwargs.get("additional")) + properties = ['text'] + result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do() + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") + docs = [] + for res in result["data"]["Get"][collection_name]: + text = res.pop(Field.TEXT_KEY.value) + docs.append(Document(page_content=text, metadata=res)) + return docs + + def _default_schema(self, index_name: str) -> dict: + return { + "class": index_name, + "properties": [ + { + "name": "text", + "dataType": ["text"], + } + ], + } + + def _json_serializable(self, value: Any) -> Any: + if isinstance(value, datetime.datetime): + return value.isoformat() + return value diff --git a/api/core/rag/extractor/blod/blod.py b/api/core/rag/extractor/blod/blod.py new file mode 100644 index 0000000000..368946b5e4 --- /dev/null +++ b/api/core/rag/extractor/blod/blod.py @@ -0,0 +1,166 @@ +"""Schema for Blobs and Blob Loaders. + +The goal is to facilitate decoupling of content loading from content parsing code. + +In addition, content loading code should provide a lazy loading interface by default. +""" +from __future__ import annotations + +import contextlib +import mimetypes +from abc import ABC, abstractmethod +from collections.abc import Generator, Iterable, Mapping +from io import BufferedReader, BytesIO +from pathlib import PurePath +from typing import Any, Optional, Union + +from pydantic import BaseModel, root_validator + +PathLike = Union[str, PurePath] + + +class Blob(BaseModel): + """A blob is used to represent raw data by either reference or value. + + Provides an interface to materialize the blob in different representations, and + help to decouple the development of data loaders from the downstream parsing of + the raw data. + + Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob + """ + + data: Union[bytes, str, None] # Raw data + mimetype: Optional[str] = None # Not to be confused with a file extension + encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string + # Location where the original content was found + # Represent location on the local file system + # Useful for situations where downstream code assumes it must work with file paths + # rather than in-memory content. 
+ path: Optional[PathLike] = None + + class Config: + arbitrary_types_allowed = True + frozen = True + + @property + def source(self) -> Optional[str]: + """The source location of the blob as string if known otherwise none.""" + return str(self.path) if self.path else None + + @root_validator(pre=True) + def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]: + """Verify that either data or path is provided.""" + if "data" not in values and "path" not in values: + raise ValueError("Either data or path must be provided") + return values + + def as_string(self) -> str: + """Read data as a string.""" + if self.data is None and self.path: + with open(str(self.path), encoding=self.encoding) as f: + return f.read() + elif isinstance(self.data, bytes): + return self.data.decode(self.encoding) + elif isinstance(self.data, str): + return self.data + else: + raise ValueError(f"Unable to get string for blob {self}") + + def as_bytes(self) -> bytes: + """Read data as bytes.""" + if isinstance(self.data, bytes): + return self.data + elif isinstance(self.data, str): + return self.data.encode(self.encoding) + elif self.data is None and self.path: + with open(str(self.path), "rb") as f: + return f.read() + else: + raise ValueError(f"Unable to get bytes for blob {self}") + + @contextlib.contextmanager + def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]: + """Read data as a byte stream.""" + if isinstance(self.data, bytes): + yield BytesIO(self.data) + elif self.data is None and self.path: + with open(str(self.path), "rb") as f: + yield f + else: + raise NotImplementedError(f"Unable to convert blob {self}") + + @classmethod + def from_path( + cls, + path: PathLike, + *, + encoding: str = "utf-8", + mime_type: Optional[str] = None, + guess_type: bool = True, + ) -> Blob: + """Load the blob from a path like object. + + Args: + path: path like object to file to be read + encoding: Encoding to use if decoding the bytes into a string + mime_type: if provided, will be set as the mime-type of the data + guess_type: If True, the mimetype will be guessed from the file extension, + if a mime-type was not provided + + Returns: + Blob instance + """ + if mime_type is None and guess_type: + _mimetype = mimetypes.guess_type(path)[0] if guess_type else None + else: + _mimetype = mime_type + # We do not load the data immediately, instead we treat the blob as a + # reference to the underlying data. + return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path) + + @classmethod + def from_data( + cls, + data: Union[str, bytes], + *, + encoding: str = "utf-8", + mime_type: Optional[str] = None, + path: Optional[str] = None, + ) -> Blob: + """Initialize the blob from in-memory data. + + Args: + data: the in-memory data associated with the blob + encoding: Encoding to use if decoding the bytes into a string + mime_type: if provided, will be set as the mime-type of the data + path: if provided, will be set as the source from which the data came + + Returns: + Blob instance + """ + return cls(data=data, mimetype=mime_type, encoding=encoding, path=path) + + def __repr__(self) -> str: + """Define the blob representation.""" + str_repr = f"Blob {id(self)}" + if self.source: + str_repr += f" {self.source}" + return str_repr + + +class BlobLoader(ABC): + """Abstract interface for blob loaders implementation. + + Implementer should be able to load raw content from a datasource system according + to some criteria and return the raw content lazily as a stream of blobs. 
+ """ + + @abstractmethod + def yield_blobs( + self, + ) -> Iterable[Blob]: + """A lazy loader for raw data represented by LangChain's Blob object. + + Returns: + A generator over blobs + """ diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py new file mode 100644 index 0000000000..c391d7ae66 --- /dev/null +++ b/api/core/rag/extractor/csv_extractor.py @@ -0,0 +1,72 @@ +"""Abstract interface for document loader implementations.""" +import csv +from typing import Optional + +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.extractor.helpers import detect_file_encodings +from core.rag.models.document import Document + + +class CSVExtractor(BaseExtractor): + """Load CSV files. + + + Args: + file_path: Path to the file to load. + """ + + def __init__( + self, + file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False, + source_column: Optional[str] = None, + csv_args: Optional[dict] = None, + ): + """Initialize with file path.""" + self._file_path = file_path + self._encoding = encoding + self._autodetect_encoding = autodetect_encoding + self.source_column = source_column + self.csv_args = csv_args or {} + + def extract(self) -> list[Document]: + """Load data into document objects.""" + try: + with open(self._file_path, newline="", encoding=self._encoding) as csvfile: + docs = self._read_from_file(csvfile) + except UnicodeDecodeError as e: + if self._autodetect_encoding: + detected_encodings = detect_file_encodings(self._file_path) + for encoding in detected_encodings: + try: + with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile: + docs = self._read_from_file(csvfile) + break + except UnicodeDecodeError: + continue + else: + raise RuntimeError(f"Error loading {self._file_path}") from e + + return docs + + def _read_from_file(self, csvfile) -> list[Document]: + docs = [] + csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore + for i, row in enumerate(csv_reader): + content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items()) + try: + source = ( + row[self.source_column] + if self.source_column is not None + else '' + ) + except KeyError: + raise ValueError( + f"Source column '{self.source_column}' not found in CSV file." + ) + metadata = {"source": source, "row": i} + doc = Document(page_content=content, metadata=metadata) + docs.append(doc) + + return docs diff --git a/api/core/rag/extractor/entity/datasource_type.py b/api/core/rag/extractor/entity/datasource_type.py new file mode 100644 index 0000000000..2c79e7b97b --- /dev/null +++ b/api/core/rag/extractor/entity/datasource_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class DatasourceType(Enum): + FILE = "upload_file" + NOTION = "notion_import" diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py new file mode 100644 index 0000000000..bc5310f7be --- /dev/null +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -0,0 +1,36 @@ +from pydantic import BaseModel + +from models.dataset import Document +from models.model import UploadFile + + +class NotionInfo(BaseModel): + """ + Notion import info. + """ + notion_workspace_id: str + notion_obj_id: str + notion_page_type: str + document: Document = None + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **data) -> None: + super().__init__(**data) + + +class ExtractSetting(BaseModel): + """ + Model class for extract setting.
+ """ + datasource_type: str + upload_file: UploadFile = None + notion_info: NotionInfo = None + document_model: str = None + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **data) -> None: + super().__init__(**data) diff --git a/api/core/data_loader/loader/excel.py b/api/core/rag/extractor/excel_extractor.py similarity index 66% rename from api/core/data_loader/loader/excel.py rename to api/core/rag/extractor/excel_extractor.py index cddb298547..532391048b 100644 --- a/api/core/data_loader/loader/excel.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,14 +1,14 @@ -import logging +"""Abstract interface for document loader implementations.""" +from typing import Optional -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document from openpyxl.reader.excel import load_workbook -logger = logging.getLogger(__name__) +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document -class ExcelLoader(BaseLoader): - """Load xlxs files. +class ExcelExtractor(BaseExtractor): + """Load Excel files. Args: @@ -16,13 +16,18 @@ class ExcelLoader(BaseLoader): """ def __init__( - self, - file_path: str + self, + file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False ): """Initialize with file path.""" self._file_path = file_path + self._encoding = encoding + self._autodetect_encoding = autodetect_encoding - def load(self) -> list[Document]: + def extract(self) -> list[Document]: + """Load from file path.""" data = [] keys = [] wb = load_workbook(filename=self._file_path, read_only=True) diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py new file mode 100644 index 0000000000..7c7dc5bdae --- /dev/null +++ b/api/core/rag/extractor/extract_processor.py @@ -0,0 +1,139 @@ +import tempfile +from pathlib import Path +from typing import Union + +import requests +from flask import current_app + +from core.rag.extractor.csv_extractor import CSVExtractor +from core.rag.extractor.entity.datasource_type import DatasourceType +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.excel_extractor import ExcelExtractor +from core.rag.extractor.html_extractor import HtmlExtractor +from core.rag.extractor.markdown_extractor import MarkdownExtractor +from core.rag.extractor.notion_extractor import NotionExtractor +from core.rag.extractor.pdf_extractor import PdfExtractor +from core.rag.extractor.text_extractor import TextExtractor +from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor +from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor +from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor +from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor +from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor +from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor +from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor +from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor +from core.rag.extractor.word_extractor import WordExtractor +from core.rag.models.document import Document +from extensions.ext_storage import storage +from models.model import UploadFile + 
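# --- Illustrative aside, not part of the patch: a rough sketch of how the ExtractSetting /
# NotionInfo models from extract_setting.py above might be filled in before being handed to
# the ExtractProcessor defined just below. The workspace id and page id are placeholders,
# and the sketch assumes it runs inside the Dify api project so the import path resolves.
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo

notion_setting = ExtractSetting(
    datasource_type='notion_import',        # DatasourceType.NOTION.value
    notion_info=NotionInfo(
        notion_workspace_id='00000000-0000-0000-0000-000000000000',  # placeholder
        notion_obj_id='00000000-0000-0000-0000-000000000000',        # placeholder
        notion_page_type='page',
    ),
    document_model='text_model',
)

# For an uploaded file, the setting references an UploadFile row instead (shown here
# commented out because the row would come from the database):
# file_setting = ExtractSetting(
#     datasource_type='upload_file',        # DatasourceType.FILE.value
#     upload_file=upload_file_row,          # placeholder UploadFile record
#     document_model='text_model',
# )

# Either setting would then be passed to ExtractProcessor.extract(...), e.g.:
# documents = ExtractProcessor.extract(notion_setting)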
+SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain'] +USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + + +class ExtractProcessor: + @classmethod + def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \ + -> Union[list[Document], str]: + extract_setting = ExtractSetting( + datasource_type="upload_file", + upload_file=upload_file, + document_model='text_model' + ) + if return_text: + delimiter = '\n' + return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)]) + else: + return cls.extract(extract_setting, is_automatic) + + @classmethod + def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: + response = requests.get(url, headers={ + "User-Agent": USER_AGENT + }) + + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(url).suffix + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + with open(file_path, 'wb') as file: + file.write(response.content) + extract_setting = ExtractSetting( + datasource_type="upload_file", + document_model='text_model' + ) + if return_text: + delimiter = '\n' + return delimiter.join([document.page_content for document in cls.extract( + extract_setting=extract_setting, file_path=file_path)]) + else: + return cls.extract(extract_setting=extract_setting, file_path=file_path) + + @classmethod + def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, + file_path: str = None) -> list[Document]: + if extract_setting.datasource_type == DatasourceType.FILE.value: + with tempfile.TemporaryDirectory() as temp_dir: + if not file_path: + upload_file: UploadFile = extract_setting.upload_file + suffix = Path(upload_file.key).suffix + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + storage.download(upload_file.key, file_path) + input_file = Path(file_path) + file_extension = input_file.suffix.lower() + etl_type = current_app.config['ETL_TYPE'] + unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL'] + if etl_type == 'Unstructured': + if file_extension == '.xlsx': + extractor = ExcelExtractor(file_path) + elif file_extension == '.pdf': + extractor = PdfExtractor(file_path) + elif file_extension in ['.md', '.markdown']: + extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \ + else MarkdownExtractor(file_path, autodetect_encoding=True) + elif file_extension in ['.htm', '.html']: + extractor = HtmlExtractor(file_path) + elif file_extension in ['.docx']: + extractor = UnstructuredWordExtractor(file_path, unstructured_api_url) + elif file_extension == '.csv': + extractor = CSVExtractor(file_path, autodetect_encoding=True) + elif file_extension == '.msg': + extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url) + elif file_extension == '.eml': + extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url) + elif file_extension == '.ppt': + extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url) + elif file_extension == '.pptx': + extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url) + elif file_extension == '.xml': + extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url) + else: + # txt + extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \ + else TextExtractor(file_path, autodetect_encoding=True) + else: + if 
file_extension == '.xlsx': + extractor = ExcelExtractor(file_path) + elif file_extension == '.pdf': + extractor = PdfExtractor(file_path) + elif file_extension in ['.md', '.markdown']: + extractor = MarkdownExtractor(file_path, autodetect_encoding=True) + elif file_extension in ['.htm', '.html']: + extractor = HtmlExtractor(file_path) + elif file_extension in ['.docx']: + extractor = WordExtractor(file_path) + elif file_extension == '.csv': + extractor = CSVExtractor(file_path, autodetect_encoding=True) + else: + # txt + extractor = TextExtractor(file_path, autodetect_encoding=True) + return extractor.extract() + elif extract_setting.datasource_type == DatasourceType.NOTION.value: + extractor = NotionExtractor( + notion_workspace_id=extract_setting.notion_info.notion_workspace_id, + notion_obj_id=extract_setting.notion_info.notion_obj_id, + notion_page_type=extract_setting.notion_info.notion_page_type, + document_model=extract_setting.notion_info.document + ) + return extractor.extract() + else: + raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}") diff --git a/api/core/rag/extractor/extractor_base.py b/api/core/rag/extractor/extractor_base.py new file mode 100644 index 0000000000..c490e59332 --- /dev/null +++ b/api/core/rag/extractor/extractor_base.py @@ -0,0 +1,12 @@ +"""Abstract interface for document loader implementations.""" +from abc import ABC, abstractmethod + + +class BaseExtractor(ABC): + """Interface for extract files. + """ + + @abstractmethod + def extract(self): + raise NotImplementedError + diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py new file mode 100644 index 0000000000..0c17a47b32 --- /dev/null +++ b/api/core/rag/extractor/helpers.py @@ -0,0 +1,46 @@ +"""Document loader helpers.""" + +import concurrent.futures +from typing import NamedTuple, Optional, cast + + +class FileEncoding(NamedTuple): + """A file encoding as the NamedTuple.""" + + encoding: Optional[str] + """The encoding of the file.""" + confidence: float + """The confidence of the encoding.""" + language: Optional[str] + """The language of the file.""" + + +def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]: + """Try to detect the file encoding. + + Returns a list of `FileEncoding` tuples with the detected encodings ordered + by confidence. + + Args: + file_path: The path to the file to detect the encoding for. + timeout: The timeout in seconds for the encoding detection. 
+ """ + import chardet + + def read_and_detect(file_path: str) -> list[dict]: + with open(file_path, "rb") as f: + rawdata = f.read() + return cast(list[dict], chardet.detect_all(rawdata)) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(read_and_detect, file_path) + try: + encodings = future.result(timeout=timeout) + except concurrent.futures.TimeoutError: + raise TimeoutError( + f"Timeout reached while detecting encoding for {file_path}" + ) + + if all(encoding["encoding"] is None for encoding in encodings): + raise RuntimeError(f"Could not detect encoding for {file_path}") + return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None] diff --git a/api/core/data_loader/loader/csv_loader.py b/api/core/rag/extractor/html_extractor.py similarity index 61% rename from api/core/data_loader/loader/csv_loader.py rename to api/core/rag/extractor/html_extractor.py index ce252c157e..557ea42b19 100644 --- a/api/core/data_loader/loader/csv_loader.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,51 +1,55 @@ -import csv -import logging +"""Abstract interface for document loader implementations.""" from typing import Optional -from langchain.document_loaders import CSVLoader as LCCSVLoader -from langchain.document_loaders.helpers import detect_file_encodings -from langchain.schema import Document - -logger = logging.getLogger(__name__) +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.extractor.helpers import detect_file_encodings +from core.rag.models.document import Document -class CSVLoader(LCCSVLoader): +class HtmlExtractor(BaseExtractor): + """Load html files. + + + Args: + file_path: Path to the file to load. + """ + def __init__( self, file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False, source_column: Optional[str] = None, csv_args: Optional[dict] = None, - encoding: Optional[str] = None, - autodetect_encoding: bool = True, ): - self.file_path = file_path + """Initialize with file path.""" + self._file_path = file_path + self._encoding = encoding + self._autodetect_encoding = autodetect_encoding self.source_column = source_column - self.encoding = encoding self.csv_args = csv_args or {} - self.autodetect_encoding = autodetect_encoding - def load(self) -> list[Document]: + def extract(self) -> list[Document]: """Load data into document objects.""" try: - with open(self.file_path, newline="", encoding=self.encoding) as csvfile: + with open(self._file_path, newline="", encoding=self._encoding) as csvfile: docs = self._read_from_file(csvfile) except UnicodeDecodeError as e: - if self.autodetect_encoding: - detected_encodings = detect_file_encodings(self.file_path) + if self._autodetect_encoding: + detected_encodings = detect_file_encodings(self._file_path) for encoding in detected_encodings: - logger.debug("Trying encoding: ", encoding.encoding) try: - with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile: + with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile: docs = self._read_from_file(csvfile) break except UnicodeDecodeError: continue else: - raise RuntimeError(f"Error loading {self.file_path}") from e + raise RuntimeError(f"Error loading {self._file_path}") from e return docs - def _read_from_file(self, csvfile): + def _read_from_file(self, csvfile) -> list[Document]: docs = [] csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore for i, row in enumerate(csv_reader): diff --git 
a/api/core/data_loader/loader/markdown.py b/api/core/rag/extractor/markdown_extractor.py similarity index 79% rename from api/core/data_loader/loader/markdown.py rename to api/core/rag/extractor/markdown_extractor.py index ecbc6d548f..91c687bac9 100644 --- a/api/core/data_loader/loader/markdown.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -1,39 +1,27 @@ -import logging +"""Abstract interface for document loader implementations.""" import re from typing import Optional, cast -from langchain.document_loaders.base import BaseLoader -from langchain.document_loaders.helpers import detect_file_encodings -from langchain.schema import Document - -logger = logging.getLogger(__name__) +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.extractor.helpers import detect_file_encodings +from core.rag.models.document import Document -class MarkdownLoader(BaseLoader): - """Load md files. +class MarkdownExtractor(BaseExtractor): + """Load Markdown files. Args: file_path: Path to the file to load. - - remove_hyperlinks: Whether to remove hyperlinks from the text. - - remove_images: Whether to remove images from the text. - - encoding: File encoding to use. If `None`, the file will be loaded - with the default system encoding. - - autodetect_encoding: Whether to try to autodetect the file encoding - if the specified encoding fails. """ def __init__( - self, - file_path: str, - remove_hyperlinks: bool = True, - remove_images: bool = True, - encoding: Optional[str] = None, - autodetect_encoding: bool = True, + self, + file_path: str, + remove_hyperlinks: bool = True, + remove_images: bool = True, + encoding: Optional[str] = None, + autodetect_encoding: bool = True, ): """Initialize with file path.""" self._file_path = file_path @@ -42,7 +30,8 @@ class MarkdownLoader(BaseLoader): self._encoding = encoding self._autodetect_encoding = autodetect_encoding - def load(self) -> list[Document]: + def extract(self) -> list[Document]: + """Load from file path.""" tups = self.parse_tups(self._file_path) documents = [] for header, value in tups: @@ -113,7 +102,6 @@ class MarkdownLoader(BaseLoader): if self._autodetect_encoding: detected_encodings = detect_file_encodings(filepath) for encoding in detected_encodings: - logger.debug("Trying encoding: ", encoding.encoding) try: with open(filepath, encoding=encoding.encoding) as f: content = f.read() diff --git a/api/core/data_loader/loader/notion.py b/api/core/rag/extractor/notion_extractor.py similarity index 89% rename from api/core/data_loader/loader/notion.py rename to api/core/rag/extractor/notion_extractor.py index f8d8837683..f28436ffd9 100644 --- a/api/core/data_loader/loader/notion.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -4,9 +4,10 @@ from typing import Any, Optional import requests from flask import current_app -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from flask_login import current_user +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import Document as DocumentModel from models.source import DataSourceBinding @@ -22,52 +23,37 @@ RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] -class NotionLoader(BaseLoader): +class NotionExtractor(BaseExtractor): + def __init__( self, - notion_access_token: str, notion_workspace_id: str, notion_obj_id: str, notion_page_type: 
str, - document_model: Optional[DocumentModel] = None + document_model: Optional[DocumentModel] = None, + notion_access_token: Optional[str] = None ): + self._notion_access_token = None self._document_model = document_model self._notion_workspace_id = notion_workspace_id self._notion_obj_id = notion_obj_id self._notion_page_type = notion_page_type - self._notion_access_token = notion_access_token + if notion_access_token: + self._notion_access_token = notion_access_token + else: + self._notion_access_token = self._get_access_token(current_user.current_tenant_id, + self._notion_workspace_id) + if not self._notion_access_token: + integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') + if integration_token is None: + raise ValueError( + "Must specify `integration_token` or set environment " + "variable `NOTION_INTEGRATION_TOKEN`." + ) - if not self._notion_access_token: - integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') - if integration_token is None: - raise ValueError( - "Must specify `integration_token` or set environment " - "variable `NOTION_INTEGRATION_TOKEN`." - ) + self._notion_access_token = integration_token - self._notion_access_token = integration_token - - @classmethod - def from_document(cls, document_model: DocumentModel): - data_source_info = document_model.data_source_info_dict - if not data_source_info or 'notion_page_id' not in data_source_info \ - or 'notion_workspace_id' not in data_source_info: - raise ValueError("no notion page found") - - notion_workspace_id = data_source_info['notion_workspace_id'] - notion_obj_id = data_source_info['notion_page_id'] - notion_page_type = data_source_info['type'] - notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id) - - return cls( - notion_access_token=notion_access_token, - notion_workspace_id=notion_workspace_id, - notion_obj_id=notion_obj_id, - notion_page_type=notion_page_type, - document_model=document_model - ) - - def load(self) -> list[Document]: + def extract(self) -> list[Document]: self.update_last_edited_time( self._document_model ) diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py new file mode 100644 index 0000000000..cbb2655390 --- /dev/null +++ b/api/core/rag/extractor/pdf_extractor.py @@ -0,0 +1,72 @@ +"""Abstract interface for document loader implementations.""" +from collections.abc import Iterator +from typing import Optional + +from core.rag.extractor.blod.blod import Blob +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document +from extensions.ext_storage import storage + + +class PdfExtractor(BaseExtractor): + """Load pdf files. + + + Args: + file_path: Path to the file to load. 
+ """ + + def __init__( + self, + file_path: str, + file_cache_key: Optional[str] = None + ): + """Initialize with file path.""" + self._file_path = file_path + self._file_cache_key = file_cache_key + + def extract(self) -> list[Document]: + plaintext_file_key = '' + plaintext_file_exists = False + if self._file_cache_key: + try: + text = storage.load(self._file_cache_key).decode('utf-8') + plaintext_file_exists = True + return [Document(page_content=text)] + except FileNotFoundError: + pass + documents = list(self.load()) + text_list = [] + for document in documents: + text_list.append(document.page_content) + text = "\n\n".join(text_list) + + # save plaintext file for caching + if not plaintext_file_exists and plaintext_file_key: + storage.save(plaintext_file_key, text.encode('utf-8')) + + return documents + + def load( + self, + ) -> Iterator[Document]: + """Lazy load given path as pages.""" + blob = Blob.from_path(self._file_path) + yield from self.parse(blob) + + def parse(self, blob: Blob) -> Iterator[Document]: + """Lazily parse the blob.""" + import pypdfium2 + + with blob.as_bytes_io() as file_path: + pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) + try: + for page_number, page in enumerate(pdf_reader): + text_page = page.get_textpage() + content = text_page.get_text_range() + text_page.close() + page.close() + metadata = {"source": blob.source, "page": page_number} + yield Document(page_content=content, metadata=metadata) + finally: + pdf_reader.close() diff --git a/api/core/rag/extractor/text_extractor.py b/api/core/rag/extractor/text_extractor.py new file mode 100644 index 0000000000..ac5d0920cf --- /dev/null +++ b/api/core/rag/extractor/text_extractor.py @@ -0,0 +1,50 @@ +"""Abstract interface for document loader implementations.""" +from typing import Optional + +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.extractor.helpers import detect_file_encodings +from core.rag.models.document import Document + + +class TextExtractor(BaseExtractor): + """Load text files. + + + Args: + file_path: Path to the file to load. 
+ """ + + def __init__( + self, + file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False + ): + """Initialize with file path.""" + self._file_path = file_path + self._encoding = encoding + self._autodetect_encoding = autodetect_encoding + + def extract(self) -> list[Document]: + """Load from file path.""" + text = "" + try: + with open(self._file_path, encoding=self._encoding) as f: + text = f.read() + except UnicodeDecodeError as e: + if self._autodetect_encoding: + detected_encodings = detect_file_encodings(self._file_path) + for encoding in detected_encodings: + try: + with open(self._file_path, encoding=encoding.encoding) as f: + text = f.read() + break + except UnicodeDecodeError: + continue + else: + raise RuntimeError(f"Error loading {self._file_path}") from e + except Exception as e: + raise RuntimeError(f"Error loading {self._file_path}") from e + + metadata = {"source": self._file_path} + return [Document(page_content=text, metadata=metadata)] diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py new file mode 100644 index 0000000000..b37981a30d --- /dev/null +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -0,0 +1,61 @@ +import logging +import os + +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document + +logger = logging.getLogger(__name__) + + +class UnstructuredWordExtractor(BaseExtractor): + """Loader that uses unstructured to load word documents. + """ + + def __init__( + self, + file_path: str, + api_url: str, + ): + """Initialize with file path.""" + self._file_path = file_path + self._api_url = api_url + + def extract(self) -> list[Document]: + from unstructured.__version__ import __version__ as __unstructured_version__ + from unstructured.file_utils.filetype import FileType, detect_filetype + + unstructured_version = tuple( + [int(x) for x in __unstructured_version__.split(".")] + ) + # check the file extension + try: + import magic # noqa: F401 + + is_doc = detect_filetype(self._file_path) == FileType.DOC + except ImportError: + _, extension = os.path.splitext(str(self._file_path)) + is_doc = extension == ".doc" + + if is_doc and unstructured_version < (0, 4, 11): + raise ValueError( + f"You are on unstructured version {__unstructured_version__}. " + "Partitioning .doc files is only supported in unstructured>=0.4.11. " + "Please upgrade the unstructured package and try again." 
+ ) + + if is_doc: + from unstructured.partition.doc import partition_doc + + elements = partition_doc(filename=self._file_path) + else: + from unstructured.partition.docx import partition_docx + + elements = partition_docx(filename=self._file_path) + + from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0) + documents = [] + for chunk in chunks: + text = chunk.text.strip() + documents.append(Document(page_content=text)) + return documents diff --git a/api/core/data_loader/loader/unstructured/unstructured_eml.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py similarity index 87% rename from api/core/data_loader/loader/unstructured/unstructured_eml.py rename to api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index 2fa3aac133..1d92bbbee6 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_eml.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -2,13 +2,14 @@ import base64 import logging from bs4 import BeautifulSoup -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document + +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredEmailLoader(BaseLoader): +class UnstructuredEmailExtractor(BaseExtractor): """Load msg files. Args: file_path: Path to the file to load. @@ -23,7 +24,7 @@ class UnstructuredEmailLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.email import partition_email elements = partition_email(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_markdown.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py similarity index 85% rename from api/core/data_loader/loader/unstructured/unstructured_markdown.py rename to api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 036a2afd25..3ac04ddc17 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_markdown.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -1,12 +1,12 @@ import logging -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredMarkdownLoader(BaseLoader): +class UnstructuredMarkdownExtractor(BaseExtractor): """Load md files. 
@@ -33,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.md import partition_md elements = partition_md(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_msg.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py similarity index 80% rename from api/core/data_loader/loader/unstructured/unstructured_msg.py rename to api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index 495be328ed..d4b72e37eb 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_msg.py +++ b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -1,12 +1,12 @@ import logging -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredMsgLoader(BaseLoader): +class UnstructuredMsgExtractor(BaseExtractor): """Load msg files. @@ -23,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.msg import partition_msg elements = partition_msg(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_ppt.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py similarity index 78% rename from api/core/data_loader/loader/unstructured/unstructured_ppt.py rename to api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index cfac91cc7b..cd3aba9866 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_ppt.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -1,11 +1,12 @@ import logging -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredPPTLoader(BaseLoader): + +class UnstructuredPPTExtractor(BaseExtractor): """Load msg files. 
@@ -14,15 +15,15 @@ class UnstructuredPPTLoader(BaseLoader): """ def __init__( - self, - file_path: str, - api_url: str + self, + file_path: str, + api_url: str ): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.ppt import partition_ppt elements = partition_ppt(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_pptx.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py similarity index 78% rename from api/core/data_loader/loader/unstructured/unstructured_pptx.py rename to api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index 41e3bfcb54..f9667d2527 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_pptx.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -1,10 +1,12 @@ import logging -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredPPTXLoader(BaseLoader): + + +class UnstructuredPPTXExtractor(BaseExtractor): """Load msg files. @@ -13,15 +15,15 @@ class UnstructuredPPTXLoader(BaseLoader): """ def __init__( - self, - file_path: str, - api_url: str + self, + file_path: str, + api_url: str ): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.pptx import partition_pptx elements = partition_pptx(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_text.py b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py similarity index 80% rename from api/core/data_loader/loader/unstructured/unstructured_text.py rename to api/core/rag/extractor/unstructured/unstructured_text_extractor.py index 09d14fdb17..5af21b2b1d 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_text.py +++ b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py @@ -1,12 +1,12 @@ import logging -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredTextLoader(BaseLoader): +class UnstructuredTextExtractor(BaseExtractor): """Load msg files. 
@@ -23,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.text import partition_text elements = partition_text(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_xml.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py similarity index 81% rename from api/core/data_loader/loader/unstructured/unstructured_xml.py rename to api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index cca6e1b0b7..b08ff63a1c 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_xml.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -1,12 +1,12 @@ import logging -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredXmlLoader(BaseLoader): +class UnstructuredXmlExtractor(BaseExtractor): """Load msg files. @@ -23,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.xml import partition_xml elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py new file mode 100644 index 0000000000..8e2cd14be7 --- /dev/null +++ b/api/core/rag/extractor/word_extractor.py @@ -0,0 +1,62 @@ +"""Abstract interface for document loader implementations.""" +import os +import tempfile +from urllib.parse import urlparse + +import requests + +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document + + +class WordExtractor(BaseExtractor): + """Load pdf files. + + + Args: + file_path: Path to the file to load. 
+ """ + + def __init__(self, file_path: str): + """Initialize with file path.""" + self.file_path = file_path + if "~" in self.file_path: + self.file_path = os.path.expanduser(self.file_path) + + # If the file is a web path, download it to a temporary file, and use that + if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path): + r = requests.get(self.file_path) + + if r.status_code != 200: + raise ValueError( + "Check the url of your file; returned status code %s" + % r.status_code + ) + + self.web_path = self.file_path + self.temp_file = tempfile.NamedTemporaryFile() + self.temp_file.write(r.content) + self.file_path = self.temp_file.name + elif not os.path.isfile(self.file_path): + raise ValueError("File path %s is not a valid file or url" % self.file_path) + + def __del__(self) -> None: + if hasattr(self, "temp_file"): + self.temp_file.close() + + def extract(self) -> list[Document]: + """Load given path as single page.""" + import docx2txt + + return [ + Document( + page_content=docx2txt.process(self.file_path), + metadata={"source": self.file_path}, + ) + ] + + @staticmethod + def _is_valid_url(url: str) -> bool: + """Check if the url is valid.""" + parsed = urlparse(url) + return bool(parsed.netloc) and bool(parsed.scheme) diff --git a/api/core/rag/index_processor/__init__.py b/api/core/rag/index_processor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/index_processor/constant/__init__.py b/api/core/rag/index_processor/constant/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/index_processor/constant/index_type.py b/api/core/rag/index_processor/constant/index_type.py new file mode 100644 index 0000000000..e42cc44c6f --- /dev/null +++ b/api/core/rag/index_processor/constant/index_type.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class IndexType(Enum): + PARAGRAPH_INDEX = "text_model" + QA_INDEX = "qa_model" + PARENT_CHILD_INDEX = "parent_child_index" + SUMMARY_INDEX = "summary_index" diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py new file mode 100644 index 0000000000..fcb06e5c84 --- /dev/null +++ b/api/core/rag/index_processor/index_processor_base.py @@ -0,0 +1,69 @@ +"""Abstract interface for document loader implementations.""" +from abc import ABC, abstractmethod +from typing import Optional + +from core.model_manager import ModelInstance +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.models.document import Document +from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter +from core.splitter.text_splitter import TextSplitter +from models.dataset import Dataset, DatasetProcessRule + + +class BaseIndexProcessor(ABC): + """Interface for extract files. 
+ """ + + @abstractmethod + def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: + raise NotImplementedError + + @abstractmethod + def transform(self, documents: list[Document], **kwargs) -> list[Document]: + raise NotImplementedError + + @abstractmethod + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + raise NotImplementedError + + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + raise NotImplementedError + + @abstractmethod + def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, + score_threshold: float, reranking_model: dict) -> list[Document]: + raise NotImplementedError + + def _get_splitter(self, processing_rule: dict, + embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + """ + Get the NodeParser object according to the processing rule. + """ + if processing_rule['mode'] == "custom": + # The user-defined segmentation rule + rules = processing_rule['rules'] + segmentation = rules["segmentation"] + if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000: + raise ValueError("Custom segment length should be between 50 and 1000.") + + separator = segmentation["separator"] + if separator: + separator = separator.replace('\\n', '\n') + + character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( + chunk_size=segmentation["max_tokens"], + chunk_overlap=0, + fixed_separator=separator, + separators=["\n\n", "。", ".", " ", ""], + embedding_model_instance=embedding_model_instance + ) + else: + # Automatic segmentation + character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( + chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], + chunk_overlap=0, + separators=["\n\n", "。", ".", " ", ""], + embedding_model_instance=embedding_model_instance + ) + + return character_splitter diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py new file mode 100644 index 0000000000..df43a64910 --- /dev/null +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -0,0 +1,28 @@ +"""Abstract interface for document loader implementations.""" + +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor + + +class IndexProcessorFactory: + """IndexProcessorInit. 
+ """ + + def __init__(self, index_type: str): + self._index_type = index_type + + def init_index_processor(self) -> BaseIndexProcessor: + """Init index processor.""" + + if not self._index_type: + raise ValueError("Index type must be specified.") + + if self._index_type == IndexType.PARAGRAPH_INDEX.value: + return ParagraphIndexProcessor() + elif self._index_type == IndexType.QA_INDEX.value: + + return QAIndexProcessor() + else: + raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/__init__.py b/api/core/rag/index_processor/processor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py new file mode 100644 index 0000000000..3f0467ee24 --- /dev/null +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -0,0 +1,92 @@ +"""Paragraph index processor.""" +import uuid +from typing import Optional + +from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.extract_processor import ExtractProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.models.document import Document +from libs import helper +from models.dataset import Dataset + + +class ParagraphIndexProcessor(BaseIndexProcessor): + + def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: + + text_docs = ExtractProcessor.extract(extract_setting=extract_setting, + is_automatic=kwargs.get('process_rule_mode') == "automatic") + + return text_docs + + def transform(self, documents: list[Document], **kwargs) -> list[Document]: + # Split the text documents into nodes. 
+ splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), + embedding_model_instance=kwargs.get('embedding_model_instance')) + all_documents = [] + for document in documents: + # document clean + document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document.page_content = document_text + # parse document to nodes + document_nodes = splitter.split_documents([document]) + split_documents = [] + for document_node in document_nodes: + + if document_node.page_content.strip(): + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document_node.page_content) + document_node.metadata['doc_id'] = doc_id + document_node.metadata['doc_hash'] = hash + # delete Spliter character + page_content = document_node.page_content + if page_content.startswith(".") or page_content.startswith("。"): + page_content = page_content[1:] + else: + page_content = page_content + document_node.page_content = page_content + split_documents.append(document_node) + all_documents.extend(split_documents) + return all_documents + + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + if dataset.indexing_technique == 'high_quality': + vector = Vector(dataset) + vector.create(documents) + if with_keywords: + keyword = Keyword(dataset) + keyword.create(documents) + + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + if dataset.indexing_technique == 'high_quality': + vector = Vector(dataset) + if node_ids: + vector.delete_by_ids(node_ids) + else: + vector.delete() + if with_keywords: + keyword = Keyword(dataset) + if node_ids: + keyword.delete_by_ids(node_ids) + else: + keyword.delete() + + def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, + score_threshold: float, reranking_model: dict) -> list[Document]: + # Set search parameters. + results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, + top_k=top_k, score_threshold=score_threshold, + reranking_model=reranking_model) + # Organize results. 
+ docs = [] + for result in results: + metadata = result.metadata + metadata['score'] = result.score + if result.score > score_threshold: + doc = Document(page_content=result.page_content, metadata=metadata) + docs.append(doc) + return docs diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py new file mode 100644 index 0000000000..f61c728b49 --- /dev/null +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -0,0 +1,161 @@ +"""Paragraph index processor.""" +import logging +import re +import threading +import uuid +from typing import Optional + +import pandas as pd +from flask import Flask, current_app +from flask_login import current_user +from werkzeug.datastructures import FileStorage + +from core.generator.llm_generator import LLMGenerator +from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.extract_processor import ExtractProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.models.document import Document +from libs import helper +from models.dataset import Dataset + + +class QAIndexProcessor(BaseIndexProcessor): + def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: + + text_docs = ExtractProcessor.extract(extract_setting=extract_setting, + is_automatic=kwargs.get('process_rule_mode') == "automatic") + return text_docs + + def transform(self, documents: list[Document], **kwargs) -> list[Document]: + splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), + embedding_model_instance=None) + + # Split the text documents into nodes. 
+ all_documents = [] + all_qa_documents = [] + for document in documents: + # document clean + document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document.page_content = document_text + + # parse document to nodes + document_nodes = splitter.split_documents([document]) + split_documents = [] + for document_node in document_nodes: + + if document_node.page_content.strip(): + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document_node.page_content) + document_node.metadata['doc_id'] = doc_id + document_node.metadata['doc_hash'] = hash + # delete Spliter character + page_content = document_node.page_content + if page_content.startswith(".") or page_content.startswith("。"): + page_content = page_content[1:] + else: + page_content = page_content + document_node.page_content = page_content + split_documents.append(document_node) + all_documents.extend(split_documents) + for i in range(0, len(all_documents), 10): + threads = [] + sub_documents = all_documents[i:i + 10] + for doc in sub_documents: + document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={ + 'flask_app': current_app._get_current_object(), + 'tenant_id': current_user.current_tenant.id, + 'document_node': doc, + 'all_qa_documents': all_qa_documents, + 'document_language': kwargs.get('document_language', 'English')}) + threads.append(document_format_thread) + document_format_thread.start() + for thread in threads: + thread.join() + return all_qa_documents + + def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: + + # check file type + if not file.filename.endswith('.csv'): + raise ValueError("Invalid file type. Only CSV files are allowed") + + try: + # Skip the first row + df = pd.read_csv(file) + text_docs = [] + for index, row in df.iterrows(): + data = Document(page_content=row[0], metadata={'answer': row[1]}) + text_docs.append(data) + if len(text_docs) == 0: + raise ValueError("The CSV file is empty.") + + except Exception as e: + raise ValueError(str(e)) + return text_docs + + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + if dataset.indexing_technique == 'high_quality': + vector = Vector(dataset) + vector.create(documents) + + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + vector = Vector(dataset) + if node_ids: + vector.delete_by_ids(node_ids) + else: + vector.delete() + + def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, + score_threshold: float, reranking_model: dict): + # Set search parameters. + results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, + top_k=top_k, score_threshold=score_threshold, + reranking_model=reranking_model) + # Organize results. 
+ docs = [] + for result in results: + metadata = result.metadata + metadata['score'] = result.score + if result.score > score_threshold: + doc = Document(page_content=result.page_content, metadata=metadata) + docs.append(doc) + return docs + + def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): + format_documents = [] + if document_node.page_content is None or not document_node.page_content.strip(): + return + with flask_app.app_context(): + try: + # qa model document + response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language) + document_qa_list = self._format_split_text(response) + qa_documents = [] + for result in document_qa_list: + qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(result['question']) + qa_document.metadata['answer'] = result['answer'] + qa_document.metadata['doc_id'] = doc_id + qa_document.metadata['doc_hash'] = hash + qa_documents.append(qa_document) + format_documents.extend(qa_documents) + except Exception as e: + logging.exception(e) + + all_qa_documents.extend(format_documents) + + def _format_split_text(self, text): + regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" + matches = re.findall(regex, text, re.UNICODE) + + return [ + { + "question": q, + "answer": re.sub(r"\n\s*", "\n", a.strip()) + } + for q, a in matches if q and a + ] diff --git a/api/core/rag/models/__init__.py b/api/core/rag/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py new file mode 100644 index 0000000000..221318c2c3 --- /dev/null +++ b/api/core/rag/models/document.py @@ -0,0 +1,79 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Any, Optional + +from pydantic import BaseModel, Field + + +class Document(BaseModel): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + + """Arbitrary metadata about the page content (e.g., source, relationships to other + documents, etc.). + """ + metadata: Optional[dict] = Field(default_factory=dict) + + +class BaseDocumentTransformer(ABC): + """Abstract base class for document transformation systems. + + A document transformation system takes a sequence of Documents and returns a + sequence of transformed Documents. + + Example: + .. code-block:: python + + class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): + embeddings: Embeddings + similarity_fn: Callable = cosine_similarity + similarity_threshold: float = 0.95 + + class Config: + arbitrary_types_allowed = True + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + stateful_documents = get_stateful_documents(documents) + embedded_documents = _get_embeddings_from_stateful_docs( + self.embeddings, stateful_documents + ) + included_idxs = _filter_similar_embeddings( + embedded_documents, self.similarity_fn, self.similarity_threshold + ) + return [stateful_documents[i] for i in sorted(included_idxs)] + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + raise NotImplementedError + + """ # noqa: E501 + + @abstractmethod + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Transform a list of documents. 
+ + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ + + @abstractmethod + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Asynchronously transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ diff --git a/api/core/rerank/rerank.py b/api/core/rerank/rerank.py index a675dfc568..7000f4e0ad 100644 --- a/api/core/rerank/rerank.py +++ b/api/core/rerank/rerank.py @@ -1,8 +1,7 @@ from typing import Optional -from langchain.schema import Document - from core.model_manager import ModelInstance +from core.rag.models.document import Document class RerankRunner: diff --git a/api/core/splitter/fixed_text_splitter.py b/api/core/splitter/fixed_text_splitter.py index 285a7ba14e..a1510259ac 100644 --- a/api/core/splitter/fixed_text_splitter.py +++ b/api/core/splitter/fixed_text_splitter.py @@ -3,20 +3,18 @@ from __future__ import annotations from typing import Any, Optional, cast -from langchain.text_splitter import ( - TS, - AbstractSet, - Collection, - Literal, - RecursiveCharacterTextSplitter, - TokenTextSplitter, - Type, - Union, -) - from core.model_manager import ModelInstance from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.splitter.text_splitter import ( + TS, + Collection, + Literal, + RecursiveCharacterTextSplitter, + Set, + TokenTextSplitter, + Union, +) class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): @@ -26,9 +24,9 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): @classmethod def from_encoder( - cls: Type[TS], + cls: type[TS], embedding_model_instance: Optional[ModelInstance], - allowed_special: Union[Literal[all], AbstractSet[str]] = set(), + allowed_special: Union[Literal[all], Set[str]] = set(), disallowed_special: Union[Literal[all], Collection[str]] = "all", **kwargs: Any, ): diff --git a/api/core/splitter/text_splitter.py b/api/core/splitter/text_splitter.py new file mode 100644 index 0000000000..e3d43c0658 --- /dev/null +++ b/api/core/splitter/text_splitter.py @@ -0,0 +1,903 @@ +from __future__ import annotations + +import copy +import logging +import re +from abc import ABC, abstractmethod +from collections.abc import Callable, Collection, Iterable, Sequence, Set +from dataclasses import dataclass +from enum import Enum +from typing import ( + Any, + Literal, + Optional, + TypedDict, + TypeVar, + Union, +) + +from core.rag.models.document import BaseDocumentTransformer, Document + +logger = logging.getLogger(__name__) + +TS = TypeVar("TS", bound="TextSplitter") + + +def _split_text_with_regex( + text: str, separator: str, keep_separator: bool +) -> list[str]: + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. 
+ _splits = re.split(f"({separator})", text) + splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] + if len(_splits) % 2 == 0: + splits += _splits[-1:] + splits = [_splits[0]] + splits + else: + splits = re.split(separator, text) + else: + splits = list(text) + return [s for s in splits if s != ""] + + +class TextSplitter(BaseDocumentTransformer, ABC): + """Interface for splitting text into chunks.""" + + def __init__( + self, + chunk_size: int = 4000, + chunk_overlap: int = 200, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, + ) -> None: + """Create a new TextSplitter. + + Args: + chunk_size: Maximum size of chunks to return + chunk_overlap: Overlap in characters between chunks + length_function: Function that measures the length of given chunks + keep_separator: Whether to keep the separator in the chunks + add_start_index: If `True`, includes chunk's start index in metadata + """ + if chunk_overlap > chunk_size: + raise ValueError( + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " + f"({chunk_size}), should be smaller." + ) + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._length_function = length_function + self._keep_separator = keep_separator + self._add_start_index = add_start_index + + @abstractmethod + def split_text(self, text: str) -> list[str]: + """Split text into multiple components.""" + + def create_documents( + self, texts: list[str], metadatas: Optional[list[dict]] = None + ) -> list[Document]: + """Create documents from a list of texts.""" + _metadatas = metadatas or [{}] * len(texts) + documents = [] + for i, text in enumerate(texts): + index = -1 + for chunk in self.split_text(text): + metadata = copy.deepcopy(_metadatas[i]) + if self._add_start_index: + index = text.find(chunk, index + 1) + metadata["start_index"] = index + new_doc = Document(page_content=chunk, metadata=metadata) + documents.append(new_doc) + return documents + + def split_documents(self, documents: Iterable[Document]) -> list[Document]: + """Split documents.""" + texts, metadatas = [], [] + for doc in documents: + texts.append(doc.page_content) + metadatas.append(doc.metadata) + return self.create_documents(texts, metadatas=metadatas) + + def _join_docs(self, docs: list[str], separator: str) -> Optional[str]: + text = separator.join(docs) + text = text.strip() + if text == "": + return None + else: + return text + + def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]: + # We now want to combine these smaller pieces into medium size + # chunks to send to the LLM. 
+ separator_len = self._length_function(separator) + + docs = [] + current_doc: list[str] = [] + total = 0 + for d in splits: + _len = self._length_function(d) + if ( + total + _len + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + ): + if total > self._chunk_size: + logger.warning( + f"Created a chunk of size {total}, " + f"which is longer than the specified {self._chunk_size}" + ) + if len(current_doc) > 0: + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + # Keep on popping if: + # - we have a larger chunk than in the chunk overlap + # - or if we still have any chunks and the length is long + while total > self._chunk_overlap or ( + total + _len + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + and total > 0 + ): + total -= self._length_function(current_doc[0]) + ( + separator_len if len(current_doc) > 1 else 0 + ) + current_doc = current_doc[1:] + current_doc.append(d) + total += _len + (separator_len if len(current_doc) > 1 else 0) + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + return docs + + @classmethod + def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter: + """Text splitter that uses HuggingFace tokenizer to count length.""" + try: + from transformers import PreTrainedTokenizerBase + + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise ValueError( + "Tokenizer received was not an instance of PreTrainedTokenizerBase" + ) + + def _huggingface_tokenizer_length(text: str) -> int: + return len(tokenizer.encode(text)) + + except ImportError: + raise ValueError( + "Could not import transformers python package. " + "Please install it with `pip install transformers`." + ) + return cls(length_function=_huggingface_tokenizer_length, **kwargs) + + @classmethod + def from_tiktoken_encoder( + cls: type[TS], + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, + ) -> TS: + """Text splitter that uses tiktoken encoder to count length.""" + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to calculate max_tokens_for_prompt. " + "Please install it with `pip install tiktoken`." 
+ ) + + if model_name is not None: + enc = tiktoken.encoding_for_model(model_name) + else: + enc = tiktoken.get_encoding(encoding_name) + + def _tiktoken_encoder(text: str) -> int: + return len( + enc.encode( + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + + if issubclass(cls, TokenTextSplitter): + extra_kwargs = { + "encoding_name": encoding_name, + "model_name": model_name, + "allowed_special": allowed_special, + "disallowed_special": disallowed_special, + } + kwargs = {**kwargs, **extra_kwargs} + + return cls(length_function=_tiktoken_encoder, **kwargs) + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Transform sequence of documents by splitting them.""" + return self.split_documents(list(documents)) + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Asynchronously transform a sequence of documents by splitting them.""" + raise NotImplementedError + + +class CharacterTextSplitter(TextSplitter): + """Splitting text that looks at characters.""" + + def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + self._separator = separator + + def split_text(self, text: str) -> list[str]: + """Split incoming text and return chunks.""" + # First we naively split the large input into a bunch of smaller ones. + splits = _split_text_with_regex(text, self._separator, self._keep_separator) + _separator = "" if self._keep_separator else self._separator + return self._merge_splits(splits, _separator) + + +class LineType(TypedDict): + """Line type as typed dict.""" + + metadata: dict[str, str] + content: str + + +class HeaderType(TypedDict): + """Header type as typed dict.""" + + level: int + name: str + data: str + + +class MarkdownHeaderTextSplitter: + """Splitting markdown files based on specified headers.""" + + def __init__( + self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False + ): + """Create a new MarkdownHeaderTextSplitter. + + Args: + headers_to_split_on: Headers we want to track + return_each_line: Return each line w/ associated headers + """ + # Output line-by-line or aggregated into chunks w/ common headers + self.return_each_line = return_each_line + # Given the headers we want to split on, + # (e.g., "#, ##, etc") order by length + self.headers_to_split_on = sorted( + headers_to_split_on, key=lambda split: len(split[0]), reverse=True + ) + + def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: + """Combine lines with common metadata into chunks + Args: + lines: Line of text / associated header metadata + """ + aggregated_chunks: list[LineType] = [] + + for line in lines: + if ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] == line["metadata"] + ): + # If the last line in the aggregated list + # has the same metadata as the current line, + # append the current content to the last lines's content + aggregated_chunks[-1]["content"] += " \n" + line["content"] + else: + # Otherwise, append the current line to the aggregated list + aggregated_chunks.append(line) + + return [ + Document(page_content=chunk["content"], metadata=chunk["metadata"]) + for chunk in aggregated_chunks + ] + + def split_text(self, text: str) -> list[Document]: + """Split markdown file + Args: + text: Markdown file""" + + # Split the input text by newline character ("\n"). 
+ lines = text.split("\n") + # Final output + lines_with_metadata: list[LineType] = [] + # Content and metadata of the chunk currently being processed + current_content: list[str] = [] + current_metadata: dict[str, str] = {} + # Keep track of the nested header structure + # header_stack: List[Dict[str, Union[int, str]]] = [] + header_stack: list[HeaderType] = [] + initial_metadata: dict[str, str] = {} + + for line in lines: + stripped_line = line.strip() + # Check each line against each of the header types (e.g., #, ##) + for sep, name in self.headers_to_split_on: + # Check if line starts with a header that we intend to split on + if stripped_line.startswith(sep) and ( + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) + or stripped_line[len(sep)] == " " + ): + # Ensure we are tracking the header as metadata + if name is not None: + # Get the current header level + current_header_level = sep.count("#") + + # Pop out headers of lower or same level from the stack + while ( + header_stack + and header_stack[-1]["level"] >= current_header_level + ): + # We have encountered a new header + # at the same or higher level + popped_header = header_stack.pop() + # Clear the metadata for the + # popped header in initial_metadata + if popped_header["name"] in initial_metadata: + initial_metadata.pop(popped_header["name"]) + + # Push the current header to the stack + header: HeaderType = { + "level": current_header_level, + "name": name, + "data": stripped_line[len(sep):].strip(), + } + header_stack.append(header) + # Update initial_metadata with the current header + initial_metadata[name] = header["data"] + + # Add the previous line to the lines_with_metadata + # only if current_content is not empty + if current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata.copy(), + } + ) + current_content.clear() + + break + else: + if stripped_line: + current_content.append(stripped_line) + elif current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata.copy(), + } + ) + current_content.clear() + + current_metadata = initial_metadata.copy() + + if current_content: + lines_with_metadata.append( + {"content": "\n".join(current_content), "metadata": current_metadata} + ) + + # lines_with_metadata has each line with associated header metadata + # aggregate these into chunks based on common metadata + if not self.return_each_line: + return self.aggregate_lines_to_chunks(lines_with_metadata) + else: + return [ + Document(page_content=chunk["content"], metadata=chunk["metadata"]) + for chunk in lines_with_metadata + ] + + +# should be in newer Python versions (3.10+) +# @dataclass(frozen=True, kw_only=True, slots=True) +@dataclass(frozen=True) +class Tokenizer: + chunk_overlap: int + tokens_per_chunk: int + decode: Callable[[list[int]], str] + encode: Callable[[str], list[int]] + + +def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]: + """Split incoming text and return chunks using tokenizer.""" + splits: list[str] = [] + input_ids = tokenizer.encode(text) + start_idx = 0 + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + while start_idx < len(input_ids): + splits.append(tokenizer.decode(chunk_ids)) + start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap + cur_idx = min(start_idx + 
tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + return splits + + +class TokenTextSplitter(TextSplitter): + """Splitting text to tokens using model tokenizer.""" + + def __init__( + self, + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to for TokenTextSplitter. " + "Please install it with `pip install tiktoken`." + ) + + if model_name is not None: + enc = tiktoken.encoding_for_model(model_name) + else: + enc = tiktoken.get_encoding(encoding_name) + self._tokenizer = enc + self._allowed_special = allowed_special + self._disallowed_special = disallowed_special + + def split_text(self, text: str) -> list[str]: + def _encode(_text: str) -> list[int]: + return self._tokenizer.encode( + _text, + allowed_special=self._allowed_special, + disallowed_special=self._disallowed_special, + ) + + tokenizer = Tokenizer( + chunk_overlap=self._chunk_overlap, + tokens_per_chunk=self._chunk_size, + decode=self._tokenizer.decode, + encode=_encode, + ) + + return split_text_on_tokens(text=text, tokenizer=tokenizer) + + +class Language(str, Enum): + """Enum of the programming languages.""" + + CPP = "cpp" + GO = "go" + JAVA = "java" + JS = "js" + PHP = "php" + PROTO = "proto" + PYTHON = "python" + RST = "rst" + RUBY = "ruby" + RUST = "rust" + SCALA = "scala" + SWIFT = "swift" + MARKDOWN = "markdown" + LATEX = "latex" + HTML = "html" + SOL = "sol" + + +class RecursiveCharacterTextSplitter(TextSplitter): + """Splitting text by recursively look at characters. + + Recursively tries to split by different characters to find one + that works. + """ + + def __init__( + self, + separators: Optional[list[str]] = None, + keep_separator: bool = True, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(keep_separator=keep_separator, **kwargs) + self._separators = separators or ["\n\n", "\n", " ", ""] + + def _split_text(self, text: str, separators: list[str]) -> list[str]: + """Split incoming text and return chunks.""" + final_chunks = [] + # Get appropriate separator to use + separator = separators[-1] + new_separators = [] + for i, _s in enumerate(separators): + if _s == "": + separator = _s + break + if re.search(_s, text): + separator = _s + new_separators = separators[i + 1:] + break + + splits = _split_text_with_regex(text, separator, self._keep_separator) + # Now go merging things, recursively splitting longer texts. 
+ _good_splits = [] + _separator = "" if self._keep_separator else separator + for s in splits: + if self._length_function(s) < self._chunk_size: + _good_splits.append(s) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + _good_splits = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + return final_chunks + + def split_text(self, text: str) -> list[str]: + return self._split_text(text, self._separators) + + @classmethod + def from_language( + cls, language: Language, **kwargs: Any + ) -> RecursiveCharacterTextSplitter: + separators = cls.get_separators_for_language(language) + return cls(separators=separators, **kwargs) + + @staticmethod + def get_separators_for_language(language: Language) -> list[str]: + if language == Language.CPP: + return [ + # Split along class definitions + "\nclass ", + # Split along function definitions + "\nvoid ", + "\nint ", + "\nfloat ", + "\ndouble ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.GO: + return [ + # Split along function definitions + "\nfunc ", + "\nvar ", + "\nconst ", + "\ntype ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.JAVA: + return [ + # Split along class definitions + "\nclass ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\nstatic ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.JS: + return [ + # Split along function definitions + "\nfunction ", + "\nconst ", + "\nlet ", + "\nvar ", + "\nclass ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\ndefault ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PHP: + return [ + # Split along function definitions + "\nfunction ", + # Split along class definitions + "\nclass ", + # Split along control flow statements + "\nif ", + "\nforeach ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PROTO: + return [ + # Split along message definitions + "\nmessage ", + # Split along service definitions + "\nservice ", + # Split along enum definitions + "\nenum ", + # Split along option definitions + "\noption ", + # Split along import statements + "\nimport ", + # Split along syntax declarations + "\nsyntax ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PYTHON: + return [ + # First, try to split along class definitions + "\nclass ", + "\ndef ", + "\n\tdef ", + # Now split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RST: + return [ + # Split along section titles + "\n=+\n", + "\n-+\n", + "\n\*+\n", + # Split along directive markers + "\n\n.. 
*\n\n", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RUBY: + return [ + # Split along method definitions + "\ndef ", + "\nclass ", + # Split along control flow statements + "\nif ", + "\nunless ", + "\nwhile ", + "\nfor ", + "\ndo ", + "\nbegin ", + "\nrescue ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RUST: + return [ + # Split along function definitions + "\nfn ", + "\nconst ", + "\nlet ", + # Split along control flow statements + "\nif ", + "\nwhile ", + "\nfor ", + "\nloop ", + "\nmatch ", + "\nconst ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SCALA: + return [ + # Split along class definitions + "\nclass ", + "\nobject ", + # Split along method definitions + "\ndef ", + "\nval ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nmatch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SWIFT: + return [ + # Split along function definitions + "\nfunc ", + # Split along class definitions + "\nclass ", + "\nstruct ", + "\nenum ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.MARKDOWN: + return [ + # First, try to split along Markdown headings (starting with level 2) + "\n#{1,6} ", + # Note the alternative syntax for headings (below) is not handled here + # Heading level 2 + # --------------- + # End of code block + "```\n", + # Horizontal lines + "\n\*\*\*+\n", + "\n---+\n", + "\n___+\n", + # Note that this splitter doesn't handle horizontal lines defined + # by *three or more* of ***, ---, or ___, but this is not handled + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.LATEX: + return [ + # First, try to split along Latex sections + "\n\\\chapter{", + "\n\\\section{", + "\n\\\subsection{", + "\n\\\subsubsection{", + # Now split by environments + "\n\\\begin{enumerate}", + "\n\\\begin{itemize}", + "\n\\\begin{description}", + "\n\\\begin{list}", + "\n\\\begin{quote}", + "\n\\\begin{quotation}", + "\n\\\begin{verse}", + "\n\\\begin{verbatim}", + # Now split by math environments + "\n\\\begin{align}", + "$$", + "$", + # Now split by the normal type of lines + " ", + "", + ] + elif language == Language.HTML: + return [ + # First, try to split along HTML tags + " str: headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" } - supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] + supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) @@ -158,8 +158,8 @@ def get_url(url: str) -> str: if main_content_type not in supported_content_types: return "Unsupported content-type [{}] of URL.".format(main_content_type) - if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES: - return FileExtractor.load_from_url(url, return_text=True) + if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: + return ExtractProcessor.load_from_url(url, return_text=True) response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) a = extract_using_readabilipy(response.text) 
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index b4579ec65a..ad27706c3a 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -113,7 +113,7 @@ class ToolParameter(BaseModel): form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") llm_description: Optional[str] = None required: Optional[bool] = False - default: Optional[str] = None + default: Optional[Union[int, str]] = None min: Optional[Union[float, int]] = None max: Optional[Union[float, int]] = None options: Optional[list[ToolParameterOption]] = None diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index 7568b733cf..7b740293dd 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -1,4 +1,5 @@ from typing import Any, Union +from urllib.parse import quote from requests import get @@ -34,6 +35,18 @@ class BingSearchTool(BuiltinTool): market = tool_parameters.get('market', 'US') lang = tool_parameters.get('language', 'en') + filter = [] + + if tool_parameters.get('enable_computation', False): + filter.append('Computation') + if tool_parameters.get('enable_entities', False): + filter.append('Entities') + if tool_parameters.get('enable_news', False): + filter.append('News') + if tool_parameters.get('enable_related_search', False): + filter.append('RelatedSearches') + if tool_parameters.get('enable_webpages', False): + filter.append('WebPages') market_code = f'{lang}-{market}' accept_language = f'{lang},{market_code};q=0.9' @@ -42,35 +55,72 @@ class BingSearchTool(BuiltinTool): 'Accept-Language': accept_language } - params = { - 'q': query, - 'mkt': market_code - } - - response = get(server_url, headers=headers, params=params) + query = quote(query) + server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filter)}' + response = get(server_url, headers=headers) if response.status_code != 200: raise Exception(f'Error {response.status_code}: {response.text}') response = response.json() - search_results = response['webPages']['value'][:limit] + search_results = response['webPages']['value'][:limit] if 'webPages' in response else [] + related_searches = response['relatedSearches']['value'] if 'relatedSearches' in response else [] + entities = response['entities']['value'] if 'entities' in response else [] + news = response['news']['value'] if 'news' in response else [] + computation = response['computation']['value'] if 'computation' in response else None if result_type == 'link': results = [] - for result in search_results: - results.append(self.create_text_message( - text=f'{result["name"]}: {result["url"]}' - )) + if search_results: + for result in search_results: + results.append(self.create_text_message( + text=f'{result["name"]}: {result["url"]}' + )) + + if entities: + for entity in entities: + results.append(self.create_text_message( + text=f'{entity["name"]}: {entity["url"]}' + )) + + if news: + for news_item in news: + results.append(self.create_text_message( + text=f'{news_item["name"]}: {news_item["url"]}' + )) + + if related_searches: + for related in related_searches: + results.append(self.create_text_message( + text=f'{related["displayText"]}: {related["webSearchUrl"]}' + )) + return results else: # construct text text = '' - for i, result in enumerate(search_results): - 
text += f'{i+1}: {result["name"]} - {result["snippet"]}\n' + if search_results: + for i, result in enumerate(search_results): + text += f'{i+1}: {result["name"]} - {result["snippet"]}\n' - text += '\n\nRelated Searches:\n' - for related in response['relatedSearches']['value']: - text += f'{related["displayText"]} - {related["webSearchUrl"]}\n' + if computation and 'expression' in computation and 'value' in computation: + text += '\nComputation:\n' + text += f'{computation["expression"]} = {computation["value"]}\n' + + if entities: + text += '\nEntities:\n' + for entity in entities: + text += f'{entity["name"]} - {entity["url"]}\n' + + if news: + text += '\nNews:\n' + for news_item in news: + text += f'{news_item["name"]} - {news_item["url"]}\n' + + if related_searches: + text += '\n\nRelated Searches:\n' + for related in related_searches: + text += f'{related["displayText"]} - {related["webSearchUrl"]}\n' return self.create_text_message(text=self.summary(user_id=user_id, content=text)) diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.yaml b/api/core/tools/provider/builtin/bing/tools/bing_web_search.yaml index 329a83f12f..6bf64efb99 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.yaml +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.yaml @@ -25,9 +25,74 @@ parameters: zh_Hans: 用于搜索网页内容 pt_BR: used for searching llm_description: key words for searching + - name: enable_computation + type: boolean + required: false + form: form + label: + en_US: Enable computation + zh_Hans: 启用计算 + pt_BR: Enable computation + human_description: + en_US: enable computation + zh_Hans: 启用计算 + pt_BR: enable computation + default: false + - name: enable_entities + type: boolean + required: false + form: form + label: + en_US: Enable entities + zh_Hans: 启用实体搜索 + pt_BR: Enable entities + human_description: + en_US: enable entities + zh_Hans: 启用实体搜索 + pt_BR: enable entities + default: true + - name: enable_news + type: boolean + required: false + form: form + label: + en_US: Enable news + zh_Hans: 启用新闻搜索 + pt_BR: Enable news + human_description: + en_US: enable news + zh_Hans: 启用新闻搜索 + pt_BR: enable news + default: false + - name: enable_related_search + type: boolean + required: false + form: form + label: + en_US: Enable related search + zh_Hans: 启用相关搜索 + pt_BR: Enable related search + human_description: + en_US: enable related search + zh_Hans: 启用相关搜索 + pt_BR: enable related search + default: false + - name: enable_webpages + type: boolean + required: false + form: form + label: + en_US: Enable webpages search + zh_Hans: 启用网页搜索 + pt_BR: Enable webpages search + human_description: + en_US: enable webpages search + zh_Hans: 启用网页搜索 + pt_BR: enable webpages search + default: true - name: limit type: number - required: false + required: true form: form label: en_US: Limit for results length @@ -42,7 +107,7 @@ parameters: default: 5 - name: result_type type: select - required: false + required: true label: en_US: result type zh_Hans: 结果类型 diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 490f8514de..57b6e090c4 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -6,15 +6,12 @@ from langchain.tools import BaseTool from pydantic import BaseModel, Field from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler 
-from core.embedding.cached_embedding import CacheEmbedding -from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError -from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.retrieval_service import RetrievalService from core.rerank.rerank import RerankRunner from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -from services.retrieval_service import RetrievalService default_retrieval_model = { 'search_method': 'semantic_search', @@ -174,76 +171,24 @@ class DatasetMultiRetrieverTool(BaseTool): if dataset.indexing_technique == "economy": # use keyword table query - kw_table_index = KeywordTableIndex( - dataset=dataset, - config=KeywordTableConfig( - max_keywords_per_chunk=5 - ) - ) - - documents = kw_table_index.search(query, search_kwargs={'k': self.top_k}) + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k + ) if documents: all_documents.extend(documents) else: - - try: - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) - except LLMBadRequestError: - return [] - except ProviderTokenNotInitError: - return [] - - embeddings = CacheEmbedding(embedding_model) - - documents = [] - threads = [] if self.top_k > 0: - # retrieval_model source with semantic - if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[ - 'search_method'] == 'hybrid_search': - embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': str(dataset.id), - 'query': query, - 'top_k': self.top_k, - 'score_threshold': self.score_threshold, - 'reranking_model': None, - 'all_documents': documents, - 'search_method': 'hybrid_search', - 'embeddings': embeddings - }) - threads.append(embedding_thread) - embedding_thread.start() - - # retrieval_model source with full text - if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[ - 'search_method'] == 'hybrid_search': - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, - kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': str(dataset.id), - 'query': query, - 'search_method': 'hybrid_search', - 'embeddings': embeddings, - 'score_threshold': retrieval_model[ - 'score_threshold'] if retrieval_model[ - 'score_threshold_enabled'] else None, - 'top_k': self.top_k, - 'reranking_model': retrieval_model[ - 'reranking_model'] if retrieval_model[ - 'reranking_enable'] else None, - 'all_documents': documents - }) - threads.append(full_text_index_thread) - full_text_index_thread.start() - - for thread in threads: - thread.join() + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) all_documents.extend(documents) \ No newline at end of file diff 
--git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index c07cd3a5ff..d3ec0fba69 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,20 +1,12 @@ -import threading from typing import Optional -from flask import current_app from langchain.tools import BaseTool from pydantic import BaseModel, Field from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.embedding.cached_embedding import CacheEmbedding -from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError -from core.rerank.rerank import RerankRunner +from core.rag.datasource.retrieval_service import RetrievalService from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -from services.retrieval_service import RetrievalService default_retrieval_model = { 'search_method': 'semantic_search', @@ -77,94 +69,24 @@ class DatasetRetrieverTool(BaseTool): retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query - kw_table_index = KeywordTableIndex( - dataset=dataset, - config=KeywordTableConfig( - max_keywords_per_chunk=5 - ) - ) - - documents = kw_table_index.search(query, search_kwargs={'k': self.top_k}) + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k + ) return str("\n".join([document.page_content for document in documents])) else: - # get embedding model instance - try: - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) - except InvokeAuthorizationError: - return '' - - embeddings = CacheEmbedding(embedding_model) - - documents = [] - threads = [] if self.top_k > 0: - # retrieval source with semantic - if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': - embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': str(dataset.id), - 'query': query, - 'top_k': self.top_k, - 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ - 'score_threshold_enabled'] else None, - 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ - 'reranking_enable'] else None, - 'all_documents': documents, - 'search_method': retrieval_model['search_method'], - 'embeddings': embeddings - }) - threads.append(embedding_thread) - embedding_thread.start() - - # retrieval_model source with full text - if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': str(dataset.id), - 'query': query, - 'search_method': retrieval_model['search_method'], - 'embeddings': embeddings, - 
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ - 'score_threshold_enabled'] else None, - 'top_k': self.top_k, - 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ - 'reranking_enable'] else None, - 'all_documents': documents - }) - threads.append(full_text_index_thread) - full_text_index_thread.start() - - for thread in threads: - thread.join() - - # hybrid search: rerank after all documents have been searched - if retrieval_model['search_method'] == 'hybrid_search': - # get rerank model instance - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=retrieval_model['reranking_model']['reranking_provider_name'], - model_type=ModelType.RERANK, - model=retrieval_model['reranking_model']['reranking_model_name'] - ) - except InvokeAuthorizationError: - return '' - - rerank_runner = RerankRunner(rerank_model_instance) - documents = rerank_runner.run( - query=query, - documents=documents, - score_threshold=retrieval_model['score_threshold'] if retrieval_model[ - 'score_threshold_enabled'] else None, - top_n=self.top_k - ) + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) else: documents = [] @@ -234,4 +156,4 @@ class DatasetRetrieverTool(BaseTool): return str("\n".join(document_context_list)) async def _arun(self, tool_input: str) -> str: - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 72ff3cdbdd..9975978357 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -13,7 +13,6 @@ import requests from bs4 import BeautifulSoup, CData, Comment, NavigableString from langchain.chains import RefineDocumentsChain from langchain.chains.summarize import refine_prompts -from langchain.schema import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.tools.base import BaseTool from newspaper import Article @@ -21,9 +20,10 @@ from pydantic import BaseModel, Field from regex import regex from core.chain.llm_chain import LLMChain -from core.data_loader import file_extractor -from core.data_loader.file_extractor import FileExtractor from core.entities.application_entities import ModelConfigEntity +from core.rag.extractor import extract_processor +from core.rag.extractor.extract_processor import ExtractProcessor +from core.rag.models.document import Document FULL_TEMPLATE = """ TITLE: {title} @@ -149,7 +149,7 @@ def get_url(url: str, user_agent: str = None) -> str: if user_agent: headers["User-Agent"] = user_agent - supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] + supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) @@ -161,8 +161,8 @@ def get_url(url: str, user_agent: str = None) -> str: if main_content_type not in supported_content_types: return "Unsupported content-type [{}] of URL.".format(main_content_type) - if main_content_type in 
file_extractor.SUPPORT_URL_CONTENT_TYPES: - return FileExtractor.load_from_url(url, return_text=True) + if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: + return ExtractProcessor.load_from_url(url, return_text=True) response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) a = extract_using_readabilipy(response.text) diff --git a/api/core/vector_store/milvus_vector_store.py b/api/core/vector_store/milvus_vector_store.py deleted file mode 100644 index 67b958ded0..0000000000 --- a/api/core/vector_store/milvus_vector_store.py +++ /dev/null @@ -1,56 +0,0 @@ -from core.vector_store.vector.milvus import Milvus - - -class MilvusVectorStore(Milvus): - def del_texts(self, where_filter: dict): - if not where_filter: - raise ValueError('where_filter must not be empty') - - self.col.delete(where_filter.get('filter')) - - def del_text(self, uuid: str) -> None: - expr = f"id == {uuid}" - self.col.delete(expr) - - def text_exists(self, uuid: str) -> bool: - result = self.col.query( - expr=f'metadata["doc_id"] == "{uuid}"', - output_fields=["id"] - ) - - return len(result) > 0 - - def get_ids_by_document_id(self, document_id: str): - result = self.col.query( - expr=f'metadata["document_id"] == "{document_id}"', - output_fields=["id"] - ) - if result: - return [item["id"] for item in result] - else: - return None - - def get_ids_by_metadata_field(self, key: str, value: str): - result = self.col.query( - expr=f'metadata["{key}"] == "{value}"', - output_fields=["id"] - ) - if result: - return [item["id"] for item in result] - else: - return None - - def get_ids_by_doc_ids(self, doc_ids: list): - result = self.col.query( - expr=f'metadata["doc_id"] in {doc_ids}', - output_fields=["id"] - ) - if result: - return [item["id"] for item in result] - else: - return None - - def delete(self): - from pymilvus import utility - utility.drop_collection(self.collection_name, None, self.alias) - diff --git a/api/core/vector_store/qdrant_vector_store.py b/api/core/vector_store/qdrant_vector_store.py deleted file mode 100644 index 53ad7b2aae..0000000000 --- a/api/core/vector_store/qdrant_vector_store.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Any, cast - -from langchain.schema import Document -from qdrant_client.http.models import Filter, FilterSelector, PointIdsList -from qdrant_client.local.qdrant_local import QdrantLocal - -from core.vector_store.vector.qdrant import Qdrant - - -class QdrantVectorStore(Qdrant): - def del_texts(self, filter: Filter): - if not filter: - raise ValueError('filter must not be empty') - - self._reload_if_needed() - - self.client.delete( - collection_name=self.collection_name, - points_selector=FilterSelector( - filter=filter - ), - ) - - def del_text(self, uuid: str) -> None: - self._reload_if_needed() - - self.client.delete( - collection_name=self.collection_name, - points_selector=PointIdsList( - points=[uuid], - ), - ) - - def text_exists(self, uuid: str) -> bool: - self._reload_if_needed() - - response = self.client.retrieve( - collection_name=self.collection_name, - ids=[uuid] - ) - - return len(response) > 0 - - def delete(self): - self._reload_if_needed() - - self.client.delete_collection(collection_name=self.collection_name) - - def delete_group(self): - self._reload_if_needed() - - self.client.delete_collection(collection_name=self.collection_name) - - @classmethod - def _document_from_scored_point( - cls, - scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, - ) -> Document: - if 
scored_point.payload.get('doc_id'): - return Document( - page_content=scored_point.payload.get(content_payload_key), - metadata={'doc_id': scored_point.id} - ) - - return Document( - page_content=scored_point.payload.get(content_payload_key), - metadata=scored_point.payload.get(metadata_payload_key) or {}, - ) - - def _reload_if_needed(self): - if isinstance(self.client, QdrantLocal): - self.client = cast(QdrantLocal, self.client) - self.client._load() - diff --git a/api/core/vector_store/vector/milvus.py b/api/core/vector_store/vector/milvus.py deleted file mode 100644 index 9d8695dc5a..0000000000 --- a/api/core/vector_store/vector/milvus.py +++ /dev/null @@ -1,852 +0,0 @@ -"""Wrapper around the Milvus vector database.""" -from __future__ import annotations - -import logging -from collections.abc import Iterable, Sequence -from typing import Any, Optional, Union -from uuid import uuid4 - -import numpy as np -from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings -from langchain.vectorstores.base import VectorStore -from langchain.vectorstores.utils import maximal_marginal_relevance - -logger = logging.getLogger(__name__) - -DEFAULT_MILVUS_CONNECTION = { - "host": "localhost", - "port": "19530", - "user": "", - "password": "", - "secure": False, -} - - -class Milvus(VectorStore): - """Initialize wrapper around the milvus vector database. - - In order to use this you need to have `pymilvus` installed and a - running Milvus - - See the following documentation for how to run a Milvus instance: - https://milvus.io/docs/install_standalone-docker.md - - If looking for a hosted Milvus, take a look at this documentation: - https://zilliz.com/cloud and make use of the Zilliz vectorstore found in - this project, - - IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA. - - Args: - embedding_function (Embeddings): Function used to embed the text. - collection_name (str): Which Milvus collection to use. Defaults to - "LangChainCollection". - connection_args (Optional[dict[str, any]]): The connection args used for - this class comes in the form of a dict. - consistency_level (str): The consistency level to use for a collection. - Defaults to "Session". - index_params (Optional[dict]): Which index params to use. Defaults to - HNSW/AUTOINDEX depending on service. - search_params (Optional[dict]): Which search params to use. Defaults to - default of index. - drop_old (Optional[bool]): Whether to drop the current collection. Defaults - to False. - - The connection args used for this class comes in the form of a dict, - here are a few of the options: - address (str): The actual address of Milvus - instance. Example address: "localhost:19530" - uri (str): The uri of Milvus instance. Example uri: - "http://randomwebsite:19530", - "tcp:foobarsite:19530", - "https://ok.s3.south.com:19530". - host (str): The host of Milvus instance. Default at "localhost", - PyMilvus will fill in the default host if only port is provided. - port (str/int): The port of Milvus instance. Default at 19530, PyMilvus - will fill in the default port if only host is provided. - user (str): Use which user to connect to Milvus instance. If user and - password are provided, we will add related header in every RPC call. - password (str): Required when user is provided. The password - corresponding to the user. - secure (bool): Default is false. If set to true, tls will be enabled. - client_key_path (str): If use tls two-way authentication, need to - write the client.key path. 
- client_pem_path (str): If use tls two-way authentication, need to - write the client.pem path. - ca_pem_path (str): If use tls two-way authentication, need to write - the ca.pem path. - server_pem_path (str): If use tls one-way authentication, need to - write the server.pem path. - server_name (str): If use tls, need to write the common name. - - Example: - .. code-block:: python - - from langchain import Milvus - from langchain.embeddings import OpenAIEmbeddings - - embedding = OpenAIEmbeddings() - # Connect to a milvus instance on localhost - milvus_store = Milvus( - embedding_function = Embeddings, - collection_name = "LangChainCollection", - drop_old = True, - ) - - Raises: - ValueError: If the pymilvus python package is not installed. - """ - - def __init__( - self, - embedding_function: Embeddings, - collection_name: str = "LangChainCollection", - connection_args: Optional[dict[str, Any]] = None, - consistency_level: str = "Session", - index_params: Optional[dict] = None, - search_params: Optional[dict] = None, - drop_old: Optional[bool] = False, - ): - """Initialize the Milvus vector store.""" - try: - from pymilvus import Collection, utility - except ImportError: - raise ValueError( - "Could not import pymilvus python package. " - "Please install it with `pip install pymilvus`." - ) - - # Default search params when one is not provided. - self.default_search_params = { - "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, - "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}}, - "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, - "HNSW": {"metric_type": "L2", "params": {"ef": 10}}, - "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}}, - "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}}, - "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}}, - "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}}, - "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}}, - "AUTOINDEX": {"metric_type": "L2", "params": {}}, - } - - self.embedding_func = embedding_function - self.collection_name = collection_name - self.index_params = index_params - self.search_params = search_params - self.consistency_level = consistency_level - - # In order for a collection to be compatible, pk needs to be auto'id and int - self._primary_field = "id" - # In order for compatibility, the text field will need to be called "text" - self._text_field = "page_content" - # In order for compatibility, the vector field needs to be called "vector" - self._vector_field = "vectors" - # In order for compatibility, the metadata field will need to be called "metadata" - self._metadata_field = "metadata" - self.fields: list[str] = [] - # Create the connection to the server - if connection_args is None: - connection_args = DEFAULT_MILVUS_CONNECTION - self.alias = self._create_connection_alias(connection_args) - self.col: Optional[Collection] = None - - # Grab the existing collection if it exists - if utility.has_collection(self.collection_name, using=self.alias): - self.col = Collection( - self.collection_name, - using=self.alias, - ) - # If need to drop old, drop it - if drop_old and isinstance(self.col, Collection): - self.col.drop() - self.col = None - - # Initialize the vector store - self._init() - - @property - def embeddings(self) -> Embeddings: - return self.embedding_func - - def _create_connection_alias(self, connection_args: dict) -> str: - """Create the connection to the Milvus server.""" - from pymilvus import MilvusException, connections - - # Grab the 
connection arguments that are used for checking existing connection - host: str = connection_args.get("host", None) - port: Union[str, int] = connection_args.get("port", None) - address: str = connection_args.get("address", None) - uri: str = connection_args.get("uri", None) - user = connection_args.get("user", None) - - # Order of use is host/port, uri, address - if host is not None and port is not None: - given_address = str(host) + ":" + str(port) - elif uri is not None: - given_address = uri.split("https://")[1] - elif address is not None: - given_address = address - else: - given_address = None - logger.debug("Missing standard address type for reuse atttempt") - - # User defaults to empty string when getting connection info - if user is not None: - tmp_user = user - else: - tmp_user = "" - - # If a valid address was given, then check if a connection exists - if given_address is not None: - for con in connections.list_connections(): - addr = connections.get_connection_addr(con[0]) - if ( - con[1] - and ("address" in addr) - and (addr["address"] == given_address) - and ("user" in addr) - and (addr["user"] == tmp_user) - ): - logger.debug("Using previous connection: %s", con[0]) - return con[0] - - # Generate a new connection if one doesn't exist - alias = uuid4().hex - try: - connections.connect(alias=alias, **connection_args) - logger.debug("Created new connection using: %s", alias) - return alias - except MilvusException as e: - logger.error("Failed to create new connection using: %s", alias) - raise e - - def _init( - self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None - ) -> None: - if embeddings is not None: - self._create_collection(embeddings, metadatas) - self._extract_fields() - self._create_index() - self._create_search_params() - self._load() - - def _create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None - ) -> None: - from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException - from pymilvus.orm.types import infer_dtype_bydata - - # Determine embedding dim - dim = len(embeddings[0]) - fields = [] - # Determine metadata schema - # if metadatas: - # # Create FieldSchema for each entry in metadata. 
- # for key, value in metadatas[0].items(): - # # Infer the corresponding datatype of the metadata - # dtype = infer_dtype_bydata(value) - # # Datatype isn't compatible - # if dtype == DataType.UNKNOWN or dtype == DataType.NONE: - # logger.error( - # "Failure to create collection, unrecognized dtype for key: %s", - # key, - # ) - # raise ValueError(f"Unrecognized datatype for {key}.") - # # Dataype is a string/varchar equivalent - # elif dtype == DataType.VARCHAR: - # fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535)) - # else: - # fields.append(FieldSchema(key, dtype)) - if metadatas: - fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535)) - - # Create the text field - fields.append( - FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535) - ) - # Create the primary key field - fields.append( - FieldSchema( - self._primary_field, DataType.INT64, is_primary=True, auto_id=True - ) - ) - # Create the vector field, supports binary or float vectors - fields.append( - FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim) - ) - - # Create the schema for the collection - schema = CollectionSchema(fields) - - # Create the collection - try: - self.col = Collection( - name=self.collection_name, - schema=schema, - consistency_level=self.consistency_level, - using=self.alias, - ) - except MilvusException as e: - logger.error( - "Failed to create collection: %s error: %s", self.collection_name, e - ) - raise e - - def _extract_fields(self) -> None: - """Grab the existing fields from the Collection""" - from pymilvus import Collection - - if isinstance(self.col, Collection): - schema = self.col.schema - for x in schema.fields: - self.fields.append(x.name) - # Since primary field is auto-id, no need to track it - self.fields.remove(self._primary_field) - - def _get_index(self) -> Optional[dict[str, Any]]: - """Return the vector index information if it exists""" - from pymilvus import Collection - - if isinstance(self.col, Collection): - for x in self.col.indexes: - if x.field_name == self._vector_field: - return x.to_dict() - return None - - def _create_index(self) -> None: - """Create a index on the collection""" - from pymilvus import Collection, MilvusException - - if isinstance(self.col, Collection) and self._get_index() is None: - try: - # If no index params, use a default HNSW based one - if self.index_params is None: - self.index_params = { - "metric_type": "IP", - "index_type": "HNSW", - "params": {"M": 8, "efConstruction": 64}, - } - - try: - self.col.create_index( - self._vector_field, - index_params=self.index_params, - using=self.alias, - ) - - # If default did not work, most likely on Zilliz Cloud - except MilvusException: - # Use AUTOINDEX based index - self.index_params = { - "metric_type": "L2", - "index_type": "AUTOINDEX", - "params": {}, - } - self.col.create_index( - self._vector_field, - index_params=self.index_params, - using=self.alias, - ) - logger.debug( - "Successfully created an index on collection: %s", - self.collection_name, - ) - - except MilvusException as e: - logger.error( - "Failed to create an index on collection: %s", self.collection_name - ) - raise e - - def _create_search_params(self) -> None: - """Generate search params based on the current index type""" - from pymilvus import Collection - - if isinstance(self.col, Collection) and self.search_params is None: - index = self._get_index() - if index is not None: - index_type: str = index["index_param"]["index_type"] - metric_type: str = 
index["index_param"]["metric_type"] - self.search_params = self.default_search_params[index_type] - self.search_params["metric_type"] = metric_type - - def _load(self) -> None: - """Load the collection if available.""" - from pymilvus import Collection - - if isinstance(self.col, Collection) and self._get_index() is not None: - self.col.load() - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - timeout: Optional[int] = None, - batch_size: int = 1000, - **kwargs: Any, - ) -> list[str]: - """Insert text data into Milvus. - - Inserting data when the collection has not be made yet will result - in creating a new Collection. The data of the first entity decides - the schema of the new collection, the dim is extracted from the first - embedding and the columns are decided by the first metadata dict. - Metada keys will need to be present for all inserted values. At - the moment there is no None equivalent in Milvus. - - Args: - texts (Iterable[str]): The texts to embed, it is assumed - that they all fit in memory. - metadatas (Optional[List[dict]]): Metadata dicts attached to each of - the texts. Defaults to None. - timeout (Optional[int]): Timeout for each batch insert. Defaults - to None. - batch_size (int, optional): Batch size to use for insertion. - Defaults to 1000. - - Raises: - MilvusException: Failure to add texts - - Returns: - List[str]: The resulting keys for each inserted element. - """ - from pymilvus import Collection, MilvusException - - texts = list(texts) - - try: - embeddings = self.embedding_func.embed_documents(texts) - except NotImplementedError: - embeddings = [self.embedding_func.embed_query(x) for x in texts] - - if len(embeddings) == 0: - logger.debug("Nothing to insert, skipping.") - return [] - - # If the collection hasn't been initialized yet, perform all steps to do so - if not isinstance(self.col, Collection): - self._init(embeddings, metadatas) - - # Dict to hold all insert columns - insert_dict: dict[str, list] = { - self._text_field: texts, - self._vector_field: embeddings, - } - - # Collect the metadata into the insert dict. - # if metadatas is not None: - # for d in metadatas: - # for key, value in d.items(): - # if key in self.fields: - # insert_dict.setdefault(key, []).append(value) - if metadatas is not None: - for d in metadatas: - insert_dict.setdefault(self._metadata_field, []).append(d) - - # Total insert count - vectors: list = insert_dict[self._vector_field] - total_count = len(vectors) - - pks: list[str] = [] - - assert isinstance(self.col, Collection) - for i in range(0, total_count, batch_size): - # Grab end index - end = min(i + batch_size, total_count) - # Convert dict to list of lists batch for insertion - insert_list = [insert_dict[x][i:end] for x in self.fields] - # Insert into the collection. - try: - res: Collection - res = self.col.insert(insert_list, timeout=timeout, **kwargs) - pks.extend(res.primary_keys) - except MilvusException as e: - logger.error( - "Failed to insert batch starting at entity: %s/%s", i, total_count - ) - raise e - return pks - - def similarity_search( - self, - query: str, - k: int = 4, - param: Optional[dict] = None, - expr: Optional[str] = None, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> list[Document]: - """Perform a similarity search against the query string. - - Args: - query (str): The text to search. - k (int, optional): How many results to return. Defaults to 4. - param (dict, optional): The search params for the index type. - Defaults to None. 
- expr (str, optional): Filtering expression. Defaults to None. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - kwargs: Collection.search() keyword arguments. - - Returns: - List[Document]: Document results for search. - """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - res = self.similarity_search_with_score( - query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs - ) - return [doc for doc, _ in res] - - def similarity_search_by_vector( - self, - embedding: list[float], - k: int = 4, - param: Optional[dict] = None, - expr: Optional[str] = None, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> list[Document]: - """Perform a similarity search against the query string. - - Args: - embedding (List[float]): The embedding vector to search. - k (int, optional): How many results to return. Defaults to 4. - param (dict, optional): The search params for the index type. - Defaults to None. - expr (str, optional): Filtering expression. Defaults to None. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - kwargs: Collection.search() keyword arguments. - - Returns: - List[Document]: Document results for search. - """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - res = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs - ) - return [doc for doc, _ in res] - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - param: Optional[dict] = None, - expr: Optional[str] = None, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Perform a search on a query string and return results with score. - - For more information about the search parameters, take a look at the pymilvus - documentation found here: - https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md - - Args: - query (str): The text being searched. - k (int, optional): The amount of results to return. Defaults to 4. - param (dict): The search params for the specified index. - Defaults to None. - expr (str, optional): Filtering expression. Defaults to None. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - kwargs: Collection.search() keyword arguments. - - Returns: - List[float], List[Tuple[Document, any, any]]: - """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - - # Embed the query text. - embedding = self.embedding_func.embed_query(query) - - res = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs - ) - return res - - def _similarity_search_with_relevance_scores( - self, - query: str, - k: int = 4, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and relevance scores in the range [0, 1]. - - 0 is dissimilar, 1 is most similar. - - Args: - query: input text - k: Number of Documents to return. Defaults to 4. - **kwargs: kwargs to be passed to similarity search. 
Should include: - score_threshold: Optional, a floating point value between 0 to 1 to - filter the resulting set of retrieved docs - - Returns: - List of Tuples of (doc, similarity_score) - """ - return self.similarity_search_with_score(query, k, **kwargs) - - def similarity_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - param: Optional[dict] = None, - expr: Optional[str] = None, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Perform a search on a query string and return results with score. - - For more information about the search parameters, take a look at the pymilvus - documentation found here: - https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md - - Args: - embedding (List[float]): The embedding vector being searched. - k (int, optional): The amount of results to return. Defaults to 4. - param (dict): The search params for the specified index. - Defaults to None. - expr (str, optional): Filtering expression. Defaults to None. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - kwargs: Collection.search() keyword arguments. - - Returns: - List[Tuple[Document, float]]: Result doc and score. - """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - - if param is None: - param = self.search_params - - # Determine result metadata fields. - output_fields = self.fields[:] - output_fields.remove(self._vector_field) - - # Perform the search. - res = self.col.search( - data=[embedding], - anns_field=self._vector_field, - param=param, - limit=k, - expr=expr, - output_fields=output_fields, - timeout=timeout, - **kwargs, - ) - # Organize results. - ret = [] - for result in res[0]: - meta = {x: result.entity.get(x) for x in output_fields} - doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata')) - pair = (doc, result.score) - ret.append(pair) - - return ret - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - param: Optional[dict] = None, - expr: Optional[str] = None, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> list[Document]: - """Perform a search and return results that are reordered by MMR. - - Args: - query (str): The text being searched. - k (int, optional): How many results to give. Defaults to 4. - fetch_k (int, optional): Total results to select k from. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5 - param (dict, optional): The search params for the specified index. - Defaults to None. - expr (str, optional): Filtering expression. Defaults to None. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - kwargs: Collection.search() keyword arguments. - - - Returns: - List[Document]: Document results for search. 
- """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - - embedding = self.embedding_func.embed_query(query) - - return self.max_marginal_relevance_search_by_vector( - embedding=embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - param=param, - expr=expr, - timeout=timeout, - **kwargs, - ) - - def max_marginal_relevance_search_by_vector( - self, - embedding: list[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - param: Optional[dict] = None, - expr: Optional[str] = None, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> list[Document]: - """Perform a search and return results that are reordered by MMR. - - Args: - embedding (str): The embedding vector being searched. - k (int, optional): How many results to give. Defaults to 4. - fetch_k (int, optional): Total results to select k from. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5 - param (dict, optional): The search params for the specified index. - Defaults to None. - expr (str, optional): Filtering expression. Defaults to None. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - kwargs: Collection.search() keyword arguments. - - Returns: - List[Document]: Document results for search. - """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - - if param is None: - param = self.search_params - - # Determine result metadata fields. - output_fields = self.fields[:] - output_fields.remove(self._vector_field) - - # Perform the search. - res = self.col.search( - data=[embedding], - anns_field=self._vector_field, - param=param, - limit=fetch_k, - expr=expr, - output_fields=output_fields, - timeout=timeout, - **kwargs, - ) - # Organize results. - ids = [] - documents = [] - scores = [] - for result in res[0]: - meta = {x: result.entity.get(x) for x in output_fields} - doc = Document(page_content=meta.pop(self._text_field), metadata=meta) - documents.append(doc) - scores.append(result.score) - ids.append(result.id) - - vectors = self.col.query( - expr=f"{self._primary_field} in {ids}", - output_fields=[self._primary_field, self._vector_field], - timeout=timeout, - ) - # Reorganize the results from query to match search order. - vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors} - - ordered_result_embeddings = [vectors[x] for x in ids] - - # Get the new order of results. - new_ordering = maximal_marginal_relevance( - np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult - ) - - # Reorder the values and return. - ret = [] - for x in new_ordering: - # Function can return -1 index - if x == -1: - break - else: - ret.append(documents[x]) - return ret - - @classmethod - def from_texts( - cls, - texts: list[str], - embedding: Embeddings, - metadatas: Optional[list[dict]] = None, - collection_name: str = "LangChainCollection", - connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION, - consistency_level: str = "Session", - index_params: Optional[dict] = None, - search_params: Optional[dict] = None, - drop_old: bool = False, - batch_size: int = 100, - ids: Optional[Sequence[str]] = None, - **kwargs: Any, - ) -> Milvus: - """Create a Milvus collection, indexes it with HNSW, and insert data. - - Args: - texts (List[str]): Text data. - embedding (Embeddings): Embedding function. 
- metadatas (Optional[List[dict]]): Metadata for each text if it exists. - Defaults to None. - collection_name (str, optional): Collection name to use. Defaults to - "LangChainCollection". - connection_args (dict[str, Any], optional): Connection args to use. Defaults - to DEFAULT_MILVUS_CONNECTION. - consistency_level (str, optional): Which consistency level to use. Defaults - to "Session". - index_params (Optional[dict], optional): Which index_params to use. Defaults - to None. - search_params (Optional[dict], optional): Which search params to use. - Defaults to None. - drop_old (Optional[bool], optional): Whether to drop the collection with - that name if it exists. Defaults to False. - batch_size: - How many vectors upload per-request. - Default: 100 - ids: Optional[Sequence[str]] = None, - - Returns: - Milvus: Milvus Vector Store - """ - vector_db = cls( - embedding_function=embedding, - collection_name=collection_name, - connection_args=connection_args, - consistency_level=consistency_level, - index_params=index_params, - search_params=search_params, - drop_old=drop_old, - **kwargs, - ) - vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size) - return vector_db diff --git a/api/core/vector_store/vector/qdrant.py b/api/core/vector_store/vector/qdrant.py deleted file mode 100644 index 47e5fe27c7..0000000000 --- a/api/core/vector_store/vector/qdrant.py +++ /dev/null @@ -1,1759 +0,0 @@ -"""Wrapper around Qdrant vector database.""" -from __future__ import annotations - -import asyncio -import functools -import uuid -import warnings -from collections.abc import Callable, Generator, Iterable, Sequence -from itertools import islice -from operator import itemgetter -from typing import TYPE_CHECKING, Any, Optional, Union - -import numpy as np -from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings -from langchain.vectorstores import VectorStore -from langchain.vectorstores.utils import maximal_marginal_relevance -from qdrant_client.http.models import PayloadSchemaType, TextIndexParams, TextIndexType, TokenizerType - -if TYPE_CHECKING: - from qdrant_client import grpc # noqa - from qdrant_client.conversions import common_types - from qdrant_client.http import models as rest - - DictFilter = dict[str, Union[str, int, bool, dict, list]] - MetadataFilter = Union[DictFilter, common_types.Filter] - - -class QdrantException(Exception): - """Base class for all the Qdrant related exceptions""" - - -def sync_call_fallback(method: Callable) -> Callable: - """ - Decorator to call the synchronous method of the class if the async method is not - implemented. This decorator might be only used for the methods that are defined - as async in the class. - """ - - @functools.wraps(method) - async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - try: - return await method(self, *args, **kwargs) - except NotImplementedError: - # If the async method is not implemented, call the synchronous method - # by removing the first letter from the method name. For example, - # if the async method is called ``aaad_texts``, the synchronous method - # will be called ``aad_texts``. - sync_method = functools.partial( - getattr(self, method.__name__[1:]), *args, **kwargs - ) - return await asyncio.get_event_loop().run_in_executor(None, sync_method) - - return wrapper - - -class Qdrant(VectorStore): - """Wrapper around Qdrant vector database. - - To use you should have the ``qdrant-client`` package installed. - - Example: - .. 
code-block:: python - - from qdrant_client import QdrantClient - from langchain import Qdrant - - client = QdrantClient() - collection_name = "MyCollection" - qdrant = Qdrant(client, collection_name, embedding_function) - """ - - CONTENT_KEY = "page_content" - METADATA_KEY = "metadata" - GROUP_KEY = "group_id" - VECTOR_NAME = None - - def __init__( - self, - client: Any, - collection_name: str, - embeddings: Optional[Embeddings] = None, - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - group_payload_key: str = GROUP_KEY, - group_id: str = None, - distance_strategy: str = "COSINE", - vector_name: Optional[str] = VECTOR_NAME, - embedding_function: Optional[Callable] = None, # deprecated - is_new_collection: bool = False - ): - """Initialize with necessary components.""" - try: - import qdrant_client - except ImportError: - raise ValueError( - "Could not import qdrant-client python package. " - "Please install it with `pip install qdrant-client`." - ) - - if not isinstance(client, qdrant_client.QdrantClient): - raise ValueError( - f"client should be an instance of qdrant_client.QdrantClient, " - f"got {type(client)}" - ) - - if embeddings is None and embedding_function is None: - raise ValueError( - "`embeddings` value can't be None. Pass `Embeddings` instance." - ) - - if embeddings is not None and embedding_function is not None: - raise ValueError( - "Both `embeddings` and `embedding_function` are passed. " - "Use `embeddings` only." - ) - - self._embeddings = embeddings - self._embeddings_function = embedding_function - self.client: qdrant_client.QdrantClient = client - self.collection_name = collection_name - self.content_payload_key = content_payload_key or self.CONTENT_KEY - self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY - self.group_payload_key = group_payload_key or self.GROUP_KEY - self.vector_name = vector_name or self.VECTOR_NAME - self.group_id = group_id - self.is_new_collection= is_new_collection - - if embedding_function is not None: - warnings.warn( - "Using `embedding_function` is deprecated. " - "Pass `Embeddings` instance to `embeddings` instead." - ) - - if not isinstance(embeddings, Embeddings): - warnings.warn( - "`embeddings` should be an instance of `Embeddings`." - "Using `embeddings` as `embedding_function` which is deprecated" - ) - self._embeddings_function = embeddings - self._embeddings = None - - self.distance_strategy = distance_strategy.upper() - - @property - def embeddings(self) -> Optional[Embeddings]: - return self._embeddings - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - **kwargs: Any, - ) -> list[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - ids: - Optional list of ids to associate with the texts. Ids have to be - uuid-like strings. - batch_size: - How many vectors upload per-request. - Default: 64 - group_id: - collection group - - Returns: - List of ids from adding the texts into the vectorstore. 
- """ - added_ids = [] - for batch_ids, points in self._generate_rest_batches( - texts, metadatas, ids, batch_size - ): - self.client.upsert( - collection_name=self.collection_name, points=points - ) - added_ids.extend(batch_ids) - # if is new collection, create payload index on group_id - if self.is_new_collection: - # create payload index - self.client.create_payload_index(self.collection_name, self.group_payload_key, - field_schema=PayloadSchemaType.KEYWORD, - field_type=PayloadSchemaType.KEYWORD) - # creat full text index - text_index_params = TextIndexParams( - type=TextIndexType.TEXT, - tokenizer=TokenizerType.MULTILINGUAL, - min_token_len=2, - max_token_len=20, - lowercase=True - ) - self.client.create_payload_index(self.collection_name, self.content_payload_key, - field_schema=text_index_params) - return added_ids - - @sync_call_fallback - async def aadd_texts( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - **kwargs: Any, - ) -> list[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - ids: - Optional list of ids to associate with the texts. Ids have to be - uuid-like strings. - batch_size: - How many vectors upload per-request. - Default: 64 - - Returns: - List of ids from adding the texts into the vectorstore. - """ - from qdrant_client import grpc # noqa - from qdrant_client.conversions.conversion import RestToGrpc - - added_ids = [] - for batch_ids, points in self._generate_rest_batches( - texts, metadatas, ids, batch_size - ): - await self.client.async_grpc_points.Upsert( - grpc.UpsertPoints( - collection_name=self.collection_name, - points=[RestToGrpc.convert_point_struct(point) for point in points], - ) - ) - added_ids.extend(batch_ids) - - return added_ids - - def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of Documents most similar to the query. 
- """ - results = self.similarity_search_with_score( - query, - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return list(map(itemgetter(0), results)) - - @sync_call_fallback - async def asimilarity_search( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilter] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs most similar to query. - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - Returns: - List of Documents most similar to the query. - """ - results = await self.asimilarity_search_with_score(query, k, filter, **kwargs) - return list(map(itemgetter(0), results)) - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of documents most similar to the query text and distance for each. - """ - return self.similarity_search_with_score_by_vector( - self._embed_query(query), - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - - @sync_call_fallback - async def asimilarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. 
- Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of documents most similar to the query text and distance for each. - """ - return await self.asimilarity_search_with_score_by_vector( - self._embed_query(query), - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - - def similarity_search_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of Documents most similar to the query. - """ - results = self.similarity_search_with_score_by_vector( - embedding, - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return list(map(itemgetter(0), results)) - - @sync_call_fallback - async def asimilarity_search_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. 
- search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of Documents most similar to the query. - """ - results = await self.asimilarity_search_with_score_by_vector( - embedding, - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return list(map(itemgetter(0), results)) - - def similarity_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of documents most similar to the query text and distance for each. - """ - if filter is not None and isinstance(filter, dict): - warnings.warn( - "Using dict as a `filter` is deprecated. 
Please use qdrant-client " - "filters directly: " - "https://qdrant.tech/documentation/concepts/filtering/", - DeprecationWarning, - ) - qdrant_filter = self._qdrant_filter_from_dict(filter) - else: - qdrant_filter = filter - - query_vector = embedding - if self.vector_name is not None: - query_vector = (self.vector_name, embedding) # type: ignore[assignment] - - results = self.client.search( - collection_name=self.collection_name, - query_vector=query_vector, - query_filter=qdrant_filter, - search_params=search_params, - limit=k, - offset=offset, - with_payload=True, - with_vectors=True, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return [ - ( - self._document_from_scored_point( - result, self.content_payload_key, self.metadata_payload_key - ), - result.score, - ) - for result in results - ] - - def similarity_search_by_bm25( - self, - filter: Optional[MetadataFilter] = None, - k: int = 4 - ) -> list[Document]: - """Return docs most similar by bm25. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - Returns: - List of documents most similar to the query text and distance for each. - """ - response = self.client.scroll( - collection_name=self.collection_name, - scroll_filter=filter, - limit=k, - with_payload=True, - with_vectors=True - - ) - results = response[0] - documents = [] - for result in results: - if result: - documents.append(self._document_from_scored_point( - result, self.content_payload_key, self.metadata_payload_key - )) - - return documents - - @sync_call_fallback - async def asimilarity_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of documents most similar to the query text and distance for each. 
- """ - from qdrant_client import grpc # noqa - from qdrant_client.conversions.conversion import RestToGrpc - from qdrant_client.http import models as rest - - if filter is not None and isinstance(filter, dict): - warnings.warn( - "Using dict as a `filter` is deprecated. Please use qdrant-client " - "filters directly: " - "https://qdrant.tech/documentation/concepts/filtering/", - DeprecationWarning, - ) - qdrant_filter = self._qdrant_filter_from_dict(filter) - else: - qdrant_filter = filter - - if qdrant_filter is not None and isinstance(qdrant_filter, rest.Filter): - qdrant_filter = RestToGrpc.convert_filter(qdrant_filter) - - response = await self.client.async_grpc_points.Search( - grpc.SearchPoints( - collection_name=self.collection_name, - vector_name=self.vector_name, - vector=embedding, - filter=qdrant_filter, - params=search_params, - limit=k, - offset=offset, - with_payload=grpc.WithPayloadSelector(enable=True), - with_vectors=grpc.WithVectorsSelector(enable=False), - score_threshold=score_threshold, - read_consistency=consistency, - **kwargs, - ) - ) - - return [ - ( - self._document_from_scored_point_grpc( - result, self.content_payload_key, self.metadata_payload_key - ), - result.score, - ) - for result in response.result - ] - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance. - """ - query_embedding = self._embed_query(query) - return self.max_marginal_relevance_search_by_vector( - query_embedding, k, fetch_k, lambda_mult, **kwargs - ) - - @sync_call_fallback - async def amax_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance. - """ - query_embedding = self._embed_query(query) - return await self.amax_marginal_relevance_search_by_vector( - query_embedding, k, fetch_k, lambda_mult, **kwargs - ) - - def max_marginal_relevance_search_by_vector( - self, - embedding: list[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. 
- - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance. - """ - results = self.max_marginal_relevance_search_with_score_by_vector( - embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs - ) - return list(map(itemgetter(0), results)) - - @sync_call_fallback - async def amax_marginal_relevance_search_by_vector( - self, - embedding: list[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance and distance for - each. - """ - results = await self.amax_marginal_relevance_search_with_score_by_vector( - embedding, k, fetch_k, lambda_mult, **kwargs - ) - return list(map(itemgetter(0), results)) - - def max_marginal_relevance_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance and distance for - each. 
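Note:
    The selection delegated to `maximal_marginal_relevance` can be pictured with the rough numpy sketch below; it only illustrates the trade-off controlled by `lambda_mult` and is not the exact langchain implementation:

    .. code-block:: python

        import numpy as np

        def mmr_sketch(query_vec, doc_vecs, k=4, lambda_mult=0.5):
            # Cosine similarity between two 1-D vectors.
            def cos(a, b):
                return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

            selected, candidates = [], list(range(len(doc_vecs)))
            while candidates and len(selected) < k:
                def mmr_score(i):
                    relevance = cos(query_vec, doc_vecs[i])
                    redundancy = max((cos(doc_vecs[i], doc_vecs[j]) for j in selected), default=0.0)
                    return lambda_mult * relevance - (1 - lambda_mult) * redundancy

                best = max(candidates, key=mmr_score)  # highest relevance-minus-redundancy wins
                selected.append(best)
                candidates.remove(best)
            return selected  # indices of chosen documents, relevant-yet-diverse first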
- """ - query_vector = embedding - if self.vector_name is not None: - query_vector = (self.vector_name, query_vector) # type: ignore[assignment] - - results = self.client.search( - collection_name=self.collection_name, - query_vector=query_vector, - with_payload=True, - with_vectors=True, - limit=fetch_k, - ) - embeddings = [ - result.vector.get(self.vector_name) # type: ignore[index, union-attr] - if self.vector_name is not None - else result.vector - for result in results - ] - mmr_selected = maximal_marginal_relevance( - np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult - ) - return [ - ( - self._document_from_scored_point( - results[i], self.content_payload_key, self.metadata_payload_key - ), - results[i].score, - ) - for i in mmr_selected - ] - - @sync_call_fallback - async def amax_marginal_relevance_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance and distance for - each. - """ - from qdrant_client import grpc # noqa - from qdrant_client.conversions.conversion import GrpcToRest - - response = await self.client.async_grpc_points.Search( - grpc.SearchPoints( - collection_name=self.collection_name, - vector_name=self.vector_name, - vector=embedding, - with_payload=grpc.WithPayloadSelector(enable=True), - with_vectors=grpc.WithVectorsSelector(enable=True), - limit=fetch_k, - ) - ) - results = [ - GrpcToRest.convert_vectors(result.vectors) for result in response.result - ] - embeddings: list[list[float]] = [ - result.get(self.vector_name) # type: ignore - if isinstance(result, dict) - else result - for result in results - ] - mmr_selected: list[int] = maximal_marginal_relevance( - np.array(embedding), - embeddings, - k=k, - lambda_mult=lambda_mult, - ) - return [ - ( - self._document_from_scored_point_grpc( - response.result[i], - self.content_payload_key, - self.metadata_payload_key, - ), - response.result[i].score, - ) - for i in mmr_selected - ] - - def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]: - """Delete by vector ID or other criteria. - - Args: - ids: List of ids to delete. - **kwargs: Other keyword arguments that subclasses might use. - - Returns: - Optional[bool]: True if deletion is successful, - False otherwise, None if not implemented. 
- """ - from qdrant_client.http import models as rest - - result = self.client.delete( - collection_name=self.collection_name, - points_selector=ids, - ) - return result.status == rest.UpdateStatus.COMPLETED - - @classmethod - def from_texts( - cls: type[Qdrant], - texts: list[str], - embedding: Embeddings, - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, - distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - group_payload_key: str = GROUP_KEY, - group_id: str = None, - vector_name: Optional[str] = VECTOR_NAME, - batch_size: int = 64, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, - hnsw_config: Optional[common_types.HnswConfigDiff] = None, - optimizers_config: Optional[common_types.OptimizersConfigDiff] = None, - wal_config: Optional[common_types.WalConfigDiff] = None, - quantization_config: Optional[common_types.QuantizationConfig] = None, - init_from: Optional[common_types.InitFrom] = None, - force_recreate: bool = False, - **kwargs: Any, - ) -> Qdrant: - """Construct Qdrant wrapper from a list of texts. - - Args: - texts: A list of texts to be indexed in Qdrant. - embedding: A subclass of `Embeddings`, responsible for text vectorization. - metadatas: - An optional list of metadata. If provided it has to be of the same - length as a list of texts. - ids: - Optional list of ids to associate with the texts. Ids have to be - uuid-like strings. - location: - If `:memory:` - use in-memory Qdrant instance. - If `str` - use it as a `url` parameter. - If `None` - fallback to relying on `host` and `port` parameters. - url: either host or str of "Optional[scheme], host, Optional[port], - Optional[prefix]". Default: `None` - port: Port of the REST API interface. Default: 6333 - grpc_port: Port of the gRPC interface. Default: 6334 - prefer_grpc: - If true - use gPRC interface whenever possible in custom methods. - Default: False - https: If true - use HTTPS(SSL) protocol. Default: None - api_key: API key for authentication in Qdrant Cloud. Default: None - prefix: - If not None - add prefix to the REST URL path. - Example: service/v1 will result in - http://localhost:6333/service/v1/{qdrant-endpoint} for REST API. - Default: None - timeout: - Timeout for REST and gRPC API requests. - Default: 5.0 seconds for REST and unlimited for gRPC - host: - Host name of Qdrant service. If url and host are None, set to - 'localhost'. Default: None - path: - Path in which the vectors will be stored while using local mode. - Default: None - collection_name: - Name of the Qdrant collection to be used. If not provided, - it will be created randomly. Default: None - distance_func: - Distance function. One of: "Cosine" / "Euclid" / "Dot". - Default: "Cosine" - content_payload_key: - A payload key used to store the content of the document. - Default: "page_content" - metadata_payload_key: - A payload key used to store the metadata of the document. 
- Default: "metadata" - group_payload_key: - A payload key used to store the content of the document. - Default: "group_id" - group_id: - collection group id - vector_name: - Name of the vector to be used internally in Qdrant. - Default: None - batch_size: - How many vectors upload per-request. - Default: 64 - shard_number: Number of shards in collection. Default is 1, minimum is 1. - replication_factor: - Replication factor for collection. Default is 1, minimum is 1. - Defines how many copies of each shard will be created. - Have effect only in distributed mode. - write_consistency_factor: - Write consistency factor for collection. Default is 1, minimum is 1. - Defines how many replicas should apply the operation for us to consider - it successful. Increasing this number will make the collection more - resilient to inconsistencies, but will also make it fail if not enough - replicas are available. - Does not have any performance impact. - Have effect only in distributed mode. - on_disk_payload: - If true - point`s payload will not be stored in memory. - It will be read from the disk every time it is requested. - This setting saves RAM by (slightly) increasing the response time. - Note: those payload values that are involved in filtering and are - indexed - remain in RAM. - hnsw_config: Params for HNSW index - optimizers_config: Params for optimizer - wal_config: Params for Write-Ahead-Log - quantization_config: - Params for quantization, if None - quantization will be disabled - init_from: - Use data stored in another collection to initialize this collection - force_recreate: - Force recreating the collection - **kwargs: - Additional arguments passed directly into REST client initialization - - This is a user-friendly interface that: - 1. Creates embeddings, one for each text - 2. Initializes the Qdrant database as an in-memory docstore by default - (and overridable to a remote docstore) - 3. Adds the text embeddings to the Qdrant database - - This is intended to be a quick way to get started. - - Example: - .. 
code-block:: python - - from langchain import Qdrant - from langchain.embeddings import OpenAIEmbeddings - embeddings = OpenAIEmbeddings() - qdrant = Qdrant.from_texts(texts, embeddings, "localhost") - """ - qdrant = cls._construct_instance( - texts, - embedding, - metadatas, - ids, - location, - url, - port, - grpc_port, - prefer_grpc, - https, - api_key, - prefix, - timeout, - host, - path, - collection_name, - distance_func, - content_payload_key, - metadata_payload_key, - group_payload_key, - group_id, - vector_name, - shard_number, - replication_factor, - write_consistency_factor, - on_disk_payload, - hnsw_config, - optimizers_config, - wal_config, - quantization_config, - init_from, - force_recreate, - **kwargs, - ) - qdrant.add_texts(texts, metadatas, ids, batch_size) - return qdrant - - @classmethod - @sync_call_fallback - async def afrom_texts( - cls: type[Qdrant], - texts: list[str], - embedding: Embeddings, - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, - distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - vector_name: Optional[str] = VECTOR_NAME, - batch_size: int = 64, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, - hnsw_config: Optional[common_types.HnswConfigDiff] = None, - optimizers_config: Optional[common_types.OptimizersConfigDiff] = None, - wal_config: Optional[common_types.WalConfigDiff] = None, - quantization_config: Optional[common_types.QuantizationConfig] = None, - init_from: Optional[common_types.InitFrom] = None, - force_recreate: bool = False, - **kwargs: Any, - ) -> Qdrant: - """Construct Qdrant wrapper from a list of texts. - - Args: - texts: A list of texts to be indexed in Qdrant. - embedding: A subclass of `Embeddings`, responsible for text vectorization. - metadatas: - An optional list of metadata. If provided it has to be of the same - length as a list of texts. - ids: - Optional list of ids to associate with the texts. Ids have to be - uuid-like strings. - location: - If `:memory:` - use in-memory Qdrant instance. - If `str` - use it as a `url` parameter. - If `None` - fallback to relying on `host` and `port` parameters. - url: either host or str of "Optional[scheme], host, Optional[port], - Optional[prefix]". Default: `None` - port: Port of the REST API interface. Default: 6333 - grpc_port: Port of the gRPC interface. Default: 6334 - prefer_grpc: - If true - use gPRC interface whenever possible in custom methods. - Default: False - https: If true - use HTTPS(SSL) protocol. Default: None - api_key: API key for authentication in Qdrant Cloud. Default: None - prefix: - If not None - add prefix to the REST URL path. - Example: service/v1 will result in - http://localhost:6333/service/v1/{qdrant-endpoint} for REST API. - Default: None - timeout: - Timeout for REST and gRPC API requests. - Default: 5.0 seconds for REST and unlimited for gRPC - host: - Host name of Qdrant service. If url and host are None, set to - 'localhost'. 
Default: None - path: - Path in which the vectors will be stored while using local mode. - Default: None - collection_name: - Name of the Qdrant collection to be used. If not provided, - it will be created randomly. Default: None - distance_func: - Distance function. One of: "Cosine" / "Euclid" / "Dot". - Default: "Cosine" - content_payload_key: - A payload key used to store the content of the document. - Default: "page_content" - metadata_payload_key: - A payload key used to store the metadata of the document. - Default: "metadata" - vector_name: - Name of the vector to be used internally in Qdrant. - Default: None - batch_size: - How many vectors upload per-request. - Default: 64 - shard_number: Number of shards in collection. Default is 1, minimum is 1. - replication_factor: - Replication factor for collection. Default is 1, minimum is 1. - Defines how many copies of each shard will be created. - Have effect only in distributed mode. - write_consistency_factor: - Write consistency factor for collection. Default is 1, minimum is 1. - Defines how many replicas should apply the operation for us to consider - it successful. Increasing this number will make the collection more - resilient to inconsistencies, but will also make it fail if not enough - replicas are available. - Does not have any performance impact. - Have effect only in distributed mode. - on_disk_payload: - If true - point`s payload will not be stored in memory. - It will be read from the disk every time it is requested. - This setting saves RAM by (slightly) increasing the response time. - Note: those payload values that are involved in filtering and are - indexed - remain in RAM. - hnsw_config: Params for HNSW index - optimizers_config: Params for optimizer - wal_config: Params for Write-Ahead-Log - quantization_config: - Params for quantization, if None - quantization will be disabled - init_from: - Use data stored in another collection to initialize this collection - force_recreate: - Force recreating the collection - **kwargs: - Additional arguments passed directly into REST client initialization - - This is a user-friendly interface that: - 1. Creates embeddings, one for each text - 2. Initializes the Qdrant database as an in-memory docstore by default - (and overridable to a remote docstore) - 3. Adds the text embeddings to the Qdrant database - - This is intended to be a quick way to get started. - - Example: - .. 
code-block:: python - - from langchain import Qdrant - from langchain.embeddings import OpenAIEmbeddings - embeddings = OpenAIEmbeddings() - qdrant = await Qdrant.afrom_texts(texts, embeddings, "localhost") - """ - qdrant = cls._construct_instance( - texts, - embedding, - metadatas, - ids, - location, - url, - port, - grpc_port, - prefer_grpc, - https, - api_key, - prefix, - timeout, - host, - path, - collection_name, - distance_func, - content_payload_key, - metadata_payload_key, - vector_name, - shard_number, - replication_factor, - write_consistency_factor, - on_disk_payload, - hnsw_config, - optimizers_config, - wal_config, - quantization_config, - init_from, - force_recreate, - **kwargs, - ) - await qdrant.aadd_texts(texts, metadatas, ids, batch_size) - return qdrant - - @classmethod - def _construct_instance( - cls: type[Qdrant], - texts: list[str], - embedding: Embeddings, - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, - distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - group_payload_key: str = GROUP_KEY, - group_id: str = None, - vector_name: Optional[str] = VECTOR_NAME, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, - hnsw_config: Optional[common_types.HnswConfigDiff] = None, - optimizers_config: Optional[common_types.OptimizersConfigDiff] = None, - wal_config: Optional[common_types.WalConfigDiff] = None, - quantization_config: Optional[common_types.QuantizationConfig] = None, - init_from: Optional[common_types.InitFrom] = None, - force_recreate: bool = False, - **kwargs: Any, - ) -> Qdrant: - try: - import qdrant_client - except ImportError: - raise ValueError( - "Could not import qdrant-client python package. " - "Please install it with `pip install qdrant-client`." - ) - from qdrant_client.http import models as rest - - # Just do a single quick embedding to get vector size - partial_embeddings = embedding.embed_documents(texts[:1]) - vector_size = len(partial_embeddings[0]) - collection_name = collection_name or uuid.uuid4().hex - distance_func = distance_func.upper() - is_new_collection = False - client = qdrant_client.QdrantClient( - location=location, - url=url, - port=port, - grpc_port=grpc_port, - prefer_grpc=prefer_grpc, - https=https, - api_key=api_key, - prefix=prefix, - timeout=timeout, - host=host, - path=path, - **kwargs, - ) - all_collection_name = [] - collections_response = client.get_collections() - collection_list = collections_response.collections - for collection in collection_list: - all_collection_name.append(collection.name) - if collection_name not in all_collection_name: - vectors_config = rest.VectorParams( - size=vector_size, - distance=rest.Distance[distance_func], - ) - - # If vector name was provided, we're going to use the named vectors feature - # with just a single vector. 
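# Illustrative sketch of the two shapes accepted by qdrant-client for `vectors_config`
# (the 1536 size and the "content" vector name are placeholder values, not from the original code):
#
#     from qdrant_client.http import models as rest
#     unnamed = rest.VectorParams(size=1536, distance=rest.Distance.COSINE)
#     named = {"content": rest.VectorParams(size=1536, distance=rest.Distance.COSINE)}
#     # client.recreate_collection(collection_name="demo", vectors_config=named)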
- if vector_name is not None: - vectors_config = { # type: ignore[assignment] - vector_name: vectors_config, - } - - client.recreate_collection( - collection_name=collection_name, - vectors_config=vectors_config, - shard_number=shard_number, - replication_factor=replication_factor, - write_consistency_factor=write_consistency_factor, - on_disk_payload=on_disk_payload, - hnsw_config=hnsw_config, - optimizers_config=optimizers_config, - wal_config=wal_config, - quantization_config=quantization_config, - init_from=init_from, - timeout=int(timeout), # type: ignore[arg-type] - ) - is_new_collection = True - if force_recreate: - raise ValueError - - # Get the vector configuration of the existing collection and vector, if it - # was specified. If the old configuration does not match the current one, - # an exception is being thrown. - collection_info = client.get_collection(collection_name=collection_name) - current_vector_config = collection_info.config.params.vectors - if isinstance(current_vector_config, dict) and vector_name is not None: - if vector_name not in current_vector_config: - raise QdrantException( - f"Existing Qdrant collection {collection_name} does not " - f"contain vector named {vector_name}. Did you mean one of the " - f"existing vectors: {', '.join(current_vector_config.keys())}? " - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - current_vector_config = current_vector_config.get( - vector_name - ) # type: ignore[assignment] - elif isinstance(current_vector_config, dict) and vector_name is None: - raise QdrantException( - f"Existing Qdrant collection {collection_name} uses named vectors. " - f"If you want to reuse it, please set `vector_name` to any of the " - f"existing named vectors: " - f"{', '.join(current_vector_config.keys())}." # noqa - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - elif ( - not isinstance(current_vector_config, dict) and vector_name is not None - ): - raise QdrantException( - f"Existing Qdrant collection {collection_name} doesn't use named " - f"vectors. If you want to reuse it, please set `vector_name` to " - f"`None`. If you want to recreate the collection, set " - f"`force_recreate` parameter to `True`." - ) - - # Check if the vector configuration has the same dimensionality. - if current_vector_config.size != vector_size: # type: ignore[union-attr] - raise QdrantException( - f"Existing Qdrant collection is configured for vectors with " - f"{current_vector_config.size} " # type: ignore[union-attr] - f"dimensions. Selected embeddings are {vector_size}-dimensional. " - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - - current_distance_func = ( - current_vector_config.distance.name.upper() # type: ignore[union-attr] - ) - if current_distance_func != distance_func: - raise QdrantException( - f"Existing Qdrant collection is configured for " - f"{current_vector_config.distance} " # type: ignore[union-attr] - f"similarity. Please set `distance_func` parameter to " - f"`{distance_func}` if you want to reuse it. If you want to " - f"recreate the collection, set `force_recreate` parameter to " - f"`True`." 
- ) - qdrant = cls( - client=client, - collection_name=collection_name, - embeddings=embedding, - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - distance_strategy=distance_func, - vector_name=vector_name, - group_id=group_id, - group_payload_key=group_payload_key, - is_new_collection=is_new_collection - ) - return qdrant - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """ - The 'correct' relevance function - may differ depending on a few things, including: - - the distance / similarity metric used by the VectorStore - - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) - - embedding dimensionality - - etc. - """ - - if self.distance_strategy == "COSINE": - return self._cosine_relevance_score_fn - elif self.distance_strategy == "DOT": - return self._max_inner_product_relevance_score_fn - elif self.distance_strategy == "EUCLID": - return self._euclidean_relevance_score_fn - else: - raise ValueError( - "Unknown distance strategy, must be cosine, " - "max_inner_product, or euclidean" - ) - - def _similarity_search_with_relevance_scores( - self, - query: str, - k: int = 4, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and relevance scores in the range [0, 1]. - - 0 is dissimilar, 1 is most similar. - - Args: - query: input text - k: Number of Documents to return. Defaults to 4. - **kwargs: kwargs to be passed to similarity search. Should include: - score_threshold: Optional, a floating point value between 0 to 1 to - filter the resulting set of retrieved docs - - Returns: - List of Tuples of (doc, similarity_score) - """ - return self.similarity_search_with_score(query, k, **kwargs) - - @classmethod - def _build_payloads( - cls, - texts: Iterable[str], - metadatas: Optional[list[dict]], - content_payload_key: str, - metadata_payload_key: str, - group_id: str, - group_payload_key: str - ) -> list[dict]: - payloads = [] - for i, text in enumerate(texts): - if text is None: - raise ValueError( - "At least one of the texts is None. Please remove it before " - "calling .from_texts or .add_texts on Qdrant instance." 
- ) - metadata = metadatas[i] if metadatas is not None else None - payloads.append( - { - content_payload_key: text, - metadata_payload_key: metadata, - group_payload_key: group_id - } - ) - - return payloads - - @classmethod - def _document_from_scored_point( - cls, - scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, - ) -> Document: - return Document( - page_content=scored_point.payload.get(content_payload_key), - metadata=scored_point.payload.get(metadata_payload_key) or {}, - ) - - @classmethod - def _document_from_scored_point_grpc( - cls, - scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, - ) -> Document: - from qdrant_client.conversions.conversion import grpc_to_payload - - payload = grpc_to_payload(scored_point.payload) - return Document( - page_content=payload[content_payload_key], - metadata=payload.get(metadata_payload_key) or {}, - ) - - def _build_condition(self, key: str, value: Any) -> list[rest.FieldCondition]: - from qdrant_client.http import models as rest - - out = [] - - if isinstance(value, dict): - for _key, value in value.items(): - out.extend(self._build_condition(f"{key}.{_key}", value)) - elif isinstance(value, list): - for _value in value: - if isinstance(_value, dict): - out.extend(self._build_condition(f"{key}[]", _value)) - else: - out.extend(self._build_condition(f"{key}", _value)) - else: - out.append( - rest.FieldCondition( - key=key, - match=rest.MatchValue(value=value), - ) - ) - - return out - - def _qdrant_filter_from_dict( - self, filter: Optional[DictFilter] - ) -> Optional[rest.Filter]: - from qdrant_client.http import models as rest - - if not filter: - return None - - return rest.Filter( - must=[ - condition - for key, value in filter.items() - for condition in self._build_condition(key, value) - ] - ) - - def _embed_query(self, query: str) -> list[float]: - """Embed query text. - - Used to provide backward compatibility with `embedding_function` argument. - - Args: - query: Query text. - - Returns: - List of floats representing the query embedding. - """ - if self.embeddings is not None: - embedding = self.embeddings.embed_query(query) - else: - if self._embeddings_function is not None: - embedding = self._embeddings_function(query) - else: - raise ValueError("Neither of embeddings or embedding_function is set") - return embedding.tolist() if hasattr(embedding, "tolist") else embedding - - def _embed_texts(self, texts: Iterable[str]) -> list[list[float]]: - """Embed search texts. - - Used to provide backward compatibility with `embedding_function` argument. - - Args: - texts: Iterable of texts to embed. - - Returns: - List of floats representing the texts embedding. 
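Example:
    An illustrative call (the instance and input strings are placeholders); the result is one embedding vector per input text:

    .. code-block:: python

        vectors = qdrant._embed_texts(["first chunk", "second chunk"])
        assert len(vectors) == 2
        assert all(isinstance(value, float) for value in vectors[0])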
- """ - if self.embeddings is not None: - embeddings = self.embeddings.embed_documents(list(texts)) - if hasattr(embeddings, "tolist"): - embeddings = embeddings.tolist() - elif self._embeddings_function is not None: - embeddings = [] - for text in texts: - embedding = self._embeddings_function(text) - if hasattr(embeddings, "tolist"): - embedding = embedding.tolist() - embeddings.append(embedding) - else: - raise ValueError("Neither of embeddings or embedding_function is set") - - return embeddings - - def _generate_rest_batches( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - group_id: Optional[str] = None, - ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: - from qdrant_client.http import models as rest - - texts_iterator = iter(texts) - metadatas_iterator = iter(metadatas or []) - ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) - while batch_texts := list(islice(texts_iterator, batch_size)): - # Take the corresponding metadata and id for each text in a batch - batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None - batch_ids = list(islice(ids_iterator, batch_size)) - - # Generate the embeddings for all the texts in a batch - batch_embeddings = self._embed_texts(batch_texts) - - points = [ - rest.PointStruct( - id=point_id, - vector=vector - if self.vector_name is None - else {self.vector_name: vector}, - payload=payload, - ) - for point_id, vector, payload in zip( - batch_ids, - batch_embeddings, - self._build_payloads( - batch_texts, - batch_metadatas, - self.content_payload_key, - self.metadata_payload_key, - self.group_id, - self.group_payload_key - ), - ) - ] - - yield batch_ids, points diff --git a/api/core/vector_store/vector/weaviate.py b/api/core/vector_store/vector/weaviate.py deleted file mode 100644 index 8ac77152e1..0000000000 --- a/api/core/vector_store/vector/weaviate.py +++ /dev/null @@ -1,506 +0,0 @@ -"""Wrapper around weaviate vector database.""" -from __future__ import annotations - -import datetime -from collections.abc import Callable, Iterable -from typing import Any, Optional -from uuid import uuid4 - -import numpy as np -from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings -from langchain.utils import get_from_dict_or_env -from langchain.vectorstores.base import VectorStore -from langchain.vectorstores.utils import maximal_marginal_relevance - - -def _default_schema(index_name: str) -> dict: - return { - "class": index_name, - "properties": [ - { - "name": "text", - "dataType": ["text"], - } - ], - } - - -def _create_weaviate_client(**kwargs: Any) -> Any: - client = kwargs.get("client") - if client is not None: - return client - - weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL") - - try: - # the weaviate api key param should not be mandatory - weaviate_api_key = get_from_dict_or_env( - kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None - ) - except ValueError: - weaviate_api_key = None - - try: - import weaviate - except ImportError: - raise ValueError( - "Could not import weaviate python package. 
" - "Please install it with `pip install weaviate-client`" - ) - - auth = ( - weaviate.auth.AuthApiKey(api_key=weaviate_api_key) - if weaviate_api_key is not None - else None - ) - client = weaviate.Client(weaviate_url, auth_client_secret=auth) - - return client - - -def _default_score_normalizer(val: float) -> float: - return 1 - val - - -def _json_serializable(value: Any) -> Any: - if isinstance(value, datetime.datetime): - return value.isoformat() - return value - - -class Weaviate(VectorStore): - """Wrapper around Weaviate vector database. - - To use, you should have the ``weaviate-client`` python package installed. - - Example: - .. code-block:: python - - import weaviate - from langchain.vectorstores import Weaviate - client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...) - weaviate = Weaviate(client, index_name, text_key) - - """ - - def __init__( - self, - client: Any, - index_name: str, - text_key: str, - embedding: Optional[Embeddings] = None, - attributes: Optional[list[str]] = None, - relevance_score_fn: Optional[ - Callable[[float], float] - ] = _default_score_normalizer, - by_text: bool = True, - ): - """Initialize with Weaviate client.""" - try: - import weaviate - except ImportError: - raise ValueError( - "Could not import weaviate python package. " - "Please install it with `pip install weaviate-client`." - ) - if not isinstance(client, weaviate.Client): - raise ValueError( - f"client should be an instance of weaviate.Client, got {type(client)}" - ) - self._client = client - self._index_name = index_name - self._embedding = embedding - self._text_key = text_key - self._query_attrs = [self._text_key] - self.relevance_score_fn = relevance_score_fn - self._by_text = by_text - if attributes is not None: - self._query_attrs.extend(attributes) - - @property - def embeddings(self) -> Optional[Embeddings]: - return self._embedding - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - return ( - self.relevance_score_fn - if self.relevance_score_fn - else _default_score_normalizer - ) - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - **kwargs: Any, - ) -> list[str]: - """Upload texts with metadata (properties) to Weaviate.""" - from weaviate.util import get_valid_uuid - - ids = [] - embeddings: Optional[list[list[float]]] = None - if self._embedding: - if not isinstance(texts, list): - texts = list(texts) - embeddings = self._embedding.embed_documents(texts) - - with self._client.batch as batch: - for i, text in enumerate(texts): - data_properties = {self._text_key: text} - if metadatas is not None: - for key, val in metadatas[i].items(): - data_properties[key] = _json_serializable(val) - - # Allow for ids (consistent w/ other methods) - # # Or uuids (backwards compatble w/ existing arg) - # If the UUID of one of the objects already exists - # then the existing object will be replaced by the new object. - _id = get_valid_uuid(uuid4()) - if "uuids" in kwargs: - _id = kwargs["uuids"][i] - elif "ids" in kwargs: - _id = kwargs["ids"][i] - - batch.add_data_object( - data_object=data_properties, - class_name=self._index_name, - uuid=_id, - vector=embeddings[i] if embeddings else None, - ) - ids.append(_id) - return ids - - def similarity_search( - self, query: str, k: int = 4, **kwargs: Any - ) -> list[Document]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - - Returns: - List of Documents most similar to the query. 
- """ - if self._by_text: - return self.similarity_search_by_text(query, k, **kwargs) - else: - if self._embedding is None: - raise ValueError( - "_embedding cannot be None for similarity_search when " - "_by_text=False" - ) - embedding = self._embedding.embed_query(query) - return self.similarity_search_by_vector(embedding, k, **kwargs) - - def similarity_search_by_text( - self, query: str, k: int = 4, **kwargs: Any - ) -> list[Document]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - - Returns: - List of Documents most similar to the query. - """ - content: dict[str, Any] = {"concepts": [query]} - if kwargs.get("search_distance"): - content["certainty"] = kwargs.get("search_distance") - query_obj = self._client.query.get(self._index_name, self._query_attrs) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) - if kwargs.get("additional"): - query_obj = query_obj.with_additional(kwargs.get("additional")) - result = query_obj.with_near_text(content).with_limit(k).do() - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - docs = [] - for res in result["data"]["Get"][self._index_name]: - text = res.pop(self._text_key) - docs.append(Document(page_content=text, metadata=res)) - return docs - - def similarity_search_by_bm25( - self, query: str, k: int = 4, **kwargs: Any - ) -> list[Document]: - """Return docs using BM25F. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - - Returns: - List of Documents most similar to the query. - """ - content: dict[str, Any] = {"concepts": [query]} - if kwargs.get("search_distance"): - content["certainty"] = kwargs.get("search_distance") - query_obj = self._client.query.get(self._index_name, self._query_attrs) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) - if kwargs.get("additional"): - query_obj = query_obj.with_additional(kwargs.get("additional")) - properties = ['text'] - result = query_obj.with_bm25(query=query, properties=properties).with_limit(k).do() - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - docs = [] - for res in result["data"]["Get"][self._index_name]: - text = res.pop(self._text_key) - docs.append(Document(page_content=text, metadata=res)) - return docs - - def similarity_search_by_vector( - self, embedding: list[float], k: int = 4, **kwargs: Any - ) -> list[Document]: - """Look up similar documents by embedding vector in Weaviate.""" - vector = {"vector": embedding} - query_obj = self._client.query.get(self._index_name, self._query_attrs) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) - if kwargs.get("additional"): - query_obj = query_obj.with_additional(kwargs.get("additional")) - result = query_obj.with_near_vector(vector).with_limit(k).do() - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - docs = [] - for res in result["data"]["Get"][self._index_name]: - text = res.pop(self._text_key) - docs.append(Document(page_content=text, metadata=res)) - return docs - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance. 
- - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - if self._embedding is not None: - embedding = self._embedding.embed_query(query) - else: - raise ValueError( - "max_marginal_relevance_search requires a suitable Embeddings object" - ) - - return self.max_marginal_relevance_search_by_vector( - embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs - ) - - def max_marginal_relevance_search_by_vector( - self, - embedding: list[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - vector = {"vector": embedding} - query_obj = self._client.query.get(self._index_name, self._query_attrs) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) - results = ( - query_obj.with_additional("vector") - .with_near_vector(vector) - .with_limit(fetch_k) - .do() - ) - - payload = results["data"]["Get"][self._index_name] - embeddings = [result["_additional"]["vector"] for result in payload] - mmr_selected = maximal_marginal_relevance( - np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult - ) - - docs = [] - for idx in mmr_selected: - text = payload[idx].pop(self._text_key) - payload[idx].pop("_additional") - meta = payload[idx] - docs.append(Document(page_content=text, metadata=meta)) - return docs - - def similarity_search_with_score( - self, query: str, k: int = 4, **kwargs: Any - ) -> list[tuple[Document, float]]: - """ - Return list of documents most similar to the query - text and cosine distance in float for each. - Lower score represents more similarity. 
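Example:
    A minimal illustrative call (the instance name and query are placeholders); because the score is a distance, smaller values mean closer matches:

    .. code-block:: python

        for doc, distance in weaviate_store.similarity_search_with_score("pricing plans", k=3):
            print(round(distance, 4), doc.page_content[:80])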
- """ - if self._embedding is None: - raise ValueError( - "_embedding cannot be None for similarity_search_with_score" - ) - content: dict[str, Any] = {"concepts": [query]} - if kwargs.get("search_distance"): - content["certainty"] = kwargs.get("search_distance") - query_obj = self._client.query.get(self._index_name, self._query_attrs) - - embedded_query = self._embedding.embed_query(query) - if not self._by_text: - vector = {"vector": embedded_query} - result = ( - query_obj.with_near_vector(vector) - .with_limit(k) - .with_additional(["vector", "distance"]) - .do() - ) - else: - result = ( - query_obj.with_near_text(content) - .with_limit(k) - .with_additional(["vector", "distance"]) - .do() - ) - - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - - docs_and_scores = [] - for res in result["data"]["Get"][self._index_name]: - text = res.pop(self._text_key) - score = res["_additional"]["distance"] - docs_and_scores.append((Document(page_content=text, metadata=res), score)) - return docs_and_scores - - @classmethod - def from_texts( - cls: type[Weaviate], - texts: list[str], - embedding: Embeddings, - metadatas: Optional[list[dict]] = None, - **kwargs: Any, - ) -> Weaviate: - """Construct Weaviate wrapper from raw documents. - - This is a user-friendly interface that: - 1. Embeds documents. - 2. Creates a new index for the embeddings in the Weaviate instance. - 3. Adds the documents to the newly created Weaviate index. - - This is intended to be a quick way to get started. - - Example: - .. code-block:: python - - from langchain.vectorstores.weaviate import Weaviate - from langchain.embeddings import OpenAIEmbeddings - embeddings = OpenAIEmbeddings() - weaviate = Weaviate.from_texts( - texts, - embeddings, - weaviate_url="http://localhost:8080" - ) - """ - - client = _create_weaviate_client(**kwargs) - - from weaviate.util import get_valid_uuid - - index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}") - embeddings = embedding.embed_documents(texts) if embedding else None - text_key = "text" - schema = _default_schema(index_name) - attributes = list(metadatas[0].keys()) if metadatas else None - - # check whether the index already exists - if not client.schema.contains(schema): - client.schema.create_class(schema) - - with client.batch as batch: - for i, text in enumerate(texts): - data_properties = { - text_key: text, - } - if metadatas is not None: - for key in metadatas[i].keys(): - data_properties[key] = metadatas[i][key] - - # If the UUID of one of the objects already exists - # then the existing objectwill be replaced by the new object. - if "uuids" in kwargs: - _id = kwargs["uuids"][i] - else: - _id = get_valid_uuid(uuid4()) - - # if an embedding strategy is not provided, we let - # weaviate create the embedding. Note that this will only - # work if weaviate has been installed with a vectorizer module - # like text2vec-contextionary for example - params = { - "uuid": _id, - "data_object": data_properties, - "class_name": index_name, - } - if embeddings is not None: - params["vector"] = embeddings[i] - - batch.add_data_object(**params) - - batch.flush() - - relevance_score_fn = kwargs.get("relevance_score_fn") - by_text: bool = kwargs.get("by_text", False) - - return cls( - client, - index_name, - text_key, - embedding=embedding, - attributes=attributes, - relevance_score_fn=relevance_score_fn, - by_text=by_text, - ) - - def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> None: - """Delete by vector IDs. 
- - Args: - ids: List of ids to delete. - """ - - if ids is None: - raise ValueError("No ids provided to delete.") - - # TODO: Check if this can be done in bulk - for id in ids: - self._client.data_object.delete(uuid=id) diff --git a/api/core/vector_store/weaviate_vector_store.py b/api/core/vector_store/weaviate_vector_store.py deleted file mode 100644 index b5b3d84a9a..0000000000 --- a/api/core/vector_store/weaviate_vector_store.py +++ /dev/null @@ -1,38 +0,0 @@ -from core.vector_store.vector.weaviate import Weaviate - - -class WeaviateVectorStore(Weaviate): - def del_texts(self, where_filter: dict): - if not where_filter: - raise ValueError('where_filter must not be empty') - - self._client.batch.delete_objects( - class_name=self._index_name, - where=where_filter, - output='minimal' - ) - - def del_text(self, uuid: str) -> None: - self._client.data_object.delete( - uuid, - class_name=self._index_name - ) - - def text_exists(self, uuid: str) -> bool: - result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({ - "path": ["doc_id"], - "operator": "Equal", - "valueText": uuid, - }).with_limit(1).do() - - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - - entries = result["data"]["Get"][self._index_name] - if len(entries) == 0: - return False - - return True - - def delete(self): - self._client.schema.delete_class(self._index_name) diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py index 93181ea161..42f1c70614 100644 --- a/api/events/event_handlers/clean_when_dataset_deleted.py +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -6,4 +6,4 @@ from tasks.clean_dataset_task import clean_dataset_task def handle(sender, **kwargs): dataset = sender clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, - dataset.index_struct, dataset.collection_binding_id) + dataset.index_struct, dataset.collection_binding_id, dataset.doc_form) diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py index d6553b385e..d0bec667a9 100644 --- a/api/events/event_handlers/clean_when_document_deleted.py +++ b/api/events/event_handlers/clean_when_document_deleted.py @@ -6,4 +6,5 @@ from tasks.clean_document_task import clean_document_task def handle(sender, **kwargs): document_id = sender dataset_id = kwargs.get('dataset_id') - clean_document_task.delay(document_id, dataset_id) + doc_form = kwargs.get('doc_form') + clean_document_task.delay(document_id, dataset_id, doc_form) diff --git a/api/models/dataset.py b/api/models/dataset.py index d31e49f6ca..473a796be5 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -94,6 +94,14 @@ class Dataset(db.Model): return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ .filter(Document.dataset_id == self.id).scalar() + @property + def doc_form(self): + document = db.session.query(Document).filter( + Document.dataset_id == self.id).first() + if document: + return document.doc_form + return None + @property def retrieval_model_dict(self): default_retrieval_model = { diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 5db863fe8d..cdcb3121b9 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -6,7 +6,7 @@ from flask import current_app from werkzeug.exceptions import NotFound import app 
-from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, DatasetQuery, Document @@ -41,18 +41,9 @@ def clean_unused_datasets_task(): if not documents or len(documents) == 0: try: # remove index - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - kw_index = IndexBuilder.get_index(dataset, 'economy') - # delete from vector index - if vector_index: - if dataset.collection_binding_id: - vector_index.delete_by_group_id(dataset.id) - else: - if dataset.collection_binding_id: - vector_index.delete_by_group_id(dataset.id) - else: - vector_index.delete() - kw_index.delete() + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + index_processor.clean(dataset, None) + # update document update_params = { Document.enabled: False diff --git a/api/services/account_service.py b/api/services/account_service.py index 17999c9e25..e35d325ae4 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -282,9 +282,9 @@ class TenantService: else: TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False}) tenant_account_join.current = True - db.session.commit() # Set the current tenant for the account account.current_tenant_id = tenant_account_join.tenant_id + db.session.commit() @staticmethod def get_tenant_members(tenant: Tenant) -> list[Account]: diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 0a9e835586..db4639d40b 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -10,6 +10,7 @@ from werkzeug.exceptions import NotFound from extensions.ext_database import db from extensions.ext_redis import redis_client from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation +from services.feature_service import FeatureService from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task from tasks.annotation.delete_annotation_index_task import delete_annotation_index_task @@ -284,6 +285,12 @@ class AppAnnotationService: result.append(content) if len(result) == 0: raise ValueError("The CSV file is empty.") + # check annotation limit + features = FeatureService.get_features(current_user.current_tenant_id) + if features.billing.enabled: + annotation_quota_limit = features.annotation_quota_limit + if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size: + raise ValueError("The number of annotations exceeds the limit of your subscription.") # async job job_id = str(uuid.uuid4()) indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 66c45ab8da..b151ebada8 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -11,10 +11,11 @@ from flask_login import current_user from sqlalchemy import func from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError -from core.index.index import IndexBuilder from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from 
core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.models.document import Document as RAGDocument from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db @@ -36,6 +37,7 @@ from services.errors.account import NoPermissionError from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError +from services.feature_service import FeatureService from services.vector_service import VectorService from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task @@ -401,7 +403,7 @@ class DocumentService: @staticmethod def delete_document(document): # trigger document_was_deleted signal - document_was_deleted.send(document.id, dataset_id=document.dataset_id) + document_was_deleted.send(document.id, dataset_id=document.dataset_id, doc_form=document.doc_form) db.session.delete(document) db.session.commit() @@ -452,7 +454,9 @@ class DocumentService: created_from: str = 'web'): # check document limit - if current_app.config['EDITION'] == 'CLOUD': + features = FeatureService.get_features(current_user.current_tenant_id) + + if features.billing.enabled: if 'original_document_id' not in document_data or not document_data['original_document_id']: count = 0 if document_data["data_source"]["type"] == "upload_file": @@ -462,6 +466,9 @@ class DocumentService: notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] for notion_info in notion_info_list: count = count + len(notion_info['pages']) + batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") # if dataset is empty, update dataset data_source_type if not dataset.data_source_type: dataset.data_source_type = document_data["data_source"]["type"] @@ -741,14 +748,20 @@ class DocumentService: @staticmethod def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account): - count = 0 - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] - count = len(upload_file_list) - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] - for notion_info in notion_info_list: - count = count + len(notion_info['pages']) + features = FeatureService.get_features(current_user.current_tenant_id) + + if features.billing.enabled: + count = 0 + if document_data["data_source"]["type"] == "upload_file": + upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + count = len(upload_file_list) + elif document_data["data_source"]["type"] == "notion_import": + notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + for notion_info in notion_info_list: + count = count + len(notion_info['pages']) + batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") embedding_model = None dataset_collection_binding_id = None @@ -1048,7 +1061,7 @@ class SegmentService: # save vector index try: - 
VectorService.create_segment_vector(args['keywords'], segment_document, dataset) + VectorService.create_segments_vector([args['keywords']], [segment_document], dataset) except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False @@ -1075,6 +1088,7 @@ class SegmentService: ).scalar() pre_segment_data_list = [] segment_data_list = [] + keywords_list = [] for segment_item in segments: content = segment_item['content'] doc_id = str(uuid.uuid4()) @@ -1107,15 +1121,13 @@ class SegmentService: segment_document.answer = segment_item['answer'] db.session.add(segment_document) segment_data_list.append(segment_document) - pre_segment_data = { - 'segment': segment_document, - 'keywords': segment_item['keywords'] - } - pre_segment_data_list.append(pre_segment_data) + + pre_segment_data_list.append(segment_document) + keywords_list.append(segment_item['keywords']) try: # save vector index - VectorService.multi_create_segment_vector(pre_segment_data_list, dataset) + VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset) except Exception as e: logging.exception("create segment index failed") for segment_document in segment_data_list: @@ -1139,17 +1151,24 @@ class SegmentService: segment.answer = args['answer'] if 'keywords' in args and args['keywords']: segment.keywords = args['keywords'] - if'enabled' in args and args['enabled'] is not None: + if 'enabled' in args and args['enabled'] is not None: segment.enabled = args['enabled'] db.session.add(segment) db.session.commit() # update segment index task if args['keywords']: - kw_index = IndexBuilder.get_index(dataset, 'economy') - # delete from keyword index - kw_index.delete_by_ids([segment.index_node_id]) - # save keyword index - kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords) + keyword = Keyword(dataset) + keyword.delete_by_ids([segment.index_node_id]) + document = RAGDocument( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } + ) + keyword.add_texts([document], keywords_list=[args['keywords']]) else: segment_hash = helper.generate_text_hash(content) tokens = 0 diff --git a/api/services/file_service.py b/api/services/file_service.py index 215ccf688a..53dd090236 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -9,8 +9,8 @@ from flask_login import current_user from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound -from core.data_loader.file_extractor import FileExtractor from core.file.upload_file_parser import UploadFileParser +from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Account @@ -20,9 +20,9 @@ from services.errors.file import FileTooLargeError, UnsupportedFileTypeError IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) -ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv'] + IMAGE_EXTENSIONS +ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv'] UNSTRUSTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', - 'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml'] + IMAGE_EXTENSIONS + 'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml'] 
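The SegmentService.update_segment hunk above gives the clearest view of the new keyword-index path: instead of asking IndexBuilder for an 'economy' index, the service now builds the Keyword facade, deletes the stale node, and re-adds the segment content with the caller-supplied keywords. A minimal sketch of that flow, assuming the imports introduced in this diff (the standalone helper is illustrative, not part of the change):

# Minimal sketch of the keyword re-index flow used in SegmentService.update_segment above.
# The helper function is illustrative; the Keyword/RAGDocument calls mirror the diff.
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.models.document import Document as RAGDocument

def reindex_segment_keywords(dataset, segment, keywords: list[str]) -> None:
    keyword = Keyword(dataset)
    # drop the stale keyword entry for this segment's index node
    keyword.delete_by_ids([segment.index_node_id])
    # rebuild it from the current segment content and the supplied keywords
    document = RAGDocument(
        page_content=segment.content,
        metadata={
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        },
    )
    keyword.add_texts([document], keywords_list=[keywords])

The batch create path in VectorService.create_segments_vector goes through the same facade, just without the preceding delete.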
PREVIEW_WORDS_LIMIT = 3000 @@ -32,7 +32,8 @@ class FileService: def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: extension = file.filename.split('.')[-1] etl_type = current_app.config['ETL_TYPE'] - allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS + allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \ + else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() elif only_image and extension.lower() not in IMAGE_EXTENSIONS: @@ -136,7 +137,7 @@ class FileService: if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() - text = FileExtractor.load(upload_file, return_text=True) + text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) text = text[0:PREVIEW_WORDS_LIMIT] if text else '' return text @@ -164,7 +165,7 @@ class FileService: return generator, upload_file.mime_type @staticmethod - def get_public_image_preview(file_id: str) -> str: + def get_public_image_preview(file_id: str) -> tuple[Generator, str]: upload_file = db.session.query(UploadFile) \ .filter(UploadFile.id == file_id) \ .first() diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index e52527c627..568974b74f 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,21 +1,18 @@ import logging -import threading import time import numpy as np -from flask import current_app -from langchain.embeddings.base import Embeddings -from langchain.schema import Document from sklearn.manifold import TSNE from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.rerank.rerank import RerankRunner +from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.models.document import Document from extensions.ext_database import db from models.account import Account from models.dataset import Dataset, DatasetQuery, DocumentSegment -from services.retrieval_service import RetrievalService default_retrieval_model = { 'search_method': 'semantic_search', @@ -28,6 +25,7 @@ default_retrieval_model = { 'score_threshold_enabled': False } + class HitTestingService: @classmethod def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict: @@ -57,61 +55,15 @@ class HitTestingService: embeddings = CacheEmbedding(embedding_model) - all_documents = [] - threads = [] - - # retrieval_model source with semantic - if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': - embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': str(dataset.id), - 'query': query, - 'top_k': retrieval_model['top_k'], - 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, - 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, - 'all_documents': all_documents, - 'search_method': retrieval_model['search_method'], - 'embeddings': embeddings - }) - threads.append(embedding_thread) - embedding_thread.start() - - # retrieval source with 
full text - if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': str(dataset.id), - 'query': query, - 'search_method': retrieval_model['search_method'], - 'embeddings': embeddings, - 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, - 'top_k': retrieval_model['top_k'], - 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, - 'all_documents': all_documents - }) - threads.append(full_text_index_thread) - full_text_index_thread.start() - - for thread in threads: - thread.join() - - if retrieval_model['search_method'] == 'hybrid_search': - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=retrieval_model['reranking_model']['reranking_provider_name'], - model_type=ModelType.RERANK, - model=retrieval_model['reranking_model']['reranking_model_name'] - ) - - rerank_runner = RerankRunner(rerank_model_instance) - all_documents = rerank_runner.run( - query=query, - documents=all_documents, - score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, - top_n=retrieval_model['top_k'], - user=f"account-{account.id}" - ) + all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=retrieval_model['top_k'], + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) end = time.perf_counter() logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") @@ -203,4 +155,3 @@ class HitTestingService: if not query or len(query) > 250: raise ValueError('Query is required and cannot exceed 250 characters') - diff --git a/api/services/retrieval_service.py b/api/services/retrieval_service.py deleted file mode 100644 index bc8f4ad5be..0000000000 --- a/api/services/retrieval_service.py +++ /dev/null @@ -1,119 +0,0 @@ -from typing import Optional - -from flask import Flask, current_app -from langchain.embeddings.base import Embeddings - -from core.index.vector_index.vector_index import VectorIndex -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError -from core.rerank.rerank import RerankRunner -from extensions.ext_database import db -from models.dataset import Dataset - -default_retrieval_model = { - 'search_method': 'semantic_search', - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False -} - - -class RetrievalService: - - @classmethod - def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, search_method: str, embeddings: Embeddings): - with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() - - vector_index = VectorIndex( - dataset=dataset, - config=current_app.config, - embeddings=embeddings 
- ) - - documents = vector_index.search( - query, - search_type='similarity_score_threshold', - search_kwargs={ - 'k': top_k, - 'score_threshold': score_threshold, - 'filter': { - 'group_id': [dataset.id] - } - } - ) - - if documents: - if reranking_model and search_method == 'semantic_search': - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=reranking_model['reranking_provider_name'], - model_type=ModelType.RERANK, - model=reranking_model['reranking_model_name'] - ) - except InvokeAuthorizationError: - return - - rerank_runner = RerankRunner(rerank_model_instance) - all_documents.extend(rerank_runner.run( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) - else: - all_documents.extend(documents) - - @classmethod - def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, search_method: str, embeddings: Embeddings): - with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() - - vector_index = VectorIndex( - dataset=dataset, - config=current_app.config, - embeddings=embeddings - ) - - documents = vector_index.search_by_full_text_index( - query, - search_type='similarity_score_threshold', - top_k=top_k - ) - if documents: - if reranking_model and search_method == 'full_text_search': - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=reranking_model['reranking_provider_name'], - model_type=ModelType.RERANK, - model=reranking_model['reranking_model_name'] - ) - except InvokeAuthorizationError: - return - - rerank_runner = RerankRunner(rerank_model_instance) - all_documents.extend(rerank_runner.run( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) - else: - all_documents.extend(documents) diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 1a0447b38f..d336162bae 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,44 +1,18 @@ - from typing import Optional -from langchain.schema import Document - -from core.index.index import IndexBuilder +from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document from models.dataset import Dataset, DocumentSegment class VectorService: @classmethod - def create_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - ) - - # save vector index - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts([document], duplicate_check=True) - - # save keyword index - index = IndexBuilder.get_index(dataset, 'economy') - if index: - if keywords and len(keywords) > 0: - index.create_segment_keywords(segment.index_node_id, keywords) - else: - index.add_texts([document]) - - @classmethod - def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset): + def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], + 
segments: list[DocumentSegment], dataset: Dataset): documents = [] - for pre_segment_data in pre_segment_data_list: - segment = pre_segment_data['segment'] + for segment in segments: document = Document( page_content=segment.content, metadata={ @@ -49,30 +23,26 @@ class VectorService: } ) documents.append(document) - - # save vector index - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts(documents, duplicate_check=True) + if dataset.indexing_technique == 'high_quality': + # save vector index + vector = Vector( + dataset=dataset + ) + vector.add_texts(documents, duplicate_check=True) # save keyword index - keyword_index = IndexBuilder.get_index(dataset, 'economy') - if keyword_index: - keyword_index.multi_create_segment_keywords(pre_segment_data_list) + keyword = Keyword(dataset) + + if keywords_list and len(keywords_list) > 0: + keyword.add_texts(documents, keyword_list=keywords_list) + else: + keyword.add_texts(documents) @classmethod def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): # update segment index task - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - kw_index = IndexBuilder.get_index(dataset, 'economy') - # delete from vector index - if vector_index: - vector_index.delete_by_ids([segment.index_node_id]) - # delete from keyword index - kw_index.delete_by_ids([segment.index_node_id]) - - # add new index + # format new index document = Document( page_content=segment.content, metadata={ @@ -82,13 +52,20 @@ class VectorService: "dataset_id": segment.dataset_id, } ) + if dataset.indexing_technique == 'high_quality': + # update vector index + vector = Vector( + dataset=dataset + ) + vector.delete_by_ids([segment.index_node_id]) + vector.add_texts([document], duplicate_check=True) - # save vector index - if vector_index: - vector_index.add_texts([document], duplicate_check=True) + # update keyword index + keyword = Keyword(dataset) + keyword.delete_by_ids([segment.index_node_id]) # save keyword index if keywords and len(keywords) > 0: - kw_index.create_segment_keywords(segment.index_node_id, keywords) + keyword.add_texts([document], keywords_list=[keywords]) else: - kw_index.add_texts([document]) + keyword.add_texts([document]) diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index ae235a2a63..a26ecf5526 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -4,10 +4,10 @@ import time import click from celery import shared_task -from langchain.schema import Document from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Document as DatasetDocument @@ -60,15 +60,9 @@ def add_document_to_index_task(dataset_document_id: str): if not dataset: raise Exception('Document has no dataset') - # save vector index - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts(documents) - - # save keyword index - index = IndexBuilder.get_index(dataset, 'economy') - if index: - index.add_texts(documents) + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.load(dataset, documents) end_at = time.perf_counter() logging.info( diff --git 
a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 61529f9bde..b3aa8b596c 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -3,9 +3,9 @@ import time import click from celery import shared_task -from langchain.schema import Document -from core.index.index import IndexBuilder +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -48,9 +48,9 @@ def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: s "doc_id": annotation_id } ) - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts([document]) + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector.create([document], duplicate_check=True) + end_at = time.perf_counter() logging.info( click.style( diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 5b6c45b4f3..063ffca6fd 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -3,10 +3,10 @@ import time import click from celery import shared_task -from langchain.schema import Document from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -79,9 +79,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: collection_binding_id=dataset_collection_binding.id ) - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts(documents) + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector.create(documents, duplicate_check=True) db.session.commit() redis_client.setex(indexing_cache_key, 600, 'completed') diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 852f899512..81155a35e4 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -4,7 +4,7 @@ import time import click from celery import shared_task -from core.index.index import IndexBuilder +from core.rag.datasource.vdb.vector_factory import Vector from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -30,12 +30,11 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str collection_binding_id=dataset_collection_binding.id ) - vector_index = IndexBuilder.get_default_high_quality_index(dataset) - if vector_index: - try: - vector_index.delete_by_metadata_field('annotation_id', annotation_id) - except Exception: - logging.exception("Delete annotation index failed when annotation deleted.") + try: + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector.delete_by_metadata_field('annotation_id', annotation_id) + except Exception: + logging.exception("Delete annotation index failed when annotation deleted.") end_at = time.perf_counter() logging.info( click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - 
start_at), diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index c5f028c72d..7b88d7ac50 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -5,7 +5,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder +from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -48,12 +48,11 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): collection_binding_id=app_annotation_setting.collection_binding_id ) - vector_index = IndexBuilder.get_default_high_quality_index(dataset) - if vector_index: - try: - vector_index.delete_by_metadata_field('app_id', app_id) - except Exception: - logging.exception("Delete doc index failed when dataset deleted.") + try: + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector.delete_by_metadata_field('app_id', app_id) + except Exception: + logging.exception("Delete annotation index failed when annotation deleted.") redis_client.setex(disable_app_annotation_job_key, 600, 'completed') # delete annotation setting diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index a125dd5717..f3260bbb50 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -4,10 +4,10 @@ import time import click from celery import shared_task -from langchain.schema import Document from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -81,15 +81,15 @@ def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_ } ) documents.append(document) - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - try: - index.delete_by_metadata_field('app_id', app_id) - except Exception as e: - logging.info( - click.style('Delete annotation index error: {}'.format(str(e)), - fg='red')) - index.add_texts(documents) + + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + try: + vector.delete_by_metadata_field('app_id', app_id) + except Exception as e: + logging.info( + click.style('Delete annotation index error: {}'.format(str(e)), + fg='red')) + vector.add_texts(documents) db.session.commit() redis_client.setex(enable_app_annotation_job_key, 600, 'completed') end_at = time.perf_counter() diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index e632c3a24e..7219abd3cd 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -3,9 +3,9 @@ import time import click from celery import shared_task -from langchain.schema import Document -from core.index.index import IndexBuilder +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -49,10 +49,9 @@ def 
update_annotation_to_index_task(annotation_id: str, question: str, tenant_id "doc_id": annotation_id } ) - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.delete_by_metadata_field('annotation_id', annotation_id) - index.add_texts([document]) + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector.delete_by_metadata_field('annotation_id', annotation_id) + vector.add_texts([document]) end_at = time.perf_counter() logging.info( click.style( diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 74ebcea15f..16e4affc91 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -4,7 +4,7 @@ import time import click from celery import shared_task -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import ( AppDatasetJoin, @@ -18,7 +18,7 @@ from models.dataset import ( @shared_task(queue='dataset') def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, - index_struct: str, collection_binding_id: str): + index_struct: str, collection_binding_id: str, doc_form: str): """ Clean dataset when dataset deleted. :param dataset_id: dataset id @@ -26,6 +26,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, :param indexing_technique: indexing technique :param index_struct: index struct dict :param collection_binding_id: collection binding id + :param doc_form: dataset form Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ @@ -38,26 +39,14 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, tenant_id=tenant_id, indexing_technique=indexing_technique, index_struct=index_struct, - collection_binding_id=collection_binding_id + collection_binding_id=collection_binding_id, + doc_form=doc_form ) documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() - kw_index = IndexBuilder.get_index(dataset, 'economy') - - # delete from vector index - if dataset.indexing_technique == 'high_quality': - vector_index = IndexBuilder.get_default_high_quality_index(dataset) - try: - vector_index.delete_by_group_id(dataset.id) - except Exception: - logging.exception("Delete doc index failed when dataset deleted.") - - # delete from keyword index - try: - kw_index.delete() - except Exception: - logging.exception("Delete nodes index failed when dataset deleted.") + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, None) for document in documents: db.session.delete(document) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 76eb1a572c..71ebad1da4 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -4,17 +4,18 @@ import time import click from celery import shared_task -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment @shared_task(queue='dataset') -def clean_document_task(document_id: str, dataset_id: str): +def clean_document_task(document_id: str, dataset_id: str, doc_form: str): """ Clean document when document deleted. 
:param document_id: document id :param dataset_id: dataset id + :param doc_form: doc_form Usage: clean_document_task.delay(document_id, dataset_id) """ @@ -27,21 +28,12 @@ def clean_document_task(document_id: str, dataset_id: str): if not dataset: raise Exception('Document has no dataset') - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - kw_index = IndexBuilder.get_index(dataset, 'economy') - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] - - # delete from vector index - if vector_index: - vector_index.delete_by_document_id(document_id) - - # delete from keyword index - if index_node_ids: - kw_index.delete_by_ids(index_node_ids) + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, index_node_ids) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index b85938d7d6..9b697b6351 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -4,7 +4,7 @@ import time import click from celery import shared_task -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -26,9 +26,8 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): if not dataset: raise Exception('Document has no dataset') - - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - kw_index = IndexBuilder.get_index(dataset, 'economy') + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() for document_id in document_ids: document = db.session.query(Document).filter( Document.id == document_id @@ -38,13 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - if vector_index: - vector_index.delete_by_document_id(document_id) - - # delete from keyword index - if index_node_ids: - kw_index.delete_by_ids(index_node_ids) + index_processor.clean(dataset, index_node_ids) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 6c492f0692..f33a1e91bf 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -5,10 +5,10 @@ from typing import Optional import click from celery import shared_task -from langchain.schema import Document from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment @@ -68,18 +68,9 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) return - # save vector index - index = IndexBuilder.get_index(dataset, 'high_quality') - if 
index: - index.add_texts([document], duplicate_check=True) - - # save keyword index - index = IndexBuilder.get_index(dataset, 'economy') - if index: - if keywords and len(keywords) > 0: - index.create_segment_keywords(segment.index_node_id, keywords) - else: - index.add_texts([document]) + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.load(dataset, [document]) # update segment to completed update_params = { diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 008f122a82..3827d62fbf 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -3,9 +3,9 @@ import time import click from celery import shared_task -from langchain.schema import Document -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -29,10 +29,10 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): if not dataset: raise Exception('Dataset not found') - + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "remove": - index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True) - index.delete_by_group_id(dataset.id) + index_processor.clean(dataset, None, with_keywords=False) elif action == "add": dataset_documents = db.session.query(DatasetDocument).filter( DatasetDocument.dataset_id == dataset_id, @@ -42,8 +42,6 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): ).all() if dataset_documents: - # save vector index - index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False) documents = [] for dataset_document in dataset_documents: # delete from vector index @@ -65,7 +63,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): documents.append(document) # save vector index - index.create(documents) + index_processor.load(dataset, documents, with_keywords=False) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index 9c9b00a2f5..d79286cf3d 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -4,7 +4,7 @@ import time import click from celery import shared_task -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, Document @@ -39,15 +39,9 @@ def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_ logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment_id), fg='cyan')) return - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - kw_index = IndexBuilder.get_index(dataset, 'economy') - - # delete from vector index - if vector_index: - vector_index.delete_by_ids([index_node_id]) - - # delete from keyword index - kw_index.delete_by_ids([index_node_id]) + index_type = dataset_document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.clean(dataset, 
[index_node_id]) end_at = time.perf_counter() logging.info(click.style('Segment deleted from index: {} latency: {}'.format(segment_id, end_at - start_at), fg='green')) diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 97f4fd0677..4788bf4e4b 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -5,7 +5,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment @@ -48,15 +48,9 @@ def disable_segment_from_index_task(segment_id: str): logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) return - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - kw_index = IndexBuilder.get_index(dataset, 'economy') - - # delete from vector index - if vector_index: - vector_index.delete_by_ids([segment.index_node_id]) - - # delete from keyword index - kw_index.delete_by_ids([segment.index_node_id]) + index_type = dataset_document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.clean(dataset, [segment.index_node_id]) end_at = time.perf_counter() logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 57f080e3ff..84e2029705 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -6,9 +6,9 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.data_loader.loader.notion import NotionLoader -from core.index.index import IndexBuilder from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.rag.extractor.notion_extractor import NotionExtractor +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment from models.source import DataSourceBinding @@ -54,11 +54,11 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): if not data_source_binding: raise ValueError('Data source binding not found.') - loader = NotionLoader( - notion_access_token=data_source_binding.access_token, + loader = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, - notion_page_type=page_type + notion_page_type=page_type, + notion_access_token=data_source_binding.access_token ) last_edited_time = loader.get_notion_last_edited_time() @@ -74,20 +74,14 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: raise Exception('Dataset not found') - - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - kw_index = IndexBuilder.get_index(dataset, 'economy') + index_type = document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - if vector_index: - 
vector_index.delete_by_document_id(document_id) - - # delete from keyword index - if index_node_ids: - kw_index.delete_by_ids(index_node_ids) + index_processor.clean(dataset, index_node_ids) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 87081e19e3..b776207050 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -4,10 +4,12 @@ import time import click from celery import shared_task +from flask import current_app from core.indexing_runner import DocumentIsPausedException, IndexingRunner from extensions.ext_database import db -from models.dataset import Document +from models.dataset import Dataset, Document +from services.feature_service import FeatureService @shared_task(queue='dataset') @@ -21,6 +23,35 @@ def document_indexing_task(dataset_id: str, document_ids: list): """ documents = [] start_at = time.perf_counter() + + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + + # check document limit + features = FeatureService.get_features(dataset.tenant_id) + try: + if features.billing.enabled: + vector_space = features.vector_space + count = len(document_ids) + batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + if 0 < vector_space.limit <= vector_space.size: + raise ValueError("Your total number of documents plus the number of uploads have over the limit of " + "your subscription.") + except Exception as e: + for document_id in document_ids: + document = db.session.query(Document).filter( + Document.id == document_id, + Document.dataset_id == dataset_id + ).first() + if document: + document.indexing_status = 'error' + document.error = str(e) + document.stopped_at = datetime.datetime.utcnow() + db.session.add(document) + db.session.commit() + return + for document_id in document_ids: logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 12014799b0..e59c549a65 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -6,8 +6,8 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -42,19 +42,14 @@ def document_indexing_update_task(dataset_id: str, document_id: str): if not dataset: raise Exception('Dataset not found') - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - kw_index = IndexBuilder.get_index(dataset, 'economy') + index_type = document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - if vector_index: - vector_index.delete_by_ids(index_node_ids) - - # delete from keyword index - if index_node_ids: - kw_index.delete_by_ids(index_node_ids) + index_processor.clean(dataset, index_node_ids) for segment in segments: 
db.session.delete(segment) diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 8dffd01520..a6254a822d 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -4,10 +4,10 @@ import time import click from celery import shared_task -from langchain.schema import Document from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment @@ -60,15 +60,9 @@ def enable_segment_to_index_task(segment_id: str): logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) return + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() # save vector index - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts([document], duplicate_check=True) - - # save keyword index - index = IndexBuilder.get_index(dataset, 'economy') - if index: - index.add_texts([document]) + index_processor.load(dataset, [document]) end_at = time.perf_counter() logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index a18842a59a..cff8dddc53 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -5,7 +5,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Document, DocumentSegment @@ -37,18 +37,15 @@ def remove_document_from_index_task(document_id: str): if not dataset: raise Exception('Document has no dataset') - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - kw_index = IndexBuilder.get_index(dataset, 'economy') + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - # delete from vector index - if vector_index: - vector_index.delete_by_document_id(document.id) - - # delete from keyword index segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all() index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: - kw_index.delete_by_ids(index_node_ids) + try: + index_processor.clean(dataset, index_node_ids) + except Exception: + logging.exception(f"clean dataset {dataset.id} from index failed") end_at = time.perf_counter() logging.info( diff --git a/api/tasks/update_segment_index_task.py b/api/tasks/update_segment_index_task.py deleted file mode 100644 index 802bac8857..0000000000 --- a/api/tasks/update_segment_index_task.py +++ /dev/null @@ -1,114 +0,0 @@ -import datetime -import logging -import time -from typing import Optional - -import click -from celery import shared_task -from langchain.schema import Document -from werkzeug.exceptions import NotFound - -from core.index.index import IndexBuilder -from extensions.ext_database import db -from extensions.ext_redis import redis_client -from models.dataset import DocumentSegment - - 
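One behavioural note on the document_indexing_task change above: the billing checks now run before any indexing work starts, and a failed check marks every document in the batch as errored instead of raising partway through. Condensed, the guard looks roughly like this (the standalone function is a sketch; the limit values come from FeatureService and BATCH_UPLOAD_LIMIT as shown in the diff):

# Condensed sketch of the pre-indexing guard added to document_indexing_task (illustrative).
def check_indexing_limits(features, document_ids: list[str], batch_upload_limit: int) -> None:
    """Raise ValueError when the batch size or the tenant's vector space quota is exceeded."""
    if not features.billing.enabled:
        return
    if len(document_ids) > batch_upload_limit:
        raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
    vector_space = features.vector_space
    # the quota check only fires for a positive limit that has already been reached
    if 0 < vector_space.limit <= vector_space.size:
        raise ValueError("Your total number of documents plus the number of uploads have over "
                         "the limit of your subscription.")

In the task itself this check runs inside a try/except so that any ValueError is written back onto the affected Document rows (indexing_status = 'error') before the task returns.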
-@shared_task(queue='dataset') -def update_segment_index_task(segment_id: str, keywords: Optional[list[str]] = None): - """ - Async update segment index - :param segment_id: - :param keywords: - Usage: update_segment_index_task.delay(segment_id) - """ - logging.info(click.style('Start update segment index: {}'.format(segment_id), fg='green')) - start_at = time.perf_counter() - - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() - if not segment: - raise NotFound('Segment not found') - - if segment.status != 'updating': - return - - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) - - try: - dataset = segment.dataset - - if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) - return - - dataset_document = segment.document - - if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) - return - - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) - return - - # update segment status to indexing - update_params = { - DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.utcnow() - } - DocumentSegment.query.filter_by(id=segment.id).update(update_params) - db.session.commit() - - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - kw_index = IndexBuilder.get_index(dataset, 'economy') - - # delete from vector index - if vector_index: - vector_index.delete_by_ids([segment.index_node_id]) - - # delete from keyword index - kw_index.delete_by_ids([segment.index_node_id]) - - # add new index - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - ) - - # save vector index - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts([document], duplicate_check=True) - - # save keyword index - index = IndexBuilder.get_index(dataset, 'economy') - if index: - if keywords and len(keywords) > 0: - index.create_segment_keywords(segment.index_node_id, keywords) - else: - index.add_texts([document]) - - # update segment to completed - update_params = { - DocumentSegment.status: "completed", - DocumentSegment.completed_at: datetime.datetime.utcnow() - } - DocumentSegment.query.filter_by(id=segment.id).update(update_params) - db.session.commit() - - end_at = time.perf_counter() - logging.info(click.style('Segment update index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) - except Exception as e: - logging.exception("update segment index failed") - segment.enabled = False - segment.disabled_at = datetime.datetime.utcnow() - segment.status = 'error' - segment.error = str(e) - db.session.commit() - finally: - redis_client.delete(indexing_cache_key) diff --git a/api/tasks/update_segment_keyword_index_task.py b/api/tasks/update_segment_keyword_index_task.py deleted file mode 100644 index ee88beba98..0000000000 --- a/api/tasks/update_segment_keyword_index_task.py +++ /dev/null @@ -1,68 +0,0 @@ -import datetime -import logging -import time - -import click -from celery import shared_task -from werkzeug.exceptions import NotFound - -from core.index.index import IndexBuilder -from extensions.ext_database import db -from 
extensions.ext_redis import redis_client -from models.dataset import DocumentSegment - - -@shared_task(queue='dataset') -def update_segment_keyword_index_task(segment_id: str): - """ - Async update segment index - :param segment_id: - Usage: update_segment_keyword_index_task.delay(segment_id) - """ - logging.info(click.style('Start update segment keyword index: {}'.format(segment_id), fg='green')) - start_at = time.perf_counter() - - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() - if not segment: - raise NotFound('Segment not found') - - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) - - try: - dataset = segment.dataset - - if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) - return - - dataset_document = segment.document - - if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) - return - - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) - return - - kw_index = IndexBuilder.get_index(dataset, 'economy') - - # delete from keyword index - kw_index.delete_by_ids([segment.index_node_id]) - - # save keyword index - index = IndexBuilder.get_index(dataset, 'economy') - if index: - index.update_segment_keywords_index(segment.index_node_id, segment.keywords) - - end_at = time.perf_counter() - logging.info(click.style('Segment update index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) - except Exception as e: - logging.exception("update segment index failed") - segment.enabled = False - segment.disabled_at = datetime.datetime.utcnow() - segment.status = 'error' - segment.error = str(e) - db.session.commit() - finally: - redis_client.delete(indexing_cache_key) diff --git a/web/.eslintrc.json b/web/.eslintrc.json index 1ab9727739..53308105b6 100644 --- a/web/.eslintrc.json +++ b/web/.eslintrc.json @@ -8,6 +8,7 @@ "error", "type" ], + "@typescript-eslint/no-var-requires": "off", "no-console": "off", "indent": "off", "@typescript-eslint/indent": [ diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/page.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/page.tsx index 7061bf1253..8f924268f2 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/page.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/page.tsx @@ -1,8 +1,7 @@ import React from 'react' import ChartView from './chartView' import CardView from './cardView' -import { getLocaleOnServer } from '@/i18n/server' -import { useTranslation as translate } from '@/i18n/i18next-serverside-config' +import { getLocaleOnServer, useTranslation as translate } from '@/i18n/server' import ApikeyInfoPanel from '@/app/components/app/overview/apikey-info-panel' export type IDevelopProps = { diff --git a/web/app/(commonLayout)/apps/page.tsx b/web/app/(commonLayout)/apps/page.tsx index 1b54c3c5ad..feb4cb0821 100644 --- a/web/app/(commonLayout)/apps/page.tsx +++ b/web/app/(commonLayout)/apps/page.tsx @@ -1,8 +1,7 @@ import classNames from 'classnames' import style from '../list.module.css' import Apps from './Apps' -import { getLocaleOnServer } from '@/i18n/server' -import { useTranslation as translate } from '@/i18n/i18next-serverside-config' +import { getLocaleOnServer, useTranslation as translate } from 
'@/i18n/server' const AppList = async () => { const locale = getLocaleOnServer() diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx index b280e73d74..c93c0761fa 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx @@ -25,7 +25,7 @@ import Link from 'next/link' import s from './style.module.css' import { fetchDatasetDetail, fetchDatasetRelatedApps } from '@/service/datasets' import type { RelatedApp, RelatedAppResponse } from '@/models/datasets' -import { getLocaleOnClient } from '@/i18n/client' +import { getLocaleOnClient } from '@/i18n' import AppSideBar from '@/app/components/app-sidebar' import Divider from '@/app/components/base/divider' import Indicator from '@/app/components/header/indicator' @@ -35,7 +35,7 @@ import FloatPopoverContainer from '@/app/components/base/float-popover-container import DatasetDetailContext from '@/context/dataset-detail' import { DataSourceType } from '@/models/datasets' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' -import { LanguagesSupported, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' export type IAppDetailLayoutProps = { children: React.ReactNode @@ -72,7 +72,7 @@ const LikedItem = ({ const TargetIcon = ({ className }: SVGProps) => { return - + @@ -105,7 +105,6 @@ type IExtraInfoProps = { const ExtraInfo = ({ isMobile, relatedApps }: IExtraInfoProps) => { const locale = getLocaleOnClient() - const language = getModelRuntimeSupported(locale) const [isShowTips, { toggle: toggleTips, set: setShowTips }] = useBoolean(!isMobile) const { t } = useTranslation() @@ -150,7 +149,7 @@ const ExtraInfo = ({ isMobile, relatedApps }: IExtraInfoProps) => { { +const Settings = async () => { const locale = getLocaleOnServer() const { t } = await translate(locale, 'dataset-settings') @@ -19,7 +12,7 @@ const Settings = async ({
{t('title')}
{t('desc')}
-
+ ) } diff --git a/web/app/(commonLayout)/datasets/Doc.tsx b/web/app/(commonLayout)/datasets/Doc.tsx index b1b7b00cf7..a6dd8c23ef 100644 --- a/web/app/(commonLayout)/datasets/Doc.tsx +++ b/web/app/(commonLayout)/datasets/Doc.tsx @@ -5,7 +5,7 @@ import { useContext } from 'use-context-selector' import TemplateEn from './template/template.en.mdx' import TemplateZh from './template/template.zh.mdx' import I18n from '@/context/i18n' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' type DocProps = { apiBaseUrl: string @@ -14,11 +14,10 @@ const Doc: FC = ({ apiBaseUrl, }) => { const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) return (
{ - language !== LanguagesSupportedUnderscore[1] + locale !== LanguagesSupported[1] ? : } diff --git a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx index a820a650ea..3cf88ce281 100644 --- a/web/app/activate/activateForm.tsx +++ b/web/app/activate/activateForm.tsx @@ -12,7 +12,7 @@ import Button from '@/app/components/base/button' import { SimpleSelect } from '@/app/components/base/select' import { timezones } from '@/utils/timezone' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported, languages } from '@/utils/language' +import { LanguagesSupported, languages } from '@/i18n/language' import { activateMember, invitationCheck } from '@/service/common' import Toast from '@/app/components/base/toast' import Loading from '@/app/components/base/loading' @@ -42,9 +42,9 @@ const ActivateForm = () => { const [name, setName] = useState('') const [password, setPassword] = useState('') const [timezone, setTimezone] = useState('Asia/Shanghai') - const [language, setLanguage] = useState(getModelRuntimeSupported(locale)) + const [language, setLanguage] = useState(locale) const [showSuccess, setShowSuccess] = useState(false) - const defaultLanguage = useCallback(() => (window.navigator.language.startsWith('zh') ? LanguagesSupportedUnderscore[1] : LanguagesSupportedUnderscore[0]) || LanguagesSupportedUnderscore[0], []) + const defaultLanguage = useCallback(() => (window.navigator.language.startsWith('zh') ? LanguagesSupported[1] : LanguagesSupported[0]) || LanguagesSupported[0], []) const showErrorMessage = useCallback((message: string) => { Toast.notify({ @@ -207,7 +207,7 @@ const ActivateForm = () => { {t('login.license.link')} diff --git a/web/app/components/app-sidebar/basic.tsx b/web/app/components/app-sidebar/basic.tsx index 36bcdd9e7d..46dce55eb8 100644 --- a/web/app/components/app-sidebar/basic.tsx +++ b/web/app/components/app-sidebar/basic.tsx @@ -36,7 +36,7 @@ const WebappSvg = const NotionSvg = - + diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.tsx index 37512abf6c..fbbea70612 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.tsx @@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { Download02 as DownloadIcon } from '@/app/components/base/icons/src/vender/solid/general' import I18n from '@/context/i18n' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' const CSV_TEMPLATE_QA_EN = [ ['question', 'answer'], @@ -25,11 +25,10 @@ const CSVDownload: FC = () => { const { t } = useTranslation() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const { CSVDownloader, Type } = useCSVDownloader() const getTemplate = () => { - return language !== LanguagesSupportedUnderscore[1] ? CSV_TEMPLATE_QA_EN : CSV_TEMPLATE_QA_CN + return locale !== LanguagesSupported[1] ? 
CSV_TEMPLATE_QA_EN : CSV_TEMPLATE_QA_CN } return ( @@ -58,7 +57,7 @@ const CSVDownload: FC = () => { diff --git a/web/app/components/app/annotation/header-opts/index.tsx b/web/app/components/app/annotation/header-opts/index.tsx index b5c749bcd7..90b1a9672e 100644 --- a/web/app/components/app/annotation/header-opts/index.tsx +++ b/web/app/components/app/annotation/header-opts/index.tsx @@ -20,7 +20,7 @@ import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows import I18n from '@/context/i18n' import { fetchExportAnnotationList } from '@/service/annotation' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' const CSV_HEADER_QA_EN = ['Question', 'Answer'] const CSV_HEADER_QA_CN = ['问题', '答案'] @@ -40,7 +40,6 @@ const HeaderOptions: FC = ({ }) => { const { t } = useTranslation() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const { CSVDownloader, Type } = useCSVDownloader() const [list, setList] = useState([]) @@ -56,7 +55,7 @@ const HeaderOptions: FC = ({ const content = listTransformer(list).join('\n') const file = new Blob([content], { type: 'application/jsonl' }) a.href = URL.createObjectURL(file) - a.download = `annotations-${language}.jsonl` + a.download = `annotations-${locale}.jsonl` a.click() } @@ -110,10 +109,10 @@ const HeaderOptions: FC = ({ > [item.question, item.answer]), ]} > diff --git a/web/app/components/app/configuration/config-prompt/conversation-histroy/history-panel.tsx b/web/app/components/app/configuration/config-prompt/conversation-histroy/history-panel.tsx index bf4b9c9f51..f40bd4b733 100644 --- a/web/app/components/app/configuration/config-prompt/conversation-histroy/history-panel.tsx +++ b/web/app/components/app/configuration/config-prompt/conversation-histroy/history-panel.tsx @@ -7,7 +7,7 @@ import OperationBtn from '@/app/components/app/configuration/base/operation-btn' import Panel from '@/app/components/app/configuration/base/feature-panel' import { MessageClockCircle } from '@/app/components/base/icons/src/vender/solid/general' import I18n from '@/context/i18n' -import { LanguagesSupported, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' type Props = { showWarning: boolean @@ -20,7 +20,6 @@ const HistoryPanel: FC = ({ }) => { const { t } = useTranslation() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) return ( = ({ {showWarning && (
{t('appDebug.feature.conversationHistory.tip')} - { const { t } = useTranslation() const pathname = usePathname() diff --git a/web/app/components/app/configuration/config/agent/agent-tools/choose-tool/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/choose-tool/index.tsx index b5dec18a5b..0b0e0676b7 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/choose-tool/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/choose-tool/index.tsx @@ -10,7 +10,7 @@ import Drawer from '@/app/components/base/drawer-plus' import ConfigContext from '@/context/debug-configuration' import type { ModelConfig } from '@/models/debug' import I18n from '@/context/i18n' -import { getModelRuntimeSupported } from '@/utils/language' + type Props = { show: boolean onHide: () => void @@ -24,7 +24,6 @@ const ChooseTool: FC = ({ }) => { const { t } = useTranslation() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const { modelConfig, setModelConfig, @@ -60,7 +59,7 @@ const ChooseTool: FC = ({ provider_type: collection.type, provider_name: collection.name, tool_name: tool.name, - tool_label: tool.label[language], + tool_label: tool.label[locale], tool_parameters: parameters, enabled: true, }) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx index 55b554bfcb..98f454bd20 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx @@ -13,7 +13,7 @@ import I18n from '@/context/i18n' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' import { DiagonalDividingLine } from '@/app/components/base/icons/src/public/common' -import { getModelRuntimeSupported } from '@/utils/language' +import { getLanguage } from '@/i18n/language' type Props = { collection: Collection toolName: string @@ -32,7 +32,7 @@ const SettingBuiltInTool: FC = ({ onSave, }) => { const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) + const language = getLanguage(locale) const { t } = useTranslation() const [isLoading, setIsLoading] = useState(true) @@ -116,7 +116,7 @@ const SettingBuiltInTool: FC = ({
) - const setttingUI = ( + const settingUI = ( = ({
: (
- {isInfoActive ? infoUI : setttingUI} + {isInfoActive ? infoUI : settingUI}
{!readonly && !isInfoActive && (
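The hunks above all drop the old getModelRuntimeSupported(locale) indirection from '@/utils/language': components now either compare the raw locale against LanguagesSupported from '@/i18n/language' (template/CSV selection, warnings, links) or convert it with getLanguage(locale) where a label map is read with the underscore form (e.g. tool.label[language]). A minimal sketch of the two patterns, assuming LanguagesSupported is ['en-US', 'zh-Hans'] and that getLanguage simply swaps the dash for an underscore — the real '@/i18n/language' module is not shown in this diff:

    // Sketch only — names mirror the new imports above; bodies are assumptions.
    const LanguagesSupported = ['en-US', 'zh-Hans'] as const

    // Assumed conversion for label maps still keyed like { en_US: ..., zh_Hans: ... }.
    const getLanguage = (locale: string) => locale.replace('-', '_')

    // Pattern 1: branch directly on the raw locale (templates, CSV samples, warnings).
    const isZh = (locale: string) => locale === LanguagesSupported[1]

    // Pattern 2: read underscore-keyed label maps via getLanguage, e.g. tool labels.
    const toolLabel = (labels: Record<string, string>, locale: string) =>
      labels[getLanguage(locale)] // e.g. labels['zh_Hans']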
diff --git a/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx b/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx index 24d3e0e64a..6bd40547ca 100644 --- a/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx +++ b/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx @@ -7,8 +7,9 @@ import { usePathname } from 'next/navigation' import Panel from '@/app/components/app/configuration/base/feature-panel' import { Speaker } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' import ConfigContext from '@/context/debug-configuration' -import { languages } from '@/utils/language' +import { languages } from '@/i18n/language' import { fetchAppVoices } from '@/service/apps' +import AudioBtn from '@/app/components/base/audio-btn' const TextToSpeech: FC = () => { const { t } = useTranslation() @@ -20,19 +21,28 @@ const TextToSpeech: FC = () => { const matched = pathname.match(/\/app\/([^/]+)/) const appId = (matched?.length && matched[1]) ? matched[1] : '' const language = textToSpeechConfig.language + const languageInfo = languages.find(i => i.value === textToSpeechConfig.language) + const voiceItems = useSWR({ appId, language }, fetchAppVoices).data const voiceItem = voiceItems?.find(item => item.value === textToSpeechConfig.voice) + return ( +
{t('appDebug.feature.textToSpeech.title')}
} headerIcon={} headerRight={ -
- {languages.find(i => i.value === textToSpeechConfig.language)?.name} - {voiceItem?.name ?? t('appDebug.voice.defaultDisplay')} +
+ {languageInfo && (`${languageInfo?.name} - `)}{voiceItem?.name ?? t('appDebug.voice.defaultDisplay')} + { languageInfo?.example && ( + + )}
} noBodySpacing diff --git a/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx b/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx index 0011b7054a..759a15213d 100644 --- a/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx +++ b/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx @@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import I18n from '@/context/i18n' import { FlipBackward } from '@/app/components/base/icons/src/vender/line/arrows' -import { LanguagesSupported, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' type Props = { onReturnToSimpleMode: () => void } @@ -15,7 +15,6 @@ const AdvancedModeWarning: FC = ({ }) => { const { t } = useTranslation() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const [show, setShow] = React.useState(true) if (!show) return null @@ -27,7 +26,7 @@ const AdvancedModeWarning: FC = ({ {t('appDebug.promptMode.advancedWarning.description')}
{t('appDebug.promptMode.advancedWarning.learnMore')} diff --git a/web/app/components/app/configuration/toolbox/moderation/index.tsx b/web/app/components/app/configuration/toolbox/moderation/index.tsx index 7731d9ebd2..9eb14e98d2 100644 --- a/web/app/components/app/configuration/toolbox/moderation/index.tsx +++ b/web/app/components/app/configuration/toolbox/moderation/index.tsx @@ -7,12 +7,10 @@ import { useModalContext } from '@/context/modal-context' import ConfigContext from '@/context/debug-configuration' import { fetchCodeBasedExtensionList } from '@/service/common' import I18n from '@/context/i18n' -import { getModelRuntimeSupported } from '@/utils/language' const Moderation = () => { const { t } = useTranslation() const { setShowModerationSettingModal } = useModalContext() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const { moderationConfig, setModerationConfig, @@ -39,7 +37,7 @@ const Moderation = () => { else if (moderationConfig.type === 'api') prefix = t('common.apiBasedExtension.selector.title') else - prefix = codeBasedExtensionList?.data.find(item => item.name === moderationConfig.type)?.label[language] || '' + prefix = codeBasedExtensionList?.data.find(item => item.name === moderationConfig.type)?.label[locale] || '' if (moderationConfig.config?.inputs_config?.enabled && moderationConfig.config?.outputs_config?.enabled) suffix = t('appDebug.feature.moderation.allEnabled') diff --git a/web/app/components/app/configuration/toolbox/moderation/moderation-setting-modal.tsx b/web/app/components/app/configuration/toolbox/moderation/moderation-setting-modal.tsx index a1b8d7deda..5bc52e5564 100644 --- a/web/app/components/app/configuration/toolbox/moderation/moderation-setting-modal.tsx +++ b/web/app/components/app/configuration/toolbox/moderation/moderation-setting-modal.tsx @@ -17,7 +17,7 @@ import { } from '@/service/common' import type { CodeBasedExtensionItem } from '@/models/common' import I18n from '@/context/i18n' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' import { InfoCircle } from '@/app/components/base/icons/src/vender/line/general' import { useModalContext } from '@/context/modal-context' import { CustomConfigurationStatusEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' @@ -44,7 +44,6 @@ const ModerationSettingModal: FC = ({ const { t } = useTranslation() const { notify } = useToastContext() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const { data: modelProviders, isLoading, mutate } = useSWR('/workspaces/current/model-providers', fetchModelProviders) const [localeData, setLocaleData] = useState(data) const { setShowAccountSettingModal } = useModalContext() @@ -200,12 +199,12 @@ const ModerationSettingModal: FC = ({ } if (localeData.type === 'keywords' && !localeData.config.keywords) { - notify({ type: 'error', message: t('appDebug.errorMessage.valueOfVarRequired', { key: language !== LanguagesSupportedUnderscore[1] ? 'keywords' : '关键词' }) }) + notify({ type: 'error', message: t('appDebug.errorMessage.valueOfVarRequired', { key: locale !== LanguagesSupported[1] ? 'keywords' : '关键词' }) }) return } if (localeData.type === 'api' && !localeData.config.api_based_extension_id) { - notify({ type: 'error', message: t('appDebug.errorMessage.valueOfVarRequired', { key: language !== LanguagesSupportedUnderscore[1] ? 
'API Extension' : 'API 扩展' }) }) + notify({ type: 'error', message: t('appDebug.errorMessage.valueOfVarRequired', { key: locale !== LanguagesSupported[1] ? 'API Extension' : 'API 扩展' }) }) return } @@ -214,7 +213,7 @@ const ModerationSettingModal: FC = ({ if (!localeData.config?.[currentProvider.form_schema[i].variable] && currentProvider.form_schema[i].required) { notify({ type: 'error', - message: t('appDebug.errorMessage.valueOfVarRequired', { key: language !== LanguagesSupportedUnderscore[1] ? currentProvider.form_schema[i].label['en-US'] : currentProvider.form_schema[i].label['zh-Hans'] }), + message: t('appDebug.errorMessage.valueOfVarRequired', { key: locale !== LanguagesSupported[1] ? currentProvider.form_schema[i].label['en-US'] : currentProvider.form_schema[i].label['zh-Hans'] }), }) return } diff --git a/web/app/components/app/configuration/tools/external-data-tool-modal.tsx b/web/app/components/app/configuration/tools/external-data-tool-modal.tsx index d52b66b65c..6a6e355956 100644 --- a/web/app/components/app/configuration/tools/external-data-tool-modal.tsx +++ b/web/app/components/app/configuration/tools/external-data-tool-modal.tsx @@ -12,7 +12,7 @@ import { BookOpen01 } from '@/app/components/base/icons/src/vender/line/educatio import { fetchCodeBasedExtensionList } from '@/service/common' import { SimpleSelect } from '@/app/components/base/select' import I18n from '@/context/i18n' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' import type { CodeBasedExtensionItem, ExternalDataTool, @@ -41,7 +41,6 @@ const ExternalDataToolModal: FC = ({ const { t } = useTranslation() const { notify } = useToastContext() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const [localeData, setLocaleData] = useState(data.type ? data : { ...data, type: 'api' }) const [showEmojiPicker, setShowEmojiPicker] = useState(false) const { data: codeBasedExtensionList } = useSWR( @@ -157,7 +156,7 @@ const ExternalDataToolModal: FC = ({ } if (localeData.type === 'api' && !localeData.config?.api_based_extension_id) { - notify({ type: 'error', message: t('appDebug.errorMessage.valueOfVarRequired', { key: language !== LanguagesSupportedUnderscore[1] ? 'API Extension' : 'API 扩展' }) }) + notify({ type: 'error', message: t('appDebug.errorMessage.valueOfVarRequired', { key: locale !== LanguagesSupported[1] ? 'API Extension' : 'API 扩展' }) }) return } @@ -166,7 +165,7 @@ const ExternalDataToolModal: FC = ({ if (!localeData.config?.[currentProvider.form_schema[i].variable] && currentProvider.form_schema[i].required) { notify({ type: 'error', - message: t('appDebug.errorMessage.valueOfVarRequired', { key: language !== LanguagesSupportedUnderscore[1] ? currentProvider.form_schema[i].label['en-US'] : currentProvider.form_schema[i].label['zh-Hans'] }), + message: t('appDebug.errorMessage.valueOfVarRequired', { key: locale !== LanguagesSupported[1] ? 
currentProvider.form_schema[i].label['en-US'] : currentProvider.form_schema[i].label['zh-Hans'] }), }) return } diff --git a/web/app/components/app/overview/customize/index.tsx b/web/app/components/app/overview/customize/index.tsx index 826a85aae5..5baf6dd4d3 100644 --- a/web/app/components/app/overview/customize/index.tsx +++ b/web/app/components/app/overview/customize/index.tsx @@ -9,7 +9,7 @@ import I18n from '@/context/i18n' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' import Tag from '@/app/components/base/tag' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' type IShareLinkProps = { isShow: boolean @@ -44,7 +44,6 @@ const CustomizeModal: FC = ({ }) => { const { t } = useTranslation() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const isChatApp = mode === 'chat' return = ({ className='w-36 mt-2' onClick={() => window.open( - `https://docs.dify.ai/${language !== LanguagesSupportedUnderscore[1] + `https://docs.dify.ai/${locale !== LanguagesSupported[1] ? 'user-guide/launching-dify-apps/developing-with-apis' : `v/${locale.toLowerCase()}/guides/application-publishing/developing-with-apis` }`, diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 6926dace3c..026ba6ae10 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -13,7 +13,7 @@ import type { AppDetailResponse } from '@/models/app' import type { Language } from '@/types/app' import EmojiPicker from '@/app/components/base/emoji-picker' -import { languages } from '@/utils/language' +import { languages } from '@/i18n/language' export type ISettingsModalProps = { appInfo: AppDetailResponse @@ -122,7 +122,7 @@ const SettingsModal: FC = ({ />
{t(`${prefixSettings}.language`)}
item.supported)} defaultValue={language} onSelect={item => setLanguage(item.value as Language)} /> diff --git a/web/app/components/base/audio-btn/index.tsx b/web/app/components/base/audio-btn/index.tsx index eaaca72594..c10755d3a1 100644 --- a/web/app/components/base/audio-btn/index.tsx +++ b/web/app/components/base/audio-btn/index.tsx @@ -10,11 +10,13 @@ import { textToAudio } from '@/service/share' type AudioBtnProps = { value: string className?: string + isAudition?: boolean } const AudioBtn = ({ value, className, + isAudition, }: AudioBtnProps) => { const audioRef = useRef(null) const [isPlaying, setIsPlaying] = useState(false) @@ -97,10 +99,10 @@ const AudioBtn = ({ className='z-10' >
-
+
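The audio-btn change above adds an optional isAudition flag, which the text-to-speech panel further up uses to preview the selected voice with the language's example sentence. A hedged sketch of the prop shape and the intended call; the exact markup and styling are stripped from the hunks, so the usage line is illustrative:

    // Assumed prop shape, matching the additions to audio-btn above.
    type AudioBtnProps = {
      value: string        // text to synthesize
      className?: string
      isAudition?: boolean // voice-preview ("audition") mode used by the TTS panel
    }

    // Illustrative call from the text-to-speech hunk (markup stripped there):
    // {languageInfo?.example && <AudioBtn value={languageInfo.example} isAudition />}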
diff --git a/web/app/components/base/chat/chat-with-history/config-panel/index.tsx b/web/app/components/base/chat/chat-with-history/config-panel/index.tsx index 20a54da83a..d90e4025d3 100644 --- a/web/app/components/base/chat/chat-with-history/config-panel/index.tsx +++ b/web/app/components/base/chat/chat-with-history/config-panel/index.tsx @@ -3,6 +3,7 @@ import { useTranslation } from 'react-i18next' import { useChatWithHistoryContext } from '../context' import Form from './form' import Button from '@/app/components/base/button' +import AppIcon from '@/app/components/base/app-icon' import { MessageDotsCircle } from '@/app/components/base/icons/src/vender/solid/communication' import { Edit02 } from '@/app/components/base/icons/src/vender/line/general' import { Star06 } from '@/app/components/base/icons/src/vender/solid/shapes' @@ -40,8 +41,13 @@ const ConfigPanel = () => { { showConfigPanelBeforeChat && ( <> -
- {appData?.site.icon} {appData?.site.title} +
+ + {appData?.site.title}
{ appData?.site.description && ( diff --git a/web/app/components/base/chat/chat/answer/operation.tsx b/web/app/components/base/chat/chat/answer/operation.tsx index 1f7c57ae75..eb5dead657 100644 --- a/web/app/components/base/chat/chat/answer/operation.tsx +++ b/web/app/components/base/chat/chat/answer/operation.tsx @@ -74,7 +74,7 @@ const Operation: FC = ({ ) } - {(!isOpeningStatement && config?.text_to_speech.enabled) && ( + {(!isOpeningStatement && config?.text_to_speech?.enabled) && ( = ({ }) => { const { t } = useTranslation() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) - const isZh = language === LanguagesSupportedUnderscore[1] + + const isZh = locale === LanguagesSupported[1] const [loading, setLoading] = React.useState(false) const i18nPrefix = `billing.plans.${plan}` const isFreePlan = plan === Plan.sandbox diff --git a/web/app/components/datasets/create/file-uploader/index.tsx b/web/app/components/datasets/create/file-uploader/index.tsx index 3b8146a5d9..dfedffbbdc 100644 --- a/web/app/components/datasets/create/file-uploader/index.tsx +++ b/web/app/components/datasets/create/file-uploader/index.tsx @@ -12,7 +12,9 @@ import { upload } from '@/service/base' import { fetchFileUploadConfig } from '@/service/common' import { fetchSupportFileTypes } from '@/service/datasets' import I18n from '@/context/i18n' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' + +const FILES_NUMBER_LIMIT = 20 type IFileUploaderProps = { fileList: FileItem[] @@ -34,7 +36,6 @@ const FileUploader = ({ const { t } = useTranslation() const { notify } = useContext(ToastContext) const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const [dragging, setDragging] = useState(false) const dropRef = useRef(null) const dragRef = useRef(null) @@ -75,7 +76,7 @@ const FileUploader = ({ res = res.map(item => item.toLowerCase()) res = res.filter((item, index, self) => self.indexOf(item) === index) - return res.map(item => item.toUpperCase()).join(language !== LanguagesSupportedUnderscore[1] ? ', ' : '、 ') + return res.map(item => item.toUpperCase()).join(locale !== LanguagesSupported[1] ? ', ' : '、 ') })() const ACCEPTS = supportTypes.map((ext: string) => `.${ext}`) const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? 
{ @@ -176,6 +177,11 @@ const FileUploader = ({ if (!files.length) return false + if (files.length + fileList.length > FILES_NUMBER_LIMIT) { + notify({ type: 'error', message: t('datasetCreation.stepOne.uploader.validation.filesNumber', { filesNumber: FILES_NUMBER_LIMIT }) }) + return false + } + const preparedFiles = files.map((file, index) => ({ fileID: `file${index}-${Date.now()}`, file, @@ -185,7 +191,7 @@ const FileUploader = ({ prepareFileList(newFiles) fileListRef.current = newFiles uploadMultipleFiles(preparedFiles) - }, [prepareFileList, uploadMultipleFiles]) + }, [prepareFileList, uploadMultipleFiles, notify, t, fileList]) const handleDragEnter = (e: DragEvent) => { e.preventDefault() diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 7a2c61e28e..4688d7afee 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -42,7 +42,7 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import Tooltip from '@/app/components/base/tooltip' import TooltipPlus from '@/app/components/base/tooltip-plus' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' type ValueOf = T[keyof T] type StepTwoProps = { @@ -89,7 +89,6 @@ const StepTwo = ({ }: StepTwoProps) => { const { t } = useTranslation() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const media = useBreakpoints() const isMobile = media === MediaType.mobile @@ -114,7 +113,7 @@ const StepTwo = ({ const [docForm, setDocForm] = useState( (datasetId && documentDetail) ? documentDetail.doc_form : DocForm.TEXT, ) - const [docLanguage, setDocLanguage] = useState(language !== LanguagesSupportedUnderscore[1] ? 'English' : 'Chinese') + const [docLanguage, setDocLanguage] = useState(locale !== LanguagesSupported[1] ? 
'English' : 'Chinese') const [QATipHide, setQATipHide] = useState(false) const [previewSwitched, setPreviewSwitched] = useState(false) const [showPreview, { setTrue: setShowPreview, setFalse: hidePreview }] = useBoolean() diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx index d44f36303f..36216aa7c8 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx @@ -9,7 +9,7 @@ import { useContext } from 'use-context-selector' import { Download02 as DownloadIcon } from '@/app/components/base/icons/src/vender/solid/general' import { DocForm } from '@/models/datasets' import I18n from '@/context/i18n' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' const CSV_TEMPLATE_QA_EN = [ ['question', 'answer'], @@ -35,11 +35,10 @@ const CSV_TEMPLATE_CN = [ const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => { const { t } = useTranslation() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const { CSVDownloader, Type } = useCSVDownloader() const getTemplate = () => { - if (language === LanguagesSupportedUnderscore[1]) { + if (locale === LanguagesSupported[1]) { if (docForm === DocForm.QA) return CSV_TEMPLATE_QA_CN return CSV_TEMPLATE_CN diff --git a/web/app/components/datasets/documents/detail/embedding/index.tsx b/web/app/components/datasets/documents/detail/embedding/index.tsx index a1432fc5d5..3c152f3f99 100644 --- a/web/app/components/datasets/documents/detail/embedding/index.tsx +++ b/web/app/components/datasets/documents/detail/embedding/index.tsx @@ -35,7 +35,7 @@ type Props = { const StopIcon = ({ className }: SVGProps) => { return - + diff --git a/web/app/components/datasets/documents/index.tsx b/web/app/components/datasets/documents/index.tsx index 7a05bec341..13f8fd8f9b 100644 --- a/web/app/components/datasets/documents/index.tsx +++ b/web/app/components/datasets/documents/index.tsx @@ -38,7 +38,7 @@ const ThreeDotsIcon = ({ className }: React.SVGProps) => { const NotionIcon = ({ className }: React.SVGProps) => { return - + diff --git a/web/app/components/develop/doc.tsx b/web/app/components/develop/doc.tsx index 698849c8fc..41ed2e9b3a 100644 --- a/web/app/components/develop/doc.tsx +++ b/web/app/components/develop/doc.tsx @@ -5,7 +5,7 @@ import TemplateZh from './template/template.zh.mdx' import TemplateChatEn from './template/template_chat.en.mdx' import TemplateChatZh from './template/template_chat.zh.mdx' import I18n from '@/context/i18n' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' type IDocProps = { appDetail: any @@ -13,7 +13,7 @@ type IDocProps = { const Doc = ({ appDetail }: IDocProps) => { const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) + const variables = appDetail?.model_config?.configs?.prompt_variables || [] const inputs = variables.reduce((res: any, variable: any) => { res[variable.key] = variable.name || '' @@ -24,10 +24,10 @@ const Doc = ({ appDetail }: IDocProps) => {
{appDetail?.mode === 'completion' ? ( - language !== LanguagesSupportedUnderscore[1] ? : + locale !== LanguagesSupported[1] ? : ) : ( - language !== LanguagesSupportedUnderscore[1] ? : + locale !== LanguagesSupported[1] ? : )}
) diff --git a/web/app/components/develop/secret-key/secret-key-modal.tsx b/web/app/components/develop/secret-key/secret-key-modal.tsx index 852c91cf37..510399e187 100644 --- a/web/app/components/develop/secret-key/secret-key-modal.tsx +++ b/web/app/components/develop/secret-key/secret-key-modal.tsx @@ -27,7 +27,7 @@ import Tooltip from '@/app/components/base/tooltip' import Loading from '@/app/components/base/loading' import Confirm from '@/app/components/base/confirm' import I18n from '@/context/i18n' -import { LanguagesSupported, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' import { useAppContext } from '@/context/app-context' type ISecretKeyModalProps = { @@ -56,7 +56,6 @@ const SecretKeyModal = ({ const [delKeyID, setDelKeyId] = useState('') const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) // const [isCopied, setIsCopied] = useState(false) const [copyValue, setCopyValue] = useState('') @@ -102,7 +101,7 @@ const SecretKeyModal = ({ } const formatDate = (timestamp: string) => { - if (language === LanguagesSupported[0]) + if (locale === LanguagesSupported[0]) return new Intl.DateTimeFormat('en-US', { year: 'numeric', month: 'long', day: 'numeric' }).format((+timestamp) * 1000) else return new Intl.DateTimeFormat('fr-CA', { year: 'numeric', month: '2-digit', day: '2-digit' }).format((+timestamp) * 1000) diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx index 7963b38b84..9e8dd69874 100644 --- a/web/app/components/develop/template/template_chat.en.mdx +++ b/web/app/components/develop/template/template_chat.en.mdx @@ -47,7 +47,7 @@ Chat applications support session persistence, allowing previous chat history to Allows the entry of various variable values defined by the App. - The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. + The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. 
Default `{}` The mode of response return, supporting: diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx index 71e101e208..47f64466e7 100644 --- a/web/app/components/develop/template/template_chat.zh.mdx +++ b/web/app/components/develop/template/template_chat.zh.mdx @@ -44,9 +44,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 用户输入/提问内容。 - (选填)允许传入 App 定义的各变量值。 + 允许传入 App 定义的各变量值。 inputs 参数包含了多组键值对(Key/Value pairs),每组的键对应一个特定变量,每组的值则是该变量的具体值。 - + 默认 `{}` - `streaming` 流式模式(推荐)。基于 SSE(**[Server-Sent Events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events)**)实现类似打字机输出方式的流式返回。 diff --git a/web/app/components/explore/category.tsx b/web/app/components/explore/category.tsx index dc5dbf1b7a..a2a11973b5 100644 --- a/web/app/components/explore/category.tsx +++ b/web/app/components/explore/category.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' import cn from 'classnames' -import exploreI18n from '@/i18n/lang/explore.en' +import exploreI18n from '@/i18n/en-US/explore' import type { AppCategory } from '@/models/explore' const categoryI18n = exploreI18n.category diff --git a/web/app/components/header/account-about/index.tsx b/web/app/components/header/account-about/index.tsx index 9ec99c6320..6ab0541b07 100644 --- a/web/app/components/header/account-about/index.tsx +++ b/web/app/components/header/account-about/index.tsx @@ -9,7 +9,7 @@ import { XClose } from '@/app/components/base/icons/src/vender/line/general' import type { LangGeniusVersionResponse } from '@/models/common' import { IS_CE_EDITION } from '@/config' import I18n from '@/context/i18n' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' import LogoSite from '@/app/components/base/logo/logo-site' type IAccountSettingProps = { @@ -26,7 +26,6 @@ export default function AccountAbout({ }: IAccountSettingProps) { const { t } = useTranslation() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const isLatest = langeniusVersionInfo.current_version === langeniusVersionInfo.latest_version return ( @@ -49,8 +48,8 @@ export default function AccountAbout({ IS_CE_EDITION ? Open Source License : <> - Privacy Policy, - Terms of Service + Privacy Policy, + Terms of Service }
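For the MDX template changes above: both the English and Chinese chat docs now state that inputs carries the app-defined variable values and defaults to {}. A hedged request sketch of what those docs describe; the endpoint path and field names follow the template text, while the base URL, key, and user id are purely illustrative:

    // Illustrative only — mirrors the request body described in the MDX docs above.
    async function sendChatMessage(apiBaseUrl: string, apiKey: string) {
      return fetch(`${apiBaseUrl}/chat-messages`, {
        method: 'POST',
        headers: {
          'Authorization': `Bearer ${apiKey}`,
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          query: 'Hello',
          inputs: {},                 // app-defined variable values; defaults to {}
          response_mode: 'streaming', // SSE-based streaming, per the docs above
          user: 'end-user-id',
        }),
      })
    }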
diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 739d5aabe4..da9d3b5dc6 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -16,7 +16,7 @@ import { useAppContext } from '@/context/app-context' import { ArrowUpRight, ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows' import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general' import { useModalContext } from '@/context/modal-context' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' export type IAppSelecotr = { isMobile: boolean } @@ -30,7 +30,6 @@ export default function AppSelector({ isMobile }: IAppSelecotr) { const [aboutVisible, setAboutVisible] = useState(false) const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const { t } = useTranslation() const { userProfile, langeniusVersionInfo } = useAppContext() const { setShowAccountSettingModal } = useModalContext() @@ -123,7 +122,7 @@ export default function AppSelector({ isMobile }: IAppSelecotr) {
{t('common.userProfile.helpCenter')}
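The customize modal and the help-center link above now build documentation URLs straight from the locale: the English tree by default, or the lower-cased v/<locale> tree for Chinese. A small sketch of that pattern, assuming LanguagesSupported[1] is 'zh-Hans'; the example path segments are taken from the customize-modal hunk:

    // Sketch of the locale-aware docs link used in the hunks above.
    const docsUrl = (locale: string, enPath: string, zhPath: string) =>
      locale !== 'zh-Hans'
        ? `https://docs.dify.ai/${enPath}`
        : `https://docs.dify.ai/v/${locale.toLowerCase()}/${zhPath}`

    // e.g. docsUrl(locale,
    //   'user-guide/launching-dify-apps/developing-with-apis',
    //   'guides/application-publishing/developing-with-apis')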
diff --git a/web/app/components/header/account-setting/language-page/index.tsx b/web/app/components/header/account-setting/language-page/index.tsx index 0c5f5f4357..7f30113867 100644 --- a/web/app/components/header/account-setting/language-page/index.tsx +++ b/web/app/components/header/account-setting/language-page/index.tsx @@ -10,7 +10,7 @@ import { updateUserProfile } from '@/service/common' import { ToastContext } from '@/app/components/base/toast' import I18n from '@/context/i18n' import { timezones } from '@/utils/timezone' -import { languages } from '@/utils/language' +import { languages } from '@/i18n/language' const titleClassName = ` mb-2 text-sm font-medium text-gray-900 @@ -53,7 +53,7 @@ export default function LanguagePage() {
{t('common.language.displayLanguage')}
item.supported)} onSelect={item => handleSelect('language', item)} disabled={editing} /> diff --git a/web/app/components/header/account-setting/members-page/index.tsx b/web/app/components/header/account-setting/members-page/index.tsx index 0bd9c6cf4a..2c10bc24fa 100644 --- a/web/app/components/header/account-setting/members-page/index.tsx +++ b/web/app/components/header/account-setting/members-page/index.tsx @@ -20,7 +20,7 @@ import { useProviderContext } from '@/context/provider-context' import { Plan } from '@/app/components/billing/type' import UpgradeBtn from '@/app/components/billing/upgrade-btn' import { NUM_INFINITE } from '@/app/components/billing/config' -import { LanguagesSupportedUnderscore, getModelRuntimeSupported } from '@/utils/language' +import { LanguagesSupported } from '@/i18n/language' dayjs.extend(relativeTime) const MembersPage = () => { @@ -31,7 +31,7 @@ const MembersPage = () => { normal: t('common.members.normal'), } const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) + const { userProfile, currentWorkspace, isCurrentWorkspaceManager } = useAppContext() const { data, mutate } = useSWR({ url: '/workspaces/current/members' }, fetchMembers) const [inviteModalVisible, setInviteModalVisible] = useState(false) @@ -55,7 +55,7 @@ const MembersPage = () => { {isNotUnlimitedMemberPlan ? (
-
{t('billing.plansCommon.member')}{language !== LanguagesSupportedUnderscore[1] && accounts.length > 1 && 's'}
+
{t('billing.plansCommon.member')}{locale !== LanguagesSupported[1] && accounts.length > 1 && 's'}
{accounts.length}
/
{plan.total.teamMembers === NUM_INFINITE ? t('billing.plansCommon.unlimited') : plan.total.teamMembers}
@@ -64,7 +64,7 @@ const MembersPage = () => { : (
{accounts.length}
-
{t('billing.plansCommon.memberAfter')}{language !== LanguagesSupportedUnderscore[1] && accounts.length > 1 && 's'}
+
{t('billing.plansCommon.memberAfter')}{locale !== LanguagesSupported[1] && accounts.length > 1 && 's'}
)}
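The members-page counters above keep their English-only pluralization, now keyed off the raw locale: a trailing 's' is appended only when the locale is not Chinese and more than one member exists. Restated compactly, again assuming LanguagesSupported[1] is 'zh-Hans':

    // Sketch of the pluralization rule from the members-page hunk above.
    const memberLabel = (base: string, locale: string, count: number) =>
      base + (locale !== 'zh-Hans' && count > 1 ? 's' : '')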
diff --git a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx index 5d7af96034..3b8cb8c699 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx @@ -15,7 +15,6 @@ import { emailRegex } from '@/config' import { ToastContext } from '@/app/components/base/toast' import type { InvitationResult } from '@/models/common' import I18n from '@/context/i18n' -import { getModelRuntimeSupported } from '@/utils/language' import 'react-multi-email/dist/style.css' type IInviteModalProps = { @@ -32,7 +31,6 @@ const InviteModal = ({ const { notify } = useContext(ToastContext) const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const InvitingRoles = useMemo(() => [ { @@ -51,7 +49,7 @@ const InviteModal = ({ try { const { result, invitation_results } = await inviteMember({ url: '/workspaces/current/members/invite-email', - body: { emails, role: role.name, language }, + body: { emails, role: role.name, language: locale }, }) if (result === 'success') { diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index 7a683c8df8..da8c69b69d 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -1,8 +1,8 @@ export type FormValue = Record export type TypeWithI18N = { - 'en_US': T - 'zh_Hans': T + 'en-US': T + 'zh-Hans': T [key: string]: T } @@ -67,16 +67,16 @@ export enum ModelStatusEnum { export const MODEL_STATUS_TEXT: { [k: string]: TypeWithI18N } = { 'no-configure': { - en_US: 'No Configure', - zh_Hans: '未配置凭据', + 'en-US': 'No Configure', + 'zh-Hans': '未配置凭据', }, 'quota-exceeded': { - en_US: 'Quota Exceeded', - zh_Hans: '额度不足', + 'en-US': 'Quota Exceeded', + 'zh-Hans': '额度不足', }, 'no-permission': { - en_US: 'No Permission', - zh_Hans: '无使用权限', + 'en-US': 'No Permission', + 'zh-Hans': '无使用权限', }, } diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index 57975f2f87..3b5bdbb682 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -16,7 +16,6 @@ import { ConfigurateMethodEnum, ModelTypeEnum, } from './declarations' -import { getModelRuntimeSupported } from '@/utils/language' import I18n from '@/context/i18n' import { fetchDefaultModal, @@ -59,7 +58,7 @@ export const useSystemDefaultModelAndModelList: UseDefaultModelAndModelList = ( export const useLanguage = () => { const { locale } = useContext(I18n) - return getModelRuntimeSupported(locale) + return locale.replace('-', '_') } export const useProviderCrenditialsFormSchemasValue = ( diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx index 365cefc26a..db436a1547 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx @@ -17,6 +17,7 @@ import Input from './Input' import { SimpleSelect } from 
'@/app/components/base/select' import Tooltip from '@/app/components/base/tooltip-plus' import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general' +import Radio from '@/app/components/base/radio' type FormProps = { value: FormValue onChange: (val: FormValue) => void @@ -47,7 +48,7 @@ const Form: FC = ({ const language = useLanguage() const [changeKey, setChangeKey] = useState('') - const handleFormChange = (key: string, val: string) => { + const handleFormChange = (key: string, val: string | boolean) => { if (isEditMode && (key === '__model_type' || key === '__model_name')) return @@ -214,6 +215,37 @@ const Form: FC = ({
 )
 }
+
+    if (formSchema.type === 'boolean') {
+      const {
+        variable,
+        label,
+        show_on,
+      } = formSchema as CredentialFormSchemaRadio
+
+      if (show_on.length && !show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value))
+        return null
+
+      return (
+
+
+
+ {label[language]} + {tooltipContent} +
+ handleFormChange(variable, val === 1)} + > + True + False + +
+ {fieldMoreInfo?.(formSchema)} +
+ ) + } } return ( diff --git a/web/app/components/header/maintenance-notice.tsx b/web/app/components/header/maintenance-notice.tsx index 88e65ef526..a7f3faab8e 100644 --- a/web/app/components/header/maintenance-notice.tsx +++ b/web/app/components/header/maintenance-notice.tsx @@ -2,11 +2,10 @@ import { useState } from 'react' import { useContext } from 'use-context-selector' import I18n from '@/context/i18n' import { X } from '@/app/components/base/icons/src/vender/line/general' -import { NOTICE_I18N, getModelRuntimeSupported } from '@/utils/language' +import { NOTICE_I18N } from '@/i18n/language' const MaintenanceNotice = () => { const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) const [showNotice, setShowNotice] = useState(localStorage.getItem('hide-maintenance-notice') !== '1') const handleJumpNotice = () => { @@ -26,11 +25,11 @@ const MaintenanceNotice = () => { return (
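The Form.tsx hunk above adds a 'boolean' credential-field type rendered as a True/False radio pair: handleFormChange is widened to accept string | boolean, and the radio's numeric value is coerced with val === 1. A trimmed, self-contained sketch of the handler side; the JSX is largely stripped from the hunk, so the markup shown in the comment is an assumption:

    // Sketch of the boolean-field handling added to Form.tsx above.
    type FormValue = Record<string, string | boolean>

    const makeFormChangeHandler = (value: FormValue, onChange: (v: FormValue) => void) =>
      (key: string, val: string | boolean) =>
        onChange({ ...value, [key]: val })

    // The radio group reports 1 (True) or 0 (False); the hunk coerces it to a boolean:
    // onSelect={(val) => handleFormChange(variable, val === 1)}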
-
{titleByLocale[language]}
+
{titleByLocale[locale]}
{ (NOTICE_I18N.href && NOTICE_I18N.href !== '#') - ?
{descByLocale[language]}
- :
{descByLocale[language]}
+ ?
{descByLocale[locale]}
+ :
{descByLocale[locale]}
}
diff --git a/web/app/components/i18n.tsx b/web/app/components/i18n.tsx index 4166449e76..7fe1df23e0 100644 --- a/web/app/components/i18n.tsx +++ b/web/app/components/i18n.tsx @@ -5,7 +5,7 @@ import React, { useEffect } from 'react' import { changeLanguage } from '@/i18n/i18next-config' import I18NContext from '@/context/i18n' import type { Locale } from '@/i18n' -import { setLocaleOnClient } from '@/i18n/client' +import { setLocaleOnClient } from '@/i18n' export type II18nProps = { locale: Locale diff --git a/web/app/components/locale-switcher.tsx b/web/app/components/locale-switcher.tsx deleted file mode 100644 index 018b3f907f..0000000000 --- a/web/app/components/locale-switcher.tsx +++ /dev/null @@ -1,23 +0,0 @@ -'use client' - -import { i18n } from '@/i18n' -import { setLocaleOnClient } from '@/i18n/client' - -const LocaleSwitcher = () => { - return ( -
-

Locale switcher:

-
    - {i18n.locales.map((locale) => { - return ( -
  • -
    setLocaleOnClient(locale)}>{locale}
    -
  • - ) - })} -
-
- ) -} - -export default LocaleSwitcher diff --git a/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx b/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx index 150768096d..1deef1b531 100644 --- a/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx +++ b/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx @@ -79,7 +79,9 @@ const ConfigCredential: FC = ({ setTempCredential({ ...tempCredential, api_key_header: e.target.value })} - className='w-full h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow' /> + className='w-full h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow' + placeholder={t('tools.createTool.authMethod.types.apiKeyPlaceholder')!} + />
@@ -87,7 +89,9 @@ const ConfigCredential: FC = ({ setTempCredential({ ...tempCredential, api_key_value: e.target.value })} - className='w-full h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow' /> + className='w-full h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow' + placeholder={t('tools.createTool.authMethod.types.apiValuePlaceholder')!} + />
)} diff --git a/web/app/components/tools/edit-custom-collection-modal/test-api.tsx b/web/app/components/tools/edit-custom-collection-modal/test-api.tsx index a2e3c454d4..b1a494317e 100644 --- a/web/app/components/tools/edit-custom-collection-modal/test-api.tsx +++ b/web/app/components/tools/edit-custom-collection-modal/test-api.tsx @@ -10,7 +10,7 @@ import Button from '@/app/components/base/button' import Drawer from '@/app/components/base/drawer-plus' import I18n from '@/context/i18n' import { testAPIAvailable } from '@/service/tools' -import { getModelRuntimeSupported } from '@/utils/language' +import { getLanguage } from '@/i18n/language' type Props = { customCollection: CustomCollectionBackend @@ -27,7 +27,7 @@ const TestApi: FC = ({ }) => { const { t } = useTranslation() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) + const language = getLanguage(locale) const [credentialsModalShow, setCredentialsModalShow] = useState(false) const [tempCredential, setTempCredential] = React.useState(customCollection.credentials) const [result, setResult] = useState('') diff --git a/web/app/components/tools/tool-list/header.tsx b/web/app/components/tools/tool-list/header.tsx index bb3907b8b8..5a243a0a2b 100644 --- a/web/app/components/tools/tool-list/header.tsx +++ b/web/app/components/tools/tool-list/header.tsx @@ -8,8 +8,7 @@ import type { Collection } from '../types' import { CollectionType, LOC } from '../types' import { Settings01 } from '../../base/icons/src/vender/line/general' import I18n from '@/context/i18n' -import { getModelRuntimeSupported } from '@/utils/language' - +import { getLanguage } from '@/i18n/language' type Props = { icon: JSX.Element collection: Collection @@ -26,7 +25,7 @@ const Header: FC = ({ onShowEditCustomCollection, }) => { const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) + const language = getLanguage(locale) const { t } = useTranslation() const isInToolsPage = loc === LOC.tools const isInDebugPage = !isInToolsPage diff --git a/web/app/components/tools/tool-list/item.tsx b/web/app/components/tools/tool-list/item.tsx index f4437f065a..c4e856e11a 100644 --- a/web/app/components/tools/tool-list/item.tsx +++ b/web/app/components/tools/tool-list/item.tsx @@ -10,8 +10,7 @@ import { CollectionType } from '../types' import TooltipPlus from '../../base/tooltip-plus' import I18n from '@/context/i18n' import SettingBuiltInTool from '@/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool' -import { getModelRuntimeSupported } from '@/utils/language' - +import { getLanguage } from '@/i18n/language' type Props = { collection: Collection icon: JSX.Element @@ -33,7 +32,8 @@ const Item: FC = ({ }) => { const { t } = useTranslation() const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) + const language = getLanguage(locale) + const isBuiltIn = collection.type === CollectionType.builtIn const canShowDetail = !isBuiltIn || (isBuiltIn && isInToolsPage) const [showDetail, setShowDetail] = useState(false) diff --git a/web/app/components/tools/tool-nav-list/item.tsx b/web/app/components/tools/tool-nav-list/item.tsx index 5e32503d18..2d3606d13e 100644 --- a/web/app/components/tools/tool-nav-list/item.tsx +++ b/web/app/components/tools/tool-nav-list/item.tsx @@ -6,7 +6,7 @@ import cn from 'classnames' import AppIcon from '../../base/app-icon' import type { Collection } from '@/app/components/tools/types' import I18n from '@/context/i18n' -import { 
getModelRuntimeSupported } from '@/utils/language' +import { getLanguage } from '@/i18n/language' type Props = { isCurrent: boolean @@ -20,7 +20,7 @@ const Item: FC = ({ onClick, }) => { const { locale } = useContext(I18n) - const language = getModelRuntimeSupported(locale) + const language = getLanguage(locale) return (
, formSchemas: { varia const newValues = { ...value } formSchemas.forEach((formSchema) => { const itemValue = value[formSchema.variable] - if (formSchema.default && (value === undefined || itemValue === null || itemValue === '')) + if ((formSchema.default !== undefined) && (value === undefined || itemValue === null || itemValue === '' || itemValue === undefined)) newValues[formSchema.variable] = formSchema.default }) return newValues diff --git a/web/app/install/installForm.tsx b/web/app/install/installForm.tsx index 91d6933d99..284062d26f 100644 --- a/web/app/install/installForm.tsx +++ b/web/app/install/installForm.tsx @@ -17,8 +17,6 @@ const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ const InstallForm = () => { const { t } = useTranslation() - // const { locale } = useContext(I18n) - // const language = getModelRuntimeSupported(locale) const router = useRouter() const [email, setEmail] = React.useState('') diff --git a/web/app/signin/_header.tsx b/web/app/signin/_header.tsx index 1aec9720a5..7180a66817 100644 --- a/web/app/signin/_header.tsx +++ b/web/app/signin/_header.tsx @@ -2,7 +2,7 @@ import React from 'react' import { useContext } from 'use-context-selector' import Select from '@/app/components/base/select/locale' -import { languages } from '@/utils/language' +import { languages } from '@/i18n/language' import { type Locale } from '@/i18n' import I18n from '@/context/i18n' import LogoSite from '@/app/components/base/logo/logo-site' @@ -17,7 +17,7 @@ const Header = () => {