feat: add tool benchmark

This commit is contained in:
Yeuoly 2024-04-02 15:23:54 +08:00
parent 426abe2134
commit 1af2d06d29
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
2 changed files with 42 additions and 23 deletions

View File

@ -1,6 +1,7 @@
import json
import logging
import mimetypes
from collections.abc import Generator
from os import listdir, path
from typing import Any, Union
@ -350,14 +351,14 @@ class ToolManager:
return absolute_path, mime_type
@staticmethod
def list_builtin_providers() -> list[BuiltinToolProviderController]:
def list_builtin_providers() -> Generator[BuiltinToolProviderController, None, None]:
global _builtin_providers
# use cache first
if len(_builtin_providers) > 0:
return list(_builtin_providers.values())
yield from list(_builtin_providers.values())
return
builtin_providers: list[BuiltinToolProviderController] = []
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
if provider.startswith('__'):
continue
@ -373,18 +374,25 @@ class ToolManager:
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController)
builtin_providers.append(provider_class())
provider: BuiltinToolProviderController = provider_class()
_builtin_providers[provider.identity.name] = provider
for tool in provider.get_tools():
_builtin_tools_labels[tool.identity.name] = tool.identity.label
yield provider
except Exception as e:
logger.error(f'load builtin provider {provider} error: {e}')
continue
# cache the builtin providers
for provider in builtin_providers:
_builtin_providers[provider.identity.name] = provider
for tool in provider.get_tools():
_builtin_tools_labels[tool.identity.name] = tool.identity.label
@staticmethod
def load_builtin_providers_cache():
for _ in ToolManager.list_builtin_providers():
pass
return builtin_providers
@staticmethod
def clear_builtin_providers_cache():
global _builtin_providers
_builtin_providers = {}
@staticmethod
def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]:
@ -441,7 +449,7 @@ class ToolManager:
global _builtin_tools_labels
if len(_builtin_tools_labels) == 0:
# init the builtin providers
ToolManager.list_builtin_providers()
ToolManager.load_builtin_providers_cache()
if tool_name not in _builtin_tools_labels:
return None
@ -474,14 +482,14 @@ class ToolManager:
result_providers[provider.identity.name] = user_provider
# get model tool providers
model_providers = ToolManager.list_model_providers(tenant_id=tenant_id)
# append model providers
for provider in model_providers:
user_provider = ToolTransformService.model_provider_to_user_provider(
db_provider=provider,
)
result_providers[f'model_provider.{provider.identity.name}'] = user_provider
# # get model tool providers
# model_providers = ToolManager.list_model_providers(tenant_id=tenant_id)
# # append model providers
# for provider in model_providers:
# user_provider = ToolTransformService.model_provider_to_user_provider(
# db_provider=provider,
# )
# result_providers[f'model_provider.{provider.identity.name}'] = user_provider
# get db api providers
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \

View File

@ -1,10 +1,21 @@
import pytest
from core.tools.tool_manager import ToolManager
provider_generator = ToolManager.list_builtin_providers()
provider_names = [provider.identity.name for provider in provider_generator]
ToolManager.clear_builtin_providers_cache()
provider_generator = ToolManager.list_builtin_providers()
def test_tool_providers():
@pytest.mark.parametrize('name', provider_names)
def test_tool_providers(benchmark, name):
"""
Test that all tool providers can be loaded
"""
providers = ToolManager.list_builtin_providers()
for provider in providers:
provider.get_tools()
def test(generator):
try:
return next(generator)
except StopIteration:
return None
benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1)