From 396a3e0456d2a313b70dc9e7557b5bb3a23bd087 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Tue, 2 Apr 2024 11:58:50 +0800 Subject: [PATCH] feat: add tool parameter type converter --- api/core/tools/tool/tool.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 045802dd63..bb34d90a50 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -179,6 +179,9 @@ class Tool(BaseModel, ABC): if self.runtime.runtime_parameters: tool_parameters.update(self.runtime.runtime_parameters) + # try parse tool parameters into the correct type + tool_parameters = self._transform_tool_parameters_type(tool_parameters) + result = self._invoke( user_id=user_id, tool_parameters=tool_parameters, @@ -211,6 +214,31 @@ class Tool(BaseModel, ABC): result += f"tool response: {response.message}." return result + + def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: + """ + Transform tool parameters type + """ + for parameter in self.parameters: + if parameter.name in tool_parameters: + if parameter.type in [ + ToolParameter.ToolParameterType.SECRET_INPUT, + ToolParameter.ToolParameterType.STRING, + ToolParameter.ToolParameterType.SELECT, + ] and not isinstance(tool_parameters[parameter.name], str): + tool_parameters[parameter.name] = str(tool_parameters[parameter.name]) + elif parameter.type == ToolParameter.ToolParameterType.NUMBER \ + and not isinstance(tool_parameters[parameter.name], int | float): + if isinstance(tool_parameters[parameter.name], str): + try: + tool_parameters[parameter.name] = int(tool_parameters[parameter.name]) + except ValueError: + tool_parameters[parameter.name] = float(tool_parameters[parameter.name]) + elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: + if not isinstance(tool_parameters[parameter.name], bool): + tool_parameters[parameter.name] = bool(tool_parameters[parameter.name]) + + return tool_parameters @abstractmethod def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: