mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/workflow
This commit is contained in:
commit
6cf0e0c242
|
|
@ -39,7 +39,7 @@ DB_DATABASE=dify
|
|||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
# storage type: local, s3
|
||||
# storage type: local, s3, azure-blob
|
||||
STORAGE_TYPE=local
|
||||
STORAGE_LOCAL_PATH=storage
|
||||
S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com
|
||||
|
|
@ -47,6 +47,11 @@ S3_BUCKET_NAME=your-bucket-name
|
|||
S3_ACCESS_KEY=your-access-key
|
||||
S3_SECRET_KEY=your-secret-key
|
||||
S3_REGION=your-region
|
||||
# Azure Blob Storage configuration
|
||||
AZURE_BLOB_ACCOUNT_NAME=your-account-name
|
||||
AZURE_BLOB_ACCOUNT_KEY=your-account-key
|
||||
AZURE_BLOB_CONTAINER_NAME=yout-container-name
|
||||
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
|
||||
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
|
|
|||
|
|
@ -186,6 +186,10 @@ class Config:
|
|||
self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
|
||||
self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
|
||||
self.S3_REGION = get_env('S3_REGION')
|
||||
self.AZURE_BLOB_ACCOUNT_NAME = get_env('AZURE_BLOB_ACCOUNT_NAME')
|
||||
self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY')
|
||||
self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME')
|
||||
self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL')
|
||||
|
||||
# ------------------------
|
||||
# Vector Store Configurations.
|
||||
|
|
|
|||
|
|
@ -197,11 +197,11 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
|
||||
parser.add_argument('segment', type=dict, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(args, segment, document, dataset)
|
||||
SegmentService.segment_create_args_validate(args['segment'], document)
|
||||
segment = SegmentService.update_segment(args['segment'], segment, document, dataset)
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@
|
|||
- groq
|
||||
- replicate
|
||||
- huggingface_hub
|
||||
- xinference
|
||||
- triton_inference_server
|
||||
- zhipuai
|
||||
- baichuan
|
||||
- spark
|
||||
|
|
@ -20,7 +22,6 @@
|
|||
- moonshot
|
||||
- jina
|
||||
- chatglm
|
||||
- xinference
|
||||
- yi
|
||||
- openllm
|
||||
- localai
|
||||
|
|
|
|||
Binary file not shown.
|
After Width: | Height: | Size: 78 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
<svg width="567" height="376" viewBox="0 0 567 376" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M58.0366 161.868C58.0366 161.868 109.261 86.2912 211.538 78.4724V51.053C98.2528 60.1511 0.152344 156.098 0.152344 156.098C0.152344 156.098 55.7148 316.717 211.538 331.426V302.282C97.1876 287.896 58.0366 161.868 58.0366 161.868ZM211.538 244.32V271.013C125.114 255.603 101.125 165.768 101.125 165.768C101.125 165.768 142.621 119.799 211.538 112.345V141.633C211.486 141.633 211.449 141.617 211.406 141.617C175.235 137.276 146.978 171.067 146.978 171.067C146.978 171.067 162.816 227.949 211.538 244.32ZM211.538 0.47998V51.053C214.864 50.7981 218.189 50.5818 221.533 50.468C350.326 46.1273 434.243 156.098 434.243 156.098C434.243 156.098 337.861 273.296 237.448 273.296C228.245 273.296 219.63 272.443 211.538 271.009V302.282C218.695 303.201 225.903 303.667 233.119 303.675C326.56 303.675 394.134 255.954 459.566 199.474C470.415 208.162 514.828 229.299 523.958 238.55C461.745 290.639 316.752 332.626 234.551 332.626C226.627 332.626 219.018 332.148 211.538 331.426V375.369H566.701V0.47998H211.538ZM211.538 112.345V78.4724C214.829 78.2425 218.146 78.0672 221.533 77.9602C314.148 75.0512 374.909 157.548 374.909 157.548C374.909 157.548 309.281 248.693 238.914 248.693C228.787 248.693 219.707 247.065 211.536 244.318V141.631C247.591 145.987 254.848 161.914 276.524 198.049L324.737 157.398C324.737 157.398 289.544 111.243 230.219 111.243C223.768 111.241 217.597 111.696 211.538 112.345Z" fill="#77B900"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.5 KiB |
|
|
@ -0,0 +1,267 @@
|
|||
from collections.abc import Generator
|
||||
|
||||
from httpx import Response, post
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
"""
|
||||
invoke LLM
|
||||
|
||||
see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
|
||||
"""
|
||||
return self._generate(
|
||||
model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
|
||||
tools=tools, stop=stop, stream=stream, user=user,
|
||||
)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
validate credentials
|
||||
"""
|
||||
if 'server_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('server_url is required in credentials')
|
||||
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, prompt_messages=[
|
||||
UserPromptMessage(content='ping')
|
||||
], model_parameters={}, stream=False)
|
||||
except InvokeError as ex:
|
||||
raise CredentialsValidateFailedError(f'An error occurred during connection: {str(ex)}')
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None) -> int:
|
||||
"""
|
||||
get number of tokens
|
||||
|
||||
cause TritonInference LLM is a customized model, we could net detect which tokenizer to use
|
||||
so we just take the GPT2 tokenizer as default
|
||||
"""
|
||||
return self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages))
|
||||
|
||||
def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
|
||||
"""
|
||||
convert prompt message to text
|
||||
"""
|
||||
text = ''
|
||||
for item in message:
|
||||
if isinstance(item, UserPromptMessage):
|
||||
text += f'User: {item.content}'
|
||||
elif isinstance(item, SystemPromptMessage):
|
||||
text += f'System: {item.content}'
|
||||
elif isinstance(item, AssistantPromptMessage):
|
||||
text += f'Assistant: {item.content}'
|
||||
else:
|
||||
raise NotImplementedError(f'PromptMessage type {type(item)} is not supported')
|
||||
return text
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='temperature',
|
||||
label=I18nObject(
|
||||
zh_Hans='温度',
|
||||
en_US='Temperature'
|
||||
),
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='top_p',
|
||||
label=I18nObject(
|
||||
zh_Hans='Top P',
|
||||
en_US='Top P'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens',
|
||||
type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
min=1,
|
||||
max=int(credentials.get('context_length', 2048)),
|
||||
default=min(512, int(credentials.get('context_length', 2048))),
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度',
|
||||
en_US='Max Tokens'
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
completion_type = None
|
||||
|
||||
if 'completion_type' in credentials:
|
||||
if credentials['completion_type'] == 'chat':
|
||||
completion_type = LLMMode.CHAT.value
|
||||
elif credentials['completion_type'] == 'completion':
|
||||
completion_type = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
parameter_rules=rules,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: completion_type,
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_length', 2048)),
|
||||
},
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
"""
|
||||
generate text from LLM
|
||||
"""
|
||||
if 'server_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('server_url is required in credentials')
|
||||
|
||||
if 'stream' in credentials and not bool(credentials['stream']) and stream:
|
||||
raise ValueError(f'stream is not supported by model {model}')
|
||||
|
||||
try:
|
||||
parameters = {}
|
||||
if 'temperature' in model_parameters:
|
||||
parameters['temperature'] = model_parameters['temperature']
|
||||
if 'top_p' in model_parameters:
|
||||
parameters['top_p'] = model_parameters['top_p']
|
||||
if 'top_k' in model_parameters:
|
||||
parameters['top_k'] = model_parameters['top_k']
|
||||
if 'presence_penalty' in model_parameters:
|
||||
parameters['presence_penalty'] = model_parameters['presence_penalty']
|
||||
if 'frequency_penalty' in model_parameters:
|
||||
parameters['frequency_penalty'] = model_parameters['frequency_penalty']
|
||||
|
||||
response = post(str(URL(credentials['server_url']) / 'v2' / 'models' / model / 'generate'), json={
|
||||
'text_input': self._convert_prompt_message_to_text(prompt_messages),
|
||||
'max_tokens': model_parameters.get('max_tokens', 512),
|
||||
'parameters': {
|
||||
'stream': False,
|
||||
**parameters
|
||||
},
|
||||
}, timeout=(10, 120))
|
||||
response.raise_for_status()
|
||||
if response.status_code != 200:
|
||||
raise InvokeBadRequestError(f'Invoke failed with status code {response.status_code}, {response.text}')
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
|
||||
tools=tools, resp=response)
|
||||
return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
|
||||
tools=tools, resp=response)
|
||||
except Exception as ex:
|
||||
raise InvokeConnectionError(f'An error occurred during connection: {str(ex)}')
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
resp: Response) -> LLMResult:
|
||||
"""
|
||||
handle normal chat generate response
|
||||
"""
|
||||
text = resp.json()['text_output']
|
||||
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=text
|
||||
),
|
||||
usage=usage
|
||||
)
|
||||
|
||||
def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
resp: Response) -> Generator:
|
||||
"""
|
||||
handle normal chat generate response
|
||||
"""
|
||||
text = resp.json()['text_output']
|
||||
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=text
|
||||
),
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
ValueError
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class XinferenceAIProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
provider: triton_inference_server
|
||||
label:
|
||||
en_US: Triton Inference Server
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
background: "#EFFDFD"
|
||||
help:
|
||||
title:
|
||||
en_US: How to deploy Triton Inference Server
|
||||
zh_Hans: 如何部署 Triton Inference Server
|
||||
url:
|
||||
en_US: https://github.com/triton-inference-server/server
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: server_url
|
||||
label:
|
||||
zh_Hans: 服务器URL
|
||||
en_US: Server url
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入 Triton Inference Server 的服务器地址,如 http://192.168.1.100:8000
|
||||
en_US: Enter the url of your Triton Inference Server, e.g. http://192.168.1.100:8000
|
||||
- variable: context_size
|
||||
label:
|
||||
zh_Hans: 上下文大小
|
||||
en_US: Context size
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的上下文大小
|
||||
en_US: Enter the context size
|
||||
default: 2048
|
||||
- variable: completion_type
|
||||
label:
|
||||
zh_Hans: 补全类型
|
||||
en_US: Model type
|
||||
type: select
|
||||
required: true
|
||||
default: chat
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的补全类型
|
||||
en_US: Enter the completion type
|
||||
options:
|
||||
- label:
|
||||
zh_Hans: 补全模型
|
||||
en_US: Completion model
|
||||
value: completion
|
||||
- label:
|
||||
zh_Hans: 对话模型
|
||||
en_US: Chat model
|
||||
value: chat
|
||||
- variable: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream output
|
||||
type: select
|
||||
required: true
|
||||
default: true
|
||||
placeholder:
|
||||
zh_Hans: 是否支持流式输出
|
||||
en_US: Whether to support stream output
|
||||
options:
|
||||
- label:
|
||||
zh_Hans: 是
|
||||
en_US: Yes
|
||||
value: true
|
||||
- label:
|
||||
zh_Hans: 否
|
||||
en_US: No
|
||||
value: false
|
||||
|
|
@ -1,30 +1,119 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import tiktoken
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel
|
||||
|
||||
|
||||
class YiLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
class YiLargeLanguageModel(OpenAILargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
|
||||
# yi-vl-plus not support system prompt yet.
|
||||
if model == "yi-vl-plus":
|
||||
prompt_message_except_system: list[PromptMessage] = []
|
||||
for message in prompt_messages:
|
||||
if not isinstance(message, SystemPromptMessage):
|
||||
prompt_message_except_system.append(message)
|
||||
return super()._invoke(model, credentials, prompt_message_except_system, model_parameters, tools, stop, stream)
|
||||
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
# refactored from openai model runtime, use cl100k_base for calculate token number
|
||||
def _num_tokens_from_string(self, model: str, text: str,
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Calculate num tokens for text completion model with tiktoken package.
|
||||
|
||||
:param model: model name
|
||||
:param text: prompt text
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
num_tokens = len(encoding.encode(text))
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
# refactored from openai model runtime, use cl100k_base for calculate token number
|
||||
def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
# TODO: The current token calculation method for the image type is not implemented,
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ''
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item['type'] == 'text':
|
||||
text += item['text']
|
||||
|
||||
value = text
|
||||
|
||||
if key == "tool_calls":
|
||||
for tool_call in value:
|
||||
for t_key, t_value in tool_call.items():
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
num_tokens += len(encoding.encode(t_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(str(value)))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials['mode'] = 'chat'
|
||||
|
||||
credentials['openai_api_key']=credentials['api_key']
|
||||
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
||||
credentials['endpoint_url'] = 'https://api.lingyiwanwu.com/v1'
|
||||
credentials['openai_api_base']='https://api.lingyiwanwu.com'
|
||||
else:
|
||||
parsed_url = urlparse(credentials['endpoint_url'])
|
||||
credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ class Vector:
|
|||
raise ValueError('Dataset Collection Bindings is not exist!')
|
||||
else:
|
||||
if self._dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = self._dataset.id
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@ import os
|
|||
import shutil
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Union
|
||||
|
||||
import boto3
|
||||
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
|
||||
from botocore.exceptions import ClientError
|
||||
from flask import Flask
|
||||
|
||||
|
|
@ -27,6 +29,18 @@ class Storage:
|
|||
endpoint_url=app.config.get('S3_ENDPOINT'),
|
||||
region_name=app.config.get('S3_REGION')
|
||||
)
|
||||
elif self.storage_type == 'azure-blob':
|
||||
self.bucket_name = app.config.get('AZURE_BLOB_CONTAINER_NAME')
|
||||
sas_token = generate_account_sas(
|
||||
account_name=app.config.get('AZURE_BLOB_ACCOUNT_NAME'),
|
||||
account_key=app.config.get('AZURE_BLOB_ACCOUNT_KEY'),
|
||||
resource_types=ResourceTypes(service=True, container=True, object=True),
|
||||
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
|
||||
expiry=datetime.utcnow() + timedelta(hours=1)
|
||||
)
|
||||
self.client = BlobServiceClient(account_url=app.config.get('AZURE_BLOB_ACCOUNT_URL'),
|
||||
credential=sas_token)
|
||||
|
||||
else:
|
||||
self.folder = app.config.get('STORAGE_LOCAL_PATH')
|
||||
if not os.path.isabs(self.folder):
|
||||
|
|
@ -35,6 +49,9 @@ class Storage:
|
|||
def save(self, filename, data):
|
||||
if self.storage_type == 's3':
|
||||
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
|
||||
elif self.storage_type == 'azure-blob':
|
||||
blob_container = self.client.get_container_client(container=self.bucket_name)
|
||||
blob_container.upload_blob(filename, data)
|
||||
else:
|
||||
if not self.folder or self.folder.endswith('/'):
|
||||
filename = self.folder + filename
|
||||
|
|
@ -63,6 +80,10 @@ class Storage:
|
|||
raise FileNotFoundError("File not found")
|
||||
else:
|
||||
raise
|
||||
elif self.storage_type == 'azure-blob':
|
||||
blob = self.client.get_container_client(container=self.bucket_name)
|
||||
blob = blob.get_blob_client(blob=filename)
|
||||
data = blob.download_blob().readall()
|
||||
else:
|
||||
if not self.folder or self.folder.endswith('/'):
|
||||
filename = self.folder + filename
|
||||
|
|
@ -90,6 +111,11 @@ class Storage:
|
|||
raise FileNotFoundError("File not found")
|
||||
else:
|
||||
raise
|
||||
elif self.storage_type == 'azure-blob':
|
||||
blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
|
||||
with closing(blob.download_blob()) as blob_stream:
|
||||
while chunk := blob_stream.readall(4096):
|
||||
yield chunk
|
||||
else:
|
||||
if not self.folder or self.folder.endswith('/'):
|
||||
filename = self.folder + filename
|
||||
|
|
@ -109,6 +135,11 @@ class Storage:
|
|||
if self.storage_type == 's3':
|
||||
with closing(self.client) as client:
|
||||
client.download_file(self.bucket_name, filename, target_filepath)
|
||||
elif self.storage_type == 'azure-blob':
|
||||
blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
|
||||
with open(target_filepath, "wb") as my_blob:
|
||||
blob_data = blob.download_blob()
|
||||
blob_data.readinto(my_blob)
|
||||
else:
|
||||
if not self.folder or self.folder.endswith('/'):
|
||||
filename = self.folder + filename
|
||||
|
|
@ -128,6 +159,9 @@ class Storage:
|
|||
return True
|
||||
except:
|
||||
return False
|
||||
elif self.storage_type == 'azure-blob':
|
||||
blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
|
||||
return blob.exists()
|
||||
else:
|
||||
if not self.folder or self.folder.endswith('/'):
|
||||
filename = self.folder + filename
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class SMTPClient:
|
|||
smtp = smtplib.SMTP(self.server, self.port)
|
||||
if self._use_tls:
|
||||
smtp.starttls()
|
||||
if (self.username):
|
||||
if self.username and self.password:
|
||||
smtp.login(self.username, self.password)
|
||||
msg = MIMEMultipart()
|
||||
msg['Subject'] = mail['subject']
|
||||
|
|
|
|||
|
|
@ -67,8 +67,10 @@ yfinance~=0.2.35
|
|||
pydub~=0.25.1
|
||||
gmpy2~=2.1.5
|
||||
numexpr~=2.9.0
|
||||
duckduckgo-search==4.4.3
|
||||
duckduckgo-search==5.1.0
|
||||
arxiv==2.1.0
|
||||
yarl~=1.9.4
|
||||
twilio==9.0.0
|
||||
qrcode~=7.4.2
|
||||
azure-storage-blob==12.9.0
|
||||
azure-identity==1.15.0
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
version: '3.1'
|
||||
version: '3'
|
||||
services:
|
||||
# The postgres database.
|
||||
db:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
version: '3.1'
|
||||
version: '3'
|
||||
services:
|
||||
# API service
|
||||
api:
|
||||
|
|
@ -70,7 +70,7 @@ services:
|
|||
# If you want to enable cross-origin support,
|
||||
# you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`.
|
||||
#
|
||||
# The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local`
|
||||
# The type of storage to use for storing user files. Supported values are `local` and `s3` and `azure-blob`, Default: `local`
|
||||
STORAGE_TYPE: local
|
||||
# The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`.
|
||||
# only available when STORAGE_TYPE is `local`.
|
||||
|
|
@ -81,6 +81,11 @@ services:
|
|||
S3_ACCESS_KEY: 'ak-difyai'
|
||||
S3_SECRET_KEY: 'sk-difyai'
|
||||
S3_REGION: 'us-east-1'
|
||||
# The Azure Blob storage configurations, only available when STORAGE_TYPE is `azure-blob`.
|
||||
AZURE_BLOB_ACCOUNT_NAME: 'difyai'
|
||||
AZURE_BLOB_ACCOUNT_KEY: 'difyai'
|
||||
AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
|
||||
AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
|
||||
# The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
|
||||
VECTOR_STORE: weaviate
|
||||
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
|
||||
|
|
@ -167,9 +172,20 @@ services:
|
|||
REDIS_USE_SSL: 'false'
|
||||
# The configurations of celery broker.
|
||||
CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1
|
||||
# The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local`
|
||||
# The type of storage to use for storing user files. Supported values are `local` and `s3` and `azure-blob`, Default: `local`
|
||||
STORAGE_TYPE: local
|
||||
STORAGE_LOCAL_PATH: storage
|
||||
# The S3 storage configurations, only available when STORAGE_TYPE is `s3`.
|
||||
S3_ENDPOINT: 'https://xxx.r2.cloudflarestorage.com'
|
||||
S3_BUCKET_NAME: 'difyai'
|
||||
S3_ACCESS_KEY: 'ak-difyai'
|
||||
S3_SECRET_KEY: 'sk-difyai'
|
||||
S3_REGION: 'us-east-1'
|
||||
# The Azure Blob storage configurations, only available when STORAGE_TYPE is `azure-blob`.
|
||||
AZURE_BLOB_ACCOUNT_NAME: 'difyai'
|
||||
AZURE_BLOB_ACCOUNT_KEY: 'difyai'
|
||||
AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
|
||||
AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
|
||||
# The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
|
||||
VECTOR_STORE: weaviate
|
||||
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
|
||||
|
|
|
|||
|
|
@ -935,7 +935,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
|
||||
### Request Body
|
||||
<Properties>
|
||||
<Property name='segments' type='object list' key='segments'>
|
||||
<Property name='segment' type='object list' key='segment'>
|
||||
- <code>content</code> (text) text content/question content,required
|
||||
- <code>answer</code> (text) Answer content, not required, passed if the Knowledge is in qa mode
|
||||
- <code>keywords</code> (list) keyword, not required
|
||||
|
|
@ -948,13 +948,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
title="Request"
|
||||
tag="POST"
|
||||
label="/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}"
|
||||
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'\\\n--data-raw '{\"segments\": {\"content\": \"1\",\"answer\": \"1\", \"keywords\": [\"a\"], \"enabled\": false}}'`}
|
||||
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'\\\n--data-raw '{\"segment\": {\"content\": \"1\",\"answer\": \"1\", \"keywords\": [\"a\"], \"enabled\": false}}'`}
|
||||
>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"segments": {
|
||||
"segment": {
|
||||
"content": "1",
|
||||
"answer": "1",
|
||||
"keywords": ["a"],
|
||||
|
|
|
|||
|
|
@ -935,7 +935,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
|
||||
### Request Body
|
||||
<Properties>
|
||||
<Property name='segments' type='object list' key='segments'>
|
||||
<Property name='segment' type='object list' key='segment'>
|
||||
- <code>content</code> (text) 文本内容/问题内容,必填
|
||||
- <code>answer</code> (text) 答案内容,非必填,如果知识库的模式为qa模式则传值
|
||||
- <code>keywords</code> (list) 关键字,非必填
|
||||
|
|
@ -948,14 +948,14 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
title="Request"
|
||||
tag="POST"
|
||||
label="/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}"
|
||||
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'\\\n--data-raw '{\"segments\": {\"content\": \"1\",\"answer\": \"1\", \"keywords\": [\"a\"], \"enabled\": false}}'`}
|
||||
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'\\\n--data-raw '{\"segment\": {\"content\": \"1\",\"answer\": \"1\", \"keywords\": [\"a\"], \"enabled\": false}}'`}
|
||||
>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \
|
||||
--header 'Authorization: Bearer {api_key}' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"segments": {
|
||||
"segment": {
|
||||
"content": "1",
|
||||
"answer": "1",
|
||||
"keywords": ["a"],
|
||||
|
|
|
|||
|
|
@ -444,10 +444,18 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
Message ID
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
### Query
|
||||
<Properties>
|
||||
<Property name='user' type='string' key='user'>
|
||||
User identifier, used to define the identity of the end-user for retrieval and statistics.
|
||||
Should be uniquely defined by the developer within the application.
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
<Col sticky>
|
||||
|
||||
<CodeGroup title="Request" tag="GET" label="/messages/{message_id}/suggested" targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested \\\n--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \\\n--header 'Content-Type: application/json'`}>
|
||||
<CodeGroup title="Request" tag="GET" label="/messages/{message_id}/suggested" targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested?user=abc-123& \\\n--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \\\n--header 'Content-Type: application/json'`}>
|
||||
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested' \
|
||||
|
|
|
|||
|
|
@ -459,10 +459,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
Message ID
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
### Query
|
||||
<Properties>
|
||||
<Property name='user' type='string' key='user'>
|
||||
用户标识,由开发者定义规则,需保证用户标识在应用内唯一。
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
<Col sticky>
|
||||
|
||||
<CodeGroup title="Request" tag="GET" label="/messages/{message_id}/suggested" targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested \\\n--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \\\n--header 'Content-Type: application/json'`}>
|
||||
<CodeGroup title="Request" tag="GET" label="/messages/{message_id}/suggested" targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested?user=abc-123 \\\n--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \\\n--header 'Content-Type: application/json'`}>
|
||||
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested' \
|
||||
|
|
|
|||
Loading…
Reference in New Issue