From 87a47762720c3b9d718469c2142f0b6be0583094 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=80=E5=9D=A6=E5=85=8B=E7=9A=84=E8=B4=9D=E5=A1=94?= Date: Wed, 19 Jun 2024 17:34:11 +0800 Subject: [PATCH] feat(model/tools): filter unregistered tools and models --- .../model_providers/model_provider_factory.py | 7 ++++++- api/core/tools/provider/builtin/_positions.py | 5 ++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index a4dbaabfc9..bb147f330f 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -201,10 +201,15 @@ class ModelProviderFactory: model_providers_path = os.path.dirname(current_path) # get all folders path under model_providers_path that do not start with __ + whitelist = [ + "baichuan", "chatglm", "deepseek", "hunyuan", "minimax", "moonshot", + "siliconflow", "tongyi", "volcengine_maas", + "wenxin", "xinference", "yi", "zhipuai" + ] model_provider_dir_paths = [ os.path.join(model_providers_path, model_provider_dir) for model_provider_dir in os.listdir(model_providers_path) - if not model_provider_dir.startswith('__') + if model_provider_dir in whitelist and os.path.isdir(os.path.join(model_providers_path, model_provider_dir)) ] diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index ae806eaff4..ecb135bc7e 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -17,4 +17,7 @@ class BuiltinToolProviderSort: sorted_providers = sort_by_position_map(cls._position, providers, name_func) - return sorted_providers \ No newline at end of file + blacklist = ['duckduckgo', 'brave', 'dalle', 'github', 'google', 'jina', 'slack', 'stablediffusion', 'youtube'] + filtered_providers = [provider for provider in sorted_providers if provider.name not in blacklist] + return filtered_providers +