diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index d2a17b133b..68ad383a74 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -118,7 +118,9 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): - icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider) + tenant_id = current_user.current_tenant_id + + icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider, tenant_id) icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) @@ -290,7 +292,8 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @login_required @account_initialization_required def get(self, provider): - return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider) + tenant_id = current_user.current_tenant_id + return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id) class ToolApiProviderSchemaApi(Resource): diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 64075ed231..aea68050bd 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -166,7 +166,7 @@ class BaseAgentRunner(AppRunner): }, ) - parameters = tool_entity.get_all_runtime_parameters() + parameters = tool_entity.get_merged_runtime_parameters() for parameter in parameters: if parameter.form != ToolParameter.ToolParameterForm.LLM: continue diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 0c710c8716..51cc36d7df 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -3,6 +3,8 @@ from typing import Generic, Optional, TypeVar from pydantic import BaseModel +from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin + T = TypeVar("T", bound=(BaseModel | dict | list | bool)) @@ -26,4 +28,11 @@ class InstallPluginMessage(BaseModel): Error = "error" event: Event - data: str \ No newline at end of file + data: str + + +class PluginToolProviderEntity(BaseModel): + provider: str + plugin_unique_identifier: str + plugin_id: str + declaration: ToolProviderEntityWithPlugin \ No newline at end of file diff --git a/api/core/plugin/manager/base.py b/api/core/plugin/manager/base.py index 3f6a87dca8..8b80d923f6 100644 --- a/api/core/plugin/manager/base.py +++ b/api/core/plugin/manager/base.py @@ -93,7 +93,14 @@ class BasePluginManager: Make a request to the plugin daemon inner API and return the response as a model. """ response = self._request(method, path, headers, data, params) - rep = PluginDaemonBasicResponse[type](**response.json()) + json_response = response.json() + for provider in json_response.get("data", []): + declaration = provider.get("declaration", {}) or {} + provider_name = declaration.get("identity", {}).get("name") + for tool in declaration.get("tools", []): + tool["identity"]["provider"] = provider_name + + rep = PluginDaemonBasicResponse[type](**json_response) if rep.code != 0: raise ValueError(f"got error from plugin daemon: {rep.message}, code: {rep.code}") if rep.data is None: diff --git a/api/core/plugin/manager/tool.py b/api/core/plugin/manager/tool.py index f617400355..fe5f7bb757 100644 --- a/api/core/plugin/manager/tool.py +++ b/api/core/plugin/manager/tool.py @@ -1,13 +1,65 @@ +from collections.abc import Generator +from typing import Any + +from core.plugin.entities.plugin_daemon import PluginToolProviderEntity from core.plugin.manager.base import BasePluginManager -from core.tools.entities.tool_entities import ToolProviderEntity +from core.tools.entities.tool_entities import ToolInvokeMessage class PluginToolManager(BasePluginManager): - def fetch_tool_providers(self, tenant_id: str) -> list[ToolProviderEntity]: + def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]: """ Fetch tool providers for the given asset. """ response = self._request_with_plugin_daemon_response( - "GET", f"plugin/{tenant_id}/tools", list[ToolProviderEntity], params={"page": 1, "page_size": 256} + "GET", f"plugin/{tenant_id}/tools", list[PluginToolProviderEntity], params={"page": 1, "page_size": 256} + ) + return response + + def invoke( + self, + tenant_id: str, + user_id: str, + plugin_unique_identifier: str, + tool_provider: str, + tool_name: str, + credentials: dict[str, Any], + tool_parameters: dict[str, Any], + ) -> Generator[ToolInvokeMessage, None, None]: + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/tool/invoke", + ToolInvokeMessage, + data={ + "plugin_unique_identifier": plugin_unique_identifier, + "user_id": user_id, + "data": { + "provider": tool_provider, + "tool": tool_name, + "credentials": credentials, + "tool_parameters": tool_parameters, + }, + }, + ) + return response + + def validate_provider_credentials( + self, tenant_id: str, user_id: str, plugin_unique_identifier: str, provider: str, credentials: dict[str, Any] + ) -> bool: + """ + validate the credentials of the provider + """ + response = self._request_with_plugin_daemon_response( + "POST", + f"plugin/{tenant_id}/tool/validate_credentials", + bool, + data={ + "plugin_unique_identifier": plugin_unique_identifier, + "user_id": user_id, + "data": { + "provider": provider, + "credentials": credentials, + }, + }, ) return response diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 548db51a2a..d0bf2f0c31 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -105,11 +105,11 @@ class Tool(ABC): """ return self.entity.parameters - def get_all_runtime_parameters(self) -> list[ToolParameter]: + def get_merged_runtime_parameters(self) -> list[ToolParameter]: """ - get all runtime parameters + get merged runtime parameters - :return: all runtime parameters + :return: merged runtime parameters """ parameters = self.entity.parameters parameters = parameters.copy() diff --git a/api/core/tools/__base/tool_provider.py b/api/core/tools/__base/tool_provider.py index c71885e48d..795812a109 100644 --- a/api/core/tools/__base/tool_provider.py +++ b/api/core/tools/__base/tool_provider.py @@ -12,11 +12,9 @@ from core.tools.errors import ToolProviderCredentialValidationError class ToolProviderController(ABC): entity: ToolProviderEntity - tools: list[Tool] def __init__(self, entity: ToolProviderEntity) -> None: self.entity = entity - self.tools = [] def get_credentials_schema(self) -> dict[str, ProviderConfig]: """ diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index d4b2ef6104..c9e157cb77 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -16,7 +16,7 @@ class ToolRuntime(BaseModel): tool_id: Optional[str] = None invoke_from: Optional[InvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None - credentials: Optional[dict[str, Any]] = None + credentials: dict[str, Any] = Field(default_factory=dict) runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 4ebd82f8e7..e7e374f2e6 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -19,9 +19,7 @@ class BuiltinToolProviderController(ToolProviderController): tools: list[BuiltinTool] def __init__(self, **data: Any) -> None: - if self.provider_type == ToolProviderType.API: - super().__init__(**data) - return + self.tools = [] # load provider yaml provider = self.__class__.__module__.split(".")[-1] @@ -76,9 +74,12 @@ class BuiltinToolProviderController(ToolProviderController): parent_type=BuiltinTool, ) tool["identity"]["provider"] = provider - tools.append(assistant_tool_class( - entity=ToolEntity(**tool), runtime=ToolRuntime(tenant_id=""), - )) + tools.append( + assistant_tool_class( + entity=ToolEntity(**tool), + runtime=ToolRuntime(tenant_id=""), + ) + ) self.tools = tools return tools @@ -142,7 +143,7 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.entity.identity.tags or [] - def validate_credentials(self, credentials: dict[str, Any]) -> None: + def validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider @@ -153,10 +154,10 @@ class BuiltinToolProviderController(ToolProviderController): self.validate_credentials_format(credentials) # validate credentials - self._validate_credentials(credentials) + self._validate_credentials(user_id, credentials) @abstractmethod - def _validate_credentials(self, credentials: dict[str, Any]) -> None: + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider diff --git a/api/core/tools/builtin_tool/providers/_positions.py b/api/core/tools/builtin_tool/providers/_positions.py index 5c10f72fda..224b695ff9 100644 --- a/api/core/tools/builtin_tool/providers/_positions.py +++ b/api/core/tools/builtin_tool/providers/_positions.py @@ -1,18 +1,18 @@ import os.path from core.helper.position_helper import get_tool_position_map, sort_by_position_map -from core.tools.entities.api_entities import UserToolProvider +from core.tools.entities.api_entities import ToolProviderApiEntity class BuiltinToolProviderSort: _position = {} @classmethod - def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: + def sort(cls, providers: list[ToolProviderApiEntity]) -> list[ToolProviderApiEntity]: if not cls._position: cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), "..")) - def name_func(provider: UserToolProvider) -> str: + def name_func(provider: ToolProviderApiEntity) -> str: return provider.name sorted_providers = sort_by_position_map(cls._position, providers, name_func) diff --git a/api/core/tools/builtin_tool/providers/code/code.py b/api/core/tools/builtin_tool/providers/code/code.py index 53210e9c43..18b7cd4c90 100644 --- a/api/core/tools/builtin_tool/providers/code/code.py +++ b/api/core/tools/builtin_tool/providers/code/code.py @@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController class CodeToolProvider(BuiltinToolProviderController): - def _validate_credentials(self, credentials: dict[str, Any]) -> None: + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: pass diff --git a/api/core/tools/builtin_tool/providers/qrcode/qrcode.py b/api/core/tools/builtin_tool/providers/qrcode/qrcode.py index e792382ee3..3999f3b3ef 100644 --- a/api/core/tools/builtin_tool/providers/qrcode/qrcode.py +++ b/api/core/tools/builtin_tool/providers/qrcode/qrcode.py @@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController class QRCodeProvider(BuiltinToolProviderController): - def _validate_credentials(self, credentials: dict[str, Any]) -> None: + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: pass diff --git a/api/core/tools/builtin_tool/providers/time/time.py b/api/core/tools/builtin_tool/providers/time/time.py index d70fc22dfc..323a7c41b8 100644 --- a/api/core/tools/builtin_tool/providers/time/time.py +++ b/api/core/tools/builtin_tool/providers/time/time.py @@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController class WikiPediaProvider(BuiltinToolProviderController): - def _validate_credentials(self, credentials: dict[str, Any]) -> None: + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: pass diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index 32eda1d9bc..c5e3e8488e 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -24,9 +24,10 @@ class ApiToolProviderController(ToolProviderController): super().__init__(entity) self.provider_id = provider_id self.tenant_id = tenant_id + self.tools = [] - @staticmethod - def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController": + @classmethod + def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType): credentials_schema = { "auth_type": ProviderConfig( name="auth_type", diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index b3a98b4a6d..18db659bb9 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -9,7 +9,7 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -class UserTool(BaseModel): +class ToolApiEntity(BaseModel): author: str name: str # identifier label: I18nObject # label @@ -18,10 +18,10 @@ class UserTool(BaseModel): labels: list[str] = Field(default_factory=list) -UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]] +ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] -class UserToolProvider(BaseModel): +class ToolProviderApiEntity(BaseModel): id: str author: str name: str # identifier @@ -33,7 +33,7 @@ class UserToolProvider(BaseModel): original_credentials: Optional[dict] = None is_team_authorization: bool = False allow_delete: bool = True - tools: list[UserTool] = Field(default_factory=list) + tools: list[ToolApiEntity] = Field(default_factory=list) labels: list[str] = Field(default_factory=list) def to_dict(self) -> dict: @@ -63,5 +63,5 @@ class UserToolProvider(BaseModel): } -class UserToolProviderCredentials(BaseModel): +class ToolProviderCredentialsApiEntity(BaseModel): credentials: dict[str, ProviderConfig] diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 2a85c0f882..07ea2d2b11 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -224,6 +224,13 @@ class ToolParameter(BaseModel): max: Optional[Union[float, int]] = None options: list[ToolParameterOption] = Field(default_factory=list) + @field_validator("options", mode="before") + @classmethod + def transform_options(cls, v): + if not isinstance(v, list): + return [] + return v + @classmethod def get_simple_instance( cls, @@ -304,6 +311,9 @@ class ToolEntity(BaseModel): class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict) + + +class ToolProviderEntityWithPlugin(ToolProviderEntity): tools: list[ToolEntity] = Field(default_factory=list) diff --git a/api/core/tools/plugin_tool/plugin_tool_provider.py b/api/core/tools/plugin_tool/plugin_tool_provider.py deleted file mode 100644 index 47a78ee318..0000000000 --- a/api/core/tools/plugin_tool/plugin_tool_provider.py +++ /dev/null @@ -1,30 +0,0 @@ - - -from core.entities.provider_entities import ProviderConfig -from core.tools.__base.tool import Tool -from core.tools.__base.tool_provider import ToolProviderController -from core.tools.entities.tool_entities import ToolProviderType - - -class PluginToolProvider(ToolProviderController): - @property - def provider_type(self) -> ToolProviderType: - """ - returns the type of the provider - - :return: type of the provider - """ - return ToolProviderType.PLUGIN - - def get_tool(self, tool_name: str) -> Tool: - """ - return tool with given name - """ - return super().get_tool(tool_name) - - def get_credentials_schema(self) -> dict[str, ProviderConfig]: - """ - get credentials schema - """ - return super().get_credentials_schema() - \ No newline at end of file diff --git a/api/core/tools/plugin_tool/provider.py b/api/core/tools/plugin_tool/provider.py new file mode 100644 index 0000000000..a52e7c967f --- /dev/null +++ b/api/core/tools/plugin_tool/provider.py @@ -0,0 +1,72 @@ +from typing import Any + +from core.plugin.manager.tool import PluginToolManager +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.plugin_tool.tool import PluginTool + + +class PluginToolProviderController(BuiltinToolProviderController): + entity: ToolProviderEntityWithPlugin + tenant_id: str + plugin_unique_identifier: str + + def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str, plugin_unique_identifier: str) -> None: + self.entity = entity + self.tenant_id = tenant_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def provider_type(self) -> ToolProviderType: + """ + returns the type of the provider + + :return: type of the provider + """ + return ToolProviderType.PLUGIN + + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + """ + validate the credentials of the provider + """ + manager = PluginToolManager() + if not manager.validate_provider_credentials( + tenant_id=self.tenant_id, + user_id=user_id, + plugin_unique_identifier=self.plugin_unique_identifier, + provider=self.entity.identity.name, + credentials=credentials, + ): + raise ToolProviderCredentialValidationError("Invalid credentials") + + def get_tool(self, tool_name: str) -> PluginTool: + """ + return tool with given name + """ + tool_entity = next(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name) + + if not tool_entity: + raise ValueError(f"Tool with name {tool_name} not found") + + return PluginTool( + entity=tool_entity, + runtime=ToolRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + plugin_unique_identifier=self.plugin_unique_identifier, + ) + + def get_tools(self) -> list[PluginTool]: + """ + get all tools + """ + return [ + PluginTool( + entity=tool_entity, + runtime=ToolRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + plugin_unique_identifier=self.plugin_unique_identifier, + ) + for tool_entity in self.entity.tools + ] diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py new file mode 100644 index 0000000000..898e8a552d --- /dev/null +++ b/api/core/tools/plugin_tool/tool.py @@ -0,0 +1,41 @@ +from collections.abc import Generator +from typing import Any + +from core.plugin.manager.tool import PluginToolManager +from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType + + +class PluginTool(Tool): + tenant_id: str + plugin_unique_identifier: str + + def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, plugin_unique_identifier: str) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.PLUGIN + + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]: + manager = PluginToolManager() + return manager.invoke( + tenant_id=self.tenant_id, + user_id=user_id, + plugin_unique_identifier=self.plugin_unique_identifier, + tool_provider=self.entity.identity.provider, + tool_name=self.entity.identity.name, + credentials=self.runtime.credentials, + tool_parameters=tool_parameters, + ) + + def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool": + return PluginTool( + entity=self.entity, + runtime=runtime, + tenant_id=self.tenant_id, + plugin_unique_identifier=self.plugin_unique_identifier, + ) \ No newline at end of file diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index c37ee730c8..76c7232e01 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -6,7 +6,10 @@ from os import listdir, path from threading import Lock from typing import TYPE_CHECKING, Any, Union, cast +from core.plugin.manager.tool import PluginToolManager from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.plugin_tool.provider import PluginToolProviderController +from core.tools.plugin_tool.tool import PluginTool if TYPE_CHECKING: from core.workflow.nodes.tool.entities import ToolEntity @@ -24,7 +27,7 @@ from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.builtin_tool.tool import BuiltinTool from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.custom_tool.tool import ApiTool -from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral +from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType from core.tools.errors import ToolProviderNotFoundError @@ -41,38 +44,61 @@ logger = logging.getLogger(__name__) class ToolManager: _builtin_provider_lock = Lock() - _builtin_providers = {} + _hardcoded_providers = {} _builtin_providers_loaded = False _builtin_tools_labels = {} @classmethod - def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: + def get_builtin_provider( + cls, provider: str, tenant_id: str + ) -> BuiltinToolProviderController | PluginToolProviderController: """ get the builtin provider :param provider: the name of the provider + :param tenant_id: the id of the tenant :return: the provider """ - if len(cls._builtin_providers) == 0: + if len(cls._hardcoded_providers) == 0: # init the builtin providers - cls.load_builtin_providers_cache() + cls.load_hardcoded_providers_cache() - if provider not in cls._builtin_providers: - raise ToolProviderNotFoundError(f"builtin provider {provider} not found") + if provider not in cls._hardcoded_providers: + # get plugin provider + plugin_provider = cls.get_plugin_provider(provider, tenant_id) + if plugin_provider: + return plugin_provider - return cls._builtin_providers[provider] + return cls._hardcoded_providers[provider] @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool | None: + def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController: + """ + get the plugin provider + """ + manager = PluginToolManager() + providers = manager.fetch_tool_providers(tenant_id) + provider_entity = next((x for x in providers if x.declaration.identity.name == provider), None) + if not provider_entity: + raise ToolProviderNotFoundError(f"plugin provider {provider} not found") + + return PluginToolProviderController( + entity=provider_entity.declaration, + tenant_id=tenant_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + ) + + @classmethod + def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: """ get the builtin tool :param provider: the name of the provider :param tool_name: the name of the tool - + :param tenant_id: the id of the tenant :return: the provider, the tool """ - provider_controller = cls.get_builtin_provider(provider) + provider_controller = cls.get_builtin_provider(provider, tenant_id) tool = provider_controller.get_tool(tool_name) return tool @@ -97,12 +123,12 @@ class ToolManager: :return: the tool """ if provider_type == ToolProviderType.BUILT_IN: - builtin_tool = cls.get_builtin_tool(provider_id, tool_name) + builtin_tool = cls.get_builtin_tool(provider_id, tool_name, tenant_id) if not builtin_tool: raise ValueError(f"tool {tool_name} not found") # check if the builtin tool need credentials - provider_controller = cls.get_builtin_provider(provider_id) + provider_controller = cls.get_builtin_provider(provider_id, tenant_id) if not provider_controller.need_credentials: return cast( BuiltinTool, @@ -131,7 +157,7 @@ class ToolManager: # decrypt the credentials credentials = builtin_provider.credentials - controller = cls.get_builtin_provider(provider_id) + controller = cls.get_builtin_provider(provider_id, tenant_id) tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, config=controller.get_credentials_schema(), @@ -246,7 +272,7 @@ class ToolManager: tool_invoke_from=ToolInvokeFrom.AGENT, ) runtime_parameters = {} - parameters = tool_entity.get_all_runtime_parameters() + parameters = tool_entity.get_merged_runtime_parameters() for parameter in parameters: # check file types if parameter.type == ToolParameter.ToolParameterType.FILE: @@ -294,7 +320,7 @@ class ToolManager: tool_invoke_from=ToolInvokeFrom.WORKFLOW, ) runtime_parameters = {} - parameters = tool_entity.get_all_runtime_parameters() + parameters = tool_entity.get_merged_runtime_parameters() for parameter in parameters: # save tool parameter to tool entity memory @@ -321,16 +347,17 @@ class ToolManager: return tool_entity @classmethod - def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]: + def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]: """ get the absolute path of the icon of the builtin provider :param provider: the name of the provider + :param tenant_id: the id of the tenant :return: the absolute path of the icon, the mime type of the icon """ # get provider - provider_controller = cls.get_builtin_provider(provider) + provider_controller = cls.get_builtin_provider(provider, tenant_id) absolute_path = path.join( path.dirname(path.realpath(__file__)), @@ -351,21 +378,48 @@ class ToolManager: return absolute_path, mime_type @classmethod - def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: + def list_hardcoded_providers(cls): # use cache first if cls._builtin_providers_loaded: - yield from list(cls._builtin_providers.values()) + yield from list(cls._hardcoded_providers.values()) return with cls._builtin_provider_lock: if cls._builtin_providers_loaded: - yield from list(cls._builtin_providers.values()) + yield from list(cls._hardcoded_providers.values()) return - yield from cls._list_builtin_providers() + yield from cls._list_hardcoded_providers() @classmethod - def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: + def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]: + """ + list all the plugin providers + """ + manager = PluginToolManager() + provider_entities = manager.fetch_tool_providers(tenant_id) + return [ + PluginToolProviderController( + entity=provider.declaration, + tenant_id=tenant_id, + plugin_unique_identifier=provider.plugin_unique_identifier, + ) + for provider in provider_entities + ] + + @classmethod + def list_builtin_providers( + cls, tenant_id: str + ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]: + """ + list all the builtin providers + """ + yield from cls.list_hardcoded_providers() + # get plugin providers + yield from cls.list_plugin_providers(tenant_id) + + @classmethod + def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: """ list all the builtin providers """ @@ -391,7 +445,7 @@ class ToolManager: parent_type=BuiltinToolProviderController, ) provider: BuiltinToolProviderController = provider_class() - cls._builtin_providers[provider.entity.identity.name] = provider + cls._hardcoded_providers[provider.entity.identity.name] = provider for tool in provider.get_tools(): cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label yield provider @@ -403,13 +457,13 @@ class ToolManager: cls._builtin_providers_loaded = True @classmethod - def load_builtin_providers_cache(cls): - for _ in cls.list_builtin_providers(): + def load_hardcoded_providers_cache(cls): + for _ in cls.list_hardcoded_providers(): pass @classmethod - def clear_builtin_providers_cache(cls): - cls._builtin_providers = {} + def clear_hardcoded_providers_cache(cls): + cls._hardcoded_providers = {} cls._builtin_providers_loaded = False @classmethod @@ -423,7 +477,7 @@ class ToolManager: """ if len(cls._builtin_tools_labels) == 0: # init the builtin providers - cls.load_builtin_providers_cache() + cls.load_hardcoded_providers_cache() if tool_name not in cls._builtin_tools_labels: return None @@ -432,9 +486,9 @@ class ToolManager: @classmethod def user_list_providers( - cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral - ) -> list[UserToolProvider]: - result_providers: dict[str, UserToolProvider] = {} + cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral + ) -> list[ToolProviderApiEntity]: + result_providers: dict[str, ToolProviderApiEntity] = {} filters = [] if not typ: @@ -444,7 +498,7 @@ class ToolManager: if "builtin" in filters: # get builtin providers - builtin_providers = cls.list_builtin_providers() + builtin_providers = cls.list_builtin_providers(tenant_id) # get db builtin providers db_builtin_providers: list[BuiltinToolProvider] = ( @@ -666,4 +720,4 @@ class ToolManager: raise ValueError(f"provider type {provider_type} not found") -ToolManager.load_builtin_providers_cache() +ToolManager.load_hardcoded_providers_cache() diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index e1fc5140d0..d8cddd02d8 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -167,7 +167,7 @@ class WorkflowTool(Tool): :param tool_parameters: the tool parameters :return: tool_parameters, files """ - parameter_rules = self.get_all_runtime_parameters() + parameter_rules = self.get_merged_runtime_parameters() parameters_result = {} files = [] for parameter in parameter_rules: diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index e39c2b8a5b..58f0e7bbf5 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -7,7 +7,7 @@ from core.entities.provider_entities import ProviderConfig from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_runtime import ToolRuntime from core.tools.custom_tool.provider import ApiToolProviderController -from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( @@ -201,7 +201,7 @@ class ApiToolManageService: return {"schema": schema} @staticmethod - def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]: + def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]: """ list api tool provider tools """ @@ -438,7 +438,7 @@ class ApiToolManageService: return {"result": result or "empty response"} @staticmethod - def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: + def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: """ list api tools """ @@ -447,7 +447,7 @@ class ApiToolManageService: db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or [] ) - result: list[UserToolProvider] = [] + result: list[ToolProviderApiEntity] = [] for provider in db_providers: # convert provider controller to user provider diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 83b363bb58..c3d778558b 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -5,9 +5,8 @@ from pathlib import Path from configs import dify_config from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort -from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager @@ -21,11 +20,17 @@ logger = logging.getLogger(__name__) class BuiltinToolManageService: @staticmethod - def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: + def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[ToolApiEntity]: """ list builtin tool provider tools + + :param user_id: the id of the user + :param tenant_id: the id of the tenant + :param provider: the name of the provider + + :return: the list of tools """ - provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) tools = provider_controller.get_tools() tool_provider_configurations = ProviderConfigEncrypter( @@ -64,14 +69,16 @@ class BuiltinToolManageService: return result @staticmethod - def list_builtin_provider_credentials_schema(provider_name): + def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str): """ list builtin provider credentials schema + :param provider_name: the name of the provider + :param tenant_id: the id of the tenant :return: the list of tool providers """ - provider = ToolManager.get_builtin_provider(provider_name) - return jsonable_encoder([v for _, v in (provider.entity.credentials_schema or {}).items()]) + provider = ToolManager.get_builtin_provider(provider_name, tenant_id) + return jsonable_encoder([v for _, v in (provider.get_credentials_schema() or {}).items()]) @staticmethod def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): @@ -90,7 +97,7 @@ class BuiltinToolManageService: try: # get provider - provider_controller = ToolManager.get_builtin_provider(provider_name) + provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) if not provider_controller.need_credentials: raise ValueError(f"provider {provider_name} does not need credentials") tool_configuration = ProviderConfigEncrypter( @@ -109,7 +116,7 @@ class BuiltinToolManageService: if name in masked_credentials and value == masked_credentials[name]: credentials[name] = original_credentials[name] # validate credentials - provider_controller.validate_credentials(credentials) + provider_controller.validate_credentials(user_id, credentials) # encrypt credentials credentials = tool_configuration.encrypt(credentials) except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e: @@ -154,7 +161,7 @@ class BuiltinToolManageService: if provider_obj is None: return {} - provider_controller = ToolManager.get_builtin_provider(provider_obj.provider) + provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id) tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, config=provider_controller.get_credentials_schema(), @@ -186,7 +193,7 @@ class BuiltinToolManageService: db.session.commit() # delete cache - provider_controller = ToolManager.get_builtin_provider(provider_name) + provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, config=provider_controller.get_credentials_schema(), @@ -198,22 +205,22 @@ class BuiltinToolManageService: return {"result": "success"} @staticmethod - def get_builtin_tool_provider_icon(provider: str): + def get_builtin_tool_provider_icon(provider: str, tenant_id: str): """ get tool provider icon and it's mimetype """ - icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) + icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider, tenant_id) icon_bytes = Path(icon_path).read_bytes() return icon_bytes, mime_type @staticmethod - def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: + def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: """ list builtin tools """ # get all builtin providers - provider_controllers = ToolManager.list_builtin_providers() + provider_controllers = ToolManager.list_builtin_providers(tenant_id) # get all user added providers db_providers: list[BuiltinToolProvider] = ( @@ -225,7 +232,7 @@ class BuiltinToolManageService: filter(lambda db_provider: db_provider.provider == provider, db_providers), None ) - result: list[UserToolProvider] = [] + result: list[ToolProviderApiEntity] = [] for provider_controller in provider_controllers: try: diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 1c67f7648c..184596fc23 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -1,6 +1,6 @@ import logging -from core.tools.entities.api_entities import UserToolProviderTypeLiteral +from core.tools.entities.api_entities import ToolProviderTypeApiLiteral from core.tools.tool_manager import ToolManager from services.tools.tools_transform_service import ToolTransformService @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) class ToolCommonService: @staticmethod - def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None): + def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral = None): """ list tool providers diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index d4f132e902..4c2876dca3 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -7,7 +7,7 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController -from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( @@ -15,6 +15,7 @@ from core.tools.entities.tool_entities import ( ToolParameter, ToolProviderType, ) +from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool @@ -44,7 +45,7 @@ class ToolTransformService: return "" @staticmethod - def repack_provider(provider: Union[dict, UserToolProvider]): + def repack_provider(provider: Union[dict, ToolProviderApiEntity]): """ repack provider @@ -54,7 +55,7 @@ class ToolTransformService: provider["icon"] = ToolTransformService.get_tool_provider_icon_url( provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"] ) - elif isinstance(provider, UserToolProvider): + elif isinstance(provider, ToolProviderApiEntity): provider.icon = ToolTransformService.get_tool_provider_icon_url( provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon ) @@ -62,14 +63,14 @@ class ToolTransformService: @classmethod def builtin_provider_to_user_provider( cls, - provider_controller: BuiltinToolProviderController, + provider_controller: BuiltinToolProviderController | PluginToolProviderController, db_provider: Optional[BuiltinToolProvider], decrypt_credentials: bool = True, - ) -> UserToolProvider: + ) -> ToolProviderApiEntity: """ convert provider controller to user provider """ - result = UserToolProvider( + result = ToolProviderApiEntity( id=provider_controller.entity.identity.name, author=provider_controller.entity.identity.author, name=provider_controller.entity.identity.name, @@ -154,7 +155,7 @@ class ToolTransformService: """ convert provider controller to user provider """ - return UserToolProvider( + return ToolProviderApiEntity( id=provider_controller.provider_id, author=provider_controller.entity.identity.author, name=provider_controller.entity.identity.name, @@ -181,7 +182,7 @@ class ToolTransformService: db_provider: ApiToolProvider, decrypt_credentials: bool = True, labels: list[str] | None = None, - ) -> UserToolProvider: + ) -> ToolProviderApiEntity: """ convert provider controller to user provider """ @@ -197,7 +198,7 @@ class ToolTransformService: # add provider into providers credentials = db_provider.credentials - result = UserToolProvider( + result = ToolProviderApiEntity( id=db_provider.id, author=username, name=db_provider.name, @@ -240,7 +241,7 @@ class ToolTransformService: tenant_id: str, credentials: dict | None = None, labels: list[str] | None = None, - ) -> UserTool: + ) -> ToolApiEntity: """ convert tool to user tool """ @@ -248,7 +249,7 @@ class ToolTransformService: # fork tool runtime tool = tool.fork_tool_runtime( runtime=ToolRuntime( - credentials=credentials, + credentials=credentials or {}, tenant_id=tenant_id, ) ) @@ -270,7 +271,7 @@ class ToolTransformService: if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: current_parameters.append(runtime_parameter) - return UserTool( + return ToolApiEntity( author=tool.entity.identity.author, name=tool.entity.identity.name, label=tool.entity.identity.label, @@ -279,7 +280,7 @@ class ToolTransformService: labels=labels or [], ) if isinstance(tool, ApiToolBundle): - return UserTool( + return ToolApiEntity( author=tool.author, name=tool.operation_id, label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id), diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 58bf7946bf..3178fe7999 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -4,7 +4,7 @@ from datetime import datetime from sqlalchemy import or_ from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController @@ -183,7 +183,7 @@ class WorkflowToolManageService: return {"result": "success"} @classmethod - def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]: + def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: """ List workflow tools. :param user_id: the user id @@ -309,7 +309,7 @@ class WorkflowToolManageService: } @classmethod - def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]: + def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]: """ List workflow tool provider tools. :param user_id: the user id