mirror of https://github.com/langgenius/dify.git
feat: add tool benchmark
This commit is contained in:
parent
426abe2134
commit
1af2d06d29
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
from os import listdir, path
|
||||
from typing import Any, Union
|
||||
|
||||
|
|
@ -350,14 +351,14 @@ class ToolManager:
|
|||
return absolute_path, mime_type
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_providers() -> list[BuiltinToolProviderController]:
|
||||
def list_builtin_providers() -> Generator[BuiltinToolProviderController, None, None]:
|
||||
global _builtin_providers
|
||||
|
||||
# use cache first
|
||||
if len(_builtin_providers) > 0:
|
||||
return list(_builtin_providers.values())
|
||||
yield from list(_builtin_providers.values())
|
||||
return
|
||||
|
||||
builtin_providers: list[BuiltinToolProviderController] = []
|
||||
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
|
||||
if provider.startswith('__'):
|
||||
continue
|
||||
|
|
@ -373,18 +374,25 @@ class ToolManager:
|
|||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=BuiltinToolProviderController)
|
||||
builtin_providers.append(provider_class())
|
||||
provider: BuiltinToolProviderController = provider_class()
|
||||
_builtin_providers[provider.identity.name] = provider
|
||||
for tool in provider.get_tools():
|
||||
_builtin_tools_labels[tool.identity.name] = tool.identity.label
|
||||
yield provider
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'load builtin provider {provider} error: {e}')
|
||||
continue
|
||||
|
||||
# cache the builtin providers
|
||||
for provider in builtin_providers:
|
||||
_builtin_providers[provider.identity.name] = provider
|
||||
for tool in provider.get_tools():
|
||||
_builtin_tools_labels[tool.identity.name] = tool.identity.label
|
||||
@staticmethod
|
||||
def load_builtin_providers_cache():
|
||||
for _ in ToolManager.list_builtin_providers():
|
||||
pass
|
||||
|
||||
return builtin_providers
|
||||
@staticmethod
|
||||
def clear_builtin_providers_cache():
|
||||
global _builtin_providers
|
||||
_builtin_providers = {}
|
||||
|
||||
@staticmethod
|
||||
def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]:
|
||||
|
|
@ -441,7 +449,7 @@ class ToolManager:
|
|||
global _builtin_tools_labels
|
||||
if len(_builtin_tools_labels) == 0:
|
||||
# init the builtin providers
|
||||
ToolManager.list_builtin_providers()
|
||||
ToolManager.load_builtin_providers_cache()
|
||||
|
||||
if tool_name not in _builtin_tools_labels:
|
||||
return None
|
||||
|
|
@ -474,14 +482,14 @@ class ToolManager:
|
|||
|
||||
result_providers[provider.identity.name] = user_provider
|
||||
|
||||
# get model tool providers
|
||||
model_providers = ToolManager.list_model_providers(tenant_id=tenant_id)
|
||||
# append model providers
|
||||
for provider in model_providers:
|
||||
user_provider = ToolTransformService.model_provider_to_user_provider(
|
||||
db_provider=provider,
|
||||
)
|
||||
result_providers[f'model_provider.{provider.identity.name}'] = user_provider
|
||||
# # get model tool providers
|
||||
# model_providers = ToolManager.list_model_providers(tenant_id=tenant_id)
|
||||
# # append model providers
|
||||
# for provider in model_providers:
|
||||
# user_provider = ToolTransformService.model_provider_to_user_provider(
|
||||
# db_provider=provider,
|
||||
# )
|
||||
# result_providers[f'model_provider.{provider.identity.name}'] = user_provider
|
||||
|
||||
# get db api providers
|
||||
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
|
||||
|
|
|
|||
|
|
@ -1,10 +1,21 @@
|
|||
import pytest
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
provider_generator = ToolManager.list_builtin_providers()
|
||||
provider_names = [provider.identity.name for provider in provider_generator]
|
||||
ToolManager.clear_builtin_providers_cache()
|
||||
provider_generator = ToolManager.list_builtin_providers()
|
||||
|
||||
def test_tool_providers():
|
||||
@pytest.mark.parametrize('name', provider_names)
|
||||
def test_tool_providers(benchmark, name):
|
||||
"""
|
||||
Test that all tool providers can be loaded
|
||||
"""
|
||||
providers = ToolManager.list_builtin_providers()
|
||||
for provider in providers:
|
||||
provider.get_tools()
|
||||
|
||||
def test(generator):
|
||||
try:
|
||||
return next(generator)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1)
|
||||
Loading…
Reference in New Issue