diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 1586d39b41..c3911f31ae 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -10,6 +10,7 @@ from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.entities.request import ( RequestInvokeApp, + RequestInvokeEncrypt, RequestInvokeLLM, RequestInvokeModeration, RequestInvokeNode, @@ -132,6 +133,14 @@ class PluginInvokeAppApi(Resource): PluginAppBackwardsInvocation.convert_to_event_stream(response) ) +class PluginInvokeEncryptApi(Resource): + @setup_required + @plugin_inner_api_only + @get_tenant + @plugin_data(payload_type=RequestInvokeEncrypt) + def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeEncrypt): + """""" + api.add_resource(PluginInvokeLLMApi, '/invoke/llm') api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding') api.add_resource(PluginInvokeRerankApi, '/invoke/rerank') diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 754b792d61..235359a9bb 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -46,6 +46,8 @@ def enterprise_inner_api_user_auth(view): user_id = user_id.split(" ")[1] inner_api_key = request.headers.get("X-Inner-Api-Key") + if not inner_api_key: + raise ValueError("inner api key not found") data_to_sign = f"DIFY {user_id}" diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 15348251f2..73476bef58 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -60,7 +60,7 @@ class QueueIterationStartEvent(AppQueueEvent): node_data: BaseNodeData node_run_index: int - inputs: dict = None + inputs: Optional[dict] = None predecessor_node_id: Optional[str] = None metadata: Optional[dict] = None diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py new file mode 100644 index 0000000000..cc402de773 --- /dev/null +++ b/api/core/entities/parameter_entities.py @@ -0,0 +1,30 @@ +from enum import Enum + + +class CommonParameterType(Enum): + SECRET_INPUT = "secret-input" + TEXT_INPUT = "text-input" + SELECT = "select" + STRING = "string" + NUMBER = "number" + FILE = "file" + BOOLEAN = "boolean" + APP_SELECTOR = "app-selector" + MODEL_CONFIG = "model-config" + + +class AppSelectorScope(Enum): + ALL = "all" + CHAT = "chat" + WORKFLOW = "workflow" + COMPLETION = "completion" + + +class ModelConfigScope(Enum): + LLM = "llm" + TEXT_EMBEDDING = "text-embedding" + RERANK = "rerank" + TTS = "tts" + SPEECH2TEXT = "speech2text" + MODERATION = "moderation" + VISION = "vision" \ No newline at end of file diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 0d5b0a1b2c..ae78d9ecf9 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,8 +1,10 @@ from enum import Enum -from typing import Optional +from typing import Optional, Union -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field +from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope +from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType from models.provider import ProviderQuotaType @@ -100,3 +102,52 @@ class ModelSettings(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) + +class BasicProviderConfig(BaseModel): + """ + Base model class for common provider settings like credentials + """ + class Type(Enum): + SECRET_INPUT = CommonParameterType.SECRET_INPUT.value + TEXT_INPUT = CommonParameterType.TEXT_INPUT.value + SELECT = CommonParameterType.SELECT.value + BOOLEAN = CommonParameterType.BOOLEAN.value + APP_SELECTOR = CommonParameterType.APP_SELECTOR.value + MODEL_CONFIG = CommonParameterType.MODEL_CONFIG.value + + @classmethod + def value_of(cls, value: str) -> "ProviderConfig.Type": + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + + @staticmethod + def default(value: str) -> str: + return "" + + type: Type = Field(..., description="The type of the credentials") + name: str = Field(..., description="The name of the credentials") + +class ProviderConfig(BasicProviderConfig): + """ + Model class for common provider settings like credentials + """ + class Option(BaseModel): + value: str = Field(..., description="The value of the option") + label: I18nObject = Field(..., description="The label of the option") + + scope: AppSelectorScope | ModelConfigScope | None + required: bool = False + default: Optional[Union[int, str]] = None + options: Optional[list[Option]] = None + label: Optional[I18nObject] = None + help: Optional[I18nObject] = None + url: Optional[str] = None + placeholder: Optional[I18nObject] = None diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index ea8605ac57..98226e89c0 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -1,4 +1,9 @@ -tool_file_manager = { +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from core.tools.tool_file_manager import ToolFileManager + +tool_file_manager: dict[str, Any] = { 'manager': None } diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index fd22d1f057..0533746815 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -1,7 +1,9 @@ +from collections.abc import Mapping from typing import Any, Literal, Optional from pydantic import BaseModel, Field, field_validator +from core.entities.provider_entities import BasicProviderConfig from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -30,11 +32,10 @@ class RequestInvokeLLM(BaseRequestInvokeModel): """ Request to invoke LLM """ - model_type: ModelType = ModelType.LLM mode: str model_parameters: dict[str, Any] = Field(default_factory=dict) - prompt_messages: list[PromptMessage] + prompt_messages: list[PromptMessage] = Field(default_factory=list) tools: Optional[list[PromptMessageTool]] = Field(default_factory=list) stop: Optional[list[str]] = Field(default_factory=list) stream: Optional[bool] = False @@ -105,4 +106,11 @@ class RequestInvokeApp(BaseModel): conversation_id: Optional[str] = None user: Optional[str] = None files: list[dict] = Field(default_factory=list) - \ No newline at end of file + +class RequestInvokeEncrypt(BaseModel): + """ + Request to encryption + """ + opt: Literal["encrypt", "decrypt"] + data: dict = Field(default_factory=dict) + config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 2b01b8fd8e..71db8d8b2d 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderCredentials, ToolProviderType +from core.tools.entities.tool_entities import ProviderConfig, ToolProviderType from core.tools.tool.tool import ToolParameter @@ -62,4 +62,4 @@ class UserToolProvider(BaseModel): } class UserToolProviderCredentials(BaseModel): - credentials: dict[str, ToolProviderCredentials] \ No newline at end of file + credentials: dict[str, ProviderConfig] \ No newline at end of file diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index c266db1cdb..98efb92a0d 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -3,6 +3,7 @@ from typing import Any, Optional, Union, cast from pydantic import BaseModel, Field, field_validator +from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope from core.tools.entities.common_entities import I18nObject @@ -137,12 +138,12 @@ class ToolParameterOption(BaseModel): class ToolParameter(BaseModel): class ToolParameterType(str, Enum): - STRING = "string" - NUMBER = "number" - BOOLEAN = "boolean" - SELECT = "select" - SECRET_INPUT = "secret-input" - FILE = "file" + STRING = CommonParameterType.STRING.value + NUMBER = CommonParameterType.NUMBER.value + BOOLEAN = CommonParameterType.BOOLEAN.value + SELECT = CommonParameterType.SELECT.value + SECRET_INPUT = CommonParameterType.SECRET_INPUT.value + FILE = CommonParameterType.FILE.value class ToolParameterForm(Enum): SCHEMA = "schema" # should be set while adding tool @@ -151,16 +152,17 @@ class ToolParameter(BaseModel): name: str = Field(..., description="The name of the parameter") label: I18nObject = Field(..., description="The label presented to the user") - human_description: Optional[I18nObject] = Field(None, description="The description presented to the user") - placeholder: Optional[I18nObject] = Field(None, description="The placeholder presented to the user") + human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") + placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user") type: ToolParameterType = Field(..., description="The type of the parameter") + scope: AppSelectorScope | ModelConfigScope | None = None form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") llm_description: Optional[str] = None required: Optional[bool] = False default: Optional[Union[float, int, str]] = None min: Optional[Union[float, int]] = None max: Optional[Union[float, int]] = None - options: Optional[list[ToolParameterOption]] = None + options: list[ToolParameterOption] = Field(default_factory=list) @classmethod def get_simple_instance(cls, @@ -211,57 +213,6 @@ class ToolIdentity(BaseModel): provider: str = Field(..., description="The provider of the tool") icon: Optional[str] = None -class ToolCredentialsOption(BaseModel): - value: str = Field(..., description="The value of the option") - label: I18nObject = Field(..., description="The label of the option") - -class ToolProviderCredentials(BaseModel): - class CredentialsType(Enum): - SECRET_INPUT = "secret-input" - TEXT_INPUT = "text-input" - SELECT = "select" - BOOLEAN = "boolean" - - @classmethod - def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid mode value {value}') - - @staticmethod - def default(value: str) -> str: - return "" - - name: str = Field(..., description="The name of the credentials") - type: CredentialsType = Field(..., description="The type of the credentials") - required: bool = False - default: Optional[Union[int, str]] = None - options: Optional[list[ToolCredentialsOption]] = None - label: Optional[I18nObject] = None - help: Optional[I18nObject] = None - url: Optional[str] = None - placeholder: Optional[I18nObject] = None - - def to_dict(self) -> dict: - return { - 'name': self.name, - 'type': self.type.value, - 'required': self.required, - 'default': self.default, - 'options': self.options, - 'help': self.help.to_dict() if self.help else None, - 'label': self.label.to_dict() if self.label else None, - 'url': self.url, - 'placeholder': self.placeholder.to_dict() if self.placeholder else None, - } - class ToolRuntimeVariableType(Enum): TEXT = "text" IMAGE = "image" diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index ae80ad2114..fc7fcb675a 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -3,8 +3,8 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ProviderConfig, ToolCredentialsOption, - ToolProviderCredentials, ToolProviderType, ) from core.tools.provider.tool_provider import ToolProviderController @@ -20,10 +20,10 @@ class ApiToolProviderController(ToolProviderController): @staticmethod def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController': credentials_schema = { - 'auth_type': ToolProviderCredentials( + 'auth_type': ProviderConfig( name='auth_type', required=True, - type=ToolProviderCredentials.CredentialsType.SELECT, + type=ProviderConfig.Type.SELECT, options=[ ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')), ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key')) @@ -38,30 +38,30 @@ class ApiToolProviderController(ToolProviderController): if auth_type == ApiProviderAuthType.API_KEY: credentials_schema = { **credentials_schema, - 'api_key_header': ToolProviderCredentials( + 'api_key_header': ProviderConfig( name='api_key_header', required=False, default='api_key', - type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, + type=ProviderConfig.Type.TEXT_INPUT, help=I18nObject( en_US='The header name of the api key', zh_Hans='携带 api key 的 header 名称' ) ), - 'api_key_value': ToolProviderCredentials( + 'api_key_value': ProviderConfig( name='api_key_value', required=True, - type=ToolProviderCredentials.CredentialsType.SECRET_INPUT, + type=ProviderConfig.Type.SECRET_INPUT, help=I18nObject( en_US='The api key', zh_Hans='api key的值' ) ), - 'api_key_header_prefix': ToolProviderCredentials( + 'api_key_header_prefix': ProviderConfig( name='api_key_header_prefix', required=False, default='basic', - type=ToolProviderCredentials.CredentialsType.SELECT, + type=ProviderConfig.Type.SELECT, help=I18nObject( en_US='The prefix of the api key header', zh_Hans='api key header 的前缀' diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py deleted file mode 100644 index 2d472e0a93..0000000000 --- a/api/core/tools/provider/app_tool_provider.py +++ /dev/null @@ -1,115 +0,0 @@ -import logging -from typing import Any - -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType -from core.tools.provider.tool_provider import ToolProviderController -from core.tools.tool.tool import Tool -from extensions.ext_database import db -from models.model import App, AppModelConfig -from models.tools import PublishedAppTool - -logger = logging.getLogger(__name__) - -class AppToolProviderEntity(ToolProviderController): - @property - def provider_type(self) -> ToolProviderType: - return ToolProviderType.APP - - def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None: - pass - - def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None: - pass - - def get_tools(self, user_id: str) -> list[Tool]: - db_tools: list[PublishedAppTool] = db.session.query(PublishedAppTool).filter( - PublishedAppTool.user_id == user_id, - ).all() - - if not db_tools or len(db_tools) == 0: - return [] - - tools: list[Tool] = [] - - for db_tool in db_tools: - tool = { - 'identity': { - 'author': db_tool.author, - 'name': db_tool.tool_name, - 'label': { - 'en_US': db_tool.tool_name, - 'zh_Hans': db_tool.tool_name - }, - 'icon': '' - }, - 'description': { - 'human': { - 'en_US': db_tool.description_i18n.en_US, - 'zh_Hans': db_tool.description_i18n.zh_Hans - }, - 'llm': db_tool.llm_description - }, - 'parameters': [] - } - # get app from db - app: App = db_tool.app - - if not app: - logger.error(f"app {db_tool.app_id} not found") - continue - - app_model_config: AppModelConfig = app.app_model_config - user_input_form_list = app_model_config.user_input_form_list - for input_form in user_input_form_list: - # get type - form_type = input_form.keys()[0] - default = input_form[form_type]['default'] - required = input_form[form_type]['required'] - label = input_form[form_type]['label'] - variable_name = input_form[form_type]['variable_name'] - options = input_form[form_type].get('options', []) - if form_type == 'paragraph' or form_type == 'text-input': - tool['parameters'].append(ToolParameter( - name=variable_name, - label=I18nObject( - en_US=label, - zh_Hans=label - ), - human_description=I18nObject( - en_US=label, - zh_Hans=label - ), - llm_description=label, - form=ToolParameter.ToolParameterForm.FORM, - type=ToolParameter.ToolParameterType.STRING, - required=required, - default=default - )) - elif form_type == 'select': - tool['parameters'].append(ToolParameter( - name=variable_name, - label=I18nObject( - en_US=label, - zh_Hans=label - ), - human_description=I18nObject( - en_US=label, - zh_Hans=label - ), - llm_description=label, - form=ToolParameter.ToolParameterForm.FORM, - type=ToolParameter.ToolParameterType.SELECT, - required=required, - default=default, - options=[ToolParameterOption( - value=option, - label=I18nObject( - en_US=option, - zh_Hans=option - ) - ) for option in options] - )) - - tools.append(Tool(**tool)) - return tools \ No newline at end of file diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index bcf41c90ed..7ad8a5468b 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -2,22 +2,23 @@ from abc import abstractmethod from os import listdir, path from typing import Any +from pydantic import Field + +from core.entities.provider_entities import ProviderConfig from core.helper.module_import_helper import load_single_subclass_from_source -from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType +from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( - ToolNotFoundError, - ToolParameterValidationError, ToolProviderNotFoundError, ) from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.tool.tool import Tool -from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.tools.utils.yaml_utils import load_yaml_file class BuiltinToolProviderController(ToolProviderController): + tools: list[BuiltinTool] = Field(default_factory=list) + def __init__(self, **data: Any) -> None: if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: super().__init__(**data) @@ -41,7 +42,7 @@ class BuiltinToolProviderController(ToolProviderController): 'credentials_schema': provider_yaml.get('credentials_for_provider', None), }) - def _get_builtin_tools(self) -> list[Tool]: + def _get_builtin_tools(self) -> list[BuiltinTool]: """ returns a list of tools that the provider can provide @@ -72,7 +73,7 @@ class BuiltinToolProviderController(ToolProviderController): self.tools = tools return tools - def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: + def get_credentials_schema(self) -> dict[str, ProviderConfig]: """ returns the credentials schema of the provider @@ -83,7 +84,7 @@ class BuiltinToolProviderController(ToolProviderController): return self.credentials_schema.copy() - def get_tools(self) -> list[Tool]: + def get_tools(self) -> list[BuiltinTool]: """ returns a list of tools that the provider can provide @@ -91,24 +92,12 @@ class BuiltinToolProviderController(ToolProviderController): """ return self._get_builtin_tools() - def get_tool(self, tool_name: str) -> Tool: + def get_tool(self, tool_name: str) -> BuiltinTool | None: """ returns the tool that the provider can provide """ return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) - def get_parameters(self, tool_name: str) -> list[ToolParameter]: - """ - returns the parameters of the tool - - :param tool_name: the name of the tool, defined in `get_tools` - :return: list of parameters - """ - tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) - if tool is None: - raise ToolNotFoundError(f'tool {tool_name} not found') - return tool.parameters - @property def need_credentials(self) -> bool: """ @@ -143,67 +132,6 @@ class BuiltinToolProviderController(ToolProviderController): returns the labels of the provider """ return self.identity.tags or [] - - def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: - """ - validate the parameters of the tool and set the default value if needed - - :param tool_name: the name of the tool, defined in `get_tools` - :param tool_parameters: the parameters of the tool - """ - tool_parameters_schema = self.get_parameters(tool_name) - - tool_parameters_need_to_validate: dict[str, ToolParameter] = {} - for parameter in tool_parameters_schema: - tool_parameters_need_to_validate[parameter.name] = parameter - - for parameter in tool_parameters: - if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}') - - # check type - parameter_schema = tool_parameters_need_to_validate[parameter] - if parameter_schema.type == ToolParameter.ToolParameterType.STRING: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - - elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f'parameter {parameter} should be number') - - if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: - raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}') - - if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: - raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}') - - elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f'parameter {parameter} should be boolean') - - elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - - options = parameter_schema.options - if not isinstance(options, list): - raise ToolParameterValidationError(f'parameter {parameter} options should be list') - - if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}') - - tool_parameters_need_to_validate.pop(parameter) - - for parameter in tool_parameters_need_to_validate: - parameter_schema = tool_parameters_need_to_validate[parameter] - if parameter_schema.required: - raise ToolParameterValidationError(f'parameter {parameter} is required') - - # the parameter is not set currently, set the default value if needed - if parameter_schema.default is not None: - default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default, - parameter_schema.type) - tool_parameters[parameter] = default_value def validate_credentials(self, credentials: dict[str, Any]) -> None: """ diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index ef1ace9c7c..ac770a2a60 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -1,25 +1,23 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field +from core.entities.provider_entities import ProviderConfig from core.tools.entities.tool_entities import ( - ToolParameter, - ToolProviderCredentials, ToolProviderIdentity, ToolProviderType, ) -from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError +from core.tools.errors import ToolProviderCredentialValidationError from core.tools.tool.tool import Tool -from core.tools.utils.tool_parameter_converter import ToolParameterConverter class ToolProviderController(BaseModel, ABC): - identity: Optional[ToolProviderIdentity] = None - tools: Optional[list[Tool]] = None - credentials_schema: Optional[dict[str, ToolProviderCredentials]] = None + identity: ToolProviderIdentity + tools: list[Tool] = Field(default_factory=list) + credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict) - def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: + def get_credentials_schema(self) -> dict[str, ProviderConfig]: """ returns the credentials schema of the provider @@ -27,15 +25,6 @@ class ToolProviderController(BaseModel, ABC): """ return self.credentials_schema.copy() - @abstractmethod - def get_tools(self) -> list[Tool]: - """ - returns a list of tools that the provider can provide - - :return: list of tools - """ - pass - @abstractmethod def get_tool(self, tool_name: str) -> Tool: """ @@ -45,18 +34,6 @@ class ToolProviderController(BaseModel, ABC): """ pass - def get_parameters(self, tool_name: str) -> list[ToolParameter]: - """ - returns the parameters of the tool - - :param tool_name: the name of the tool, defined in `get_tools` - :return: list of parameters - """ - tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) - if tool is None: - raise ToolNotFoundError(f'tool {tool_name} not found') - return tool.parameters - @property def provider_type(self) -> ToolProviderType: """ @@ -66,66 +43,6 @@ class ToolProviderController(BaseModel, ABC): """ return ToolProviderType.BUILT_IN - def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: - """ - validate the parameters of the tool and set the default value if needed - - :param tool_name: the name of the tool, defined in `get_tools` - :param tool_parameters: the parameters of the tool - """ - tool_parameters_schema = self.get_parameters(tool_name) - - tool_parameters_need_to_validate: dict[str, ToolParameter] = {} - for parameter in tool_parameters_schema: - tool_parameters_need_to_validate[parameter.name] = parameter - - for parameter in tool_parameters: - if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}') - - # check type - parameter_schema = tool_parameters_need_to_validate[parameter] - if parameter_schema.type == ToolParameter.ToolParameterType.STRING: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - - elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f'parameter {parameter} should be number') - - if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: - raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}') - - if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: - raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}') - - elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f'parameter {parameter} should be boolean') - - elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - - options = parameter_schema.options - if not isinstance(options, list): - raise ToolParameterValidationError(f'parameter {parameter} options should be list') - - if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}') - - tool_parameters_need_to_validate.pop(parameter) - - for parameter in tool_parameters_need_to_validate: - parameter_schema = tool_parameters_need_to_validate[parameter] - if parameter_schema.required: - raise ToolParameterValidationError(f'parameter {parameter} is required') - - # the parameter is not set currently, set the default value if needed - if parameter_schema.default is not None: - tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default, - parameter_schema.type) - def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ validate the format of the credentials of the provider and set the default value if needed @@ -136,7 +53,7 @@ class ToolProviderController(BaseModel, ABC): if credentials_schema is None: return - credentials_need_to_validate: dict[str, ToolProviderCredentials] = {} + credentials_need_to_validate: dict[str, ProviderConfig] = {} for credential_name in credentials_schema: credentials_need_to_validate[credential_name] = credentials_schema[credential_name] @@ -146,12 +63,12 @@ class ToolProviderController(BaseModel, ABC): # check type credential_schema = credentials_need_to_validate[credential_name] - if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \ - credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT: + if credential_schema == ProviderConfig.Type.SECRET_INPUT or \ + credential_schema == ProviderConfig.Type.TEXT_INPUT: if not isinstance(credentials[credential_name], str): raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string') - elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT: + elif credential_schema.type == ProviderConfig.Type.SELECT: if not isinstance(credentials[credential_name], str): raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string') @@ -173,9 +90,9 @@ class ToolProviderController(BaseModel, ABC): if credential_schema.default is not None: default_value = credential_schema.default # parse default value into the correct type - if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \ - credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \ - credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT: + if credential_schema.type == ProviderConfig.Type.SECRET_INPUT or \ + credential_schema.type == ProviderConfig.Type.TEXT_INPUT or \ + credential_schema.type == ProviderConfig.Type.SELECT: default_value = str(default_value) credentials[credential_name] = default_value diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py index f14abac767..a84b7a36ed 100644 --- a/api/core/tools/provider/workflow_tool_provider.py +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -1,5 +1,8 @@ +from collections.abc import Mapping from typing import Optional +from pydantic import Field + from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.tools.entities.common_entities import I18nObject @@ -28,6 +31,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = { class WorkflowToolProviderController(ToolProviderController): provider_id: str + tools: list[WorkflowTool] = Field(default_factory=list) @classmethod def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController': @@ -71,16 +75,17 @@ class WorkflowToolProviderController(ToolProviderController): :param app: the app :return: the tool """ - workflow: Workflow = db.session.query(Workflow).filter( + workflow: Workflow | None = db.session.query(Workflow).filter( Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version ).first() + if not workflow: raise ValueError('workflow not found') # fetch start node - graph: dict = workflow.graph_dict - features_dict: dict = workflow.features_dict + graph: Mapping = workflow.graph_dict + features_dict: Mapping = workflow.features_dict features = WorkflowAppConfigManager.convert_features( config_dict=features_dict, app_mode=AppMode.WORKFLOW @@ -89,7 +94,7 @@ class WorkflowToolProviderController(ToolProviderController): parameters = db_provider.parameter_configurations variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) - def fetch_workflow_variable(variable_name: str) -> VariableEntity: + def fetch_workflow_variable(variable_name: str) -> VariableEntity | None: return next(filter(lambda x: x.variable == variable_name, variables), None) user = db_provider.user @@ -99,7 +104,7 @@ class WorkflowToolProviderController(ToolProviderController): variable = fetch_workflow_variable(parameter.name) if variable: parameter_type = None - options = None + options = [] if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: raise ValueError(f'unsupported variable type {variable.type}') parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] @@ -185,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController): label=db_provider.label ) - def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: + def get_tools(self, tenant_id: str) -> list[WorkflowTool]: """ fetch tools from database @@ -196,7 +201,7 @@ class WorkflowToolProviderController(ToolProviderController): if self.tools is not None: return self.tools - db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( + db_providers: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == self.provider_id, ).first() diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 68db0d5b2e..6005297118 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -55,7 +55,7 @@ class Tool(BaseModel, ABC): invoke_from: Optional[InvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None credentials: Optional[dict[str, Any]] = None - runtime_parameters: Optional[dict[str, Any]] = None + runtime_parameters: dict[str, Any] = Field(default_factory=dict) runtime: Optional[Runtime] = None variables: Optional[ToolRuntimeVariablePool] = None diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 06b6bb9f52..efc2802016 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -4,7 +4,7 @@ import mimetypes from collections.abc import Generator from os import listdir, path from threading import Lock -from typing import Any, Union +from typing import Any, Union, cast from configs import dify_config from core.agent.entities import AgentToolEntity @@ -22,6 +22,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl from core.tools.tool.api_tool import ApiTool from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool +from core.tools.tool.workflow_tool import WorkflowTool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter @@ -57,7 +58,7 @@ class ToolManager: return cls._builtin_providers[provider] @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: + def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool | None: """ get the builtin tool @@ -78,7 +79,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ - -> Union[BuiltinTool, ApiTool]: + -> Union[BuiltinTool, ApiTool, WorkflowTool]: """ get the tool runtime @@ -90,19 +91,21 @@ class ToolManager: """ if provider_type == ToolProviderType.BUILT_IN: builtin_tool = cls.get_builtin_tool(provider_id, tool_name) + if not builtin_tool: + raise ValueError(f"tool {tool_name} not found") # check if the builtin tool need credentials provider_controller = cls.get_builtin_provider(provider_id) if not provider_controller.need_credentials: - return builtin_tool.fork_tool_runtime(runtime={ + return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={ 'tenant_id': tenant_id, 'credentials': {}, 'invoke_from': invoke_from, 'tool_invoke_from': tool_invoke_from, - }) + })) # get credentials - builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + builtin_provider: BuiltinToolProvider | None = db.session.query(BuiltinToolProvider).filter( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_id, ).first() @@ -117,13 +120,13 @@ class ToolManager: decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return builtin_tool.fork_tool_runtime(runtime={ + return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={ 'tenant_id': tenant_id, 'credentials': decrypted_credentials, 'runtime_parameters': {}, 'invoke_from': invoke_from, 'tool_invoke_from': tool_invoke_from, - }) + })) elif provider_type == ToolProviderType.API: if tenant_id is None: @@ -135,12 +138,12 @@ class ToolManager: tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={ + return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={ 'tenant_id': tenant_id, 'credentials': decrypted_credentials, 'invoke_from': invoke_from, 'tool_invoke_from': tool_invoke_from, - }) + })) elif provider_type == ToolProviderType.WORKFLOW: workflow_provider = db.session.query(WorkflowToolProvider).filter( WorkflowToolProvider.tenant_id == tenant_id, @@ -154,12 +157,12 @@ class ToolManager: db_provider=workflow_provider ) - return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={ + return cast(WorkflowTool, controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={ 'tenant_id': tenant_id, 'credentials': {}, 'invoke_from': invoke_from, 'tool_invoke_from': tool_invoke_from, - }) + })) elif provider_type == ToolProviderType.APP: raise NotImplementedError('app provider not implemented') else: @@ -220,7 +223,10 @@ class ToolManager: identity_id=f'AGENT.{app_id}' ) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) - + + if not tool_entity.runtime: + raise Exception("tool missing runtime") + tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -258,6 +264,9 @@ class ToolManager: if runtime_parameters: runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + if not tool_entity.runtime: + raise Exception("tool missing runtime") + tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -304,20 +313,20 @@ class ToolManager: """ list all the builtin providers """ - for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): - if provider.startswith('__'): + for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): + if provider_path.startswith('__'): continue - if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)): - if provider.startswith('__'): + if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider_path)): + if provider_path.startswith('__'): continue # init provider try: provider_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.{provider}', + module_name=f'core.tools.provider.builtin.{provider_path}.{provider_path}', script_path=path.join(path.dirname(path.realpath(__file__)), - 'provider', 'builtin', provider, f'{provider}.py'), + 'provider', 'builtin', provider_path, f'{provider_path}.py'), parent_type=BuiltinToolProviderController) provider: BuiltinToolProviderController = provider_class() cls._builtin_providers[provider.identity.name] = provider @@ -387,8 +396,8 @@ class ToolManager: for provider in builtin_providers: # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore data=provider, name_func=lambda x: x.identity.name ): @@ -461,7 +470,7 @@ class ToolManager: :return: the provider controller, the credentials """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( + provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter( ApiToolProvider.id == provider_id, ApiToolProvider.tenant_id == tenant_id, ).first() @@ -486,22 +495,22 @@ class ToolManager: """ get tool provider """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( + provider_obj: ApiToolProvider| None = db.session.query(ApiToolProvider).filter( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider, ).first() - if provider is None: + if provider_obj is None: raise ValueError(f'you have not added provider {provider}') try: - credentials = json.loads(provider.credentials_str) or {} + credentials = json.loads(provider_obj.credentials_str) or {} except: credentials = {} # package tool provider controller controller = ApiToolProviderController.from_db( - provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE + provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE ) # init tool configuration tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) @@ -510,7 +519,7 @@ class ToolManager: masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) try: - icon = json.loads(provider.icon) + icon = json.loads(provider_obj.icon) except: icon = { "background": "#252525", @@ -521,14 +530,14 @@ class ToolManager: labels = ToolLabelManager.get_tool_labels(controller) return jsonable_encoder({ - 'schema_type': provider.schema_type, - 'schema': provider.schema, - 'tools': provider.tools, + 'schema_type': provider_obj.schema_type, + 'schema': provider_obj.schema, + 'tools': provider_obj.tools, 'icon': icon, - 'description': provider.description, + 'description': provider_obj.description, 'credentials': masked_credentials, - 'privacy_policy': provider.privacy_policy, - 'custom_disclaimer': provider.custom_disclaimer, + 'privacy_policy': provider_obj.privacy_policy, + 'custom_disclaimer': provider_obj.custom_disclaimer, 'labels': labels, }) @@ -551,25 +560,29 @@ class ToolManager: + "/icon") elif provider_type == ToolProviderType.API: try: - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( + api_provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id ).first() - return json.loads(provider.icon) + if not api_provider: + raise ValueError("api tool not found") + + return json.loads(api_provider.icon) except: return { "background": "#252525", "content": "\ud83d\ude01" } elif provider_type == ToolProviderType.WORKFLOW: - provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( + workflow_provider: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id ).first() - if provider is None: + + if workflow_provider is None: raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') - return json.loads(provider.icon) + return json.loads(workflow_provider.icon) else: raise ValueError(f"provider type {provider_type} not found") diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 68b0cea24f..2fc0ba3bcd 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -7,8 +7,8 @@ from core.helper import encrypter from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.tools.entities.tool_entities import ( + ProviderConfig, ToolParameter, - ToolProviderCredentials, ToolProviderType, ) from core.tools.provider.tool_provider import ToolProviderController @@ -36,7 +36,7 @@ class ToolConfigurationManager(BaseModel): # get fields need to be decrypted fields = self.provider_controller.get_credentials_schema() for field_name, field in fields.items(): - if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT: + if field.type == ProviderConfig.Type.SECRET_INPUT: if field_name in credentials: encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name]) credentials[field_name] = encrypted @@ -54,7 +54,7 @@ class ToolConfigurationManager(BaseModel): # get fields need to be decrypted fields = self.provider_controller.get_credentials_schema() for field_name, field in fields.items(): - if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT: + if field.type == ProviderConfig.Type.SECRET_INPUT: if field_name in credentials: if len(credentials[field_name]) > 6: credentials[field_name] = \ @@ -84,7 +84,7 @@ class ToolConfigurationManager(BaseModel): # get fields need to be decrypted fields = self.provider_controller.get_credentials_schema() for field_name, field in fields.items(): - if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT: + if field.type == ProviderConfig.Type.SECRET_INPUT: if field_name in credentials: try: credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name]) diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index ff5505bbbf..b8237fd043 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping + from core.app.app_config.entities import VariableEntity from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration @@ -13,7 +15,7 @@ class WorkflowToolConfigurationUtils: raise ValueError('invalid parameter configuration') @classmethod - def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]: + def get_workflow_graph_variables(cls, graph: Mapping) -> list[VariableEntity]: """ get workflow graph variables """ @@ -44,5 +46,3 @@ class WorkflowToolConfigurationUtils: for parameter in tool_configurations: if parameter.name not in variable_names: raise ValueError('parameter configuration mismatch, please republish the tool to update') - - return True \ No newline at end of file diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 3ded9c0989..d7538bd812 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -10,8 +10,8 @@ from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, ApiProviderSchemaType, + ProviderConfig, ToolCredentialsOption, - ToolProviderCredentials, ) from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.tool_label_manager import ToolLabelManager @@ -39,9 +39,9 @@ class ApiToolManageService: raise ValueError(f"invalid schema: {str(e)}") credentials_schema = [ - ToolProviderCredentials( + ProviderConfig( name="auth_type", - type=ToolProviderCredentials.CredentialsType.SELECT, + type=ProviderConfig.Type.SELECT, required=True, default="none", options=[ @@ -50,17 +50,17 @@ class ApiToolManageService: ], placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"), ), - ToolProviderCredentials( + ProviderConfig( name="api_key_header", - type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, + type=ProviderConfig.Type.TEXT_INPUT, required=False, placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"), default="api_key", help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"), ), - ToolProviderCredentials( + ProviderConfig( name="api_key_value", - type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, + type=ProviderConfig.Type.TEXT_INPUT, required=False, placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"), default="", diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index d072203fba..1848fb2a13 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -8,8 +8,8 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ProviderConfig, ToolParameter, - ToolProviderCredentials, ToolProviderType, ) from core.tools.provider.api_tool_provider import ApiToolProviderController @@ -92,7 +92,7 @@ class ToolTransformService: # get credentials schema schema = provider_controller.get_credentials_schema() for name, value in schema.items(): - result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type) + result.masked_credentials[name] = ProviderConfig.Type.default(value.type) # check if the provider need credentials if not provider_controller.need_credentials: