diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d45d45c520..2593e381cf 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -205,16 +205,160 @@ class ToolManager: :return: the tool """ - if provider_type == ToolProviderType.BUILT_IN: - # check if the builtin tool need credentials - provider_controller = cls.get_builtin_provider(provider_id, tenant_id) + match provider_type: + case ToolProviderType.BUILT_IN: + provider_controller = cls.get_builtin_provider(provider_id, tenant_id) - builtin_tool = provider_controller.get_tool(tool_name) - if not builtin_tool: - raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") + builtin_tool = provider_controller.get_tool(tool_name) + if not builtin_tool: + raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") + + if not provider_controller.need_credentials: + return builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + user_id=user_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ) + builtin_provider = None + if isinstance(provider_controller, PluginToolProviderController): + provider_id_entity = ToolProviderID(provider_id) + if is_valid_uuid(credential_id): + try: + builtin_provider_stmt = select(BuiltinToolProvider).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + builtin_provider = db.session.scalar(builtin_provider_stmt) + except Exception as e: + builtin_provider = None + logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) + if builtin_provider is None: + raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") + + if builtin_provider is None: + with Session(db.engine) as session: + builtin_provider = session.scalar( + sa.select(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == str(provider_id_entity)) + | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + ) + if builtin_provider is None: + raise ToolProviderNotFoundError(f"no default provider for {provider_id}") + else: + builtin_provider = db.session.scalar( + select(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id) + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .limit(1) + ) + + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=builtin_provider.id, + provider=provider_id, + credential_type=PluginCredentialType.TOOL, + check_existence=False, + ) + + encrypter, cache = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type) + ], + cache=ToolProviderCredentialsCache( + tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id + ), + ) + + decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials) + + if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): + # TODO: circular import + from core.plugin.impl.oauth import OAuthHandler + from services.tools.builtin_tools_manage_service import BuiltinToolManageService + + tool_provider = ToolProviderID(provider_id) + provider_name = tool_provider.provider_name + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" + system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) + + oauth_handler = OAuthHandler() + refreshed_credentials = oauth_handler.refresh_credentials( + tenant_id=tenant_id, + user_id=builtin_provider.user_id, + plugin_id=tool_provider.plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=system_credentials or {}, + credentials=decrypted_credentials, + ) + # update the credentials + builtin_provider.encrypted_credentials = json.dumps( + encrypter.encrypt(refreshed_credentials.credentials) + ) + builtin_provider.expires_at = refreshed_credentials.expires_at + db.session.commit() + decrypted_credentials = refreshed_credentials.credentials + cache.delete() - if not provider_controller.need_credentials: return builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + user_id=user_id, + credentials=dict(decrypted_credentials), + credential_type=builtin_provider.credential_type, + runtime_parameters={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ) + + case ToolProviderType.API: + api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) + encrypter, _ = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=api_provider, + ) + return api_provider.get_tool(tool_name).fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + user_id=user_id, + credentials=dict(encrypter.decrypt(credentials)), + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ) + case ToolProviderType.WORKFLOW: + workflow_provider_stmt = select(WorkflowToolProvider).where( + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id + ) + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + workflow_provider = session.scalar(workflow_provider_stmt) + + if workflow_provider is None: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) + controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id) + if controller_tools is None or len(controller_tools) == 0: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, user_id=user_id, @@ -223,177 +367,28 @@ class ToolManager: tool_invoke_from=tool_invoke_from, ) ) - builtin_provider = None - if isinstance(provider_controller, PluginToolProviderController): - provider_id_entity = ToolProviderID(provider_id) - # get specific credentials - if is_valid_uuid(credential_id): - try: - builtin_provider_stmt = select(BuiltinToolProvider).where( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.id == credential_id, - ) - builtin_provider = db.session.scalar(builtin_provider_stmt) - except Exception as e: - builtin_provider = None - logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) - # if the provider has been deleted, raise an error - if builtin_provider is None: - raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") - - # fallback to the default provider - if builtin_provider is None: - # use the default provider - with Session(db.engine) as session: - builtin_provider = session.scalar( - sa.select(BuiltinToolProvider) - .where( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == str(provider_id_entity)) - | (BuiltinToolProvider.provider == provider_id_entity.provider_name), - ) - .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - ) - if builtin_provider is None: - raise ToolProviderNotFoundError(f"no default provider for {provider_id}") - else: - builtin_provider = db.session.scalar( - select(BuiltinToolProvider) - .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) - .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .limit(1) - ) - - if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - - # check if the credential is allowed to be used - from core.helper.credential_utils import check_credential_policy_compliance - - check_credential_policy_compliance( - credential_id=builtin_provider.id, - provider=provider_id, - credential_type=PluginCredentialType.TOOL, - check_existence=False, - ) - - encrypter, cache = create_provider_encrypter( - tenant_id=tenant_id, - config=[ - x.to_basic_provider_config() - for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type) - ], - cache=ToolProviderCredentialsCache( - tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id - ), - ) - - # decrypt the credentials - decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials) - - # check if the credentials is expired - if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): - # TODO: circular import - from core.plugin.impl.oauth import OAuthHandler - from services.tools.builtin_tools_manage_service import BuiltinToolManageService - - # refresh the credentials - tool_provider = ToolProviderID(provider_id) - provider_name = tool_provider.provider_name - redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" - system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) - - oauth_handler = OAuthHandler() - # refresh the credentials - refreshed_credentials = oauth_handler.refresh_credentials( - tenant_id=tenant_id, - user_id=builtin_provider.user_id, - plugin_id=tool_provider.plugin_id, - provider=provider_name, - redirect_uri=redirect_uri, - system_credentials=system_credentials or {}, - credentials=decrypted_credentials, - ) - # update the credentials - builtin_provider.encrypted_credentials = json.dumps( - encrypter.encrypt(refreshed_credentials.credentials) - ) - builtin_provider.expires_at = refreshed_credentials.expires_at - db.session.commit() - decrypted_credentials = refreshed_credentials.credentials - cache.delete() - - return builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - user_id=user_id, - credentials=dict(decrypted_credentials), - credential_type=builtin_provider.credential_type, - runtime_parameters={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ) - - elif provider_type == ToolProviderType.API: - api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - encrypter, _ = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=api_provider, - ) - return api_provider.get_tool(tool_name).fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - user_id=user_id, - credentials=dict(encrypter.decrypt(credentials)), - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ) - elif provider_type == ToolProviderType.WORKFLOW: - workflow_provider_stmt = select(WorkflowToolProvider).where( - WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id - ) - with Session(db.engine, expire_on_commit=False) as session, session.begin(): - workflow_provider = session.scalar(workflow_provider_stmt) - - if workflow_provider is None: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) - controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id) - if controller_tools is None or len(controller_tools) == 0: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - user_id=user_id, - credentials={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ) - elif provider_type == ToolProviderType.APP: - raise NotImplementedError("app provider not implemented") - elif provider_type == ToolProviderType.PLUGIN: - plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) - runtime = getattr(plugin_tool, "runtime", None) - if runtime is not None: - runtime.user_id = user_id - runtime.invoke_from = invoke_from - runtime.tool_invoke_from = tool_invoke_from - return plugin_tool - elif provider_type == ToolProviderType.MCP: - mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) - runtime = getattr(mcp_tool, "runtime", None) - if runtime is not None: - runtime.user_id = user_id - runtime.invoke_from = invoke_from - runtime.tool_invoke_from = tool_invoke_from - return mcp_tool - else: - raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") + case ToolProviderType.APP: + raise NotImplementedError("app provider not implemented") + case ToolProviderType.PLUGIN: + plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + runtime = getattr(plugin_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return plugin_tool + case ToolProviderType.MCP: + mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + runtime = getattr(mcp_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return mcp_tool + case ToolProviderType.DATASET_RETRIEVAL: + raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") + case _: + raise ToolProviderNotFoundError(f"provider type {provider_type} not found") @classmethod def get_agent_tool_runtime( @@ -1027,31 +1022,31 @@ class ToolManager: :param provider_id: the id of the provider :return: """ - provider_type = provider_type - provider_id = provider_id - if provider_type == ToolProviderType.BUILT_IN: - provider = ToolManager.get_builtin_provider(provider_id, tenant_id) - if isinstance(provider, PluginToolProviderController): + match provider_type: + case ToolProviderType.BUILT_IN: + provider = ToolManager.get_builtin_provider(provider_id, tenant_id) + if isinstance(provider, PluginToolProviderController): + try: + return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) + except Exception: + return {"background": "#252525", "content": "\ud83d\ude01"} + return cls.generate_builtin_tool_icon_url(provider_id) + case ToolProviderType.API: + return cls.generate_api_tool_icon_url(tenant_id, provider_id) + case ToolProviderType.WORKFLOW: + return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) + case ToolProviderType.PLUGIN: + provider = ToolManager.get_plugin_provider(provider_id, tenant_id) try: return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} - return cls.generate_builtin_tool_icon_url(provider_id) - elif provider_type == ToolProviderType.API: - return cls.generate_api_tool_icon_url(tenant_id, provider_id) - elif provider_type == ToolProviderType.WORKFLOW: - return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) - elif provider_type == ToolProviderType.PLUGIN: - provider = ToolManager.get_plugin_provider(provider_id, tenant_id) - try: - return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) - except Exception: - return {"background": "#252525", "content": "\ud83d\ude01"} - raise ValueError(f"plugin provider {provider_id} not found") - elif provider_type == ToolProviderType.MCP: - return cls.generate_mcp_tool_icon_url(tenant_id, provider_id) - else: - raise ValueError(f"provider type {provider_type} not found") + case ToolProviderType.MCP: + return cls.generate_mcp_tool_icon_url(tenant_id, provider_id) + case ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL: + raise ValueError(f"provider type {provider_type} not found") + case _: + raise ValueError(f"provider type {provider_type} not found") @classmethod def _convert_tool_parameters_type(