From 94d04934b3aa943fce8ad9dfc5c9982e1375bab5 Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Fri, 29 Mar 2024 22:15:16 +0800
Subject: [PATCH 1/6] fix: agent tool label (#3039)

---
 .../config/agent/agent-tools/choose-tool/index.tsx          | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/web/app/components/app/configuration/config/agent/agent-tools/choose-tool/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/choose-tool/index.tsx
index 0b0e0676b7..d47406e95b 100644
--- a/web/app/components/app/configuration/config/agent/agent-tools/choose-tool/index.tsx
+++ b/web/app/components/app/configuration/config/agent/agent-tools/choose-tool/index.tsx
@@ -59,7 +59,7 @@ const ChooseTool: FC = ({
       provider_type: collection.type,
       provider_name: collection.name,
       tool_name: tool.name,
-      tool_label: tool.label[locale],
+      tool_label: tool.label[locale] || tool.label[locale.replaceAll('-', '_')],
      tool_parameters: parameters,
       enabled: true,
     })

From fc5ed17fe9708fe316b29b6094d5cc1f3aeb2b72 Mon Sep 17 00:00:00 2001
From: Leo Q
Date: Sat, 30 Mar 2024 14:44:50 +0800
Subject: [PATCH 2/6] provide a bit more info in logs when parsing api schema error (#3026)

---
 api/core/tools/utils/parser.py       | 130 ++++++++++-----------------
 api/services/tools_manage_service.py |   6 +-
 2 files changed, 52 insertions(+), 84 deletions(-)

diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index de4ecc8708..5efd2e49b9 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -1,10 +1,12 @@
 import re
 import uuid
+from json import dumps as json_dumps
 from json import loads as json_loads
+from json.decoder import JSONDecodeError
 
 from requests import get
-from yaml import FullLoader, load
+from yaml import YAMLError, safe_load
 
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiBasedToolBundle
@@ -184,27 +186,11 @@ class ApiBasedToolSchemaParser:
         warning = warning if warning is not None else {}
         extra_info = extra_info if extra_info is not None else {}
 
-        openapi: dict = load(yaml, Loader=FullLoader)
+        openapi: dict = safe_load(yaml)
         if openapi is None:
             raise ToolApiSchemaError('Invalid openapi yaml.')
         return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
 
-    @staticmethod
-    def parse_openapi_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
-        """
-            parse openapi yaml to tool bundle
-
-            :param yaml: the yaml string
-            :return: the tool bundle
-        """
-        warning = warning if warning is not None else {}
-        extra_info = extra_info if extra_info is not None else {}
-
-        openapi: dict = json_loads(json)
-        if openapi is None:
-            raise ToolApiSchemaError('Invalid openapi json.')
-        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
-
     @staticmethod
     def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict:
         """
@@ -271,38 +257,6 @@ class ApiBasedToolSchemaParser:
 
         return openapi
 
-    @staticmethod
-    def parse_swagger_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
-        """
-            parse swagger yaml to tool bundle
-
-            :param yaml: the yaml string
-            :return: the tool bundle
-        """
-        warning = warning if warning is not None else {}
-        extra_info = extra_info if extra_info is not None else {}
-
-        swagger: dict = load(yaml, Loader=FullLoader)
-
-        openapi = ApiBasedToolSchemaParser.parse_swagger_to_openapi(swagger, extra_info=extra_info, warning=warning)
-        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
-
-    @staticmethod
-    def parse_swagger_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
-        """
-            parse swagger yaml to tool bundle
-
-            :param yaml: the yaml string
-            :return: the tool bundle
-        """
-        warning = warning if warning is not None else {}
-        extra_info = extra_info if extra_info is not None else {}
-
-        swagger: dict = json_loads(json)
-
-        openapi = ApiBasedToolSchemaParser.parse_swagger_to_openapi(swagger, extra_info=extra_info, warning=warning)
-        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
-
     @staticmethod
     def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
         """
@@ -346,40 +300,50 @@ class ApiBasedToolSchemaParser:
         warning = warning if warning is not None else {}
         extra_info = extra_info if extra_info is not None else {}
 
-        json_possible = False
         content = content.strip()
+        loaded_content = None
+        json_error = None
+        yaml_error = None
+
+        try:
+            loaded_content = json_loads(content)
+        except JSONDecodeError as e:
+            json_error = e
 
-        if content.startswith('{') and content.endswith('}'):
-            json_possible = True
+        if loaded_content is None:
+            try:
+                loaded_content = safe_load(content)
+            except YAMLError as e:
+                yaml_error = e
+        if loaded_content is None:
+            raise ToolApiSchemaError(f'Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}, yaml error: {str(yaml_error)}')
 
-        if json_possible:
-            try:
-                return ApiBasedToolSchemaParser.parse_openapi_json_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
-                    ApiProviderSchemaType.OPENAPI.value
-            except:
-                pass
+        swagger_error = None
+        openapi_error = None
+        openapi_plugin_error = None
+        schema_type = None
+
+        try:
+            openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(loaded_content, extra_info=extra_info, warning=warning)
+            schema_type = ApiProviderSchemaType.OPENAPI.value
+            return openapi, schema_type
+        except ToolApiSchemaError as e:
+            openapi_error = e
+
+        # openapi parse error, fallback to swagger
+        try:
+            converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(loaded_content, extra_info=extra_info, warning=warning)
+            schema_type = ApiProviderSchemaType.SWAGGER.value
+            return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(converted_swagger, extra_info=extra_info, warning=warning), schema_type
+        except ToolApiSchemaError as e:
+            swagger_error = e
+
+        # swagger parse error, fallback to openai plugin
+        try:
+            openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(json_dumps(loaded_content), extra_info=extra_info, warning=warning)
+            return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
+        except ToolNotSupportedError as e:
+            # maybe it's not a plugin at all
+            openapi_plugin_error = e
 
-            try:
-                return ApiBasedToolSchemaParser.parse_swagger_json_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
-                    ApiProviderSchemaType.SWAGGER.value
-            except:
-                pass
-
-            try:
-                return ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
-                    ApiProviderSchemaType.OPENAI_PLUGIN.value
-            except:
-                pass
-        else:
-            try:
-                return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
-                    ApiProviderSchemaType.OPENAPI.value
-            except:
-                pass
-
-            try:
-                return ApiBasedToolSchemaParser.parse_swagger_yaml_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
-                    ApiProviderSchemaType.SWAGGER.value
-            except:
-                pass
-
-        raise ToolApiSchemaError('Invalid api schema.')
\ No newline at end of file
+        raise ToolApiSchemaError(f'Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}')
diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py
index 70c6a44459..c1160f605c 100644
--- a/api/services/tools_manage_service.py
+++ b/api/services/tools_manage_service.py
@@ -1,4 +1,5 @@
 import json
+import logging
 
 from flask import current_app
 from httpx import get
@@ -24,6 +25,8 @@ from extensions.ext_database import db
 from models.tools import ApiToolProvider, BuiltinToolProvider
 from services.model_provider_service import ModelProviderService
 
+logger = logging.getLogger(__name__)
+
 
 class ToolManageService:
     @staticmethod
@@ -309,6 +312,7 @@ class ToolManageService:
             # try to parse schema, avoid SSRF attack
             ToolManageService.parser_api_schema(schema)
         except Exception as e:
+            logger.error(f"parse api schema error: {str(e)}")
             raise ValueError('invalid schema, please check the url you provided')
 
         return {
@@ -655,4 +659,4 @@ class ToolManageService:
         except Exception as e:
             return { 'error': str(e) }
 
-        return { 'result': result or 'empty response' }
\ No newline at end of file
+        return { 'result': result or 'empty response' }

From 12782cad4da0ea621ebe16d39568443fdb3b9e20 Mon Sep 17 00:00:00 2001
From: Nanguan Lin
Date: Sun, 31 Mar 2024 12:41:16 +0800
Subject: [PATCH 3/6] Fix typo (#3041)

---
 .github/pull_request_template.md | 2 +-
 web/i18n/en-US/app-annotation.ts | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 965831ebe3..da788d8f32 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -12,7 +12,7 @@ Please delete options that are not relevant.
 - [ ] New feature (non-breaking change which adds functionality)
 - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
 - [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
-- [ ] Improvement,including but not limited to code refactoring, performance optimization, and UI/UX improvement
+- [ ] Improvement, including but not limited to code refactoring, performance optimization, and UI/UX improvement
 - [ ] Dependency upgrade
 
 # How Has This Been Tested?
diff --git a/web/i18n/en-US/app-annotation.ts b/web/i18n/en-US/app-annotation.ts
index 890e263d7f..43f24a7619 100644
--- a/web/i18n/en-US/app-annotation.ts
+++ b/web/i18n/en-US/app-annotation.ts
@@ -4,7 +4,7 @@ const translation = {
   editBy: 'Answer edited by {{author}}',
   noData: {
     title: 'No annotations',
-    description: 'You can edit annotations in app debuggiung, or import annotations in bulk here for high-quality response.',
+    description: 'You can edit annotations during app debugging or import annotations in bulk here for a high-quality response.',
   },
   table: {
     header: {

From e215aae39abd15eacd8ba7bdaf5d381a2b811e93 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=91=86=E8=90=8C=E9=97=B7=E6=B2=B9=E7=93=B6?= <253605712@qq.com>
Date: Sun, 31 Mar 2024 12:44:11 +0800
Subject: [PATCH 4/6] feat:xinference audio model support (#3045)

---
 .../xinference/speech2text/__init__.py        |   0
 .../xinference/speech2text/speech2text.py     | 148 ++++++++++++++++++
 .../xinference/xinference.yaml                |   1 +
 api/requirements.txt                          |   4 +-
 4 files changed, 151 insertions(+), 2 deletions(-)
 create mode 100644 api/core/model_runtime/model_providers/xinference/speech2text/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py

diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/__init__.py b/api/core/model_runtime/model_providers/xinference/speech2text/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py
new file mode 100644
index 0000000000..35269fceca
--- /dev/null
+++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py
@@ -0,0 +1,148 @@
+from typing import IO, Optional
+
+from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+
+
+class XinferenceSpeech2TextModel(Speech2TextModel):
+    """
+    Model class for Xinference speech to text model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                file: IO[bytes], user: Optional[str] = None) \
+            -> str:
+        """
+        Invoke speech2text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :param user: unique user id
+        :return: text for given audio file
+        """
+        return self._speech2text_invoke(model, credentials, file)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
+                raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
+
+            audio_file_path = self._get_demo_file_path()
+
+            with open(audio_file_path, 'rb') as audio_file:
+                self.invoke(model, credentials, audio_file)
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                InvokeConnectionError
+            ],
+            InvokeServerUnavailableError: [
+                InvokeServerUnavailableError
+            ],
+            InvokeRateLimitError: [
+                InvokeRateLimitError
+            ],
+            InvokeAuthorizationError: [
+                InvokeAuthorizationError
+            ],
+            InvokeBadRequestError: [
+                InvokeBadRequestError,
+                KeyError,
+                ValueError
+            ]
+        }
+
+    def _speech2text_invoke(
+            self,
+            model: str,
+            credentials: dict,
+            file: IO[bytes],
+            language: Optional[str] = None,
+            prompt: Optional[str] = None,
+            response_format: Optional[str] = "json",
+            temperature: Optional[float] = 0,
+    ) -> str:
+        """
+        Invoke speech2text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
+        :param language: The language of the input audio. Supplying the input language in ISO-639-1
+        :param prompt: An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language.
+        :param response_format: The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
+        :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
+        :return: text for given audio file
+        """
+        if credentials['server_url'].endswith('/'):
+            credentials['server_url'] = credentials['server_url'][:-1]
+
+        # initialize client
+        client = Client(
+            base_url=credentials['server_url']
+        )
+
+        xinference_client = client.get_model(model_uid=credentials['model_uid'])
+
+        if not isinstance(xinference_client, RESTfulAudioModelHandle):
+            raise InvokeBadRequestError('please check model type, the model you want to invoke is not an audio model')
+
+        response = xinference_client.transcriptions(
+            audio=file,
+            language = language,
+            prompt = prompt,
+            response_format = response_format,
+            temperature = temperature
+        )
+
+        return response["text"]
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+        used to define customizable model schema
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.SPEECH2TEXT,
+            model_properties={ },
+            parameter_rules=[]
+        )
+
+        return entity
diff --git a/api/core/model_runtime/model_providers/xinference/xinference.yaml b/api/core/model_runtime/model_providers/xinference/xinference.yaml
index bb6c6d8668..6744c34268 100644
--- a/api/core/model_runtime/model_providers/xinference/xinference.yaml
+++ b/api/core/model_runtime/model_providers/xinference/xinference.yaml
@@ -16,6 +16,7 @@ supported_model_types:
   - llm
   - text-embedding
   - rerank
+  - speech2text
 configurate_methods:
   - customizable-model
 model_credential_schema:
diff --git a/api/requirements.txt b/api/requirements.txt
index 1601f5da00..c534d36968 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -48,7 +48,7 @@ dashscope[tokenizer]~=1.14.0
 huggingface_hub~=0.16.4
 transformers~=4.31.0
 pandas==1.5.3
-xinference-client==0.8.4
+xinference-client==0.9.4
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug~=3.0.1
@@ -73,4 +73,4 @@ yarl~=1.9.4
 twilio==9.0.0
 qrcode~=7.4.2
 azure-storage-blob==12.9.0
-azure-identity==1.15.0
\ No newline at end of file
+azure-identity==1.15.0

From 1716ac562ccd10efd1903fa659d0ea0174565b87 Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Mon, 1 Apr 2024 01:34:21 +0800
Subject: [PATCH 5/6] add clean_unused_datasets_task (#3057)

Co-authored-by: jyong

---
 api/extensions/ext_celery.py | 2 +-
 api/requirements.txt         | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index b27105f4d0..89a0924763 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -50,7 +50,7 @@ def init_app(app: Flask) -> Celery:
         },
         'clean_unused_datasets_task': {
             'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task',
-            'schedule': timedelta(days=7),
+            'schedule': timedelta(minutes=3),
         }
     }
     celery_app.conf.update(
diff --git a/api/requirements.txt b/api/requirements.txt
index c534d36968..0c7d568aa7 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -74,3 +74,4 @@ twilio==9.0.0
 qrcode~=7.4.2
 azure-storage-blob==12.9.0
 azure-identity==1.15.0
+lxml==5.1.0
\ No newline at end of file

From 84d118de074ced0ea9c21fed0d62bc315c3e5553 Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Mon, 1 Apr 2024 02:10:41 +0800
Subject: [PATCH 6/6] add redis lock on create collection in multiple thread mode (#3054)

Co-authored-by: jyong

---
 .../rag/datasource/keyword/jieba/jieba.py     | 41 ++++----
 .../datasource/vdb/milvus/milvus_vector.py    | 97 ++++++++++---------
 .../datasource/vdb/qdrant/qdrant_vector.py    | 73 +++++++-------
 .../vdb/weaviate/weaviate_vector.py           | 22 +++--
 4 files changed, 128 insertions(+), 105 deletions(-)

diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index 94a692637f..344ef7babe 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -8,6 +8,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
 from core.rag.datasource.keyword.keyword_base import BaseKeyword
 from core.rag.models.document import Document
 from extensions.ext_database import db
+from extensions.ext_redis import redis_client
 from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
 
 
@@ -121,26 +122,28 @@ class Jieba(BaseKeyword):
         db.session.commit()
 
     def _get_dataset_keyword_table(self) -> Optional[dict]:
-        dataset_keyword_table = self.dataset.dataset_keyword_table
-        if dataset_keyword_table:
-            if dataset_keyword_table.keyword_table_dict:
-                return dataset_keyword_table.keyword_table_dict['__data__']['table']
-        else:
-            dataset_keyword_table = DatasetKeywordTable(
-                dataset_id=self.dataset.id,
-                keyword_table=json.dumps({
-                    '__type__': 'keyword_table',
-                    '__data__': {
-                        "index_id": self.dataset.id,
-                        "summary": None,
-                        "table": {}
-                    }
-                }, cls=SetEncoder)
-            )
-            db.session.add(dataset_keyword_table)
-            db.session.commit()
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=20):
+            dataset_keyword_table = self.dataset.dataset_keyword_table
+            if dataset_keyword_table:
+                if dataset_keyword_table.keyword_table_dict:
+                    return dataset_keyword_table.keyword_table_dict['__data__']['table']
+            else:
+                dataset_keyword_table = DatasetKeywordTable(
+                    dataset_id=self.dataset.id,
+                    keyword_table=json.dumps({
+                        '__type__': 'keyword_table',
+                        '__data__': {
+                            "index_id": self.dataset.id,
+                            "summary": None,
+                            "table": {}
+                        }
+                    }, cls=SetEncoder)
+                )
+                db.session.add(dataset_keyword_table)
+                db.session.commit()
 
-        return {}
+            return {}
 
     def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
         for keyword in keywords:
diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
index f62d603d8d..dcb37ccbe6 100644
--- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py
+++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
@@ -8,6 +8,7 @@ from pymilvus import MilvusClient, MilvusException, connections
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 
 logger = logging.getLogger(__name__)
 
@@ -61,17 +62,7 @@ class MilvusVector(BaseVector):
             'params': {"M": 8, "efConstruction": 64}
         }
         metadatas = [d.metadata for d in texts]
-
-        # Grab the existing collection if it exists
-        from pymilvus import utility
-        alias = uuid4().hex
-        if self._client_config.secure:
-            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
-        else:
-            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
-        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
-        if not utility.has_collection(self._collection_name, using=alias):
-            self.create_collection(embeddings, metadatas, index_params)
+        self.create_collection(embeddings, metadatas, index_params)
 
         self.add_texts(texts, embeddings)
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@@ -187,46 +178,60 @@ class MilvusVector(BaseVector):
     def create_collection(
             self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
-    ) -> str:
-        from pymilvus import CollectionSchema, DataType, FieldSchema
-        from pymilvus.orm.types import infer_dtype_bydata
+    ):
+        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
+            # Grab the existing collection if it exists
+            from pymilvus import utility
+            alias = uuid4().hex
+            if self._client_config.secure:
+                uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+            else:
+                uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+            connections.connect(alias=alias, uri=uri, user=self._client_config.user,
+                                password=self._client_config.password)
+            if not utility.has_collection(self._collection_name, using=alias):
+                from pymilvus import CollectionSchema, DataType, FieldSchema
+                from pymilvus.orm.types import infer_dtype_bydata
 
-        # Determine embedding dim
-        dim = len(embeddings[0])
-        fields = []
-        if metadatas:
-            fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
+                # Determine embedding dim
+                dim = len(embeddings[0])
+                fields = []
+                if metadatas:
+                    fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
 
-        # Create the text field
-        fields.append(
-            FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
-        )
-        # Create the primary key field
-        fields.append(
-            FieldSchema(
-                Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
-            )
-        )
-        # Create the vector field, supports binary or float vectors
-        fields.append(
-            FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
-        )
+                # Create the text field
+                fields.append(
+                    FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
+                )
+                # Create the primary key field
+                fields.append(
+                    FieldSchema(
+                        Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
+                    )
+                )
+                # Create the vector field, supports binary or float vectors
+                fields.append(
+                    FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
+                )
 
-        # Create the schema for the collection
-        schema = CollectionSchema(fields)
+                # Create the schema for the collection
+                schema = CollectionSchema(fields)
 
-        for x in schema.fields:
-            self._fields.append(x.name)
-        # Since primary field is auto-id, no need to track it
-        self._fields.remove(Field.PRIMARY_KEY.value)
-
-        # Create the collection
-        collection_name = self._collection_name
-        self._client.create_collection_with_schema(collection_name=collection_name,
-                                                   schema=schema, index_param=index_params,
-                                                   consistency_level=self._consistency_level)
-        return collection_name
+                for x in schema.fields:
+                    self._fields.append(x.name)
+                # Since primary field is auto-id, no need to track it
+                self._fields.remove(Field.PRIMARY_KEY.value)
+                # Create the collection
+                collection_name = self._collection_name
+                self._client.create_collection_with_schema(collection_name=collection_name,
+                                                           schema=schema, index_param=index_params,
+                                                           consistency_level=self._consistency_level)
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
     def _init_client(self, config) -> MilvusClient:
         if config.secure:
             uri = "https://" + str(config.host) + ":" + str(config.port)
diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
index 436e6b5f6a..41e8c6154a 100644
--- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
+++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
@@ -20,6 +20,7 @@ from qdrant_client.local.qdrant_local import QdrantLocal
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 
 if TYPE_CHECKING:
     from qdrant_client import grpc  # noqa
@@ -77,6 +78,17 @@ class QdrantVector(BaseVector):
         vector_size = len(embeddings[0])
         # get collection name
         collection_name = self._collection_name
+        # create collection
+        self.create_collection(collection_name, vector_size)
+
+        self.add_texts(texts, embeddings, **kwargs)
+
+    def create_collection(self, collection_name: str, vector_size: int):
+        lock_name = 'vector_indexing_lock_{}'.format(collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
         collection_name = collection_name or uuid.uuid4().hex
         all_collection_name = []
         collections_response = self._client.get_collections()
@@ -84,40 +96,35 @@ class QdrantVector(BaseVector):
         for collection in collection_list:
             all_collection_name.append(collection.name)
         if collection_name not in all_collection_name:
-            # create collection
-            self.create_collection(collection_name, vector_size)
+            from qdrant_client.http import models as rest
+            vectors_config = rest.VectorParams(
+                size=vector_size,
+                distance=rest.Distance[self._distance_func],
+            )
+            hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
+                                         max_indexing_threads=0, on_disk=False)
+            self._client.recreate_collection(
+                collection_name=collection_name,
+                vectors_config=vectors_config,
+                hnsw_config=hnsw_config,
+                timeout=int(self._client_config.timeout),
+            )
 
-        self.add_texts(texts, embeddings, **kwargs)
-
-    def create_collection(self, collection_name: str, vector_size: int):
-        from qdrant_client.http import models as rest
-        vectors_config = rest.VectorParams(
-            size=vector_size,
-            distance=rest.Distance[self._distance_func],
-        )
-        hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
-                                     max_indexing_threads=0, on_disk=False)
-        self._client.recreate_collection(
-            collection_name=collection_name,
-            vectors_config=vectors_config,
-            hnsw_config=hnsw_config,
-            timeout=int(self._client_config.timeout),
-        )
-
-        # create payload index
-        self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
-                                          field_schema=PayloadSchemaType.KEYWORD,
-                                          field_type=PayloadSchemaType.KEYWORD)
-        # creat full text index
-        text_index_params = TextIndexParams(
-            type=TextIndexType.TEXT,
-            tokenizer=TokenizerType.MULTILINGUAL,
-            min_token_len=2,
-            max_token_len=20,
-            lowercase=True
-        )
-        self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
-                                          field_schema=text_index_params)
+            # create payload index
+            self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
+                                              field_schema=PayloadSchemaType.KEYWORD,
+                                              field_type=PayloadSchemaType.KEYWORD)
+            # create full text index
+            text_index_params = TextIndexParams(
+                type=TextIndexType.TEXT,
+                tokenizer=TokenizerType.MULTILINGUAL,
+                min_token_len=2,
+                max_token_len=20,
+                lowercase=True
+            )
+            self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
+                                              field_schema=text_index_params)
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         uuids = self._get_uuids(documents)
diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
index 5d24ee9fd2..59fbaeee6a 100644
--- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
+++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
@@ -8,6 +8,7 @@ from pydantic import BaseModel, root_validator
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 from models.dataset import Dataset
 
 
@@ -79,16 +80,23 @@ class WeaviateVector(BaseVector):
         }
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-
-        schema = self._default_schema(self._collection_name)
-
-        # check whether the index already exists
-        if not self._client.schema.contains(schema):
-            # create collection
-            self._client.schema.create_class(schema)
+        # create collection
+        self._create_collection()
         # create vector
         self.add_texts(texts, embeddings)
 
+    def _create_collection(self):
+        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
+            schema = self._default_schema(self._collection_name)
+            if not self._client.schema.contains(schema):
+                # create collection
+                self._client.schema.create_class(schema)
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         uuids = self._get_uuids(documents)
         texts = [d.page_content for d in documents]
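
Taken together, the auto_parse rework in PATCH 2/6 is a reusable shape: decode the raw text once (JSON first, then YAML), then try each schema interpretation in order while keeping every failure so the final error says why nothing matched. Below is a minimal, self-contained sketch of that shape; `parse_as_openapi` and `parse_as_swagger` are hypothetical stand-ins for the real `ApiBasedToolSchemaParser` methods, and plain `ValueError`s stand in for Dify's `ToolApiSchemaError`:

```python
from json import loads as json_loads
from json.decoder import JSONDecodeError

from yaml import YAMLError, safe_load


def parse_as_openapi(doc) -> list:
    # hypothetical stand-in: accept any mapping carrying an 'openapi' marker
    if not isinstance(doc, dict) or 'openapi' not in doc:
        raise ValueError('not an openapi document')
    return list(doc.get('paths', {}))


def parse_as_swagger(doc) -> list:
    # hypothetical stand-in for swagger-to-openapi conversion plus parsing
    if not isinstance(doc, dict) or 'swagger' not in doc:
        raise ValueError('not a swagger document')
    return list(doc.get('paths', {}))


def auto_parse(content: str):
    content = content.strip()
    loaded, json_error, yaml_error = None, None, None
    # decode once: JSON first, then YAML (YAML is a superset of JSON,
    # so valid JSON rarely reaches the second branch)
    try:
        loaded = json_loads(content)
    except JSONDecodeError as e:
        json_error = e
    if loaded is None:
        try:
            loaded = safe_load(content)
        except YAMLError as e:
            yaml_error = e
    if loaded is None:
        raise ValueError(f'neither json nor yaml. json error: {json_error}, yaml error: {yaml_error}')

    # try each interpretation in order, keeping the cause of every failure
    errors = {}
    for name, parse in (('openapi', parse_as_openapi), ('swagger', parse_as_swagger)):
        try:
            return parse(loaded), name
        except ValueError as e:
            errors[name] = e
    raise ValueError(f'no schema format matched: {errors}')


print(auto_parse('{"openapi": "3.0.0", "paths": {"/ping": {}}}'))  # (['/ping'], 'openapi')
```

Returning the detected schema type alongside the parsed bundle is what lets the service layer record how the document was recognized, and the accumulated per-format errors are what PATCH 2/6's `logger.error` call in tools_manage_service.py ends up surfacing instead of a bare "invalid schema".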
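PATCH 6/6 applies one guard to every store it touches: a Redis lock serializes concurrent creators, and a short-lived cache key lets later callers return without re-checking existence. A minimal sketch of that guard in isolation — `redis_client` is assumed to be a redis-py client wired up the way Dify's `extensions.ext_redis` does it, and `create_index` is a hypothetical callable standing in for the store-specific creation code:

```python
import redis

# assumption: a shared client like the one Dify exposes as extensions.ext_redis
redis_client = redis.Redis(host='localhost', port=6379)


def create_collection_once(collection_name: str, create_index) -> None:
    lock_name = 'vector_indexing_lock_{}'.format(collection_name)
    # the lock serializes concurrent creators; timeout bounds how long a
    # crashed holder can block everyone else
    with redis_client.lock(lock_name, timeout=20):
        collection_exist_cache_key = 'vector_indexing_{}'.format(collection_name)
        if redis_client.get(collection_exist_cache_key):
            return  # another caller already created it; skip the expensive path
        create_index()  # store-specific work: schema, vector index, payload fields
        # remember success for an hour so later calls return immediately
        redis_client.set(collection_exist_cache_key, 1, ex=3600)


# usage: every thread can call this; only the first does the real work
create_collection_once('dataset_123', lambda: print('creating index ...'))
```

Note that the cache key is a performance hint rather than the correctness mechanism: the lock combined with each store's own existence check (`utility.has_collection` for Milvus, the collection listing for Qdrant, `schema.contains` for Weaviate) is what actually prevents double creation.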