diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d2cbbdec0c..c5831876c8 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1,6 +1,7 @@ import json import logging import mimetypes +from collections.abc import Generator from os import listdir, path from typing import Any, Union @@ -350,14 +351,14 @@ class ToolManager: return absolute_path, mime_type @staticmethod - def list_builtin_providers() -> list[BuiltinToolProviderController]: + def list_builtin_providers() -> Generator[BuiltinToolProviderController, None, None]: global _builtin_providers # use cache first if len(_builtin_providers) > 0: - return list(_builtin_providers.values()) + yield from list(_builtin_providers.values()) + return - builtin_providers: list[BuiltinToolProviderController] = [] for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): if provider.startswith('__'): continue @@ -373,18 +374,25 @@ class ToolManager: script_path=path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py'), parent_type=BuiltinToolProviderController) - builtin_providers.append(provider_class()) + provider: BuiltinToolProviderController = provider_class() + _builtin_providers[provider.identity.name] = provider + for tool in provider.get_tools(): + _builtin_tools_labels[tool.identity.name] = tool.identity.label + yield provider + except Exception as e: logger.error(f'load builtin provider {provider} error: {e}') continue - # cache the builtin providers - for provider in builtin_providers: - _builtin_providers[provider.identity.name] = provider - for tool in provider.get_tools(): - _builtin_tools_labels[tool.identity.name] = tool.identity.label + @staticmethod + def load_builtin_providers_cache(): + for _ in ToolManager.list_builtin_providers(): + pass - return builtin_providers + @staticmethod + def clear_builtin_providers_cache(): + global _builtin_providers + _builtin_providers = {} @staticmethod def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]: @@ -441,7 +449,7 @@ class ToolManager: global _builtin_tools_labels if len(_builtin_tools_labels) == 0: # init the builtin providers - ToolManager.list_builtin_providers() + ToolManager.load_builtin_providers_cache() if tool_name not in _builtin_tools_labels: return None @@ -474,14 +482,14 @@ class ToolManager: result_providers[provider.identity.name] = user_provider - # get model tool providers - model_providers = ToolManager.list_model_providers(tenant_id=tenant_id) - # append model providers - for provider in model_providers: - user_provider = ToolTransformService.model_provider_to_user_provider( - db_provider=provider, - ) - result_providers[f'model_provider.{provider.identity.name}'] = user_provider + # # get model tool providers + # model_providers = ToolManager.list_model_providers(tenant_id=tenant_id) + # # append model providers + # for provider in model_providers: + # user_provider = ToolTransformService.model_provider_to_user_provider( + # db_provider=provider, + # ) + # result_providers[f'model_provider.{provider.identity.name}'] = user_provider # get db api providers db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ diff --git a/api/tests/integration_tests/tools/test_all_provider.py b/api/tests/integration_tests/tools/test_all_provider.py index c846ddecfb..65645cb6c5 100644 --- a/api/tests/integration_tests/tools/test_all_provider.py +++ b/api/tests/integration_tests/tools/test_all_provider.py @@ -1,10 +1,21 @@ +import pytest from core.tools.tool_manager import ToolManager +provider_generator = ToolManager.list_builtin_providers() +provider_names = [provider.identity.name for provider in provider_generator] +ToolManager.clear_builtin_providers_cache() +provider_generator = ToolManager.list_builtin_providers() -def test_tool_providers(): +@pytest.mark.parametrize('name', provider_names) +def test_tool_providers(benchmark, name): """ Test that all tool providers can be loaded """ - providers = ToolManager.list_builtin_providers() - for provider in providers: - provider.get_tools() + + def test(generator): + try: + return next(generator) + except StopIteration: + return None + + benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1) \ No newline at end of file