diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 53d52b5866..27ac0e3455 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -159,3 +159,6 @@ class ProviderConfig(BasicProviderConfig): help: Optional[I18nObject] = None url: Optional[str] = None placeholder: Optional[I18nObject] = None + + def to_basic_provider_config(self) -> BasicProviderConfig: + return BasicProviderConfig(type=self.type, name=self.name) diff --git a/api/core/plugin/encrypt/__init__.py b/api/core/plugin/encrypt/__init__.py index 6303e2ade1..577e1bbace 100644 --- a/api/core/plugin/encrypt/__init__.py +++ b/api/core/plugin/encrypt/__init__.py @@ -1,6 +1,3 @@ -from collections.abc import Mapping -from typing import Any - from core.plugin.entities.request import RequestInvokeEncrypt from core.tools.utils.configuration import ProviderConfigEncrypter from models.account import Tenant @@ -11,7 +8,7 @@ class PluginEncrypter: def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: encrypter = ProviderConfigEncrypter( tenant_id=tenant.id, - config=payload.data, + config=payload.config, provider_type=payload.namespace, provider_identity=payload.identity, ) diff --git a/api/core/plugin/entities/endpoint.py b/api/core/plugin/entities/endpoint.py index db7819f354..b1a203b39c 100644 --- a/api/core/plugin/entities/endpoint.py +++ b/api/core/plugin/entities/endpoint.py @@ -1,4 +1,3 @@ -from collections.abc import Mapping from datetime import datetime from pydantic import BaseModel, Field @@ -12,7 +11,7 @@ class EndpointDeclaration(BaseModel): declaration of an endpoint """ - settings: Mapping[str, ProviderConfig] = Field(default_factory=Mapping) + settings: list[ProviderConfig] = Field(default_factory=list) class EndpointEntity(BasePluginEntity): diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index af40ebc5ca..19bf329674 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -1,4 +1,3 @@ -from collections.abc import Mapping from typing import Any, Literal, Optional from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -181,4 +180,4 @@ class RequestInvokeEncrypt(BaseModel): namespace: Literal["endpoint"] identity: str data: dict = Field(default_factory=dict) - config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping) + config: list[BasicProviderConfig] = Field(default_factory=list) diff --git a/api/core/tools/__base/tool_provider.py b/api/core/tools/__base/tool_provider.py index 492f1c08ae..d096fc7df7 100644 --- a/api/core/tools/__base/tool_provider.py +++ b/api/core/tools/__base/tool_provider.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from copy import deepcopy from typing import Any from core.entities.provider_entities import ProviderConfig @@ -16,13 +17,13 @@ class ToolProviderController(ABC): def __init__(self, entity: ToolProviderEntity) -> None: self.entity = entity - def get_credentials_schema(self) -> dict[str, ProviderConfig]: + def get_credentials_schema(self) -> list[ProviderConfig]: """ returns the credentials schema of the provider :return: the credentials schema """ - return self.entity.credentials_schema.copy() + return deepcopy(self.entity.credentials_schema) @abstractmethod def get_tool(self, tool_name: str) -> Tool: @@ -48,10 +49,13 @@ class ToolProviderController(ABC): :param credentials: the credentials of the tool """ - credentials_schema = self.entity.credentials_schema + credentials_schema = dict[str, ProviderConfig]() if credentials_schema is None: return + for credential in self.entity.credentials_schema: + credentials_schema[credential.name] = credential + credentials_need_to_validate: dict[str, ProviderConfig] = {} for credential_name in credentials_schema: credentials_need_to_validate[credential_name] = credentials_schema[credential_name] diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index e7e374f2e6..18ad385e49 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -34,10 +34,14 @@ class BuiltinToolProviderController(ToolProviderController): for credential_name in provider_yaml["credentials_for_provider"]: provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name + credentials_schema = [] + for credential in provider_yaml.get("credentials_for_provider", {}): + credentials_schema.append(credential) + super().__init__( entity=ToolProviderEntity( identity=provider_yaml["identity"], - credentials_schema=provider_yaml.get("credentials_for_provider", {}) or {}, + credentials_schema=credentials_schema, ), ) @@ -84,14 +88,14 @@ class BuiltinToolProviderController(ToolProviderController): self.tools = tools return tools - def get_credentials_schema(self) -> dict[str, ProviderConfig]: + def get_credentials_schema(self) -> list[ProviderConfig]: """ returns the credentials schema of the provider :return: the credentials schema """ if not self.entity.credentials_schema: - return {} + return [] return self.entity.credentials_schema.copy() diff --git a/api/core/tools/builtin_tool/providers/code/code.yaml b/api/core/tools/builtin_tool/providers/code/code.yaml index 2640a7087e..81ab6d4b45 100644 --- a/api/core/tools/builtin_tool/providers/code/code.yaml +++ b/api/core/tools/builtin_tool/providers/code/code.yaml @@ -12,4 +12,3 @@ identity: icon: icon.svg tags: - productivity -credentials_for_provider: diff --git a/api/core/tools/builtin_tool/providers/time/time.yaml b/api/core/tools/builtin_tool/providers/time/time.yaml index 1278939df5..77bdc0f87a 100644 --- a/api/core/tools/builtin_tool/providers/time/time.yaml +++ b/api/core/tools/builtin_tool/providers/time/time.yaml @@ -12,4 +12,3 @@ identity: icon: icon.svg tags: - utilities -credentials_for_provider: diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index c5e3e8488e..574cba05e6 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -28,8 +28,8 @@ class ApiToolProviderController(ToolProviderController): @classmethod def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType): - credentials_schema = { - "auth_type": ProviderConfig( + credentials_schema = [ + ProviderConfig( name="auth_type", required=True, type=ProviderConfig.Type.SELECT, @@ -40,24 +40,24 @@ class ApiToolProviderController(ToolProviderController): default="none", help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"), ) - } + ] if auth_type == ApiProviderAuthType.API_KEY: - credentials_schema = { - **credentials_schema, - "api_key_header": ProviderConfig( + credentials_schema = [ + *credentials_schema, + ProviderConfig( name="api_key_header", required=False, default="api_key", type=ProviderConfig.Type.TEXT_INPUT, help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"), ), - "api_key_value": ProviderConfig( + ProviderConfig( name="api_key_value", required=True, type=ProviderConfig.Type.SECRET_INPUT, help=I18nObject(en_US="The api key", zh_Hans="api key的值"), ), - "api_key_header_prefix": ProviderConfig( + ProviderConfig( name="api_key_header_prefix", required=False, default="basic", @@ -69,7 +69,7 @@ class ApiToolProviderController(ToolProviderController): ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")), ], ), - } + ] elif auth_type == ApiProviderAuthType.NONE: pass diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index eff808a181..3758820694 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -2,7 +2,6 @@ from typing import Literal, Optional from pydantic import BaseModel, Field -from core.entities.provider_entities import ProviderConfig from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject @@ -62,7 +61,3 @@ class ToolProviderApiEntity(BaseModel): "tools": tools, "labels": self.labels, } - - -class ToolProviderCredentialsApiEntity(BaseModel): - credentials: dict[str, ProviderConfig] diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index c96498f80b..b62707b8f7 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -312,7 +312,7 @@ class ToolEntity(BaseModel): class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity - credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict) + credentials_schema: list[ProviderConfig] = Field(default_factory=list) class ToolProviderEntityWithPlugin(ToolProviderEntity): diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 8e5e3bfd2f..fadc649e1f 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -160,7 +160,7 @@ class ToolManager: credentials = builtin_provider.credentials tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, - config=provider_controller.get_credentials_schema(), + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) @@ -186,7 +186,7 @@ class ToolManager: # decrypt the credentials tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, - config=api_provider.get_credentials_schema(), + config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], provider_type=api_provider.provider_type.value, provider_identity=api_provider.entity.identity.name, ) @@ -643,7 +643,7 @@ class ToolManager: # init tool configuration tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, - config=controller.get_credentials_schema(), + config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], provider_type=controller.provider_type.value, provider_identity=controller.entity.identity.name, ) diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 9f685a89b6..9aaac0be21 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,4 +1,3 @@ -from collections.abc import Mapping from copy import deepcopy from typing import Any @@ -17,7 +16,7 @@ from core.tools.entities.tool_entities import ( class ProviderConfigEncrypter(BaseModel): tenant_id: str - config: Mapping[str, BasicProviderConfig] + config: list[BasicProviderConfig] provider_type: str provider_identity: str @@ -36,7 +35,10 @@ class ProviderConfigEncrypter(BaseModel): data = self._deep_copy(data) # get fields need to be decrypted - fields = self.config + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + for field_name, field in fields.items(): if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field_name in data: @@ -54,7 +56,10 @@ class ProviderConfigEncrypter(BaseModel): data = self._deep_copy(data) # get fields need to be decrypted - fields = self.config + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + for field_name, field in fields.items(): if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field_name in data: @@ -83,7 +88,10 @@ class ProviderConfigEncrypter(BaseModel): return cached_credentials data = self._deep_copy(data) # get fields need to be decrypted - fields = self.config + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + for field_name, field in fields.items(): if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field_name in data: diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index c3d778558b..542add9336 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -35,7 +35,7 @@ class BuiltinToolManageService: tool_provider_configurations = ProviderConfigEncrypter( tenant_id=tenant_id, - config=provider_controller.get_credentials_schema(), + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) @@ -78,7 +78,7 @@ class BuiltinToolManageService: :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return jsonable_encoder([v for _, v in (provider.get_credentials_schema() or {}).items()]) + return jsonable_encoder(provider.get_credentials_schema()) @staticmethod def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): @@ -102,7 +102,7 @@ class BuiltinToolManageService: raise ValueError(f"provider {provider_name} does not need credentials") tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, - config=provider_controller.get_credentials_schema(), + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) @@ -164,7 +164,7 @@ class BuiltinToolManageService: provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id) tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, - config=provider_controller.get_credentials_schema(), + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) @@ -196,7 +196,7 @@ class BuiltinToolManageService: provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, - config=provider_controller.get_credentials_schema(), + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index d68818bbb2..3d1f361088 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -85,7 +85,8 @@ class ToolTransformService: ) # get credentials schema - schema = provider_controller.get_credentials_schema() + schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} + for name, value in schema.items(): if result.masked_credentials: result.masked_credentials[name] = "" @@ -103,7 +104,7 @@ class ToolTransformService: # init tool configuration tool_configuration = ProviderConfigEncrypter( tenant_id=db_provider.tenant_id, - config=provider_controller.get_credentials_schema(), + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) @@ -208,7 +209,7 @@ class ToolTransformService: # init tool configuration tool_configuration = ProviderConfigEncrypter( tenant_id=db_provider.tenant_id, - config=provider_controller.get_credentials_schema(), + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, )