From cb79a90031d16b09ef1747e227cb1e174cc8a435 Mon Sep 17 00:00:00 2001 From: Onelevenvy <49232224+Onelevenvy@users.noreply.github.com> Date: Mon, 18 Mar 2024 16:22:48 +0800 Subject: [PATCH 1/4] feat: Add tools for open weather search and image generation using the Spark API. (#2845) --- .../baichuan/text_embedding/text_embedding.py | 2 +- .../builtin/openweather/_assets/icon.svg | 12 ++ .../builtin/openweather/openweather.py | 36 ++++ .../builtin/openweather/openweather.yaml | 29 ++++ .../builtin/openweather/tools/weather.py | 60 +++++++ .../builtin/openweather/tools/weather.yaml | 80 +++++++++ .../tools/provider/builtin/spark/__init__.py | 0 .../provider/builtin/spark/_assets/icon.svg | 5 + .../tools/provider/builtin/spark/spark.py | 40 +++++ .../tools/provider/builtin/spark/spark.yaml | 59 +++++++ .../spark/tools/spark_img_generation.py | 154 ++++++++++++++++++ .../spark/tools/spark_img_generation.yaml | 36 ++++ sdks/python-client/dify_client/__init__.py | 2 +- 13 files changed, 513 insertions(+), 2 deletions(-) create mode 100644 api/core/tools/provider/builtin/openweather/_assets/icon.svg create mode 100644 api/core/tools/provider/builtin/openweather/openweather.py create mode 100644 api/core/tools/provider/builtin/openweather/openweather.yaml create mode 100644 api/core/tools/provider/builtin/openweather/tools/weather.py create mode 100644 api/core/tools/provider/builtin/openweather/tools/weather.yaml create mode 100644 api/core/tools/provider/builtin/spark/__init__.py create mode 100644 api/core/tools/provider/builtin/spark/_assets/icon.svg create mode 100644 api/core/tools/provider/builtin/spark/spark.py create mode 100644 api/core/tools/provider/builtin/spark/spark.yaml create mode 100644 api/core/tools/provider/builtin/spark/tools/spark_img_generation.py create mode 100644 api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 535714f663..5ae90d54b5 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -124,7 +124,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): elif err == 'insufficient_quota': raise InsufficientAccountBalance(msg) elif err == 'invalid_authentication': - raise InvalidAuthenticationError(msg) + raise InvalidAuthenticationError(msg) elif err and 'rate' in err: raise RateLimitReachedError(msg) elif err and 'internal' in err: diff --git a/api/core/tools/provider/builtin/openweather/_assets/icon.svg b/api/core/tools/provider/builtin/openweather/_assets/icon.svg new file mode 100644 index 0000000000..f06cd87e64 --- /dev/null +++ b/api/core/tools/provider/builtin/openweather/_assets/icon.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/openweather/openweather.py b/api/core/tools/provider/builtin/openweather/openweather.py new file mode 100644 index 0000000000..a2827177a3 --- /dev/null +++ b/api/core/tools/provider/builtin/openweather/openweather.py @@ -0,0 +1,36 @@ +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +def query_weather(city="Beijing", units="metric", language="zh_cn", api_key=None): + + url = "https://api.openweathermap.org/data/2.5/weather" + params = {"q": city, "appid": api_key, "units": units, "lang": language} + + return requests.get(url, params=params) + + +class OpenweatherProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + if "api_key" not in credentials or not credentials.get("api_key"): + raise ToolProviderCredentialValidationError( + "Open weather API key is required." + ) + apikey = credentials.get("api_key") + try: + response = query_weather(api_key=apikey) + if response.status_code == 200: + pass + else: + raise ToolProviderCredentialValidationError( + (response.json()).get("info") + ) + except Exception as e: + raise ToolProviderCredentialValidationError( + "Open weather API Key is invalid. {}".format(e) + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/openweather/openweather.yaml b/api/core/tools/provider/builtin/openweather/openweather.yaml new file mode 100644 index 0000000000..60bb33c36d --- /dev/null +++ b/api/core/tools/provider/builtin/openweather/openweather.yaml @@ -0,0 +1,29 @@ +identity: + author: Onelevenvy + name: openweather + label: + en_US: Open weather query + zh_Hans: Open Weather + pt_BR: Consulta de clima open weather + description: + en_US: Weather query toolkit based on Open Weather + zh_Hans: 基于open weather的天气查询工具包 + pt_BR: Kit de consulta de clima baseado no Open Weather + icon: icon.svg +credentials_for_provider: + api_key: + type: secret-input + required: true + label: + en_US: API Key + zh_Hans: API Key + pt_BR: Fogo a chave + placeholder: + en_US: Please enter your open weather API Key + zh_Hans: 请输入你的open weather API Key + pt_BR: Insira sua chave de API open weather + help: + en_US: Get your API Key from open weather + zh_Hans: 从open weather获取您的 API Key + pt_BR: Obtenha sua chave de API do open weather + url: https://openweathermap.org diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.py b/api/core/tools/provider/builtin/openweather/tools/weather.py new file mode 100644 index 0000000000..536a3511f4 --- /dev/null +++ b/api/core/tools/provider/builtin/openweather/tools/weather.py @@ -0,0 +1,60 @@ +import json +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class OpenweatherTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + city = tool_parameters.get("city", "") + if not city: + return self.create_text_message("Please tell me your city") + if ( + "api_key" not in self.runtime.credentials + or not self.runtime.credentials.get("api_key") + ): + return self.create_text_message("OpenWeather API key is required.") + + units = tool_parameters.get("units", "metric") + lang = tool_parameters.get("lang", "zh_cn") + try: + # request URL + url = "https://api.openweathermap.org/data/2.5/weather" + + # request parmas + params = { + "q": city, + "appid": self.runtime.credentials.get("api_key"), + "units": units, + "lang": lang, + } + response = requests.get(url, params=params) + + if response.status_code == 200: + + data = response.json() + return self.create_text_message( + self.summary( + user_id=user_id, content=json.dumps(data, ensure_ascii=False) + ) + ) + else: + error_message = { + "error": f"failed:{response.status_code}", + "data": response.text, + } + # return error + return json.dumps(error_message) + + except Exception as e: + return self.create_text_message( + "Openweather API Key is invalid. {}".format(e) + ) diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.yaml b/api/core/tools/provider/builtin/openweather/tools/weather.yaml new file mode 100644 index 0000000000..f2dae5c2df --- /dev/null +++ b/api/core/tools/provider/builtin/openweather/tools/weather.yaml @@ -0,0 +1,80 @@ +identity: + name: weather + author: Onelevenvy + label: + en_US: Open Weather Query + zh_Hans: 天气查询 + pt_BR: Previsão do tempo + icon: icon.svg +description: + human: + en_US: Weather forecast inquiry + zh_Hans: 天气查询 + pt_BR: Inquérito sobre previsão meteorológica + llm: A tool when you want to ask about the weather or weather-related question +parameters: + - name: city + type: string + required: true + label: + en_US: city + zh_Hans: 城市 + pt_BR: cidade + human_description: + en_US: Target city for weather forecast query + zh_Hans: 天气预报查询的目标城市 + pt_BR: Cidade de destino para consulta de previsão do tempo + llm_description: If you don't know you can extract the city name from the + question or you can reply:Please tell me your city. You have to extract + the Chinese city name from the question.If the input region is in Chinese + characters for China, it should be replaced with the corresponding English + name, such as '北京' for correct input is 'Beijing' + form: llm + - name: lang + type: select + required: true + human_description: + en_US: language + zh_Hans: 语言 + pt_BR: language + label: + en_US: language + zh_Hans: 语言 + pt_BR: language + form: form + options: + - value: zh_cn + label: + en_US: cn + zh_Hans: 中国 + pt_BR: cn + - value: en_us + label: + en_US: usa + zh_Hans: 美国 + pt_BR: usa + default: zh_cn + - name: units + type: select + required: true + human_description: + en_US: units for temperature + zh_Hans: 温度单位 + pt_BR: units for temperature + label: + en_US: units + zh_Hans: 单位 + pt_BR: units + form: form + options: + - value: metric + label: + en_US: metric + zh_Hans: ℃ + pt_BR: metric + - value: imperial + label: + en_US: imperial + zh_Hans: ℉ + pt_BR: imperial + default: metric diff --git a/api/core/tools/provider/builtin/spark/__init__.py b/api/core/tools/provider/builtin/spark/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/tools/provider/builtin/spark/_assets/icon.svg b/api/core/tools/provider/builtin/spark/_assets/icon.svg new file mode 100644 index 0000000000..ef0a9131a4 --- /dev/null +++ b/api/core/tools/provider/builtin/spark/_assets/icon.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/api/core/tools/provider/builtin/spark/spark.py b/api/core/tools/provider/builtin/spark/spark.py new file mode 100644 index 0000000000..cb8e69a59f --- /dev/null +++ b/api/core/tools/provider/builtin/spark/spark.py @@ -0,0 +1,40 @@ +import json + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.spark.tools.spark_img_generation import spark_response +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class SparkProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + if "APPID" not in credentials or not credentials.get("APPID"): + raise ToolProviderCredentialValidationError("APPID is required.") + if "APISecret" not in credentials or not credentials.get("APISecret"): + raise ToolProviderCredentialValidationError("APISecret is required.") + if "APIKey" not in credentials or not credentials.get("APIKey"): + raise ToolProviderCredentialValidationError("APIKey is required.") + + appid = credentials.get("APPID") + apisecret = credentials.get("APISecret") + apikey = credentials.get("APIKey") + prompt = "a cute black dog" + + try: + response = spark_response(prompt, appid, apikey, apisecret) + data = json.loads(response) + code = data["header"]["code"] + + if code == 0: + # 0 success, + pass + else: + raise ToolProviderCredentialValidationError( + "image generate error, code:{}".format(code) + ) + except Exception as e: + raise ToolProviderCredentialValidationError( + "APPID APISecret APIKey is invalid. {}".format(e) + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/spark/spark.yaml b/api/core/tools/provider/builtin/spark/spark.yaml new file mode 100644 index 0000000000..f2b9c89e96 --- /dev/null +++ b/api/core/tools/provider/builtin/spark/spark.yaml @@ -0,0 +1,59 @@ +identity: + author: Onelevenvy + name: spark + label: + en_US: Spark + zh_Hans: 讯飞星火 + pt_BR: Spark + description: + en_US: Spark Platform Toolkit + zh_Hans: 讯飞星火平台工具 + pt_BR: Pacote de Ferramentas da Plataforma Spark + icon: icon.svg +credentials_for_provider: + APPID: + type: secret-input + required: true + label: + en_US: Spark APPID + zh_Hans: APPID + pt_BR: Spark APPID + help: + en_US: Please input your APPID + zh_Hans: 请输入你的 APPID + pt_BR: Please input your APPID + placeholder: + en_US: Please input your APPID + zh_Hans: 请输入你的 APPID + pt_BR: Please input your APPID + APISecret: + type: secret-input + required: true + label: + en_US: Spark APISecret + zh_Hans: APISecret + pt_BR: Spark APISecret + help: + en_US: Please input your Spark APISecret + zh_Hans: 请输入你的 APISecret + pt_BR: Please input your Spark APISecret + placeholder: + en_US: Please input your Spark APISecret + zh_Hans: 请输入你的 APISecret + pt_BR: Please input your Spark APISecret + APIKey: + type: secret-input + required: true + label: + en_US: Spark APIKey + zh_Hans: APIKey + pt_BR: Spark APIKey + help: + en_US: Please input your Spark APIKey + zh_Hans: 请输入你的 APIKey + pt_BR: Please input your Spark APIKey + placeholder: + en_US: Please input your Spark APIKey + zh_Hans: 请输入你的 APIKey + pt_BR: Please input Spark APIKey + url: https://console.xfyun.cn/services diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py new file mode 100644 index 0000000000..a977af2b76 --- /dev/null +++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py @@ -0,0 +1,154 @@ +import base64 +import hashlib +import hmac +import json +from base64 import b64decode +from datetime import datetime +from time import mktime +from typing import Any, Union +from urllib.parse import urlencode +from wsgiref.handlers import format_date_time + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class AssembleHeaderException(Exception): + def __init__(self, msg): + self.message = msg + + +class Url: + def __init__(this, host, path, schema): + this.host = host + this.path = path + this.schema = schema + + +# calculate sha256 and encode to base64 +def sha256base64(data): + sha256 = hashlib.sha256() + sha256.update(data) + digest = base64.b64encode(sha256.digest()).decode(encoding="utf-8") + return digest + + +def parse_url(requset_url): + stidx = requset_url.index("://") + host = requset_url[stidx + 3 :] + schema = requset_url[: stidx + 3] + edidx = host.index("/") + if edidx <= 0: + raise AssembleHeaderException("invalid request url:" + requset_url) + path = host[edidx:] + host = host[:edidx] + u = Url(host, path, schema) + return u + +def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""): + u = parse_url(requset_url) + host = u.host + path = u.path + now = datetime.now() + date = format_date_time(mktime(now.timetuple())) + signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format( + host, date, method, path + ) + signature_sha = hmac.new( + api_secret.encode("utf-8"), + signature_origin.encode("utf-8"), + digestmod=hashlib.sha256, + ).digest() + signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8") + authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' + + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode( + encoding="utf-8" + ) + values = {"host": host, "date": date, "authorization": authorization} + + return requset_url + "?" + urlencode(values) + + +def get_body(appid, text): + body = { + "header": {"app_id": appid, "uid": "123456789"}, + "parameter": { + "chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096} + }, + "payload": {"message": {"text": [{"role": "user", "content": text}]}}, + } + return body + + +def spark_response(text, appid, apikey, apisecret): + host = "http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti" + url = assemble_ws_auth_url( + host, method="POST", api_key=apikey, api_secret=apisecret + ) + content = get_body(appid, text) + response = requests.post( + url, json=content, headers={"content-type": "application/json"} + ).text + return response + + +class SparkImgGeneratorTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + + if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get( + "APPID" + ): + return self.create_text_message("APPID is required.") + if ( + "APISecret" not in self.runtime.credentials + or not self.runtime.credentials.get("APISecret") + ): + return self.create_text_message("APISecret is required.") + if ( + "APIKey" not in self.runtime.credentials + or not self.runtime.credentials.get("APIKey") + ): + return self.create_text_message("APIKey is required.") + + prompt = tool_parameters.get("prompt", "") + if not prompt: + return self.create_text_message("Please input prompt") + res = self.img_generation(prompt) + result = [] + for image in res: + result.append( + self.create_blob_message( + blob=b64decode(image["base64_image"]), + meta={"mime_type": "image/png"}, + save_as=self.VARIABLE_KEY.IMAGE.value, + ) + ) + return result + + def img_generation(self, prompt): + response = spark_response( + text=prompt, + appid=self.runtime.credentials.get("APPID"), + apikey=self.runtime.credentials.get("APIKey"), + apisecret=self.runtime.credentials.get("APISecret"), + ) + data = json.loads(response) + code = data["header"]["code"] + if code != 0: + return self.create_text_message(f"error: {code}, {data}") + else: + text = data["payload"]["choices"]["text"] + image_content = text[0] + image_base = image_content["content"] + json_data = {"base64_image": image_base} + return [json_data] diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml new file mode 100644 index 0000000000..d44bbc9564 --- /dev/null +++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml @@ -0,0 +1,36 @@ +identity: + name: spark_img_generation + author: Onelevenvy + label: + en_US: Spark Image Generation + zh_Hans: 图片生成 + pt_BR: Geração de imagens Spark + icon: icon.svg + description: + en_US: Spark Image Generation + zh_Hans: 图片生成 + pt_BR: Geração de imagens Spark +description: + human: + en_US: Generate images based on user input, with image generation API + provided by Spark + zh_Hans: 根据用户的输入生成图片,由讯飞星火提供图片生成api + pt_BR: Gerar imagens com base na entrada do usuário, com API de geração + de imagem fornecida pela Spark + llm: spark_img_generation is a tool used to generate images from text +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + pt_BR: Prompt + human_description: + en_US: Image prompt + zh_Hans: 图像提示词 + pt_BR: Image prompt + llm_description: Image prompt of spark_img_generation tooll, you should + describe the image you want to generate as a list of words as possible + as detailed + form: llm diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py index 6fa9d190e5..6ef0017fee 100644 --- a/sdks/python-client/dify_client/__init__.py +++ b/sdks/python-client/dify_client/__init__.py @@ -1 +1 @@ -from dify_client.client import ChatClient, CompletionClient, DifyClient +from dify_client.client import ChatClient, CompletionClient, DifyClient \ No newline at end of file From 95b74c211df5e8191924c98f5cc1627a87343d9c Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Mon, 18 Mar 2024 16:55:26 +0800 Subject: [PATCH 2/4] Feat/support tool credentials bool schema (#2875) --- api/core/tools/entities/tool_entities.py | 3 +- api/core/tools/provider/builtin/bing/bing.py | 5 +- .../tools/provider/builtin/bing/bing.yaml | 60 +++++++ .../builtin/bing/tools/bing_web_search.py | 148 +++++++++++++----- .../tools/provider/builtin_tool_provider.py | 23 ++- api/services/tools_manage_service.py | 6 +- .../setting/build-in/config-credentials.tsx | 9 +- 7 files changed, 204 insertions(+), 50 deletions(-) diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index f7a61b0b0c..437f871864 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -171,6 +171,7 @@ class ToolProviderCredentials(BaseModel): SECRET_INPUT = "secret-input" TEXT_INPUT = "text-input" SELECT = "select" + BOOLEAN = "boolean" @classmethod def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType": @@ -192,7 +193,7 @@ class ToolProviderCredentials(BaseModel): name: str = Field(..., description="The name of the credentials") type: CredentialsType = Field(..., description="The type of the credentials") required: bool = False - default: Optional[str] = None + default: Optional[Union[int, str]] = None options: Optional[list[ToolCredentialsOption]] = None label: Optional[I18nObject] = None help: Optional[I18nObject] = None diff --git a/api/core/tools/provider/builtin/bing/bing.py b/api/core/tools/provider/builtin/bing/bing.py index ff131b26cd..6e62abfc10 100644 --- a/api/core/tools/provider/builtin/bing/bing.py +++ b/api/core/tools/provider/builtin/bing/bing.py @@ -12,12 +12,11 @@ class BingProvider(BuiltinToolProviderController): meta={ "credentials": credentials, } - ).invoke( - user_id='', + ).validate_credentials( + credentials=credentials, tool_parameters={ "query": "test", "result_type": "link", - "enable_webpages": True, }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/bing/bing.yaml b/api/core/tools/provider/builtin/bing/bing.yaml index 9df836929c..35cd729208 100644 --- a/api/core/tools/provider/builtin/bing/bing.yaml +++ b/api/core/tools/provider/builtin/bing/bing.yaml @@ -43,3 +43,63 @@ credentials_for_provider: zh_Hans: 例如 "https://api.bing.microsoft.com/v7.0/search" pt_BR: An endpoint is like "https://api.bing.microsoft.com/v7.0/search" default: https://api.bing.microsoft.com/v7.0/search + allow_entities: + type: boolean + required: false + label: + en_US: Allow Entities Search + zh_Hans: 支持实体搜索 + pt_BR: Allow Entities Search + help: + en_US: Does your subscription plan allow entity search + zh_Hans: 您的订阅计划是否支持实体搜索 + pt_BR: Does your subscription plan allow entity search + default: true + allow_web_pages: + type: boolean + required: false + label: + en_US: Allow Web Pages Search + zh_Hans: 支持网页搜索 + pt_BR: Allow Web Pages Search + help: + en_US: Does your subscription plan allow web pages search + zh_Hans: 您的订阅计划是否支持网页搜索 + pt_BR: Does your subscription plan allow web pages search + default: true + allow_computation: + type: boolean + required: false + label: + en_US: Allow Computation Search + zh_Hans: 支持计算搜索 + pt_BR: Allow Computation Search + help: + en_US: Does your subscription plan allow computation search + zh_Hans: 您的订阅计划是否支持计算搜索 + pt_BR: Does your subscription plan allow computation search + default: false + allow_news: + type: boolean + required: false + label: + en_US: Allow News Search + zh_Hans: 支持新闻搜索 + pt_BR: Allow News Search + help: + en_US: Does your subscription plan allow news search + zh_Hans: 您的订阅计划是否支持新闻搜索 + pt_BR: Does your subscription plan allow news search + default: false + allow_related_searches: + type: boolean + required: false + label: + en_US: Allow Related Searches + zh_Hans: 支持相关搜索 + pt_BR: Allow Related Searches + help: + en_US: Does your subscription plan allow related searches + zh_Hans: 您的订阅计划是否支持相关搜索 + pt_BR: Does your subscription plan allow related searches + default: false diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index 7b740293dd..8f11d2173c 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -10,53 +10,23 @@ from core.tools.tool.builtin_tool import BuiltinTool class BingSearchTool(BuiltinTool): url = 'https://api.bing.microsoft.com/v7.0/search' - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke_bing(self, + user_id: str, + subscription_key: str, query: str, limit: int, + result_type: str, market: str, lang: str, + filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke bing search """ - - key = self.runtime.credentials.get('subscription_key', None) - if not key: - raise Exception('subscription_key is required') - - server_url = self.runtime.credentials.get('server_url', None) - if not server_url: - server_url = self.url - - query = tool_parameters.get('query', None) - if not query: - raise Exception('query is required') - - limit = min(tool_parameters.get('limit', 5), 10) - result_type = tool_parameters.get('result_type', 'text') or 'text' - - market = tool_parameters.get('market', 'US') - lang = tool_parameters.get('language', 'en') - filter = [] - - if tool_parameters.get('enable_computation', False): - filter.append('Computation') - if tool_parameters.get('enable_entities', False): - filter.append('Entities') - if tool_parameters.get('enable_news', False): - filter.append('News') - if tool_parameters.get('enable_related_search', False): - filter.append('RelatedSearches') - if tool_parameters.get('enable_webpages', False): - filter.append('WebPages') - market_code = f'{lang}-{market}' accept_language = f'{lang},{market_code};q=0.9' headers = { - 'Ocp-Apim-Subscription-Key': key, + 'Ocp-Apim-Subscription-Key': subscription_key, 'Accept-Language': accept_language } query = quote(query) - server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filter)}' + server_url = f'{self.url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}' response = get(server_url, headers=headers) if response.status_code != 200: @@ -124,3 +94,105 @@ class BingSearchTool(BuiltinTool): text += f'{related["displayText"]} - {related["webSearchUrl"]}\n' return self.create_text_message(text=self.summary(user_id=user_id, content=text)) + + + def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None: + key = credentials.get('subscription_key', None) + if not key: + raise Exception('subscription_key is required') + + server_url = credentials.get('server_url', None) + if not server_url: + server_url = self.url + + query = tool_parameters.get('query', None) + if not query: + raise Exception('query is required') + + limit = min(tool_parameters.get('limit', 5), 10) + result_type = tool_parameters.get('result_type', 'text') or 'text' + + market = tool_parameters.get('market', 'US') + lang = tool_parameters.get('language', 'en') + filter = [] + + if credentials.get('allow_entities', False): + filter.append('Entities') + + if credentials.get('allow_computation', False): + filter.append('Computation') + + if credentials.get('allow_news', False): + filter.append('News') + + if credentials.get('allow_related_searches', False): + filter.append('RelatedSearches') + + if credentials.get('allow_web_pages', False): + filter.append('WebPages') + + if not filter: + raise Exception('At least one filter is required') + + self._invoke_bing( + user_id='test', + subscription_key=key, + query=query, + limit=limit, + result_type=result_type, + market=market, + lang=lang, + filters=filter + ) + + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + + key = self.runtime.credentials.get('subscription_key', None) + if not key: + raise Exception('subscription_key is required') + + server_url = self.runtime.credentials.get('server_url', None) + if not server_url: + server_url = self.url + + query = tool_parameters.get('query', None) + if not query: + raise Exception('query is required') + + limit = min(tool_parameters.get('limit', 5), 10) + result_type = tool_parameters.get('result_type', 'text') or 'text' + + market = tool_parameters.get('market', 'US') + lang = tool_parameters.get('language', 'en') + filter = [] + + if tool_parameters.get('enable_computation', False): + filter.append('Computation') + if tool_parameters.get('enable_entities', False): + filter.append('Entities') + if tool_parameters.get('enable_news', False): + filter.append('News') + if tool_parameters.get('enable_related_search', False): + filter.append('RelatedSearches') + if tool_parameters.get('enable_webpages', False): + filter.append('WebPages') + + if not filter: + raise Exception('At least one filter is required') + + return self._invoke_bing( + user_id=user_id, + subscription_key=key, + query=query, + limit=limit, + result_type=result_type, + market=market, + lang=lang, + filters=filter + ) \ No newline at end of file diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 93e7d5a39e..824f91c822 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -246,8 +246,27 @@ class BuiltinToolProviderController(ToolProviderController): if credentials[credential_name] not in [x.value for x in options]: raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be one of {options}') - - if credentials[credential_name]: + elif credential_schema.type == ToolProviderCredentials.CredentialsType.BOOLEAN: + if isinstance(credentials[credential_name], bool): + pass + elif isinstance(credentials[credential_name], str): + if credentials[credential_name].lower() == 'true': + credentials[credential_name] = True + elif credentials[credential_name].lower() == 'false': + credentials[credential_name] = False + else: + raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean') + elif isinstance(credentials[credential_name], int): + if credentials[credential_name] == 1: + credentials[credential_name] = True + elif credentials[credential_name] == 0: + credentials[credential_name] = False + else: + raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean') + else: + raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean') + + if credentials[credential_name] or credentials[credential_name] == False: credentials_need_to_validate.pop(credential_name) for credential_name in credentials_need_to_validate: diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py index ff618e5d2b..70c6a44459 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools_manage_service.py @@ -138,9 +138,9 @@ class ToolManageService: :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name) - return [ - v.to_dict() for _, v in (provider.credentials_schema or {}).items() - ] + return json.loads(serialize_base_model_array([ + v for _, v in (provider.credentials_schema or {}).items() + ])) @staticmethod def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]: diff --git a/web/app/components/tools/setting/build-in/config-credentials.tsx b/web/app/components/tools/setting/build-in/config-credentials.tsx index d5365001c8..1a3c8f015a 100644 --- a/web/app/components/tools/setting/build-in/config-credentials.tsx +++ b/web/app/components/tools/setting/build-in/config-credentials.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React, { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import cn from 'classnames' -import { toolCredentialToFormSchemas } from '../../utils/to-form-schema' +import { addDefaultValue, toolCredentialToFormSchemas } from '../../utils/to-form-schema' import type { Collection } from '../../types' import Drawer from '@/app/components/base/drawer-plus' import Button from '@/app/components/base/button' @@ -28,12 +28,15 @@ const ConfigCredential: FC = ({ const { t } = useTranslation() const [credentialSchema, setCredentialSchema] = useState(null) const { team_credentials: credentialValue, name: collectionName } = collection + const [tempCredential, setTempCredential] = React.useState(credentialValue) useEffect(() => { fetchBuiltInToolCredentialSchema(collectionName).then((res) => { - setCredentialSchema(toolCredentialToFormSchemas(res)) + const toolCredentialSchemas = toolCredentialToFormSchemas(res) + const defaultCredentials = addDefaultValue(credentialValue, toolCredentialSchemas) + setCredentialSchema(toolCredentialSchemas) + setTempCredential(defaultCredentials) }) }, []) - const [tempCredential, setTempCredential] = React.useState(credentialValue) return ( Date: Mon, 18 Mar 2024 17:01:25 +0800 Subject: [PATCH 3/4] fix/Add isModel flag to AgentTools component (#2876) --- .../config/agent/agent-tools/index.tsx | 1 + .../agent/agent-tools/setting-built-in-tool.tsx | 15 ++++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 95858d9540..b92ff94983 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -210,6 +210,7 @@ const AgentTools: FC = () => { setting={currentTool?.tool_parameters as any} collection={currentTool?.collection as Collection} isBuiltIn={currentTool?.collection?.type === CollectionType.builtIn} + isModel={currentTool?.collection?.type === CollectionType.model} onSave={handleToolSettingChange} onHide={() => setIsShowSettingTool(false)} />) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx index 378054aae6..9eb2657fcf 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx @@ -58,11 +58,16 @@ const SettingBuiltInTool: FC = ({ (async () => { setIsLoading(true) try { - const list = isBuiltIn - ? await fetchBuiltInToolList(collection.name) - : isModel - ? await fetchModelToolList(collection.name) - : await fetchCustomToolList(collection.name) + const list = await new Promise((resolve) => { + (async function () { + if (isModel) + resolve(await fetchModelToolList(collection.name)) + else if (isBuiltIn) + resolve(await fetchBuiltInToolList(collection.name)) + else + resolve(await fetchCustomToolList(collection.name)) + }()) + }) setTools(list) const currTool = list.find(tool => tool.name === toolName) if (currTool) { From 4834eae887931db0c51ca89db1e6e8c84fd2643a Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Mon, 18 Mar 2024 17:18:52 +0800 Subject: [PATCH 4/4] fix enable annotation reply when collection is None (#2877) Co-authored-by: jyong --- api/tasks/annotation/enable_annotation_reply_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index f3260bbb50..666fa8692f 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -89,7 +89,7 @@ def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_ logging.info( click.style('Delete annotation index error: {}'.format(str(e)), fg='red')) - vector.add_texts(documents) + vector.create(documents) db.session.commit() redis_client.setex(enable_app_annotation_job_key, 600, 'completed') end_at = time.perf_counter()