diff --git a/api/.env.example b/api/.env.example index 832c7e3bab..28bdcad7ed 100644 --- a/api/.env.example +++ b/api/.env.example @@ -39,7 +39,7 @@ DB_DATABASE=dify # Storage configuration # use for store upload files, private keys... -# storage type: local, s3 +# storage type: local, s3, azure-blob STORAGE_TYPE=local STORAGE_LOCAL_PATH=storage S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com @@ -47,6 +47,11 @@ S3_BUCKET_NAME=your-bucket-name S3_ACCESS_KEY=your-access-key S3_SECRET_KEY=your-secret-key S3_REGION=your-region +# Azure Blob Storage configuration +AZURE_BLOB_ACCOUNT_NAME=your-account-name +AZURE_BLOB_ACCOUNT_KEY=your-account-key +AZURE_BLOB_CONTAINER_NAME=your-container-name +AZURE_BLOB_ACCOUNT_URL=https://<your-account-name>.blob.core.windows.net # CORS configuration WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* diff --git a/api/config.py b/api/config.py index ed933372a2..76d442e5f9 100644 --- a/api/config.py +++ b/api/config.py @@ -186,6 +186,10 @@ class Config: self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY') self.S3_SECRET_KEY = get_env('S3_SECRET_KEY') self.S3_REGION = get_env('S3_REGION') + self.AZURE_BLOB_ACCOUNT_NAME = get_env('AZURE_BLOB_ACCOUNT_NAME') + self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY') + self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME') + self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL') # ------------------------ # Vector Store Configurations. diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 3f7cfcaea8..5d3a081357 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -197,11 +197,11 @@ class DatasetSegmentApi(DatasetApiResource): # validate args parser = reqparse.RequestParser() - parser.add_argument('segments', type=dict, required=False, nullable=True, location='json') + parser.add_argument('segment', type=dict, required=False, nullable=True, location='json') args = parser.parse_args() - SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.update_segment(args, segment, document, dataset) + SegmentService.segment_create_args_validate(args['segment'], document) + segment = SegmentService.update_segment(args['segment'], segment, document, dataset) return { 'data': marshal(segment, segment_fields), 'doc_form': document.doc_form diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index 049ad67a77..7b4416f44e 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -11,6 +11,8 @@ - groq - replicate - huggingface_hub +- xinference +- triton_inference_server - zhipuai - baichuan - spark @@ -20,7 +22,6 @@ - moonshot - jina - chatglm -- xinference - yi - openllm - localai diff --git a/api/core/model_runtime/model_providers/triton_inference_server/__init__.py b/api/core/model_runtime/model_providers/triton_inference_server/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_l_en.png new file mode 100644 index 0000000000..dd32d45803 Binary files /dev/null and b/api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_l_en.png differ diff --git
a/api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_s_en.svg new file mode 100644 index 0000000000..9fc02f9164 --- /dev/null +++ b/api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_s_en.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/model_runtime/model_providers/triton_inference_server/llm/__init__.py b/api/core/model_runtime/model_providers/triton_inference_server/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py new file mode 100644 index 0000000000..95272a41c2 --- /dev/null +++ b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py @@ -0,0 +1,267 @@ +from collections.abc import Generator + +from httpx import Response, post +from yarl import URL + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + + +class TritonInferenceAILargeLanguageModel(LargeLanguageModel): + def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + """ + invoke LLM + + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + """ + return self._generate( + model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, + tools=tools, stop=stop, stream=stream, user=user, + ) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + validate credentials + """ + if 'server_url' not in credentials: + raise CredentialsValidateFailedError('server_url is required in credentials') + + try: + self._invoke(model=model, credentials=credentials, prompt_messages=[ + UserPromptMessage(content='ping') + ], model_parameters={}, stream=False) + except InvokeError as ex: + raise CredentialsValidateFailedError(f'An error occurred during connection: {str(ex)}') + + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None) -> int: + """ + get number of tokens + + because TritonInference LLM is a customized model, we could not detect which tokenizer to use, + so we just take the GPT2 tokenizer as default + """ + return self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) + + def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
""" + convert prompt message to text + """ + text = '' + for item in message: + if isinstance(item, UserPromptMessage): + text += f'User: {item.content}' + elif isinstance(item, SystemPromptMessage): + text += f'System: {item.content}' + elif isinstance(item, AssistantPromptMessage): + text += f'Assistant: {item.content}' + else: + raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + return text + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + rules = [ + ParameterRule( + name='temperature', + type=ParameterType.FLOAT, + use_template='temperature', + label=I18nObject( + zh_Hans='温度', + en_US='Temperature' + ), + ), + ParameterRule( + name='top_p', + type=ParameterType.FLOAT, + use_template='top_p', + label=I18nObject( + zh_Hans='Top P', + en_US='Top P' + ) + ), + ParameterRule( + name='max_tokens', + type=ParameterType.INT, + use_template='max_tokens', + min=1, + max=int(credentials.get('context_length', 2048)), + default=min(512, int(credentials.get('context_length', 2048))), + label=I18nObject( + zh_Hans='最大生成长度', + en_US='Max Tokens' + ) + ) + ] + + completion_type = None + + if 'completion_type' in credentials: + if credentials['completion_type'] == 'chat': + completion_type = LLMMode.CHAT.value + elif credentials['completion_type'] == 'completion': + completion_type = LLMMode.COMPLETION.value + else: + raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') + + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + parameter_rules=rules, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.LLM, + model_properties={ + ModelPropertyKey.MODE: completion_type, + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_length', 2048)), + }, + ) + + return entity + + def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + """ + generate text from LLM + """ + if 'server_url' not in credentials: + raise CredentialsValidateFailedError('server_url is required in credentials') + + if 'stream' in credentials and not bool(credentials['stream']) and stream: + raise ValueError(f'stream is not supported by model {model}') + + try: + parameters = {} + if 'temperature' in model_parameters: + parameters['temperature'] = model_parameters['temperature'] + if 'top_p' in model_parameters: + parameters['top_p'] = model_parameters['top_p'] + if 'top_k' in model_parameters: + parameters['top_k'] = model_parameters['top_k'] + if 'presence_penalty' in model_parameters: + parameters['presence_penalty'] = model_parameters['presence_penalty'] + if 'frequency_penalty' in model_parameters: + parameters['frequency_penalty'] = model_parameters['frequency_penalty'] + + response = post(str(URL(credentials['server_url']) / 'v2' / 'models' / model / 'generate'), json={ + 'text_input': self._convert_prompt_message_to_text(prompt_messages), + 'max_tokens': model_parameters.get('max_tokens', 512), + 'parameters': { + 'stream': False, + **parameters + }, + }, timeout=(10, 120)) + response.raise_for_status() + if response.status_code != 200: + raise InvokeBadRequestError(f'Invoke failed with status code {response.status_code}, {response.text}') + + if stream: + return 
self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages, + tools=tools, resp=response) + return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, + tools=tools, resp=response) + except Exception as ex: + raise InvokeConnectionError(f'An error occurred during connection: {str(ex)}') + + def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Response) -> LLMResult: + """ + handle normal chat generate response + """ + text = resp.json()['text_output'] + + usage = LLMUsage.empty_usage() + usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + usage.completion_tokens = self._get_num_tokens_by_gpt2(text) + + return LLMResult( + model=model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage( + content=text + ), + usage=usage + ) + + def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Response) -> Generator: + """ + handle chat stream response + """ + text = resp.json()['text_output'] + + usage = LLMUsage.empty_usage() + usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + usage.completion_tokens = self._get_num_tokens_by_gpt2(text) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=text + ), + usage=usage + ) + ) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller.
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + ], + InvokeServerUnavailableError: [ + ], + InvokeRateLimitError: [ + ], + InvokeAuthorizationError: [ + ], + InvokeBadRequestError: [ + ValueError + ] + } \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py new file mode 100644 index 0000000000..06846825ab --- /dev/null +++ b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py @@ -0,0 +1,9 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + +class TritonInferenceAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.yaml b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.yaml new file mode 100644 index 0000000000..50a804743d --- /dev/null +++ b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.yaml @@ -0,0 +1,84 @@ +provider: triton_inference_server +label: + en_US: Triton Inference Server +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.png +background: "#EFFDFD" +help: + title: + en_US: How to deploy Triton Inference Server + zh_Hans: 如何部署 Triton Inference Server + url: + en_US: https://github.com/triton-inference-server/server +supported_model_types: + - llm +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: server_url + label: + zh_Hans: 服务器URL + en_US: Server URL + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入 Triton Inference Server 的服务器地址,如 http://192.168.1.100:8000 + en_US: Enter the URL of your Triton Inference Server, e.g.
http://192.168.1.100:8000 + - variable: context_size + label: + zh_Hans: 上下文大小 + en_US: Context size + type: text-input + required: true + placeholder: + zh_Hans: 在此输入您的上下文大小 + en_US: Enter the context size + default: 2048 + - variable: completion_type + label: + zh_Hans: 补全类型 + en_US: Model type + type: select + required: true + default: chat + placeholder: + zh_Hans: 在此输入您的补全类型 + en_US: Enter the completion type + options: + - label: + zh_Hans: 补全模型 + en_US: Completion model + value: completion + - label: + zh_Hans: 对话模型 + en_US: Chat model + value: chat + - variable: stream + label: + zh_Hans: 流式输出 + en_US: Stream output + type: select + required: true + default: true + placeholder: + zh_Hans: 是否支持流式输出 + en_US: Whether to support stream output + options: + - label: + zh_Hans: 是 + en_US: "Yes" + value: true + - label: + zh_Hans: 否 + en_US: "No" + value: false diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py index 8ad6462514..d33f38333b 100644 --- a/api/core/model_runtime/model_providers/yi/llm/llm.py +++ b/api/core/model_runtime/model_providers/yi/llm/llm.py @@ -1,30 +1,119 @@ from collections.abc import Generator from typing import Optional, Union +from urllib.parse import urlparse + +import tiktoken from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import ( PromptMessage, PromptMessageTool, + SystemPromptMessage, ) -from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel +from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel -class YiLargeLanguageModel(OAIAPICompatLargeLanguageModel): +class YiLargeLanguageModel(OpenAILargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) + + # yi-vl-plus does not support system prompts yet. + if model == "yi-vl-plus": + prompt_message_except_system: list[PromptMessage] = [] + for message in prompt_messages: + if not isinstance(message, SystemPromptMessage): + prompt_message_except_system.append(message) + return super()._invoke(model, credentials, prompt_message_except_system, model_parameters, tools, stop, stream) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) + # refactored from openai model runtime, use cl100k_base for calculate token number + def _num_tokens_from_string(self, model: str, text: str, + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Calculate num tokens for text completion model with tiktoken package.
+ + :param model: model name + :param text: prompt text + :param tools: tools for tool calling + :return: number of tokens + """ + encoding = tiktoken.get_encoding("cl100k_base") + num_tokens = len(encoding.encode(text)) + + if tools: + num_tokens += self._num_tokens_for_tools(encoding, tools) + + return num_tokens + + # refactored from openai model runtime, use cl100k_base for calculate token number + def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None) -> int: + """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. + + Official documentation: https://github.com/openai/openai-cookbook/blob/ + main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + encoding = tiktoken.get_encoding("cl100k_base") + tokens_per_message = 3 + tokens_per_name = 1 + + num_tokens = 0 + messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + # Cast str(value) in case the message value is not a string + # This occurs with function messages + # TODO: The current token calculation method for the image type is not implemented, + # which need to download the image and then get the resolution for calculation, + # and will increase the request delay + if isinstance(value, list): + text = '' + for item in value: + if isinstance(item, dict) and item['type'] == 'text': + text += item['text'] + + value = text + + if key == "tool_calls": + for tool_call in value: + for t_key, t_value in tool_call.items(): + num_tokens += len(encoding.encode(t_key)) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += len(encoding.encode(f_key)) + num_tokens += len(encoding.encode(f_value)) + else: + num_tokens += len(encoding.encode(t_key)) + num_tokens += len(encoding.encode(t_value)) + else: + num_tokens += len(encoding.encode(str(value))) + + if key == "name": + num_tokens += tokens_per_name + + # every reply is primed with assistant + num_tokens += 3 + + if tools: + num_tokens += self._num_tokens_for_tools(encoding, tools) + + return num_tokens + @staticmethod def _add_custom_parameters(credentials: dict) -> None: credentials['mode'] = 'chat' - + credentials['openai_api_key']=credentials['api_key'] if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['endpoint_url'] = 'https://api.lingyiwanwu.com/v1' + credentials['openai_api_base']='https://api.lingyiwanwu.com' + else: + parsed_url = urlparse(credentials['endpoint_url']) + credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 109d36583c..27ae15a025 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -66,7 +66,7 @@ class Vector: raise ValueError('Dataset Collection Bindings is not exist!') else: if self._dataset.index_struct_dict: - class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] collection_name = class_prefix else: dataset_id = self._dataset.id diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 3ce9935e79..497ce5d2b7 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -2,9 +2,11 @@ import os import shutil from 
collections.abc import Generator from contextlib import closing +from datetime import datetime, timedelta from typing import Union import boto3 +from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas from botocore.exceptions import ClientError from flask import Flask @@ -27,6 +29,18 @@ class Storage: endpoint_url=app.config.get('S3_ENDPOINT'), region_name=app.config.get('S3_REGION') ) + elif self.storage_type == 'azure-blob': + self.bucket_name = app.config.get('AZURE_BLOB_CONTAINER_NAME') + sas_token = generate_account_sas( + account_name=app.config.get('AZURE_BLOB_ACCOUNT_NAME'), + account_key=app.config.get('AZURE_BLOB_ACCOUNT_KEY'), + resource_types=ResourceTypes(service=True, container=True, object=True), + permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), + expiry=datetime.utcnow() + timedelta(hours=1) + ) + self.client = BlobServiceClient(account_url=app.config.get('AZURE_BLOB_ACCOUNT_URL'), + credential=sas_token) + else: self.folder = app.config.get('STORAGE_LOCAL_PATH') if not os.path.isabs(self.folder): @@ -35,6 +49,9 @@ class Storage: def save(self, filename, data): if self.storage_type == 's3': self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) + elif self.storage_type == 'azure-blob': + blob_container = self.client.get_container_client(container=self.bucket_name) + blob_container.upload_blob(filename, data) else: if not self.folder or self.folder.endswith('/'): filename = self.folder + filename @@ -63,6 +80,10 @@ class Storage: raise FileNotFoundError("File not found") else: raise + elif self.storage_type == 'azure-blob': + blob = self.client.get_container_client(container=self.bucket_name) + blob = blob.get_blob_client(blob=filename) + data = blob.download_blob().readall() else: if not self.folder or self.folder.endswith('/'): filename = self.folder + filename @@ -90,6 +111,11 @@ class Storage: raise FileNotFoundError("File not found") else: raise + elif self.storage_type == 'azure-blob': + blob = self.client.get_blob_client(container=self.bucket_name, blob=filename) + blob_data = blob.download_blob() + for chunk in blob_data.chunks(): + yield chunk else: if not self.folder or self.folder.endswith('/'): filename = self.folder + filename @@ -109,6 +135,11 @@ class Storage: if self.storage_type == 's3': with closing(self.client) as client: client.download_file(self.bucket_name, filename, target_filepath) + elif self.storage_type == 'azure-blob': + blob = self.client.get_blob_client(container=self.bucket_name, blob=filename) + with open(target_filepath, "wb") as my_blob: + blob_data = blob.download_blob() + blob_data.readinto(my_blob) else: if not self.folder or self.folder.endswith('/'): filename = self.folder + filename @@ -128,6 +159,9 @@ class Storage: return True except: return False + elif self.storage_type == 'azure-blob': + blob = self.client.get_blob_client(container=self.bucket_name, blob=filename) + return blob.exists() else: if not self.folder or self.folder.endswith('/'): filename = self.folder + filename diff --git a/api/libs/smtp.py b/api/libs/smtp.py index 6c8e0c2777..30a795bd70 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -16,7 +16,7 @@ class SMTPClient: smtp = smtplib.SMTP(self.server, self.port) if self._use_tls: smtp.starttls() - if (self.username): + if self.username and self.password: smtp.login(self.username, self.password) msg = MIMEMultipart() msg['Subject'] = mail['subject'] diff --git
a/api/requirements.txt b/api/requirements.txt index 886d7e42d0..6f9fe0cf00 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -67,8 +67,10 @@ yfinance~=0.2.35 pydub~=0.25.1 gmpy2~=2.1.5 numexpr~=2.9.0 -duckduckgo-search==4.4.3 +duckduckgo-search==5.1.0 arxiv==2.1.0 yarl~=1.9.4 twilio==9.0.0 qrcode~=7.4.2 +azure-storage-blob==12.9.0 +azure-identity==1.15.0 \ No newline at end of file diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 9ae0594bf4..ab08d3eeef 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -1,4 +1,4 @@ -version: '3.1' +version: '3' services: # The postgres database. db: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 7f5659bfee..97860d3709 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -1,4 +1,4 @@ -version: '3.1' +version: '3' services: # API service api: @@ -70,7 +70,7 @@ services: # If you want to enable cross-origin support, # you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`. # - # The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local` + # The type of storage to use for storing user files. Supported values are `local`, `s3` and `azure-blob`, Default: `local` STORAGE_TYPE: local # The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`. # only available when STORAGE_TYPE is `local`. @@ -81,6 +81,11 @@ services: S3_ACCESS_KEY: 'ak-difyai' S3_SECRET_KEY: 'sk-difyai' S3_REGION: 'us-east-1' + # The Azure Blob storage configurations, only available when STORAGE_TYPE is `azure-blob`. + AZURE_BLOB_ACCOUNT_NAME: 'difyai' + AZURE_BLOB_ACCOUNT_KEY: 'difyai' + AZURE_BLOB_CONTAINER_NAME: 'difyai-container' + AZURE_BLOB_ACCOUNT_URL: 'https://difyai.blob.core.windows.net' # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`. VECTOR_STORE: weaviate # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. @@ -167,9 +172,20 @@ services: REDIS_USE_SSL: 'false' # The configurations of celery broker. CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1 - # The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local` + # The type of storage to use for storing user files. Supported values are `local`, `s3` and `azure-blob`, Default: `local` STORAGE_TYPE: local STORAGE_LOCAL_PATH: storage + # The S3 storage configurations, only available when STORAGE_TYPE is `s3`. + S3_ENDPOINT: 'https://xxx.r2.cloudflarestorage.com' + S3_BUCKET_NAME: 'difyai' + S3_ACCESS_KEY: 'ak-difyai' + S3_SECRET_KEY: 'sk-difyai' + S3_REGION: 'us-east-1' + # The Azure Blob storage configurations, only available when STORAGE_TYPE is `azure-blob`. + AZURE_BLOB_ACCOUNT_NAME: 'difyai' + AZURE_BLOB_ACCOUNT_KEY: 'difyai' + AZURE_BLOB_CONTAINER_NAME: 'difyai-container' + AZURE_BLOB_ACCOUNT_URL: 'https://difyai.blob.core.windows.net' # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`. VECTOR_STORE: weaviate # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
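For reference, a minimal sketch of how the `azure-blob` settings above are consumed, mirroring the new branch added to `api/extensions/ext_storage.py`. The account name, key, container name and URL below are placeholders, not real credentials; in Dify they come from `app.config` / the environment variables documented above.

```python
from datetime import datetime, timedelta

from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas

# Placeholder values; supply your own account and container when trying this out.
account_name = 'difyai'
account_key = 'your-account-key'
container_name = 'difyai-container'
account_url = 'https://difyai.blob.core.windows.net'

# Generate a short-lived account SAS and build the service client, as Storage.init_app does.
sas_token = generate_account_sas(
    account_name=account_name,
    account_key=account_key,
    resource_types=ResourceTypes(service=True, container=True, object=True),
    permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
    expiry=datetime.utcnow() + timedelta(hours=1)
)
client = BlobServiceClient(account_url=account_url, credential=sas_token)

# Upload and read back a blob the same way Storage.save / Storage.load_once do.
client.get_container_client(container=container_name).upload_blob('hello.txt', b'hello dify', overwrite=True)
data = client.get_blob_client(container=container_name, blob='hello.txt').download_blob().readall()
assert data == b'hello dify'
```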
diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index 786b3277d0..793b47684a 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -935,7 +935,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - + - content (text) text content/question content,required - answer (text) Answer content, not required, passed if the Knowledge is in qa mode - keywords (list) keyword, not required @@ -948,13 +948,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from title="Request" tag="POST" label="/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}" - targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'\\\n--data-raw '{\"segments\": {\"content\": \"1\",\"answer\": \"1\", \"keywords\": [\"a\"], \"enabled\": false}}'`} + targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'\\\n--data-raw '{\"segment\": {\"content\": \"1\",\"answer\": \"1\", \"keywords\": [\"a\"], \"enabled\": false}}'`} > ```bash {{ title: 'cURL' }} curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \ --header 'Content-Type: application/json' \ --data-raw '{ - "segments": { + "segment": { "content": "1", "answer": "1", "keywords": ["a"], diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index f0bf12fac5..c44a192a03 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -935,7 +935,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - + - content (text) 文本内容/问题内容,必填 - answer (text) 答案内容,非必填,如果知识库的模式为qa模式则传值 - keywords (list) 关键字,非必填 @@ -948,14 +948,14 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from title="Request" tag="POST" label="/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}" - targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'\\\n--data-raw '{\"segments\": {\"content\": \"1\",\"answer\": \"1\", \"keywords\": [\"a\"], \"enabled\": false}}'`} + targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'\\\n--data-raw '{\"segment\": {\"content\": \"1\",\"answer\": \"1\", \"keywords\": [\"a\"], \"enabled\": false}}'`} > ```bash {{ title: 'cURL' }} curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ - "segments": { + "segment": { "content": "1", "answer": "1", "keywords": ["a"], diff --git 
a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx index d4d1e8e4b6..6a0066bbe6 100644 --- a/web/app/components/develop/template/template_chat.en.mdx +++ b/web/app/components/develop/template/template_chat.en.mdx @@ -444,10 +444,18 @@ Chat applications support session persistence, allowing previous chat history to Message ID + + ### Query + + + User identifier, used to define the identity of the end-user for retrieval and statistics. + Should be uniquely defined by the developer within the application. + + - + ```bash {{ title: 'cURL' }} curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested' \ diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx index dc8d600d00..e513f39339 100644 --- a/web/app/components/develop/template/template_chat.zh.mdx +++ b/web/app/components/develop/template/template_chat.zh.mdx @@ -459,10 +459,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' Message ID + + ### Query + + + 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + - + ```bash {{ title: 'cURL' }} curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested' \
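To round out the documentation change above, a minimal sketch of calling the suggested-questions endpoint with the newly documented `user` query parameter. The base URL, message ID and API key are placeholders; `httpx` is used here simply because it is already a dependency of the API service.

```python
import httpx

api_base_url = 'https://api.dify.ai/v1'  # placeholder, use your own deployment's base URL
message_id = 'your-message-id'           # placeholder
api_key = 'your-api-key'                 # placeholder

# `user` identifies the end-user for retrieval and statistics, as described in the updated docs.
resp = httpx.get(
    f'{api_base_url}/messages/{message_id}/suggested',
    params={'user': 'abc-123'},
    headers={'Authorization': f'Bearer {api_key}'},
)
print(resp.json())
```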