From cb79a90031d16b09ef1747e227cb1e174cc8a435 Mon Sep 17 00:00:00 2001
From: Onelevenvy <49232224+Onelevenvy@users.noreply.github.com>
Date: Mon, 18 Mar 2024 16:22:48 +0800
Subject: [PATCH 1/4] feat: Add tools for open weather search and image
generation using the Spark API. (#2845)
---
.../baichuan/text_embedding/text_embedding.py | 2 +-
.../builtin/openweather/_assets/icon.svg | 12 ++
.../builtin/openweather/openweather.py | 36 ++++
.../builtin/openweather/openweather.yaml | 29 ++++
.../builtin/openweather/tools/weather.py | 60 +++++++
.../builtin/openweather/tools/weather.yaml | 80 +++++++++
.../tools/provider/builtin/spark/__init__.py | 0
.../provider/builtin/spark/_assets/icon.svg | 5 +
.../tools/provider/builtin/spark/spark.py | 40 +++++
.../tools/provider/builtin/spark/spark.yaml | 59 +++++++
.../spark/tools/spark_img_generation.py | 154 ++++++++++++++++++
.../spark/tools/spark_img_generation.yaml | 36 ++++
sdks/python-client/dify_client/__init__.py | 2 +-
13 files changed, 513 insertions(+), 2 deletions(-)
create mode 100644 api/core/tools/provider/builtin/openweather/_assets/icon.svg
create mode 100644 api/core/tools/provider/builtin/openweather/openweather.py
create mode 100644 api/core/tools/provider/builtin/openweather/openweather.yaml
create mode 100644 api/core/tools/provider/builtin/openweather/tools/weather.py
create mode 100644 api/core/tools/provider/builtin/openweather/tools/weather.yaml
create mode 100644 api/core/tools/provider/builtin/spark/__init__.py
create mode 100644 api/core/tools/provider/builtin/spark/_assets/icon.svg
create mode 100644 api/core/tools/provider/builtin/spark/spark.py
create mode 100644 api/core/tools/provider/builtin/spark/spark.yaml
create mode 100644 api/core/tools/provider/builtin/spark/tools/spark_img_generation.py
create mode 100644 api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml
diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
index 535714f663..5ae90d54b5 100644
--- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
@@ -124,7 +124,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
elif err == 'insufficient_quota':
raise InsufficientAccountBalance(msg)
elif err == 'invalid_authentication':
- raise InvalidAuthenticationError(msg)
+ raise InvalidAuthenticationError(msg)
elif err and 'rate' in err:
raise RateLimitReachedError(msg)
elif err and 'internal' in err:
diff --git a/api/core/tools/provider/builtin/openweather/_assets/icon.svg b/api/core/tools/provider/builtin/openweather/_assets/icon.svg
new file mode 100644
index 0000000000..f06cd87e64
--- /dev/null
+++ b/api/core/tools/provider/builtin/openweather/_assets/icon.svg
@@ -0,0 +1,12 @@
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/openweather/openweather.py b/api/core/tools/provider/builtin/openweather/openweather.py
new file mode 100644
index 0000000000..a2827177a3
--- /dev/null
+++ b/api/core/tools/provider/builtin/openweather/openweather.py
@@ -0,0 +1,36 @@
+import requests
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+def query_weather(city="Beijing", units="metric", language="zh_cn", api_key=None):
+
+ url = "https://api.openweathermap.org/data/2.5/weather"
+ params = {"q": city, "appid": api_key, "units": units, "lang": language}
+
+ return requests.get(url, params=params)
+
+
+class OpenweatherProvider(BuiltinToolProviderController):
+ def _validate_credentials(self, credentials: dict) -> None:
+ try:
+ if "api_key" not in credentials or not credentials.get("api_key"):
+ raise ToolProviderCredentialValidationError(
+ "Open weather API key is required."
+ )
+ apikey = credentials.get("api_key")
+ try:
+ response = query_weather(api_key=apikey)
+ if response.status_code == 200:
+ pass
+ else:
+ raise ToolProviderCredentialValidationError(
+ (response.json()).get("info")
+ )
+ except Exception as e:
+ raise ToolProviderCredentialValidationError(
+ "Open weather API Key is invalid. {}".format(e)
+ )
+ except Exception as e:
+ raise ToolProviderCredentialValidationError(str(e))
diff --git a/api/core/tools/provider/builtin/openweather/openweather.yaml b/api/core/tools/provider/builtin/openweather/openweather.yaml
new file mode 100644
index 0000000000..60bb33c36d
--- /dev/null
+++ b/api/core/tools/provider/builtin/openweather/openweather.yaml
@@ -0,0 +1,29 @@
+identity:
+ author: Onelevenvy
+ name: openweather
+ label:
+ en_US: Open weather query
+ zh_Hans: Open Weather
+ pt_BR: Consulta de clima open weather
+ description:
+ en_US: Weather query toolkit based on Open Weather
+ zh_Hans: 基于open weather的天气查询工具包
+ pt_BR: Kit de consulta de clima baseado no Open Weather
+ icon: icon.svg
+credentials_for_provider:
+ api_key:
+ type: secret-input
+ required: true
+ label:
+ en_US: API Key
+ zh_Hans: API Key
+ pt_BR: Fogo a chave
+ placeholder:
+ en_US: Please enter your open weather API Key
+ zh_Hans: 请输入你的open weather API Key
+ pt_BR: Insira sua chave de API open weather
+ help:
+ en_US: Get your API Key from open weather
+ zh_Hans: 从open weather获取您的 API Key
+ pt_BR: Obtenha sua chave de API do open weather
+ url: https://openweathermap.org
diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.py b/api/core/tools/provider/builtin/openweather/tools/weather.py
new file mode 100644
index 0000000000..536a3511f4
--- /dev/null
+++ b/api/core/tools/provider/builtin/openweather/tools/weather.py
@@ -0,0 +1,60 @@
+import json
+from typing import Any, Union
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class OpenweatherTool(BuiltinTool):
+ def _invoke(
+ self, user_id: str, tool_parameters: dict[str, Any]
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ """
+ invoke tools
+ """
+ city = tool_parameters.get("city", "")
+ if not city:
+ return self.create_text_message("Please tell me your city")
+ if (
+ "api_key" not in self.runtime.credentials
+ or not self.runtime.credentials.get("api_key")
+ ):
+ return self.create_text_message("OpenWeather API key is required.")
+
+ units = tool_parameters.get("units", "metric")
+ lang = tool_parameters.get("lang", "zh_cn")
+ try:
+ # request URL
+ url = "https://api.openweathermap.org/data/2.5/weather"
+
+ # request parmas
+ params = {
+ "q": city,
+ "appid": self.runtime.credentials.get("api_key"),
+ "units": units,
+ "lang": lang,
+ }
+ response = requests.get(url, params=params)
+
+ if response.status_code == 200:
+
+ data = response.json()
+ return self.create_text_message(
+ self.summary(
+ user_id=user_id, content=json.dumps(data, ensure_ascii=False)
+ )
+ )
+ else:
+ error_message = {
+ "error": f"failed:{response.status_code}",
+ "data": response.text,
+ }
+ # return error
+ return json.dumps(error_message)
+
+ except Exception as e:
+ return self.create_text_message(
+ "Openweather API Key is invalid. {}".format(e)
+ )
diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.yaml b/api/core/tools/provider/builtin/openweather/tools/weather.yaml
new file mode 100644
index 0000000000..f2dae5c2df
--- /dev/null
+++ b/api/core/tools/provider/builtin/openweather/tools/weather.yaml
@@ -0,0 +1,80 @@
+identity:
+ name: weather
+ author: Onelevenvy
+ label:
+ en_US: Open Weather Query
+ zh_Hans: 天气查询
+ pt_BR: Previsão do tempo
+ icon: icon.svg
+description:
+ human:
+ en_US: Weather forecast inquiry
+ zh_Hans: 天气查询
+ pt_BR: Inquérito sobre previsão meteorológica
+ llm: A tool when you want to ask about the weather or weather-related question
+parameters:
+ - name: city
+ type: string
+ required: true
+ label:
+ en_US: city
+ zh_Hans: 城市
+ pt_BR: cidade
+ human_description:
+ en_US: Target city for weather forecast query
+ zh_Hans: 天气预报查询的目标城市
+ pt_BR: Cidade de destino para consulta de previsão do tempo
+ llm_description: If you don't know you can extract the city name from the
+ question or you can reply:Please tell me your city. You have to extract
+ the Chinese city name from the question.If the input region is in Chinese
+ characters for China, it should be replaced with the corresponding English
+ name, such as '北京' for correct input is 'Beijing'
+ form: llm
+ - name: lang
+ type: select
+ required: true
+ human_description:
+ en_US: language
+ zh_Hans: 语言
+ pt_BR: language
+ label:
+ en_US: language
+ zh_Hans: 语言
+ pt_BR: language
+ form: form
+ options:
+ - value: zh_cn
+ label:
+ en_US: cn
+ zh_Hans: 中国
+ pt_BR: cn
+ - value: en_us
+ label:
+ en_US: usa
+ zh_Hans: 美国
+ pt_BR: usa
+ default: zh_cn
+ - name: units
+ type: select
+ required: true
+ human_description:
+ en_US: units for temperature
+ zh_Hans: 温度单位
+ pt_BR: units for temperature
+ label:
+ en_US: units
+ zh_Hans: 单位
+ pt_BR: units
+ form: form
+ options:
+ - value: metric
+ label:
+ en_US: metric
+ zh_Hans: ℃
+ pt_BR: metric
+ - value: imperial
+ label:
+ en_US: imperial
+ zh_Hans: ℉
+ pt_BR: imperial
+ default: metric
diff --git a/api/core/tools/provider/builtin/spark/__init__.py b/api/core/tools/provider/builtin/spark/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/tools/provider/builtin/spark/_assets/icon.svg b/api/core/tools/provider/builtin/spark/_assets/icon.svg
new file mode 100644
index 0000000000..ef0a9131a4
--- /dev/null
+++ b/api/core/tools/provider/builtin/spark/_assets/icon.svg
@@ -0,0 +1,5 @@
+
diff --git a/api/core/tools/provider/builtin/spark/spark.py b/api/core/tools/provider/builtin/spark/spark.py
new file mode 100644
index 0000000000..cb8e69a59f
--- /dev/null
+++ b/api/core/tools/provider/builtin/spark/spark.py
@@ -0,0 +1,40 @@
+import json
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.spark.tools.spark_img_generation import spark_response
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class SparkProvider(BuiltinToolProviderController):
+ def _validate_credentials(self, credentials: dict) -> None:
+ try:
+ if "APPID" not in credentials or not credentials.get("APPID"):
+ raise ToolProviderCredentialValidationError("APPID is required.")
+ if "APISecret" not in credentials or not credentials.get("APISecret"):
+ raise ToolProviderCredentialValidationError("APISecret is required.")
+ if "APIKey" not in credentials or not credentials.get("APIKey"):
+ raise ToolProviderCredentialValidationError("APIKey is required.")
+
+ appid = credentials.get("APPID")
+ apisecret = credentials.get("APISecret")
+ apikey = credentials.get("APIKey")
+ prompt = "a cute black dog"
+
+ try:
+ response = spark_response(prompt, appid, apikey, apisecret)
+ data = json.loads(response)
+ code = data["header"]["code"]
+
+ if code == 0:
+ # 0 success,
+ pass
+ else:
+ raise ToolProviderCredentialValidationError(
+ "image generate error, code:{}".format(code)
+ )
+ except Exception as e:
+ raise ToolProviderCredentialValidationError(
+ "APPID APISecret APIKey is invalid. {}".format(e)
+ )
+ except Exception as e:
+ raise ToolProviderCredentialValidationError(str(e))
diff --git a/api/core/tools/provider/builtin/spark/spark.yaml b/api/core/tools/provider/builtin/spark/spark.yaml
new file mode 100644
index 0000000000..f2b9c89e96
--- /dev/null
+++ b/api/core/tools/provider/builtin/spark/spark.yaml
@@ -0,0 +1,59 @@
+identity:
+ author: Onelevenvy
+ name: spark
+ label:
+ en_US: Spark
+ zh_Hans: 讯飞星火
+ pt_BR: Spark
+ description:
+ en_US: Spark Platform Toolkit
+ zh_Hans: 讯飞星火平台工具
+ pt_BR: Pacote de Ferramentas da Plataforma Spark
+ icon: icon.svg
+credentials_for_provider:
+ APPID:
+ type: secret-input
+ required: true
+ label:
+ en_US: Spark APPID
+ zh_Hans: APPID
+ pt_BR: Spark APPID
+ help:
+ en_US: Please input your APPID
+ zh_Hans: 请输入你的 APPID
+ pt_BR: Please input your APPID
+ placeholder:
+ en_US: Please input your APPID
+ zh_Hans: 请输入你的 APPID
+ pt_BR: Please input your APPID
+ APISecret:
+ type: secret-input
+ required: true
+ label:
+ en_US: Spark APISecret
+ zh_Hans: APISecret
+ pt_BR: Spark APISecret
+ help:
+ en_US: Please input your Spark APISecret
+ zh_Hans: 请输入你的 APISecret
+ pt_BR: Please input your Spark APISecret
+ placeholder:
+ en_US: Please input your Spark APISecret
+ zh_Hans: 请输入你的 APISecret
+ pt_BR: Please input your Spark APISecret
+ APIKey:
+ type: secret-input
+ required: true
+ label:
+ en_US: Spark APIKey
+ zh_Hans: APIKey
+ pt_BR: Spark APIKey
+ help:
+ en_US: Please input your Spark APIKey
+ zh_Hans: 请输入你的 APIKey
+ pt_BR: Please input your Spark APIKey
+ placeholder:
+ en_US: Please input your Spark APIKey
+ zh_Hans: 请输入你的 APIKey
+ pt_BR: Please input Spark APIKey
+ url: https://console.xfyun.cn/services
diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py
new file mode 100644
index 0000000000..a977af2b76
--- /dev/null
+++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py
@@ -0,0 +1,154 @@
+import base64
+import hashlib
+import hmac
+import json
+from base64 import b64decode
+from datetime import datetime
+from time import mktime
+from typing import Any, Union
+from urllib.parse import urlencode
+from wsgiref.handlers import format_date_time
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class AssembleHeaderException(Exception):
+ def __init__(self, msg):
+ self.message = msg
+
+
+class Url:
+ def __init__(this, host, path, schema):
+ this.host = host
+ this.path = path
+ this.schema = schema
+
+
+# calculate sha256 and encode to base64
+def sha256base64(data):
+ sha256 = hashlib.sha256()
+ sha256.update(data)
+ digest = base64.b64encode(sha256.digest()).decode(encoding="utf-8")
+ return digest
+
+
+def parse_url(requset_url):
+ stidx = requset_url.index("://")
+ host = requset_url[stidx + 3 :]
+ schema = requset_url[: stidx + 3]
+ edidx = host.index("/")
+ if edidx <= 0:
+ raise AssembleHeaderException("invalid request url:" + requset_url)
+ path = host[edidx:]
+ host = host[:edidx]
+ u = Url(host, path, schema)
+ return u
+
+def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
+ u = parse_url(requset_url)
+ host = u.host
+ path = u.path
+ now = datetime.now()
+ date = format_date_time(mktime(now.timetuple()))
+ signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(
+ host, date, method, path
+ )
+ signature_sha = hmac.new(
+ api_secret.encode("utf-8"),
+ signature_origin.encode("utf-8"),
+ digestmod=hashlib.sha256,
+ ).digest()
+ signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
+ authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"'
+
+ authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
+ encoding="utf-8"
+ )
+ values = {"host": host, "date": date, "authorization": authorization}
+
+ return requset_url + "?" + urlencode(values)
+
+
+def get_body(appid, text):
+ body = {
+ "header": {"app_id": appid, "uid": "123456789"},
+ "parameter": {
+ "chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096}
+ },
+ "payload": {"message": {"text": [{"role": "user", "content": text}]}},
+ }
+ return body
+
+
+def spark_response(text, appid, apikey, apisecret):
+ host = "http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti"
+ url = assemble_ws_auth_url(
+ host, method="POST", api_key=apikey, api_secret=apisecret
+ )
+ content = get_body(appid, text)
+ response = requests.post(
+ url, json=content, headers={"content-type": "application/json"}
+ ).text
+ return response
+
+
+class SparkImgGeneratorTool(BuiltinTool):
+ def _invoke(
+ self,
+ user_id: str,
+ tool_parameters: dict[str, Any],
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ """
+ invoke tools
+ """
+
+ if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get(
+ "APPID"
+ ):
+ return self.create_text_message("APPID is required.")
+ if (
+ "APISecret" not in self.runtime.credentials
+ or not self.runtime.credentials.get("APISecret")
+ ):
+ return self.create_text_message("APISecret is required.")
+ if (
+ "APIKey" not in self.runtime.credentials
+ or not self.runtime.credentials.get("APIKey")
+ ):
+ return self.create_text_message("APIKey is required.")
+
+ prompt = tool_parameters.get("prompt", "")
+ if not prompt:
+ return self.create_text_message("Please input prompt")
+ res = self.img_generation(prompt)
+ result = []
+ for image in res:
+ result.append(
+ self.create_blob_message(
+ blob=b64decode(image["base64_image"]),
+ meta={"mime_type": "image/png"},
+ save_as=self.VARIABLE_KEY.IMAGE.value,
+ )
+ )
+ return result
+
+ def img_generation(self, prompt):
+ response = spark_response(
+ text=prompt,
+ appid=self.runtime.credentials.get("APPID"),
+ apikey=self.runtime.credentials.get("APIKey"),
+ apisecret=self.runtime.credentials.get("APISecret"),
+ )
+ data = json.loads(response)
+ code = data["header"]["code"]
+ if code != 0:
+ return self.create_text_message(f"error: {code}, {data}")
+ else:
+ text = data["payload"]["choices"]["text"]
+ image_content = text[0]
+ image_base = image_content["content"]
+ json_data = {"base64_image": image_base}
+ return [json_data]
diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml
new file mode 100644
index 0000000000..d44bbc9564
--- /dev/null
+++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml
@@ -0,0 +1,36 @@
+identity:
+ name: spark_img_generation
+ author: Onelevenvy
+ label:
+ en_US: Spark Image Generation
+ zh_Hans: 图片生成
+ pt_BR: Geração de imagens Spark
+ icon: icon.svg
+ description:
+ en_US: Spark Image Generation
+ zh_Hans: 图片生成
+ pt_BR: Geração de imagens Spark
+description:
+ human:
+ en_US: Generate images based on user input, with image generation API
+ provided by Spark
+ zh_Hans: 根据用户的输入生成图片,由讯飞星火提供图片生成api
+ pt_BR: Gerar imagens com base na entrada do usuário, com API de geração
+ de imagem fornecida pela Spark
+ llm: spark_img_generation is a tool used to generate images from text
+parameters:
+ - name: prompt
+ type: string
+ required: true
+ label:
+ en_US: Prompt
+ zh_Hans: 提示词
+ pt_BR: Prompt
+ human_description:
+ en_US: Image prompt
+ zh_Hans: 图像提示词
+ pt_BR: Image prompt
+ llm_description: Image prompt of spark_img_generation tooll, you should
+ describe the image you want to generate as a list of words as possible
+ as detailed
+ form: llm
diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py
index 6fa9d190e5..6ef0017fee 100644
--- a/sdks/python-client/dify_client/__init__.py
+++ b/sdks/python-client/dify_client/__init__.py
@@ -1 +1 @@
-from dify_client.client import ChatClient, CompletionClient, DifyClient
+from dify_client.client import ChatClient, CompletionClient, DifyClient
\ No newline at end of file
From 95b74c211df5e8191924c98f5cc1627a87343d9c Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Mon, 18 Mar 2024 16:55:26 +0800
Subject: [PATCH 2/4] Feat/support tool credentials bool schema (#2875)
---
api/core/tools/entities/tool_entities.py | 3 +-
api/core/tools/provider/builtin/bing/bing.py | 5 +-
.../tools/provider/builtin/bing/bing.yaml | 60 +++++++
.../builtin/bing/tools/bing_web_search.py | 148 +++++++++++++-----
.../tools/provider/builtin_tool_provider.py | 23 ++-
api/services/tools_manage_service.py | 6 +-
.../setting/build-in/config-credentials.tsx | 9 +-
7 files changed, 204 insertions(+), 50 deletions(-)
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index f7a61b0b0c..437f871864 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -171,6 +171,7 @@ class ToolProviderCredentials(BaseModel):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
+ BOOLEAN = "boolean"
@classmethod
def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
@@ -192,7 +193,7 @@ class ToolProviderCredentials(BaseModel):
name: str = Field(..., description="The name of the credentials")
type: CredentialsType = Field(..., description="The type of the credentials")
required: bool = False
- default: Optional[str] = None
+ default: Optional[Union[int, str]] = None
options: Optional[list[ToolCredentialsOption]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
diff --git a/api/core/tools/provider/builtin/bing/bing.py b/api/core/tools/provider/builtin/bing/bing.py
index ff131b26cd..6e62abfc10 100644
--- a/api/core/tools/provider/builtin/bing/bing.py
+++ b/api/core/tools/provider/builtin/bing/bing.py
@@ -12,12 +12,11 @@ class BingProvider(BuiltinToolProviderController):
meta={
"credentials": credentials,
}
- ).invoke(
- user_id='',
+ ).validate_credentials(
+ credentials=credentials,
tool_parameters={
"query": "test",
"result_type": "link",
- "enable_webpages": True,
},
)
except Exception as e:
diff --git a/api/core/tools/provider/builtin/bing/bing.yaml b/api/core/tools/provider/builtin/bing/bing.yaml
index 9df836929c..35cd729208 100644
--- a/api/core/tools/provider/builtin/bing/bing.yaml
+++ b/api/core/tools/provider/builtin/bing/bing.yaml
@@ -43,3 +43,63 @@ credentials_for_provider:
zh_Hans: 例如 "https://api.bing.microsoft.com/v7.0/search"
pt_BR: An endpoint is like "https://api.bing.microsoft.com/v7.0/search"
default: https://api.bing.microsoft.com/v7.0/search
+ allow_entities:
+ type: boolean
+ required: false
+ label:
+ en_US: Allow Entities Search
+ zh_Hans: 支持实体搜索
+ pt_BR: Allow Entities Search
+ help:
+ en_US: Does your subscription plan allow entity search
+ zh_Hans: 您的订阅计划是否支持实体搜索
+ pt_BR: Does your subscription plan allow entity search
+ default: true
+ allow_web_pages:
+ type: boolean
+ required: false
+ label:
+ en_US: Allow Web Pages Search
+ zh_Hans: 支持网页搜索
+ pt_BR: Allow Web Pages Search
+ help:
+ en_US: Does your subscription plan allow web pages search
+ zh_Hans: 您的订阅计划是否支持网页搜索
+ pt_BR: Does your subscription plan allow web pages search
+ default: true
+ allow_computation:
+ type: boolean
+ required: false
+ label:
+ en_US: Allow Computation Search
+ zh_Hans: 支持计算搜索
+ pt_BR: Allow Computation Search
+ help:
+ en_US: Does your subscription plan allow computation search
+ zh_Hans: 您的订阅计划是否支持计算搜索
+ pt_BR: Does your subscription plan allow computation search
+ default: false
+ allow_news:
+ type: boolean
+ required: false
+ label:
+ en_US: Allow News Search
+ zh_Hans: 支持新闻搜索
+ pt_BR: Allow News Search
+ help:
+ en_US: Does your subscription plan allow news search
+ zh_Hans: 您的订阅计划是否支持新闻搜索
+ pt_BR: Does your subscription plan allow news search
+ default: false
+ allow_related_searches:
+ type: boolean
+ required: false
+ label:
+ en_US: Allow Related Searches
+ zh_Hans: 支持相关搜索
+ pt_BR: Allow Related Searches
+ help:
+ en_US: Does your subscription plan allow related searches
+ zh_Hans: 您的订阅计划是否支持相关搜索
+ pt_BR: Does your subscription plan allow related searches
+ default: false
diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py
index 7b740293dd..8f11d2173c 100644
--- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py
+++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py
@@ -10,53 +10,23 @@ from core.tools.tool.builtin_tool import BuiltinTool
class BingSearchTool(BuiltinTool):
url = 'https://api.bing.microsoft.com/v7.0/search'
- def _invoke(self,
- user_id: str,
- tool_parameters: dict[str, Any],
- ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ def _invoke_bing(self,
+ user_id: str,
+ subscription_key: str, query: str, limit: int,
+ result_type: str, market: str, lang: str,
+ filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
- invoke tools
+ invoke bing search
"""
-
- key = self.runtime.credentials.get('subscription_key', None)
- if not key:
- raise Exception('subscription_key is required')
-
- server_url = self.runtime.credentials.get('server_url', None)
- if not server_url:
- server_url = self.url
-
- query = tool_parameters.get('query', None)
- if not query:
- raise Exception('query is required')
-
- limit = min(tool_parameters.get('limit', 5), 10)
- result_type = tool_parameters.get('result_type', 'text') or 'text'
-
- market = tool_parameters.get('market', 'US')
- lang = tool_parameters.get('language', 'en')
- filter = []
-
- if tool_parameters.get('enable_computation', False):
- filter.append('Computation')
- if tool_parameters.get('enable_entities', False):
- filter.append('Entities')
- if tool_parameters.get('enable_news', False):
- filter.append('News')
- if tool_parameters.get('enable_related_search', False):
- filter.append('RelatedSearches')
- if tool_parameters.get('enable_webpages', False):
- filter.append('WebPages')
-
market_code = f'{lang}-{market}'
accept_language = f'{lang},{market_code};q=0.9'
headers = {
- 'Ocp-Apim-Subscription-Key': key,
+ 'Ocp-Apim-Subscription-Key': subscription_key,
'Accept-Language': accept_language
}
query = quote(query)
- server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filter)}'
+ server_url = f'{self.url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
response = get(server_url, headers=headers)
if response.status_code != 200:
@@ -124,3 +94,105 @@ class BingSearchTool(BuiltinTool):
text += f'{related["displayText"]} - {related["webSearchUrl"]}\n'
return self.create_text_message(text=self.summary(user_id=user_id, content=text))
+
+
+ def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None:
+ key = credentials.get('subscription_key', None)
+ if not key:
+ raise Exception('subscription_key is required')
+
+ server_url = credentials.get('server_url', None)
+ if not server_url:
+ server_url = self.url
+
+ query = tool_parameters.get('query', None)
+ if not query:
+ raise Exception('query is required')
+
+ limit = min(tool_parameters.get('limit', 5), 10)
+ result_type = tool_parameters.get('result_type', 'text') or 'text'
+
+ market = tool_parameters.get('market', 'US')
+ lang = tool_parameters.get('language', 'en')
+ filter = []
+
+ if credentials.get('allow_entities', False):
+ filter.append('Entities')
+
+ if credentials.get('allow_computation', False):
+ filter.append('Computation')
+
+ if credentials.get('allow_news', False):
+ filter.append('News')
+
+ if credentials.get('allow_related_searches', False):
+ filter.append('RelatedSearches')
+
+ if credentials.get('allow_web_pages', False):
+ filter.append('WebPages')
+
+ if not filter:
+ raise Exception('At least one filter is required')
+
+ self._invoke_bing(
+ user_id='test',
+ subscription_key=key,
+ query=query,
+ limit=limit,
+ result_type=result_type,
+ market=market,
+ lang=lang,
+ filters=filter
+ )
+
+ def _invoke(self,
+ user_id: str,
+ tool_parameters: dict[str, Any],
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ """
+ invoke tools
+ """
+
+ key = self.runtime.credentials.get('subscription_key', None)
+ if not key:
+ raise Exception('subscription_key is required')
+
+ server_url = self.runtime.credentials.get('server_url', None)
+ if not server_url:
+ server_url = self.url
+
+ query = tool_parameters.get('query', None)
+ if not query:
+ raise Exception('query is required')
+
+ limit = min(tool_parameters.get('limit', 5), 10)
+ result_type = tool_parameters.get('result_type', 'text') or 'text'
+
+ market = tool_parameters.get('market', 'US')
+ lang = tool_parameters.get('language', 'en')
+ filter = []
+
+ if tool_parameters.get('enable_computation', False):
+ filter.append('Computation')
+ if tool_parameters.get('enable_entities', False):
+ filter.append('Entities')
+ if tool_parameters.get('enable_news', False):
+ filter.append('News')
+ if tool_parameters.get('enable_related_search', False):
+ filter.append('RelatedSearches')
+ if tool_parameters.get('enable_webpages', False):
+ filter.append('WebPages')
+
+ if not filter:
+ raise Exception('At least one filter is required')
+
+ return self._invoke_bing(
+ user_id=user_id,
+ subscription_key=key,
+ query=query,
+ limit=limit,
+ result_type=result_type,
+ market=market,
+ lang=lang,
+ filters=filter
+ )
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py
index 93e7d5a39e..824f91c822 100644
--- a/api/core/tools/provider/builtin_tool_provider.py
+++ b/api/core/tools/provider/builtin_tool_provider.py
@@ -246,8 +246,27 @@ class BuiltinToolProviderController(ToolProviderController):
if credentials[credential_name] not in [x.value for x in options]:
raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be one of {options}')
-
- if credentials[credential_name]:
+ elif credential_schema.type == ToolProviderCredentials.CredentialsType.BOOLEAN:
+ if isinstance(credentials[credential_name], bool):
+ pass
+ elif isinstance(credentials[credential_name], str):
+ if credentials[credential_name].lower() == 'true':
+ credentials[credential_name] = True
+ elif credentials[credential_name].lower() == 'false':
+ credentials[credential_name] = False
+ else:
+ raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
+ elif isinstance(credentials[credential_name], int):
+ if credentials[credential_name] == 1:
+ credentials[credential_name] = True
+ elif credentials[credential_name] == 0:
+ credentials[credential_name] = False
+ else:
+ raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
+ else:
+ raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean')
+
+ if credentials[credential_name] or credentials[credential_name] == False:
credentials_need_to_validate.pop(credential_name)
for credential_name in credentials_need_to_validate:
diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py
index ff618e5d2b..70c6a44459 100644
--- a/api/services/tools_manage_service.py
+++ b/api/services/tools_manage_service.py
@@ -138,9 +138,9 @@ class ToolManageService:
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name)
- return [
- v.to_dict() for _, v in (provider.credentials_schema or {}).items()
- ]
+ return json.loads(serialize_base_model_array([
+ v for _, v in (provider.credentials_schema or {}).items()
+ ]))
@staticmethod
def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]:
diff --git a/web/app/components/tools/setting/build-in/config-credentials.tsx b/web/app/components/tools/setting/build-in/config-credentials.tsx
index d5365001c8..1a3c8f015a 100644
--- a/web/app/components/tools/setting/build-in/config-credentials.tsx
+++ b/web/app/components/tools/setting/build-in/config-credentials.tsx
@@ -3,7 +3,7 @@ import type { FC } from 'react'
import React, { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
-import { toolCredentialToFormSchemas } from '../../utils/to-form-schema'
+import { addDefaultValue, toolCredentialToFormSchemas } from '../../utils/to-form-schema'
import type { Collection } from '../../types'
import Drawer from '@/app/components/base/drawer-plus'
import Button from '@/app/components/base/button'
@@ -28,12 +28,15 @@ const ConfigCredential: FC = ({
const { t } = useTranslation()
const [credentialSchema, setCredentialSchema] = useState(null)
const { team_credentials: credentialValue, name: collectionName } = collection
+ const [tempCredential, setTempCredential] = React.useState(credentialValue)
useEffect(() => {
fetchBuiltInToolCredentialSchema(collectionName).then((res) => {
- setCredentialSchema(toolCredentialToFormSchemas(res))
+ const toolCredentialSchemas = toolCredentialToFormSchemas(res)
+ const defaultCredentials = addDefaultValue(credentialValue, toolCredentialSchemas)
+ setCredentialSchema(toolCredentialSchemas)
+ setTempCredential(defaultCredentials)
})
}, [])
- const [tempCredential, setTempCredential] = React.useState(credentialValue)
return (
Date: Mon, 18 Mar 2024 17:01:25 +0800
Subject: [PATCH 3/4] fix/Add isModel flag to AgentTools component (#2876)
---
.../config/agent/agent-tools/index.tsx | 1 +
.../agent/agent-tools/setting-built-in-tool.tsx | 15 ++++++++++-----
2 files changed, 11 insertions(+), 5 deletions(-)
diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx
index 95858d9540..b92ff94983 100644
--- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx
+++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx
@@ -210,6 +210,7 @@ const AgentTools: FC = () => {
setting={currentTool?.tool_parameters as any}
collection={currentTool?.collection as Collection}
isBuiltIn={currentTool?.collection?.type === CollectionType.builtIn}
+ isModel={currentTool?.collection?.type === CollectionType.model}
onSave={handleToolSettingChange}
onHide={() => setIsShowSettingTool(false)}
/>)
diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx
index 378054aae6..9eb2657fcf 100644
--- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx
+++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx
@@ -58,11 +58,16 @@ const SettingBuiltInTool: FC = ({
(async () => {
setIsLoading(true)
try {
- const list = isBuiltIn
- ? await fetchBuiltInToolList(collection.name)
- : isModel
- ? await fetchModelToolList(collection.name)
- : await fetchCustomToolList(collection.name)
+ const list = await new Promise((resolve) => {
+ (async function () {
+ if (isModel)
+ resolve(await fetchModelToolList(collection.name))
+ else if (isBuiltIn)
+ resolve(await fetchBuiltInToolList(collection.name))
+ else
+ resolve(await fetchCustomToolList(collection.name))
+ }())
+ })
setTools(list)
const currTool = list.find(tool => tool.name === toolName)
if (currTool) {
From 4834eae887931db0c51ca89db1e6e8c84fd2643a Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Mon, 18 Mar 2024 17:18:52 +0800
Subject: [PATCH 4/4] fix enable annotation reply when collection is None
(#2877)
Co-authored-by: jyong
---
api/tasks/annotation/enable_annotation_reply_task.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py
index f3260bbb50..666fa8692f 100644
--- a/api/tasks/annotation/enable_annotation_reply_task.py
+++ b/api/tasks/annotation/enable_annotation_reply_task.py
@@ -89,7 +89,7 @@ def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_
logging.info(
click.style('Delete annotation index error: {}'.format(str(e)),
fg='red'))
- vector.add_texts(documents)
+ vector.create(documents)
db.session.commit()
redis_client.setex(enable_app_annotation_job_key, 600, 'completed')
end_at = time.perf_counter()