diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 12d8e282b2..1cd1745f92 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -6,16 +6,13 @@ from graphon.file import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: for parameter_name, parameter in parameters.items(): - if isinstance(parameter, File): - parameters[parameter_name] = parameter.to_plugin_parameter() - elif isinstance(parameter, list) and all(isinstance(p, File) for p in parameter): - parameters[parameter_name] = [] - for p in parameter: - parameters[parameter_name].append(p.to_plugin_parameter()) - elif isinstance(parameter, ToolSelector): - parameters[parameter_name] = parameter.to_plugin_parameter() - elif isinstance(parameter, list) and all(isinstance(p, ToolSelector) for p in parameter): - parameters[parameter_name] = [] - for p in parameter: - parameters[parameter_name].append(p.to_plugin_parameter()) + match parameter: + case File(): + parameters[parameter_name] = parameter.to_plugin_parameter() + case [*items] if all(isinstance(p, File) for p in items): + parameters[parameter_name] = [p.to_plugin_parameter() for p in items] + case ToolSelector(): + parameters[parameter_name] = parameter.to_plugin_parameter() + case [*items] if all(isinstance(p, ToolSelector) for p in items): + parameters[parameter_name] = [p.to_plugin_parameter() for p in items] return parameters diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 4d784b5f23..b219ba4957 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -66,20 +66,21 @@ class Tool(ABC): message_id=message_id, ) - if isinstance(result, ToolInvokeMessage): + match result: + case ToolInvokeMessage(): - def single_generator() -> Generator[ToolInvokeMessage, None, None]: - yield result + def single_generator() -> Generator[ToolInvokeMessage, None, None]: + yield result - return single_generator() - elif isinstance(result, list): + return single_generator() + case list(): - def generator() -> Generator[ToolInvokeMessage, None, None]: - yield from result + def generator() -> Generator[ToolInvokeMessage, None, None]: + yield from result - return generator() - else: - return result + return generator() + case _: + return result def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: """