Merge branch 'main' into feat/workflow
|
|
@ -12,6 +12,8 @@ Please delete options that are not relevant.
|
|||
- [ ] New feature (non-breaking change which adds functionality)
|
||||
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
||||
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
|
||||
- [ ] Improvement,including but not limited to code refactoring, performance optimization, and UI/UX improvement
|
||||
- [ ] Dependency upgrade
|
||||
|
||||
# How Has This Been Tested?
|
||||
|
||||
|
|
|
|||
|
|
@ -342,12 +342,20 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|||
Convert prompt messages to dict list and system
|
||||
"""
|
||||
system = ""
|
||||
prompt_message_dicts = []
|
||||
|
||||
first_loop = True
|
||||
for message in prompt_messages:
|
||||
if isinstance(message, SystemPromptMessage):
|
||||
system += message.content + ("\n" if not system else "")
|
||||
else:
|
||||
message.content=message.content.strip()
|
||||
if first_loop:
|
||||
system=message.content
|
||||
first_loop=False
|
||||
else:
|
||||
system+="\n"
|
||||
system+=message.content
|
||||
|
||||
prompt_message_dicts = []
|
||||
for message in prompt_messages:
|
||||
if not isinstance(message, SystemPromptMessage):
|
||||
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
|
||||
|
||||
return system, prompt_message_dicts
|
||||
|
|
|
|||
|
|
@ -123,6 +123,65 @@ LLM_BASE_MODELS = [
|
|||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-35-turbo-0125',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label',
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[
|
||||
ModelFeature.AGENT_THOUGHT,
|
||||
ModelFeature.MULTI_TOOL_CALL,
|
||||
ModelFeature.STREAM_TOOL_CALL,
|
||||
],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
ModelPropertyKey.CONTEXT_SIZE: 16385,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name='presence_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name='frequency_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
||||
ParameterRule(
|
||||
name='response_format',
|
||||
label=I18nObject(
|
||||
zh_Hans='回复格式',
|
||||
en_US='response_format'
|
||||
),
|
||||
type='string',
|
||||
help=I18nObject(
|
||||
zh_Hans='指定模型必须输出的格式',
|
||||
en_US='specifying the format that the model must output'
|
||||
),
|
||||
required=False,
|
||||
options=['text', 'json_object']
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=0.0005,
|
||||
output=0.0015,
|
||||
unit=0.001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-4',
|
||||
entity=AIModelEntity(
|
||||
|
|
@ -273,6 +332,81 @@ LLM_BASE_MODELS = [
|
|||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-4-0125-preview',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label',
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[
|
||||
ModelFeature.AGENT_THOUGHT,
|
||||
ModelFeature.MULTI_TOOL_CALL,
|
||||
ModelFeature.STREAM_TOOL_CALL,
|
||||
],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name='presence_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name='frequency_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
||||
ParameterRule(
|
||||
name='seed',
|
||||
label=I18nObject(
|
||||
zh_Hans='种子',
|
||||
en_US='Seed'
|
||||
),
|
||||
type='int',
|
||||
help=I18nObject(
|
||||
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
|
||||
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
|
||||
),
|
||||
required=False,
|
||||
precision=2,
|
||||
min=0,
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='response_format',
|
||||
label=I18nObject(
|
||||
zh_Hans='回复格式',
|
||||
en_US='response_format'
|
||||
),
|
||||
type='string',
|
||||
help=I18nObject(
|
||||
zh_Hans='指定模型必须输出的格式',
|
||||
en_US='specifying the format that the model must output'
|
||||
),
|
||||
required=False,
|
||||
options=['text', 'json_object']
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=0.01,
|
||||
output=0.03,
|
||||
unit=0.001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-4-1106-preview',
|
||||
entity=AIModelEntity(
|
||||
|
|
|
|||
|
|
@ -75,6 +75,12 @@ model_credential_schema:
|
|||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-35-turbo-0125
|
||||
value: gpt-35-turbo-0125
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-35-turbo-16k
|
||||
value: gpt-35-turbo-16k
|
||||
|
|
@ -93,6 +99,12 @@ model_credential_schema:
|
|||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-0125-preview
|
||||
value: gpt-4-0125-preview
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-1106-preview
|
||||
value: gpt-4-1106-preview
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||
elif err == 'insufficient_quota':
|
||||
raise InsufficientAccountBalance(msg)
|
||||
elif err == 'invalid_authentication':
|
||||
raise InvalidAuthenticationError(msg)
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif err and 'rate' in err:
|
||||
raise RateLimitReachedError(msg)
|
||||
elif err and 'internal' in err:
|
||||
|
|
|
|||
|
|
@ -48,23 +48,23 @@ provider_credential_schema:
|
|||
- value: us-east-1
|
||||
label:
|
||||
en_US: US East (N. Virginia)
|
||||
zh_Hans: US East (N. Virginia)
|
||||
zh_Hans: 美国东部 (弗吉尼亚北部)
|
||||
- value: us-west-2
|
||||
label:
|
||||
en_US: US West (Oregon)
|
||||
zh_Hans: US West (Oregon)
|
||||
zh_Hans: 美国西部 (俄勒冈州)
|
||||
- value: ap-southeast-1
|
||||
label:
|
||||
en_US: Asia Pacific (Singapore)
|
||||
zh_Hans: Asia Pacific (Singapore)
|
||||
zh_Hans: 亚太地区 (新加坡)
|
||||
- value: ap-northeast-1
|
||||
label:
|
||||
en_US: Asia Pacific (Tokyo)
|
||||
zh_Hans: Asia Pacific (Tokyo)
|
||||
zh_Hans: 亚太地区 (东京)
|
||||
- value: eu-central-1
|
||||
label:
|
||||
en_US: Europe (Frankfurt)
|
||||
zh_Hans: Europe (Frankfurt)
|
||||
zh_Hans: 欧洲 (法兰克福)
|
||||
- value: us-gov-west-1
|
||||
label:
|
||||
en_US: AWS GovCloud (US-West)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
- anthropic.claude-v1
|
||||
- anthropic.claude-v2
|
||||
- anthropic.claude-v2:1
|
||||
- anthropic.claude-3-sonnet-v1:0
|
||||
- anthropic.claude-3-haiku-v1:0
|
||||
- cohere.command-light-text-v14
|
||||
- cohere.command-text-v14
|
||||
- meta.llama2-13b-chat-v1
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
model: anthropic.claude-3-haiku-20240307-v1:0
|
||||
label:
|
||||
en_US: Claude 3 Haiku
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
# docs: https://docs.anthropic.com/claude/docs/system-prompts
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
model: anthropic.claude-3-sonnet-20240229-v1:0
|
||||
label:
|
||||
en_US: Claude 3 Sonnet
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.00025'
|
||||
output: '0.00125'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
|
|
@ -1,9 +1,22 @@
|
|||
import base64
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
from anthropic import AnthropicBedrock, Stream
|
||||
from anthropic.types import (
|
||||
ContentBlockDeltaEvent,
|
||||
Message,
|
||||
MessageDeltaEvent,
|
||||
MessageStartEvent,
|
||||
MessageStopEvent,
|
||||
MessageStreamEvent,
|
||||
)
|
||||
from botocore.config import Config
|
||||
from botocore.exceptions import (
|
||||
ClientError,
|
||||
|
|
@ -13,14 +26,18 @@ from botocore.exceptions import (
|
|||
UnknownServiceError,
|
||||
)
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
|
|
@ -54,9 +71,293 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
# invoke claude 3 models via anthropic official SDK
|
||||
if "anthropic.claude-3" in model:
|
||||
return self._invoke_claude3(model, credentials, prompt_messages, model_parameters, stop, stream, user)
|
||||
# invoke model
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
|
||||
|
||||
def _invoke_claude3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke Claude3 large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# use Anthropic official SDK references
|
||||
# - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
|
||||
# - https://github.com/anthropics/anthropic-sdk-python
|
||||
client = AnthropicBedrock(
|
||||
aws_access_key=credentials["aws_access_key_id"],
|
||||
aws_secret_key=credentials["aws_secret_access_key"],
|
||||
aws_region=credentials["aws_region"],
|
||||
)
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if stop:
|
||||
extra_model_kwargs['stop_sequences'] = stop
|
||||
|
||||
# Notice: If you request the current version of the SDK to the bedrock server,
|
||||
# you will get the following error message and you need to wait for the service or SDK to be updated.
|
||||
# Response: Error code: 400
|
||||
# {'message': 'Malformed input request: #: subject must not be valid against schema
|
||||
# {"required":["messages"]}#: extraneous key [metadata] is not permitted, please reformat your input and try again.'}
|
||||
# TODO: Open in the future when the interface is properly supported
|
||||
# if user:
|
||||
# ref: https://github.com/anthropics/anthropic-sdk-python/blob/e84645b07ca5267066700a104b4d8d6a8da1383d/src/anthropic/resources/messages.py#L465
|
||||
# extra_model_kwargs['metadata'] = message_create_params.Metadata(user_id=user)
|
||||
|
||||
system, prompt_message_dicts = self._convert_claude3_prompt_messages(prompt_messages)
|
||||
|
||||
if system:
|
||||
extra_model_kwargs['system'] = system
|
||||
|
||||
response = client.messages.create(
|
||||
model=model,
|
||||
messages=prompt_message_dicts,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_claude3_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_claude3_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_claude3_response(self, model: str, credentials: dict, response: Message,
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: full response chunk generator result
|
||||
"""
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=response.content[0].text
|
||||
)
|
||||
|
||||
# calculate num tokens
|
||||
if response.usage:
|
||||
# transform usage
|
||||
prompt_tokens = response.usage.input_tokens
|
||||
completion_tokens = response.usage.output_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
response = LLMResult(
|
||||
model=response.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _handle_claude3_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent],
|
||||
prompt_messages: list[PromptMessage], ) -> Generator:
|
||||
"""
|
||||
Handle llm chat stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
try:
|
||||
full_assistant_content = ''
|
||||
return_model = None
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
finish_reason = None
|
||||
index = 0
|
||||
|
||||
for chunk in response:
|
||||
if isinstance(chunk, MessageStartEvent):
|
||||
return_model = chunk.message.model
|
||||
input_tokens = chunk.message.usage.input_tokens
|
||||
elif isinstance(chunk, MessageDeltaEvent):
|
||||
output_tokens = chunk.usage.output_tokens
|
||||
finish_reason = chunk.delta.stop_reason
|
||||
elif isinstance(chunk, MessageStopEvent):
|
||||
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
|
||||
yield LLMResultChunk(
|
||||
model=return_model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index + 1,
|
||||
message=AssistantPromptMessage(
|
||||
content=''
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
elif isinstance(chunk, ContentBlockDeltaEvent):
|
||||
chunk_text = chunk.delta.text if chunk.delta.text else ''
|
||||
full_assistant_content += chunk_text
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=chunk_text if chunk_text else '',
|
||||
)
|
||||
index = chunk.index
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
)
|
||||
except Exception as ex:
|
||||
raise InvokeError(str(ex))
|
||||
|
||||
def _calc_claude3_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_tokens: prompt tokens
|
||||
:param completion_tokens: completion tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get prompt price info
|
||||
prompt_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=prompt_tokens,
|
||||
)
|
||||
|
||||
# get completion price info
|
||||
completion_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.OUTPUT,
|
||||
tokens=completion_tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = LLMUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_unit_price=prompt_price_info.unit_price,
|
||||
prompt_price_unit=prompt_price_info.unit,
|
||||
prompt_price=prompt_price_info.total_amount,
|
||||
completion_tokens=completion_tokens,
|
||||
completion_unit_price=completion_price_info.unit_price,
|
||||
completion_price_unit=completion_price_info.unit,
|
||||
completion_price=completion_price_info.total_amount,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
|
||||
currency=prompt_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def _convert_claude3_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
|
||||
"""
|
||||
Convert prompt messages to dict list and system
|
||||
"""
|
||||
|
||||
system = ""
|
||||
first_loop = True
|
||||
for message in prompt_messages:
|
||||
if isinstance(message, SystemPromptMessage):
|
||||
message.content=message.content.strip()
|
||||
if first_loop:
|
||||
system=message.content
|
||||
first_loop=False
|
||||
else:
|
||||
system+="\n"
|
||||
system+=message.content
|
||||
|
||||
prompt_message_dicts = []
|
||||
for message in prompt_messages:
|
||||
if not isinstance(message, SystemPromptMessage):
|
||||
prompt_message_dicts.append(self._convert_claude3_prompt_message_to_dict(message))
|
||||
|
||||
return system, prompt_message_dicts
|
||||
|
||||
def _convert_claude3_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "text",
|
||||
"text": message_content.data
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
if not message_content.data.startswith("data:"):
|
||||
# fetch image data from url
|
||||
try:
|
||||
image_content = requests.get(message_content.data).content
|
||||
mime_type, _ = mimetypes.guess_type(message_content.data)
|
||||
base64_data = base64.b64encode(image_content).decode('utf-8')
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
||||
else:
|
||||
data_split = message_content.data.split(";base64,")
|
||||
mime_type = data_split[0].replace("data:", "")
|
||||
base64_data = data_split[1]
|
||||
|
||||
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
|
||||
raise ValueError(f"Unsupported image type {mime_type}, "
|
||||
f"only support image/jpeg, image/png, image/gif, and image/webp")
|
||||
|
||||
sub_message_dict = {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": base64_data
|
||||
}
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
|
|
@ -101,7 +402,19 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
|
||||
|
||||
if "anthropic.claude-3" in model:
|
||||
try:
|
||||
self._invoke_claude3(model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[{"role": "user", "content": "ping"}],
|
||||
model_parameters={},
|
||||
stop=None,
|
||||
stream=False)
|
||||
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
try:
|
||||
ping_message = UserPromptMessage(content="ping")
|
||||
self._generate(model=model,
|
||||
|
|
|
|||
|
|
@ -449,7 +449,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
|||
help=I18nObject(en_US="The temperature of the model. "
|
||||
"Increasing the temperature will make the model answer "
|
||||
"more creatively. (Default: 0.8)"),
|
||||
default=0.8,
|
||||
default=0.1,
|
||||
min=0,
|
||||
max=2
|
||||
),
|
||||
|
|
@ -472,7 +472,6 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
|||
help=I18nObject(en_US="Reduces the probability of generating nonsense. "
|
||||
"A higher value (e.g. 100) will give more diverse answers, "
|
||||
"while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
|
||||
default=40,
|
||||
min=1,
|
||||
max=100
|
||||
),
|
||||
|
|
@ -483,7 +482,6 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
|||
help=I18nObject(en_US="Sets how strongly to penalize repetitions. "
|
||||
"A higher value (e.g., 1.5) will penalize repetitions more strongly, "
|
||||
"while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"),
|
||||
default=1.1,
|
||||
min=-2,
|
||||
max=2
|
||||
),
|
||||
|
|
@ -494,7 +492,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
|||
type=ParameterType.INT,
|
||||
help=I18nObject(en_US="Maximum number of tokens to predict when generating text. "
|
||||
"(Default: 128, -1 = infinite generation, -2 = fill context)"),
|
||||
default=128,
|
||||
default=512 if int(credentials.get('max_tokens', 4096)) >= 768 else 128,
|
||||
min=-2,
|
||||
max=int(credentials.get('max_tokens', 4096)),
|
||||
),
|
||||
|
|
@ -504,7 +502,6 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
|||
type=ParameterType.INT,
|
||||
help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. "
|
||||
"(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"),
|
||||
default=0,
|
||||
min=0,
|
||||
max=2
|
||||
),
|
||||
|
|
@ -516,7 +513,6 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
|||
"the generated text. A lower learning rate will result in slower adjustments, "
|
||||
"while a higher learning rate will make the algorithm more responsive. "
|
||||
"(Default: 0.1)"),
|
||||
default=0.1,
|
||||
precision=1
|
||||
),
|
||||
ParameterRule(
|
||||
|
|
@ -525,7 +521,6 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
|||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. "
|
||||
"A lower value will result in more focused and coherent text. (Default: 5.0)"),
|
||||
default=5.0,
|
||||
precision=1
|
||||
),
|
||||
ParameterRule(
|
||||
|
|
@ -543,7 +538,6 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
|||
type=ParameterType.INT,
|
||||
help=I18nObject(en_US="The number of layers to send to the GPU(s). "
|
||||
"On macOS it defaults to 1 to enable metal support, 0 to disable."),
|
||||
default=1,
|
||||
min=0,
|
||||
max=1
|
||||
),
|
||||
|
|
@ -563,7 +557,6 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
|||
type=ParameterType.INT,
|
||||
help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. "
|
||||
"(Default: 64, 0 = disabled, -1 = num_ctx)"),
|
||||
default=64,
|
||||
min=-1
|
||||
),
|
||||
ParameterRule(
|
||||
|
|
@ -573,7 +566,6 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
|||
help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens "
|
||||
"from the output. A higher value (e.g., 2.0) will reduce the impact more, "
|
||||
"while a value of 1.0 disables this setting. (default: 1)"),
|
||||
default=1,
|
||||
precision=1
|
||||
),
|
||||
ParameterRule(
|
||||
|
|
@ -583,7 +575,6 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
|||
help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to "
|
||||
"a specific number will make the model generate the same text for "
|
||||
"the same prompt. (Default: 0)"),
|
||||
default=0
|
||||
),
|
||||
ParameterRule(
|
||||
name='format',
|
||||
|
|
|
|||
|
|
@ -656,6 +656,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|||
if assistant_message_function_call:
|
||||
# start of stream function call
|
||||
delta_assistant_message_function_call_storage = assistant_message_function_call
|
||||
if delta_assistant_message_function_call_storage.arguments is None:
|
||||
delta_assistant_message_function_call_storage.arguments = ''
|
||||
if not has_finish_reason:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -8,54 +8,70 @@ model_properties:
|
|||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1.0
|
||||
type: float
|
||||
default: 0.85
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 2000
|
||||
min: 1
|
||||
max: 2000
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1500
|
||||
min: 1
|
||||
max: 6000
|
||||
help:
|
||||
zh_Hans: 用于限制模型生成token的数量,max_tokens设置的是生成上限,并不表示一定会生成这么多的token数量。
|
||||
en_US: It is used to limit the number of tokens generated by the model. max_tokens sets the upper limit of generation, which does not mean that so many tokens will be generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。默认不传递该参数,取值为None或当top_k大于100时,表示不启用top_k策略,此时,仅有top_p策略生效。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. This parameter is not passed by default. The value is None or when top_k is greater than 100, it means that the top_k policy is not enabled. At this time, only the top_p policy takes effect.
|
||||
required: false
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: seed
|
||||
required: false
|
||||
type: int
|
||||
default: 1234
|
||||
label:
|
||||
zh_Hans: 随机种子
|
||||
en_US: Random seed
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 生成时,随机数的种子,用于控制模型生成的随机性。如果使用相同的种子,每次运行生成的结果都将相同;当需要复现模型的生成结果时,可以使用相同的种子。seed参数支持无符号64位整数类型。
|
||||
en_US: When generating, the random number seed is used to control the randomness of model generation. If you use the same seed, the results generated by each run will be the same; when you need to reproduce the results of the model, you can use the same seed. The seed parameter supports unsigned 64-bit integer types.
|
||||
required: false
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: repetition_penalty
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
|
||||
required: false
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
- name: enable_search
|
||||
type: boolean
|
||||
default: false
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.12'
|
||||
output: '0.12'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
|
|
|
|||
|
|
@ -4,58 +4,74 @@ label:
|
|||
model_type: llm
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 30000
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1.0
|
||||
type: float
|
||||
default: 0.85
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 2000
|
||||
min: 1
|
||||
max: 2000
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 2000
|
||||
min: 1
|
||||
max: 28000
|
||||
help:
|
||||
zh_Hans: 用于限制模型生成token的数量,max_tokens设置的是生成上限,并不表示一定会生成这么多的token数量。
|
||||
en_US: It is used to limit the number of tokens generated by the model. max_tokens sets the upper limit of generation, which does not mean that so many tokens will be generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。默认不传递该参数,取值为None或当top_k大于100时,表示不启用top_k策略,此时,仅有top_p策略生效。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. This parameter is not passed by default. The value is None or when top_k is greater than 100, it means that the top_k policy is not enabled. At this time, only the top_p policy takes effect.
|
||||
required: false
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: seed
|
||||
required: false
|
||||
type: int
|
||||
default: 1234
|
||||
label:
|
||||
zh_Hans: 随机种子
|
||||
en_US: Random seed
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 生成时,随机数的种子,用于控制模型生成的随机性。如果使用相同的种子,每次运行生成的结果都将相同;当需要复现模型的生成结果时,可以使用相同的种子。seed参数支持无符号64位整数类型。
|
||||
en_US: When generating, the random number seed is used to control the randomness of model generation. If you use the same seed, the results generated by each run will be the same; when you need to reproduce the results of the model, you can use the same seed. The seed parameter supports unsigned 64-bit integer types.
|
||||
required: false
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: repetition_penalty
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
|
||||
required: false
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
- name: enable_search
|
||||
type: boolean
|
||||
default: false
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.12'
|
||||
output: '0.12'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
|
|
|
|||
|
|
@ -8,54 +8,70 @@ model_properties:
|
|||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1.0
|
||||
type: float
|
||||
default: 0.85
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 2000
|
||||
min: 1
|
||||
max: 2000
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1500
|
||||
min: 1
|
||||
max: 6000
|
||||
help:
|
||||
zh_Hans: 用于限制模型生成token的数量,max_tokens设置的是生成上限,并不表示一定会生成这么多的token数量。
|
||||
en_US: It is used to limit the number of tokens generated by the model. max_tokens sets the upper limit of generation, which does not mean that so many tokens will be generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。默认不传递该参数,取值为None或当top_k大于100时,表示不启用top_k策略,此时,仅有top_p策略生效。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. This parameter is not passed by default. The value is None or when top_k is greater than 100, it means that the top_k policy is not enabled. At this time, only the top_p policy takes effect.
|
||||
required: false
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: seed
|
||||
required: false
|
||||
type: int
|
||||
default: 1234
|
||||
label:
|
||||
zh_Hans: 随机种子
|
||||
en_US: Random seed
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 生成时,随机数的种子,用于控制模型生成的随机性。如果使用相同的种子,每次运行生成的结果都将相同;当需要复现模型的生成结果时,可以使用相同的种子。seed参数支持无符号64位整数类型。
|
||||
en_US: When generating, the random number seed is used to control the randomness of model generation. If you use the same seed, the results generated by each run will be the same; when you need to reproduce the results of the model, you can use the same seed. The seed parameter supports unsigned 64-bit integer types.
|
||||
required: false
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: repetition_penalty
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
|
||||
required: false
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
- name: enable_search
|
||||
type: boolean
|
||||
default: false
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.12'
|
||||
output: '0.12'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
|
|
|
|||
|
|
@ -4,58 +4,70 @@ label:
|
|||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 32000
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1.0
|
||||
type: float
|
||||
default: 0.85
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 1500
|
||||
min: 1
|
||||
max: 1500
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 2000
|
||||
min: 1
|
||||
max: 30000
|
||||
help:
|
||||
zh_Hans: 用于限制模型生成token的数量,max_tokens设置的是生成上限,并不表示一定会生成这么多的token数量。
|
||||
en_US: It is used to limit the number of tokens generated by the model. max_tokens sets the upper limit of generation, which does not mean that so many tokens will be generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。默认不传递该参数,取值为None或当top_k大于100时,表示不启用top_k策略,此时,仅有top_p策略生效。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. This parameter is not passed by default. The value is None or when top_k is greater than 100, it means that the top_k policy is not enabled. At this time, only the top_p policy takes effect.
|
||||
required: false
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: seed
|
||||
required: false
|
||||
type: int
|
||||
default: 1234
|
||||
label:
|
||||
zh_Hans: 随机种子
|
||||
en_US: Random seed
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 生成时,随机数的种子,用于控制模型生成的随机性。如果使用相同的种子,每次运行生成的结果都将相同;当需要复现模型的生成结果时,可以使用相同的种子。seed参数支持无符号64位整数类型。
|
||||
en_US: When generating, the random number seed is used to control the randomness of model generation. If you use the same seed, the results generated by each run will be the same; when you need to reproduce the results of the model, you can use the same seed. The seed parameter supports unsigned 64-bit integer types.
|
||||
required: false
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: repetition_penalty
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
- name: enable_search
|
||||
type: boolean
|
||||
default: false
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
|
|
|
|||
|
|
@ -8,55 +8,66 @@ model_properties:
|
|||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1.0
|
||||
type: float
|
||||
default: 0.85
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 1500
|
||||
min: 1
|
||||
max: 1500
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1500
|
||||
min: 1
|
||||
max: 6000
|
||||
help:
|
||||
zh_Hans: 用于限制模型生成token的数量,max_tokens设置的是生成上限,并不表示一定会生成这么多的token数量。
|
||||
en_US: It is used to limit the number of tokens generated by the model. max_tokens sets the upper limit of generation, which does not mean that so many tokens will be generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。默认不传递该参数,取值为None或当top_k大于100时,表示不启用top_k策略,此时,仅有top_p策略生效。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. This parameter is not passed by default. The value is None or when top_k is greater than 100, it means that the top_k policy is not enabled. At this time, only the top_p policy takes effect.
|
||||
required: false
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: seed
|
||||
required: false
|
||||
type: int
|
||||
default: 1234
|
||||
label:
|
||||
zh_Hans: 随机种子
|
||||
en_US: Random seed
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 生成时,随机数的种子,用于控制模型生成的随机性。如果使用相同的种子,每次运行生成的结果都将相同;当需要复现模型的生成结果时,可以使用相同的种子。seed参数支持无符号64位整数类型。
|
||||
en_US: When generating, the random number seed is used to control the randomness of model generation. If you use the same seed, the results generated by each run will be the same; when you need to reproduce the results of the model, you can use the same seed. The seed parameter supports unsigned 64-bit integer types.
|
||||
required: false
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: repetition_penalty
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
|
||||
required: false
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
- name: enable_search
|
||||
type: boolean
|
||||
default: false
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
model: text-embedding-v1
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 2048
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
model: text-embedding-v2
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 2048
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
import dashscope
|
||||
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import (
|
||||
EmbeddingUsage,
|
||||
TextEmbeddingResult,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import (
|
||||
TextEmbeddingModel,
|
||||
)
|
||||
from core.model_runtime.model_providers.tongyi._common import _CommonTongyi
|
||||
|
||||
|
||||
class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Tongyi text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
dashscope.api_key = credentials_kwargs["dashscope_api_key"]
|
||||
embeddings, embedding_used_tokens = self.embed_documents(model, texts)
|
||||
|
||||
return TextEmbeddingResult(
|
||||
embeddings=embeddings,
|
||||
usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens),
|
||||
model=model
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
total_num_tokens = 0
|
||||
for text in texts:
|
||||
total_num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
return total_num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
dashscope.api_key = credentials_kwargs["dashscope_api_key"]
|
||||
# call embedding model
|
||||
self.embed_documents(model=model, texts=["ping"])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@staticmethod
|
||||
def embed_documents(model: str, texts: list[str]) -> tuple[list[list[float]], int]:
|
||||
"""Call out to Tongyi's embedding endpoint.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text, and tokens usage.
|
||||
"""
|
||||
embeddings = []
|
||||
embedding_used_tokens = 0
|
||||
for text in texts:
|
||||
response = dashscope.TextEmbedding.call(model=model, input=text, text_type="document")
|
||||
data = response.output["embeddings"][0]
|
||||
embeddings.append(data["embedding"])
|
||||
embedding_used_tokens += response.usage["total_tokens"]
|
||||
|
||||
return [list(map(float, e)) for e in embeddings], embedding_used_tokens
|
||||
|
||||
def _calc_response_usage(
|
||||
self, model: str, credentials: dict, tokens: int
|
||||
) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -17,15 +17,16 @@ help:
|
|||
supported_model_types:
|
||||
- llm
|
||||
- tts
|
||||
- text-embedding
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: dashscope_api_key
|
||||
label:
|
||||
en_US: APIKey
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 APIKey
|
||||
en_US: Enter your APIKey
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
|
|
|
|||
|
|
@ -1,20 +1,12 @@
|
|||
<svg width="80" height="22" viewBox="0 0 450 120" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:space="preserve" xmlns:serif="http://www.serif.com/" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2;">
|
||||
<g transform="matrix(0.172742,0,0,0.172742,9.60932,8.17741)">
|
||||
<circle cx="300" cy="300" r="300" style="fill:rgb(0,52,37);"/>
|
||||
</g>
|
||||
<g transform="matrix(0.172742,0,0,0.172742,9.60932,8.17741)">
|
||||
<path d="M452.119,361.224C452.119,349.527 442.623,340.031 430.926,340.031C419.229,340.031 409.733,349.527 409.733,361.224L409.733,470.486C409.733,482.183 419.229,491.679 430.926,491.679C442.623,491.679 452.119,482.183 452.119,470.486L452.119,361.224Z" style="fill:white;"/>
|
||||
</g>
|
||||
<g transform="matrix(0.172742,0,0,0.172742,9.60932,8.17741)">
|
||||
<path d="M422.005,133.354C413.089,125.771 399.714,126.851 392.131,135.767L273.699,275.021C270.643,278.614 268.994,282.932 268.698,287.302C268.532,288.371 268.446,289.466 268.446,290.581L268.446,468.603C268.446,480.308 277.934,489.796 289.639,489.796C301.344,489.796 310.832,480.308 310.832,468.603L310.832,296.784L424.419,163.228C432.002,154.312 430.921,140.937 422.005,133.354Z" style="fill:white;"/>
|
||||
</g>
|
||||
<g transform="matrix(0.13359,-0.109514,0.109514,0.13359,-0.630793,25.9151)">
|
||||
<path d="M156.358,155.443C156.358,143.746 146.862,134.25 135.165,134.25C123.468,134.25 113.972,143.746 113.972,155.443L113.972,287.802C113.972,299.499 123.468,308.995 135.165,308.995C146.862,308.995 156.358,299.499 156.358,287.802L156.358,155.443Z" style="fill:white;"/>
|
||||
</g>
|
||||
<g transform="matrix(0.172742,0,0,0.172742,9.60932,8.17741)">
|
||||
<circle cx="460.126" cy="279.278" r="25.903" style="fill:rgb(0,255,37);"/>
|
||||
</g>
|
||||
<g transform="matrix(1,0,0,1,-77.4848,13.0849)">
|
||||
<text x="210.275px" y="74.595px" style="font-family:'AlibabaPuHuiTi_3_55_Regular', 'Alibaba PuHuiTi 3.0', serif;font-size:80px;">01<tspan x="294.355px " y="74.595px ">.</tspan>AI</text>
|
||||
</g>
|
||||
</svg>
|
||||
<svg width="64" height="24" viewBox="0 0 64 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M1.28808 1.39558C1.76461 1.00315 2.46905 1.07132 2.86149 1.54785L7.7517 7.48596C8.14414 7.96249 8.07597 8.66693 7.59944 9.05937C7.1229 9.45181 6.41847 9.38363 6.02603 8.9071L1.13582 2.96899C0.743382 2.49246 0.811553 1.78802 1.28808 1.39558Z" fill="#133426"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M10.1689 22.3553C9.55157 22.3553 9.05112 21.8549 9.05109 21.2375L9.05075 10.7193C9.05074 10.4478 9.14951 10.1856 9.32863 9.98168L16.1801 2.17956C16.5875 1.7157 17.2937 1.66989 17.7576 2.07723C18.2214 2.48457 18.2673 3.19081 17.8599 3.65467L11.2863 11.1403L11.2866 21.2375C11.2866 21.8548 10.7862 22.3552 10.1689 22.3553Z" fill="#133426"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M18.2138 13.7077C18.8311 13.7077 19.3315 14.2081 19.3315 14.8255V21.0896C19.3315 21.7069 18.8311 22.2073 18.2138 22.2073C17.5965 22.2073 17.096 21.7069 17.096 21.0896V14.8255C17.096 14.2081 17.5965 13.7077 18.2138 13.7077Z" fill="#133426"/>
|
||||
<circle cx="19.7936" cy="10.3307" r="1.73695" fill="#00FF00"/>
|
||||
<path d="M61.6555 10.3637V22H60.593V10.3637H61.6555Z" fill="black"/>
|
||||
<path d="M50.1101 22H48.9964L53.2294 10.3637H54.3658L58.5987 22H57.4851L53.8374 11.7444H53.7578L50.1101 22ZM50.9112 17.5398H56.6839V18.4944H50.9112V17.5398Z" fill="black"/>
|
||||
<path d="M46.3928 22.0853C46.1693 22.0853 45.9761 22.0057 45.8132 21.8466C45.6541 21.6838 45.5746 21.4906 45.5746 21.2671C45.5746 21.0398 45.6541 20.8466 45.8132 20.6875C45.9761 20.5285 46.1693 20.4489 46.3928 20.4489C46.62 20.4489 46.8132 20.5285 46.9723 20.6875C47.1314 20.8466 47.2109 21.0398 47.2109 21.2671C47.2109 21.4148 47.1731 21.5512 47.0973 21.6762C47.0253 21.8012 46.9268 21.9016 46.8018 21.9773C46.6806 22.0493 46.5443 22.0853 46.3928 22.0853Z" fill="black"/>
|
||||
<path d="M42.6996 10.3637V22H41.6371V11.4773H41.5689L38.8416 13.2898V12.1875L41.5916 10.3637H42.6996Z" fill="black"/>
|
||||
<path d="M32.9098 22.1591C32.0916 22.1591 31.3928 21.9243 30.8132 21.4546C30.2375 20.9811 29.7943 20.2974 29.4837 19.4035C29.1768 18.5095 29.0234 17.4357 29.0234 16.1819C29.0234 14.9319 29.1768 13.8618 29.4837 12.9716C29.7943 12.0777 30.2393 11.394 30.8189 10.9205C31.4022 10.4432 32.0992 10.2046 32.9098 10.2046C33.7204 10.2046 34.4155 10.4432 34.995 10.9205C35.5784 11.394 36.0234 12.0777 36.3303 12.9716C36.6409 13.8618 36.7962 14.9319 36.7962 16.1819C36.7962 17.4357 36.6409 18.5095 36.3303 19.4035C36.0234 20.2974 35.5803 20.9811 35.0007 21.4546C34.425 21.9243 33.728 22.1591 32.9098 22.1591ZM32.9098 21.2046C33.8075 21.2046 34.5083 20.7671 35.0121 19.8921C35.5159 19.0133 35.7678 17.7766 35.7678 16.1819C35.7678 15.1213 35.6522 14.216 35.4212 13.466C35.1939 12.7122 34.8662 12.1364 34.4382 11.7387C34.014 11.341 33.5045 11.1421 32.9098 11.1421C32.0196 11.1421 31.3208 11.5853 30.8132 12.4716C30.3056 13.3542 30.0518 14.591 30.0518 16.1819C30.0518 17.2425 30.1655 18.1478 30.3928 18.8978C30.6238 19.6478 30.9515 20.2197 31.3757 20.6137C31.8037 21.0076 32.3151 21.2046 32.9098 21.2046Z" fill="black"/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 2.0 KiB After Width: | Height: | Size: 3.1 KiB |
|
|
@ -1,20 +0,0 @@
|
|||
<svg width="80" height="22" viewBox="0 0 450 120" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:space="preserve" xmlns:serif="http://www.serif.com/" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2;">
|
||||
<g transform="matrix(0.172742,0,0,0.172742,9.60932,8.17741)">
|
||||
<circle cx="300" cy="300" r="300" style="fill:rgb(0,52,37);"/>
|
||||
</g>
|
||||
<g transform="matrix(0.172742,0,0,0.172742,9.60932,8.17741)">
|
||||
<path d="M452.119,361.224C452.119,349.527 442.623,340.031 430.926,340.031C419.229,340.031 409.733,349.527 409.733,361.224L409.733,470.486C409.733,482.183 419.229,491.679 430.926,491.679C442.623,491.679 452.119,482.183 452.119,470.486L452.119,361.224Z" style="fill:white;"/>
|
||||
</g>
|
||||
<g transform="matrix(0.172742,0,0,0.172742,9.60932,8.17741)">
|
||||
<path d="M422.005,133.354C413.089,125.771 399.714,126.851 392.131,135.767L273.699,275.021C270.643,278.614 268.994,282.932 268.698,287.302C268.532,288.371 268.446,289.466 268.446,290.581L268.446,468.603C268.446,480.308 277.934,489.796 289.639,489.796C301.344,489.796 310.832,480.308 310.832,468.603L310.832,296.784L424.419,163.228C432.002,154.312 430.921,140.937 422.005,133.354Z" style="fill:white;"/>
|
||||
</g>
|
||||
<g transform="matrix(0.13359,-0.109514,0.109514,0.13359,-0.630793,25.9151)">
|
||||
<path d="M156.358,155.443C156.358,143.746 146.862,134.25 135.165,134.25C123.468,134.25 113.972,143.746 113.972,155.443L113.972,287.802C113.972,299.499 123.468,308.995 135.165,308.995C146.862,308.995 156.358,299.499 156.358,287.802L156.358,155.443Z" style="fill:white;"/>
|
||||
</g>
|
||||
<g transform="matrix(0.172742,0,0,0.172742,9.60932,8.17741)">
|
||||
<circle cx="460.126" cy="279.278" r="25.903" style="fill:rgb(0,255,37);"/>
|
||||
</g>
|
||||
<g transform="matrix(1,0,0,1,-77.4848,13.0849)">
|
||||
<text x="210.275px" y="74.595px" style="font-family:'AlibabaPuHuiTi_3_55_Regular', 'Alibaba PuHuiTi 3.0', serif;font-size:80px;">零一万物</text>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 2.0 KiB |
|
|
@ -1,7 +1,8 @@
|
|||
<svg width="24" height="24" viewBox="0 0 600 600" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<circle cx="300" cy="300" r="300" fill="#003425"/>
|
||||
<rect x="409.733" y="340.031" width="42.3862" height="151.648" rx="21.1931" fill="white"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M422.005 133.354C413.089 125.771 399.714 126.851 392.131 135.767L273.699 275.021C270.643 278.614 268.994 282.932 268.698 287.302C268.532 288.371 268.446 289.466 268.446 290.581V468.603C268.446 480.308 277.934 489.796 289.639 489.796C301.344 489.796 310.832 480.308 310.832 468.603V296.784L424.419 163.228C432.002 154.312 430.921 140.937 422.005 133.354Z" fill="white"/>
|
||||
<rect x="113.972" y="134.25" width="42.3862" height="174.745" rx="21.1931" transform="rotate(-39.3441 113.972 134.25)" fill="white"/>
|
||||
<circle cx="460.126" cy="279.278" r="25.9027" fill="#00FF25"/>
|
||||
</svg>
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect x="1" y="1" width="22" height="22" rx="5" fill="#133426"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M4.52004 4.43887C4.87945 4.1429 5.41077 4.19431 5.70676 4.55371L9.39515 9.03221C9.69114 9.39161 9.63972 9.92289 9.2803 10.2189C8.92089 10.5148 8.38957 10.4634 8.09358 10.104L4.40519 5.62553C4.1092 5.26613 4.16062 4.73485 4.52004 4.43887Z" fill="white"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.2183 20.2466C10.7527 20.2466 10.3752 19.8692 10.3752 19.4036L10.3749 11.4708C10.3749 11.266 10.4494 11.0683 10.5845 10.9145L15.7522 5.03014C16.0594 4.6803 16.5921 4.64575 16.942 4.95297C17.2918 5.26018 17.3264 5.79283 17.0192 6.14266L12.0611 11.7883L12.0613 19.4035C12.0613 19.8691 11.6839 20.2466 11.2183 20.2466Z" fill="white"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M17.2861 13.7246C17.7517 13.7246 18.1291 14.102 18.1291 14.5676V19.292C18.1291 19.7576 17.7517 20.135 17.2861 20.135C16.8205 20.135 16.443 19.7576 16.443 19.292V14.5676C16.443 14.102 16.8205 13.7246 17.2861 13.7246Z" fill="white"/>
|
||||
<ellipse cx="18.4761" cy="11.1782" rx="1.31008" ry="1.31" fill="#00FF00"/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 882 B After Width: | Height: | Size: 1.2 KiB |
|
|
@ -9,7 +9,7 @@ icon_small:
|
|||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#EFFDFD"
|
||||
background: "#E9F1EC"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API Key from 01.ai
|
||||
|
|
|
|||
|
|
@ -32,3 +32,8 @@ parameter_rules:
|
|||
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
|
||||
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8192
|
||||
|
|
|
|||
|
|
@ -30,3 +30,8 @@ parameter_rules:
|
|||
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
|
||||
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8192
|
||||
|
|
|
|||
|
|
@ -171,6 +171,7 @@ class ToolProviderCredentials(BaseModel):
|
|||
SECRET_INPUT = "secret-input"
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
BOOLEAN = "boolean"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
|
||||
|
|
@ -192,7 +193,7 @@ class ToolProviderCredentials(BaseModel):
|
|||
name: str = Field(..., description="The name of the credentials")
|
||||
type: CredentialsType = Field(..., description="The type of the credentials")
|
||||
required: bool = False
|
||||
default: Optional[str] = None
|
||||
default: Optional[Union[int, str]] = None
|
||||
options: Optional[list[ToolCredentialsOption]] = None
|
||||
label: Optional[I18nObject] = None
|
||||
help: Optional[I18nObject] = None
|
||||
|
|
|
|||
|
|
@ -12,12 +12,11 @@ class BingProvider(BuiltinToolProviderController):
|
|||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
).validate_credentials(
|
||||
credentials=credentials,
|
||||
tool_parameters={
|
||||
"query": "test",
|
||||
"result_type": "link",
|
||||
"enable_webpages": True,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -43,3 +43,63 @@ credentials_for_provider:
|
|||
zh_Hans: 例如 "https://api.bing.microsoft.com/v7.0/search"
|
||||
pt_BR: An endpoint is like "https://api.bing.microsoft.com/v7.0/search"
|
||||
default: https://api.bing.microsoft.com/v7.0/search
|
||||
allow_entities:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Entities Search
|
||||
zh_Hans: 支持实体搜索
|
||||
pt_BR: Allow Entities Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow entity search
|
||||
zh_Hans: 您的订阅计划是否支持实体搜索
|
||||
pt_BR: Does your subscription plan allow entity search
|
||||
default: true
|
||||
allow_web_pages:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Web Pages Search
|
||||
zh_Hans: 支持网页搜索
|
||||
pt_BR: Allow Web Pages Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow web pages search
|
||||
zh_Hans: 您的订阅计划是否支持网页搜索
|
||||
pt_BR: Does your subscription plan allow web pages search
|
||||
default: true
|
||||
allow_computation:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Computation Search
|
||||
zh_Hans: 支持计算搜索
|
||||
pt_BR: Allow Computation Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow computation search
|
||||
zh_Hans: 您的订阅计划是否支持计算搜索
|
||||
pt_BR: Does your subscription plan allow computation search
|
||||
default: false
|
||||
allow_news:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow News Search
|
||||
zh_Hans: 支持新闻搜索
|
||||
pt_BR: Allow News Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow news search
|
||||
zh_Hans: 您的订阅计划是否支持新闻搜索
|
||||
pt_BR: Does your subscription plan allow news search
|
||||
default: false
|
||||
allow_related_searches:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Related Searches
|
||||
zh_Hans: 支持相关搜索
|
||||
pt_BR: Allow Related Searches
|
||||
help:
|
||||
en_US: Does your subscription plan allow related searches
|
||||
zh_Hans: 您的订阅计划是否支持相关搜索
|
||||
pt_BR: Does your subscription plan allow related searches
|
||||
default: false
|
||||
|
|
|
|||
|
|
@ -10,53 +10,23 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
|||
class BingSearchTool(BuiltinTool):
|
||||
url = 'https://api.bing.microsoft.com/v7.0/search'
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke_bing(self,
|
||||
user_id: str,
|
||||
subscription_key: str, query: str, limit: int,
|
||||
result_type: str, market: str, lang: str,
|
||||
filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke bing search
|
||||
"""
|
||||
|
||||
key = self.runtime.credentials.get('subscription_key', None)
|
||||
if not key:
|
||||
raise Exception('subscription_key is required')
|
||||
|
||||
server_url = self.runtime.credentials.get('server_url', None)
|
||||
if not server_url:
|
||||
server_url = self.url
|
||||
|
||||
query = tool_parameters.get('query', None)
|
||||
if not query:
|
||||
raise Exception('query is required')
|
||||
|
||||
limit = min(tool_parameters.get('limit', 5), 10)
|
||||
result_type = tool_parameters.get('result_type', 'text') or 'text'
|
||||
|
||||
market = tool_parameters.get('market', 'US')
|
||||
lang = tool_parameters.get('language', 'en')
|
||||
filter = []
|
||||
|
||||
if tool_parameters.get('enable_computation', False):
|
||||
filter.append('Computation')
|
||||
if tool_parameters.get('enable_entities', False):
|
||||
filter.append('Entities')
|
||||
if tool_parameters.get('enable_news', False):
|
||||
filter.append('News')
|
||||
if tool_parameters.get('enable_related_search', False):
|
||||
filter.append('RelatedSearches')
|
||||
if tool_parameters.get('enable_webpages', False):
|
||||
filter.append('WebPages')
|
||||
|
||||
market_code = f'{lang}-{market}'
|
||||
accept_language = f'{lang},{market_code};q=0.9'
|
||||
headers = {
|
||||
'Ocp-Apim-Subscription-Key': key,
|
||||
'Ocp-Apim-Subscription-Key': subscription_key,
|
||||
'Accept-Language': accept_language
|
||||
}
|
||||
|
||||
query = quote(query)
|
||||
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filter)}'
|
||||
server_url = f'{self.url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
|
||||
response = get(server_url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
|
|
@ -124,3 +94,105 @@ class BingSearchTool(BuiltinTool):
|
|||
text += f'{related["displayText"]} - {related["webSearchUrl"]}\n'
|
||||
|
||||
return self.create_text_message(text=self.summary(user_id=user_id, content=text))
|
||||
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None:
|
||||
key = credentials.get('subscription_key', None)
|
||||
if not key:
|
||||
raise Exception('subscription_key is required')
|
||||
|
||||
server_url = credentials.get('server_url', None)
|
||||
if not server_url:
|
||||
server_url = self.url
|
||||
|
||||
query = tool_parameters.get('query', None)
|
||||
if not query:
|
||||
raise Exception('query is required')
|
||||
|
||||
limit = min(tool_parameters.get('limit', 5), 10)
|
||||
result_type = tool_parameters.get('result_type', 'text') or 'text'
|
||||
|
||||
market = tool_parameters.get('market', 'US')
|
||||
lang = tool_parameters.get('language', 'en')
|
||||
filter = []
|
||||
|
||||
if credentials.get('allow_entities', False):
|
||||
filter.append('Entities')
|
||||
|
||||
if credentials.get('allow_computation', False):
|
||||
filter.append('Computation')
|
||||
|
||||
if credentials.get('allow_news', False):
|
||||
filter.append('News')
|
||||
|
||||
if credentials.get('allow_related_searches', False):
|
||||
filter.append('RelatedSearches')
|
||||
|
||||
if credentials.get('allow_web_pages', False):
|
||||
filter.append('WebPages')
|
||||
|
||||
if not filter:
|
||||
raise Exception('At least one filter is required')
|
||||
|
||||
self._invoke_bing(
|
||||
user_id='test',
|
||||
subscription_key=key,
|
||||
query=query,
|
||||
limit=limit,
|
||||
result_type=result_type,
|
||||
market=market,
|
||||
lang=lang,
|
||||
filters=filter
|
||||
)
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
|
||||
key = self.runtime.credentials.get('subscription_key', None)
|
||||
if not key:
|
||||
raise Exception('subscription_key is required')
|
||||
|
||||
server_url = self.runtime.credentials.get('server_url', None)
|
||||
if not server_url:
|
||||
server_url = self.url
|
||||
|
||||
query = tool_parameters.get('query', None)
|
||||
if not query:
|
||||
raise Exception('query is required')
|
||||
|
||||
limit = min(tool_parameters.get('limit', 5), 10)
|
||||
result_type = tool_parameters.get('result_type', 'text') or 'text'
|
||||
|
||||
market = tool_parameters.get('market', 'US')
|
||||
lang = tool_parameters.get('language', 'en')
|
||||
filter = []
|
||||
|
||||
if tool_parameters.get('enable_computation', False):
|
||||
filter.append('Computation')
|
||||
if tool_parameters.get('enable_entities', False):
|
||||
filter.append('Entities')
|
||||
if tool_parameters.get('enable_news', False):
|
||||
filter.append('News')
|
||||
if tool_parameters.get('enable_related_search', False):
|
||||
filter.append('RelatedSearches')
|
||||
if tool_parameters.get('enable_webpages', False):
|
||||
filter.append('WebPages')
|
||||
|
||||
if not filter:
|
||||
raise Exception('At least one filter is required')
|
||||
|
||||
return self._invoke_bing(
|
||||
user_id=user_id,
|
||||
subscription_key=key,
|
||||
query=query,
|
||||
limit=limit,
|
||||
result_type=result_type,
|
||||
market=market,
|
||||
lang=lang,
|
||||
filters=filter
|
||||
)
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<g clip-path="url(#clip0_16624_62807)">
|
||||
<path d="M7.11111 0.888889C7.11111 0.888889 7.11111 0 8 0C8.88889 0 8.88889 0.888889 8.88889 0.888889V1.77778C8.88889 1.77778 8.88889 2.66667 8 2.66667C7.11111 2.66667 7.11111 1.77778 7.11111 1.77778V0.888889ZM15.1111 7.11111C15.1111 7.11111 16 7.11111 16 8C16 8.88889 15.1111 8.88889 15.1111 8.88889H14.2222C14.2222 8.88889 13.3333 8.88889 13.3333 8C13.3333 7.11111 14.2222 7.11111 14.2222 7.11111H15.1111ZM1.77778 7.11111C1.77778 7.11111 2.66667 7.11111 2.66667 8C2.66667 8.88889 1.77778 8.88889 1.77778 8.88889H0.888889C0.888889 8.88889 0 8.88889 0 8C0 7.11111 0.888889 7.11111 0.888889 7.11111H1.77778ZM4.05378 3.24133C4.05378 3.24133 4.68222 3.86978 4.05378 4.49822C3.42533 5.12667 2.79689 4.49822 2.79689 4.49822L2.168 3.87022C2.168 3.87022 1.53956 3.24178 2.168 2.61289C2.79689 1.98444 3.42533 2.61289 3.42533 2.61289L4.05378 3.24133ZM13.2036 4.49822C13.2036 4.49822 12.5751 5.12667 11.9467 4.49822C11.3182 3.86978 11.9467 3.24133 11.9467 3.24133L12.5751 2.61289C12.5751 2.61289 13.2036 1.98444 13.832 2.61289C14.4604 3.24133 13.832 3.86978 13.832 3.86978L13.2036 4.49822ZM3.87022 13.8316C3.87022 13.8316 3.24178 14.46 2.61333 13.8316C1.98489 13.2031 2.61333 12.5747 2.61333 12.5747L3.24178 11.9462C3.24178 11.9462 3.87022 11.3178 4.49867 11.9462C5.12711 12.5747 4.49867 13.2031 4.49867 13.2031L3.87022 13.8316Z" fill="#FFCF27"/>
|
||||
<path d="M8.00011 12.4446C10.4547 12.4446 12.4446 10.4547 12.4446 8.00011C12.4446 5.54551 10.4547 3.55566 8.00011 3.55566C5.54551 3.55566 3.55566 5.54551 3.55566 8.00011C3.55566 10.4547 5.54551 12.4446 8.00011 12.4446Z" fill="#FFCB13"/>
|
||||
<path d="M13.2343 10.3111C12.949 10.3111 12.6743 10.3556 12.4152 10.4378C12.1094 9.53647 11.2774 8.88892 10.2966 8.88892C9.24411 8.88892 8.36322 9.63469 8.11922 10.6387C7.85878 10.436 7.53744 10.3116 7.18544 10.3116C6.32633 10.3116 5.62989 11.0276 5.62989 11.9116C5.62989 12.1262 5.67255 12.3298 5.74722 12.5174C5.59878 12.4742 5.44544 12.4445 5.28411 12.4445C4.32944 12.4445 3.55566 13.2405 3.55566 14.2222C3.55566 15.204 4.32944 16 5.28411 16H13.2348C14.7619 16 16.0001 14.7271 16.0001 13.1556C16.0001 11.5845 14.7619 10.3111 13.2343 10.3111Z" fill="#E9F6FF"/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_16624_62807">
|
||||
<rect width="16" height="16" fill="white"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.4 KiB |
|
|
@ -0,0 +1,36 @@
|
|||
import requests
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
def query_weather(city="Beijing", units="metric", language="zh_cn", api_key=None):
|
||||
|
||||
url = "https://api.openweathermap.org/data/2.5/weather"
|
||||
params = {"q": city, "appid": api_key, "units": units, "lang": language}
|
||||
|
||||
return requests.get(url, params=params)
|
||||
|
||||
|
||||
class OpenweatherProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
if "api_key" not in credentials or not credentials.get("api_key"):
|
||||
raise ToolProviderCredentialValidationError(
|
||||
"Open weather API key is required."
|
||||
)
|
||||
apikey = credentials.get("api_key")
|
||||
try:
|
||||
response = query_weather(api_key=apikey)
|
||||
if response.status_code == 200:
|
||||
pass
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
(response.json()).get("info")
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
"Open weather API Key is invalid. {}".format(e)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
identity:
|
||||
author: Onelevenvy
|
||||
name: openweather
|
||||
label:
|
||||
en_US: Open weather query
|
||||
zh_Hans: Open Weather
|
||||
pt_BR: Consulta de clima open weather
|
||||
description:
|
||||
en_US: Weather query toolkit based on Open Weather
|
||||
zh_Hans: 基于open weather的天气查询工具包
|
||||
pt_BR: Kit de consulta de clima baseado no Open Weather
|
||||
icon: icon.svg
|
||||
credentials_for_provider:
|
||||
api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API Key
|
||||
zh_Hans: API Key
|
||||
pt_BR: Fogo a chave
|
||||
placeholder:
|
||||
en_US: Please enter your open weather API Key
|
||||
zh_Hans: 请输入你的open weather API Key
|
||||
pt_BR: Insira sua chave de API open weather
|
||||
help:
|
||||
en_US: Get your API Key from open weather
|
||||
zh_Hans: 从open weather获取您的 API Key
|
||||
pt_BR: Obtenha sua chave de API do open weather
|
||||
url: https://openweathermap.org
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class OpenweatherTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
city = tool_parameters.get("city", "")
|
||||
if not city:
|
||||
return self.create_text_message("Please tell me your city")
|
||||
if (
|
||||
"api_key" not in self.runtime.credentials
|
||||
or not self.runtime.credentials.get("api_key")
|
||||
):
|
||||
return self.create_text_message("OpenWeather API key is required.")
|
||||
|
||||
units = tool_parameters.get("units", "metric")
|
||||
lang = tool_parameters.get("lang", "zh_cn")
|
||||
try:
|
||||
# request URL
|
||||
url = "https://api.openweathermap.org/data/2.5/weather"
|
||||
|
||||
# request parmas
|
||||
params = {
|
||||
"q": city,
|
||||
"appid": self.runtime.credentials.get("api_key"),
|
||||
"units": units,
|
||||
"lang": lang,
|
||||
}
|
||||
response = requests.get(url, params=params)
|
||||
|
||||
if response.status_code == 200:
|
||||
|
||||
data = response.json()
|
||||
return self.create_text_message(
|
||||
self.summary(
|
||||
user_id=user_id, content=json.dumps(data, ensure_ascii=False)
|
||||
)
|
||||
)
|
||||
else:
|
||||
error_message = {
|
||||
"error": f"failed:{response.status_code}",
|
||||
"data": response.text,
|
||||
}
|
||||
# return error
|
||||
return json.dumps(error_message)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(
|
||||
"Openweather API Key is invalid. {}".format(e)
|
||||
)
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
identity:
|
||||
name: weather
|
||||
author: Onelevenvy
|
||||
label:
|
||||
en_US: Open Weather Query
|
||||
zh_Hans: 天气查询
|
||||
pt_BR: Previsão do tempo
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Weather forecast inquiry
|
||||
zh_Hans: 天气查询
|
||||
pt_BR: Inquérito sobre previsão meteorológica
|
||||
llm: A tool when you want to ask about the weather or weather-related question
|
||||
parameters:
|
||||
- name: city
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: city
|
||||
zh_Hans: 城市
|
||||
pt_BR: cidade
|
||||
human_description:
|
||||
en_US: Target city for weather forecast query
|
||||
zh_Hans: 天气预报查询的目标城市
|
||||
pt_BR: Cidade de destino para consulta de previsão do tempo
|
||||
llm_description: If you don't know you can extract the city name from the
|
||||
question or you can reply:Please tell me your city. You have to extract
|
||||
the Chinese city name from the question.If the input region is in Chinese
|
||||
characters for China, it should be replaced with the corresponding English
|
||||
name, such as '北京' for correct input is 'Beijing'
|
||||
form: llm
|
||||
- name: lang
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: language
|
||||
zh_Hans: 语言
|
||||
pt_BR: language
|
||||
label:
|
||||
en_US: language
|
||||
zh_Hans: 语言
|
||||
pt_BR: language
|
||||
form: form
|
||||
options:
|
||||
- value: zh_cn
|
||||
label:
|
||||
en_US: cn
|
||||
zh_Hans: 中国
|
||||
pt_BR: cn
|
||||
- value: en_us
|
||||
label:
|
||||
en_US: usa
|
||||
zh_Hans: 美国
|
||||
pt_BR: usa
|
||||
default: zh_cn
|
||||
- name: units
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: units for temperature
|
||||
zh_Hans: 温度单位
|
||||
pt_BR: units for temperature
|
||||
label:
|
||||
en_US: units
|
||||
zh_Hans: 单位
|
||||
pt_BR: units
|
||||
form: form
|
||||
options:
|
||||
- value: metric
|
||||
label:
|
||||
en_US: metric
|
||||
zh_Hans: ℃
|
||||
pt_BR: metric
|
||||
- value: imperial
|
||||
label:
|
||||
en_US: imperial
|
||||
zh_Hans: ℉
|
||||
pt_BR: imperial
|
||||
default: metric
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M21.6547 16.7993C21.3111 18.0034 20.7384 19.0938 20.0054 20.048C18.9058 21.4111 15.1261 21.4111 12.8583 20.8204C10.4072 20.1616 8.6433 18.6395 8.50586 18.5259C9.46797 19.2756 10.6821 19.7072 12.0107 19.7072C15.1948 19.7072 17.7605 17.1174 17.7605 13.9368C17.7605 12.9826 17.5314 12.0966 17.119 11.3015C17.0961 11.2561 17.1419 11.2106 17.1649 11.2333C18.9745 11.5287 22.571 13.2098 21.6547 16.7993Z" fill="#2751D0"/>
|
||||
<path d="M21.9994 12.7773C21.9994 12.8454 21.9306 12.8682 21.8848 12.8C21.0372 11.0053 19.5483 10.46 17.7615 10.0511C16.4099 9.75577 15.5166 9.3014 15.1271 9.09694C15.0355 9.0515 14.9668 8.98335 14.8751 8.93791C12.0575 7.23404 12.0117 4.30339 12.0117 4.30339V0.0550813C12.0117 0.00964486 12.0804 -0.0130733 12.1034 0.0096449L18.7694 6.50706L19.2734 6.98414C20.7394 8.52898 21.7474 10.5509 21.9994 12.7773Z" fill="#D82F20"/>
|
||||
<path d="M20.0052 20.0462C18.1726 22.4316 15.2863 23.9992 12.0334 23.9992C6.48985 23.9992 2 19.501 2 13.9577C2 11.2543 3.05374 8.8234 4.7947 7.00594L5.29866 6.50614L9.65107 2.25783C9.69688 2.2124 9.7656 2.25783 9.7427 2.30327C9.67397 2.59861 9.55944 3.28015 9.62816 4.18888C9.71979 5.25664 10.0634 6.68789 11.0713 8.27817C11.6898 9.27777 12.5832 10.3228 13.8202 11.4133C13.9577 11.5496 14.118 11.6632 14.2784 11.7995C14.8281 12.3674 15.1488 13.1171 15.1488 13.9577C15.1488 15.6616 13.7515 17.0474 12.0563 17.0474C11.3233 17.0474 10.659 16.7975 10.1321 16.3659C10.0863 16.3204 10.1321 16.2523 10.1779 16.275C10.2925 16.2977 10.407 16.3204 10.5215 16.3204C11.1171 16.3204 11.6211 15.8433 11.6211 15.2299C11.6211 14.8665 11.4378 14.5257 11.163 14.3439C10.4299 13.7533 9.81142 13.1853 9.28455 12.6173C8.55151 11.8222 8.00174 11.0498 7.61231 10.3001C6.81055 11.2997 6.30659 12.5492 6.30659 13.935C6.30659 15.7979 7.17707 17.4563 8.55152 18.5014C8.68896 18.615 10.4528 20.1371 12.9039 20.7959C15.1259 21.432 18.9057 21.4093 20.0052 20.0462Z" fill="#69C5F4"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.0 KiB |
|
|
@ -0,0 +1,40 @@
|
|||
import json
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.spark.tools.spark_img_generation import spark_response
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class SparkProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
if "APPID" not in credentials or not credentials.get("APPID"):
|
||||
raise ToolProviderCredentialValidationError("APPID is required.")
|
||||
if "APISecret" not in credentials or not credentials.get("APISecret"):
|
||||
raise ToolProviderCredentialValidationError("APISecret is required.")
|
||||
if "APIKey" not in credentials or not credentials.get("APIKey"):
|
||||
raise ToolProviderCredentialValidationError("APIKey is required.")
|
||||
|
||||
appid = credentials.get("APPID")
|
||||
apisecret = credentials.get("APISecret")
|
||||
apikey = credentials.get("APIKey")
|
||||
prompt = "a cute black dog"
|
||||
|
||||
try:
|
||||
response = spark_response(prompt, appid, apikey, apisecret)
|
||||
data = json.loads(response)
|
||||
code = data["header"]["code"]
|
||||
|
||||
if code == 0:
|
||||
# 0 success,
|
||||
pass
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
"image generate error, code:{}".format(code)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
"APPID APISecret APIKey is invalid. {}".format(e)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
identity:
|
||||
author: Onelevenvy
|
||||
name: spark
|
||||
label:
|
||||
en_US: Spark
|
||||
zh_Hans: 讯飞星火
|
||||
pt_BR: Spark
|
||||
description:
|
||||
en_US: Spark Platform Toolkit
|
||||
zh_Hans: 讯飞星火平台工具
|
||||
pt_BR: Pacote de Ferramentas da Plataforma Spark
|
||||
icon: icon.svg
|
||||
credentials_for_provider:
|
||||
APPID:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Spark APPID
|
||||
zh_Hans: APPID
|
||||
pt_BR: Spark APPID
|
||||
help:
|
||||
en_US: Please input your APPID
|
||||
zh_Hans: 请输入你的 APPID
|
||||
pt_BR: Please input your APPID
|
||||
placeholder:
|
||||
en_US: Please input your APPID
|
||||
zh_Hans: 请输入你的 APPID
|
||||
pt_BR: Please input your APPID
|
||||
APISecret:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Spark APISecret
|
||||
zh_Hans: APISecret
|
||||
pt_BR: Spark APISecret
|
||||
help:
|
||||
en_US: Please input your Spark APISecret
|
||||
zh_Hans: 请输入你的 APISecret
|
||||
pt_BR: Please input your Spark APISecret
|
||||
placeholder:
|
||||
en_US: Please input your Spark APISecret
|
||||
zh_Hans: 请输入你的 APISecret
|
||||
pt_BR: Please input your Spark APISecret
|
||||
APIKey:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Spark APIKey
|
||||
zh_Hans: APIKey
|
||||
pt_BR: Spark APIKey
|
||||
help:
|
||||
en_US: Please input your Spark APIKey
|
||||
zh_Hans: 请输入你的 APIKey
|
||||
pt_BR: Please input your Spark APIKey
|
||||
placeholder:
|
||||
en_US: Please input your Spark APIKey
|
||||
zh_Hans: 请输入你的 APIKey
|
||||
pt_BR: Please input Spark APIKey
|
||||
url: https://console.xfyun.cn/services
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from base64 import b64decode
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AssembleHeaderException(Exception):
|
||||
def __init__(self, msg):
|
||||
self.message = msg
|
||||
|
||||
|
||||
class Url:
|
||||
def __init__(this, host, path, schema):
|
||||
this.host = host
|
||||
this.path = path
|
||||
this.schema = schema
|
||||
|
||||
|
||||
# calculate sha256 and encode to base64
|
||||
def sha256base64(data):
|
||||
sha256 = hashlib.sha256()
|
||||
sha256.update(data)
|
||||
digest = base64.b64encode(sha256.digest()).decode(encoding="utf-8")
|
||||
return digest
|
||||
|
||||
|
||||
def parse_url(requset_url):
|
||||
stidx = requset_url.index("://")
|
||||
host = requset_url[stidx + 3 :]
|
||||
schema = requset_url[: stidx + 3]
|
||||
edidx = host.index("/")
|
||||
if edidx <= 0:
|
||||
raise AssembleHeaderException("invalid request url:" + requset_url)
|
||||
path = host[edidx:]
|
||||
host = host[:edidx]
|
||||
u = Url(host, path, schema)
|
||||
return u
|
||||
|
||||
def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
|
||||
u = parse_url(requset_url)
|
||||
host = u.host
|
||||
path = u.path
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(
|
||||
host, date, method, path
|
||||
)
|
||||
signature_sha = hmac.new(
|
||||
api_secret.encode("utf-8"),
|
||||
signature_origin.encode("utf-8"),
|
||||
digestmod=hashlib.sha256,
|
||||
).digest()
|
||||
signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
|
||||
authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"'
|
||||
|
||||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
|
||||
encoding="utf-8"
|
||||
)
|
||||
values = {"host": host, "date": date, "authorization": authorization}
|
||||
|
||||
return requset_url + "?" + urlencode(values)
|
||||
|
||||
|
||||
def get_body(appid, text):
|
||||
body = {
|
||||
"header": {"app_id": appid, "uid": "123456789"},
|
||||
"parameter": {
|
||||
"chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096}
|
||||
},
|
||||
"payload": {"message": {"text": [{"role": "user", "content": text}]}},
|
||||
}
|
||||
return body
|
||||
|
||||
|
||||
def spark_response(text, appid, apikey, apisecret):
|
||||
host = "http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti"
|
||||
url = assemble_ws_auth_url(
|
||||
host, method="POST", api_key=apikey, api_secret=apisecret
|
||||
)
|
||||
content = get_body(appid, text)
|
||||
response = requests.post(
|
||||
url, json=content, headers={"content-type": "application/json"}
|
||||
).text
|
||||
return response
|
||||
|
||||
|
||||
class SparkImgGeneratorTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
|
||||
if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get(
|
||||
"APPID"
|
||||
):
|
||||
return self.create_text_message("APPID is required.")
|
||||
if (
|
||||
"APISecret" not in self.runtime.credentials
|
||||
or not self.runtime.credentials.get("APISecret")
|
||||
):
|
||||
return self.create_text_message("APISecret is required.")
|
||||
if (
|
||||
"APIKey" not in self.runtime.credentials
|
||||
or not self.runtime.credentials.get("APIKey")
|
||||
):
|
||||
return self.create_text_message("APIKey is required.")
|
||||
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message("Please input prompt")
|
||||
res = self.img_generation(prompt)
|
||||
result = []
|
||||
for image in res:
|
||||
result.append(
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image["base64_image"]),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
def img_generation(self, prompt):
|
||||
response = spark_response(
|
||||
text=prompt,
|
||||
appid=self.runtime.credentials.get("APPID"),
|
||||
apikey=self.runtime.credentials.get("APIKey"),
|
||||
apisecret=self.runtime.credentials.get("APISecret"),
|
||||
)
|
||||
data = json.loads(response)
|
||||
code = data["header"]["code"]
|
||||
if code != 0:
|
||||
return self.create_text_message(f"error: {code}, {data}")
|
||||
else:
|
||||
text = data["payload"]["choices"]["text"]
|
||||
image_content = text[0]
|
||||
image_base = image_content["content"]
|
||||
json_data = {"base64_image": image_base}
|
||||
return [json_data]
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
identity:
|
||||
name: spark_img_generation
|
||||
author: Onelevenvy
|
||||
label:
|
||||
en_US: Spark Image Generation
|
||||
zh_Hans: 图片生成
|
||||
pt_BR: Geração de imagens Spark
|
||||
icon: icon.svg
|
||||
description:
|
||||
en_US: Spark Image Generation
|
||||
zh_Hans: 图片生成
|
||||
pt_BR: Geração de imagens Spark
|
||||
description:
|
||||
human:
|
||||
en_US: Generate images based on user input, with image generation API
|
||||
provided by Spark
|
||||
zh_Hans: 根据用户的输入生成图片,由讯飞星火提供图片生成api
|
||||
pt_BR: Gerar imagens com base na entrada do usuário, com API de geração
|
||||
de imagem fornecida pela Spark
|
||||
llm: spark_img_generation is a tool used to generate images from text
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: Image prompt
|
||||
zh_Hans: 图像提示词
|
||||
pt_BR: Image prompt
|
||||
llm_description: Image prompt of spark_img_generation tooll, you should
|
||||
describe the image you want to generate as a list of words as possible
|
||||
as detailed
|
||||
form: llm
|
||||
|
|
@ -246,8 +246,27 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be one of {options}')
|
||||
|
||||
if credentials[credential_name]:
|
||||
elif credential_schema.type == ToolProviderCredentials.CredentialsType.BOOLEAN:
|
||||
if isinstance(credentials[credential_name], bool):
|
||||
pass
|
||||
elif isinstance(credentials[credential_name], str):
|
||||
if credentials[credential_name].lower() == 'true':
|
||||
credentials[credential_name] = True
|
||||
elif credentials[credential_name].lower() == 'false':
|
||||
credentials[credential_name] = False
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
|
||||
elif isinstance(credentials[credential_name], int):
|
||||
if credentials[credential_name] == 1:
|
||||
credentials[credential_name] = True
|
||||
elif credentials[credential_name] == 0:
|
||||
credentials[credential_name] = False
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
|
||||
|
||||
if credentials[credential_name] or credentials[credential_name] == False:
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import requests
|
|||
import core.helper.ssrf_proxy as ssrf_proxy
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
API_TOOL_DEFAULT_TIMEOUT = (10, 60)
|
||||
|
|
@ -81,7 +81,7 @@ class ApiTool(Tool):
|
|||
needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
|
||||
for parameter in needed_parameters:
|
||||
if parameter.required and parameter.name not in parameters:
|
||||
raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter.name}")
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
|
||||
|
||||
if parameter.default is not None and parameter.name not in parameters:
|
||||
parameters[parameter.name] = parameter.default
|
||||
|
|
@ -94,7 +94,7 @@ class ApiTool(Tool):
|
|||
"""
|
||||
if isinstance(response, httpx.Response):
|
||||
if response.status_code >= 400:
|
||||
raise ToolProviderCredentialValidationError(f"Request failed with status code {response.status_code}")
|
||||
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
|
||||
if not response.content:
|
||||
return 'Empty response from the tool, please check your parameters and try again.'
|
||||
try:
|
||||
|
|
@ -107,7 +107,7 @@ class ApiTool(Tool):
|
|||
return response.text
|
||||
elif isinstance(response, requests.Response):
|
||||
if not response.ok:
|
||||
raise ToolProviderCredentialValidationError(f"Request failed with status code {response.status_code}")
|
||||
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
|
||||
if not response.content:
|
||||
return 'Empty response from the tool, please check your parameters and try again.'
|
||||
try:
|
||||
|
|
@ -139,7 +139,7 @@ class ApiTool(Tool):
|
|||
if parameter['name'] in parameters:
|
||||
value = parameters[parameter['name']]
|
||||
elif parameter['required']:
|
||||
raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
|
||||
else:
|
||||
value = (parameter.get('schema', {}) or {}).get('default', '')
|
||||
path_params[parameter['name']] = value
|
||||
|
|
@ -149,7 +149,7 @@ class ApiTool(Tool):
|
|||
if parameter['name'] in parameters:
|
||||
value = parameters[parameter['name']]
|
||||
elif parameter['required']:
|
||||
raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
|
||||
else:
|
||||
value = (parameter.get('schema', {}) or {}).get('default', '')
|
||||
params[parameter['name']] = value
|
||||
|
|
@ -159,7 +159,7 @@ class ApiTool(Tool):
|
|||
if parameter['name'] in parameters:
|
||||
value = parameters[parameter['name']]
|
||||
elif parameter['required']:
|
||||
raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
|
||||
else:
|
||||
value = (parameter.get('schema', {}) or {}).get('default', '')
|
||||
cookies[parameter['name']] = value
|
||||
|
|
@ -169,7 +169,7 @@ class ApiTool(Tool):
|
|||
if parameter['name'] in parameters:
|
||||
value = parameters[parameter['name']]
|
||||
elif parameter['required']:
|
||||
raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
|
||||
else:
|
||||
value = (parameter.get('schema', {}) or {}).get('default', '')
|
||||
headers[parameter['name']] = value
|
||||
|
|
@ -188,7 +188,7 @@ class ApiTool(Tool):
|
|||
# convert type
|
||||
body[name] = self._convert_body_property_type(property, parameters[name])
|
||||
elif name in required:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
raise ToolParameterValidationError(
|
||||
f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
|
||||
)
|
||||
elif 'default' in property:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
"""add_tenant_id_db_index
|
||||
|
||||
Revision ID: a8f9b3c45e4a
|
||||
Revises: 16830a790f0f
|
||||
Create Date: 2024-03-18 05:07:35.588473
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'a8f9b3c45e4a'
|
||||
down_revision = '16830a790f0f'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('document_segments', schema=None) as batch_op:
|
||||
batch_op.create_index('document_segment_tenant_idx', ['tenant_id'], unique=False)
|
||||
|
||||
with op.batch_alter_table('documents', schema=None) as batch_op:
|
||||
batch_op.create_index('document_tenant_idx', ['tenant_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('documents', schema=None) as batch_op:
|
||||
batch_op.drop_index('document_tenant_idx')
|
||||
|
||||
with op.batch_alter_table('document_segments', schema=None) as batch_op:
|
||||
batch_op.drop_index('document_segment_tenant_idx')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -176,6 +176,7 @@ class Document(db.Model):
|
|||
db.PrimaryKeyConstraint('id', name='document_pkey'),
|
||||
db.Index('document_dataset_id_idx', 'dataset_id'),
|
||||
db.Index('document_is_paused_idx', 'is_paused'),
|
||||
db.Index('document_tenant_idx', 'tenant_id'),
|
||||
)
|
||||
|
||||
# initial fields
|
||||
|
|
@ -334,6 +335,7 @@ class DocumentSegment(db.Model):
|
|||
db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'),
|
||||
db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'),
|
||||
db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'),
|
||||
db.Index('document_segment_tenant_idx', 'tenant_id'),
|
||||
)
|
||||
|
||||
# initial fields
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ gunicorn~=21.2.0
|
|||
gevent~=23.9.1
|
||||
langchain==0.0.250
|
||||
openai~=1.13.3
|
||||
tiktoken~=0.5.2
|
||||
tiktoken~=0.6.0
|
||||
psycopg2-binary~=2.9.6
|
||||
pycryptodome==3.19.1
|
||||
python-dotenv==1.0.0
|
||||
|
|
@ -36,7 +36,7 @@ python-docx~=1.1.0
|
|||
pypdfium2==4.16.0
|
||||
resend~=0.7.0
|
||||
pyjwt~=2.8.0
|
||||
anthropic~=0.17.0
|
||||
anthropic~=0.20.0
|
||||
newspaper3k==0.2.8
|
||||
google-api-python-client==2.90.0
|
||||
wikipedia==1.4.0
|
||||
|
|
|
|||
|
|
@ -138,9 +138,9 @@ class ToolManageService:
|
|||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name)
|
||||
return [
|
||||
v.to_dict() for _, v in (provider.credentials_schema or {}).items()
|
||||
]
|
||||
return json.loads(serialize_base_model_array([
|
||||
v for _, v in (provider.credentials_schema or {}).items()
|
||||
]))
|
||||
|
||||
@staticmethod
|
||||
def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]:
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_
|
|||
logging.info(
|
||||
click.style('Delete annotation index error: {}'.format(str(e)),
|
||||
fg='red'))
|
||||
vector.add_texts(documents)
|
||||
vector.create(documents)
|
||||
db.session.commit()
|
||||
redis_client.setex(enable_app_annotation_job_key, 600, 'completed')
|
||||
end_at = time.perf_counter()
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from dify_client.client import ChatClient, CompletionClient, DifyClient
|
||||
from dify_client.client import ChatClient, CompletionClient, DifyClient
|
||||
|
|
@ -210,6 +210,7 @@ const AgentTools: FC = () => {
|
|||
setting={currentTool?.tool_parameters as any}
|
||||
collection={currentTool?.collection as Collection}
|
||||
isBuiltIn={currentTool?.collection?.type === CollectionType.builtIn}
|
||||
isModel={currentTool?.collection?.type === CollectionType.model}
|
||||
onSave={handleToolSettingChange}
|
||||
onHide={() => setIsShowSettingTool(false)}
|
||||
/>)
|
||||
|
|
|
|||
|
|
@ -58,11 +58,16 @@ const SettingBuiltInTool: FC<Props> = ({
|
|||
(async () => {
|
||||
setIsLoading(true)
|
||||
try {
|
||||
const list = isBuiltIn
|
||||
? await fetchBuiltInToolList(collection.name)
|
||||
: isModel
|
||||
? await fetchModelToolList(collection.name)
|
||||
: await fetchCustomToolList(collection.name)
|
||||
const list = await new Promise<Tool[]>((resolve) => {
|
||||
(async function () {
|
||||
if (isModel)
|
||||
resolve(await fetchModelToolList(collection.name))
|
||||
else if (isBuiltIn)
|
||||
resolve(await fetchBuiltInToolList(collection.name))
|
||||
else
|
||||
resolve(await fetchCustomToolList(collection.name))
|
||||
}())
|
||||
})
|
||||
setTools(list)
|
||||
const currTool = list.find(tool => tool.name === toolName)
|
||||
if (currTool) {
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import type { FC } from 'react'
|
|||
import React, { useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import cn from 'classnames'
|
||||
import { toolCredentialToFormSchemas } from '../../utils/to-form-schema'
|
||||
import { addDefaultValue, toolCredentialToFormSchemas } from '../../utils/to-form-schema'
|
||||
import type { Collection } from '../../types'
|
||||
import Drawer from '@/app/components/base/drawer-plus'
|
||||
import Button from '@/app/components/base/button'
|
||||
|
|
@ -30,12 +30,15 @@ const ConfigCredential: FC<Props> = ({
|
|||
const { t } = useTranslation()
|
||||
const [credentialSchema, setCredentialSchema] = useState<any>(null)
|
||||
const { team_credentials: credentialValue, name: collectionName } = collection
|
||||
const [tempCredential, setTempCredential] = React.useState<any>(credentialValue)
|
||||
useEffect(() => {
|
||||
fetchBuiltInToolCredentialSchema(collectionName).then((res) => {
|
||||
setCredentialSchema(toolCredentialToFormSchemas(res))
|
||||
const toolCredentialSchemas = toolCredentialToFormSchemas(res)
|
||||
const defaultCredentials = addDefaultValue(credentialValue, toolCredentialSchemas)
|
||||
setCredentialSchema(toolCredentialSchemas)
|
||||
setTempCredential(defaultCredentials)
|
||||
})
|
||||
}, [])
|
||||
const [tempCredential, setTempCredential] = React.useState<any>(credentialValue)
|
||||
|
||||
return (
|
||||
<Drawer
|
||||
|
|
|
|||