From 26b46b88c9acb6ddedd213dee5beaee7d6ddb26b Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 4 Jul 2025 14:25:33 +0800 Subject: [PATCH] feat(oauth): add multi credentials support --- api/core/plugin/impl/tool.py | 4 +- api/core/tools/__base/tool_runtime.py | 3 +- api/core/tools/plugin_tool/tool.py | 1 + api/core/tools/tool_manager.py | 48 +++++++++++++++--------- api/core/workflow/nodes/tool/entities.py | 1 + api/services/app_dsl_service.py | 5 +++ 6 files changed, 42 insertions(+), 20 deletions(-) diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 19b26c8fe3..f84e8c6c5e 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.impl.base import BasePluginClient -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderCredentialType class PluginToolManager(BasePluginClient): @@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient): tool_provider: str, tool_name: str, credentials: dict[str, Any], + credential_type: ToolProviderCredentialType, tool_parameters: dict[str, Any], conversation_id: Optional[str] = None, app_id: Optional[str] = None, @@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient): "provider": tool_provider_id.provider_name, "tool": tool_name, "credentials": credentials, + "credential_type": credential_type, "tool_parameters": tool_parameters, }, }, diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index c9e157cb77..51e339bed1 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -4,7 +4,7 @@ from openai import BaseModel from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom -from core.tools.entities.tool_entities import ToolInvokeFrom +from core.tools.entities.tool_entities import ToolInvokeFrom, ToolProviderCredentialType class ToolRuntime(BaseModel): @@ -17,6 +17,7 @@ class ToolRuntime(BaseModel): invoke_from: Optional[InvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None credentials: dict[str, Any] = Field(default_factory=dict) + credential_type: Optional[ToolProviderCredentialType] = ToolProviderCredentialType.API_KEY runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index d21e3d7d1c..aef2677c36 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -44,6 +44,7 @@ class PluginTool(Tool): tool_provider=self.entity.identity.provider, tool_name=self.entity.identity.name, credentials=self.runtime.credentials, + credential_type=self.runtime.credential_type, tool_parameters=tool_parameters, conversation_id=conversation_id, app_id=app_id, diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index e9423a6c49..7e37192979 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -4,7 +4,7 @@ import mimetypes from collections.abc import Generator from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from yarl import URL @@ -39,6 +39,7 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolInvokeFrom, ToolParameter, + ToolProviderCredentialType, ToolProviderType, ) from core.tools.errors import ToolProviderNotFoundError @@ -148,6 +149,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, + credential_id: Optional[str] = None, ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]: """ get the tool runtime @@ -158,6 +160,7 @@ class ToolManager: :param tenant_id: the tenant id :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from + :param credential_id: the credential id :return: the tool """ @@ -185,19 +188,31 @@ class ToolManager: if isinstance(provider_controller, PluginToolProviderController): provider_id_entity = ToolProviderID(provider_id) # get credentials - builtin_provider: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == str(provider_id_entity)) - | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + if credential_id: + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {credential_id} not found") + else: + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == str(provider_id_entity)) + | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .first() ) - .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .first() - ) - if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") else: builtin_provider = ( db.session.query(BuiltinToolProvider) @@ -209,8 +224,6 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - # decrypt the credentials - credentials = builtin_provider.credentials encrypter, _ = create_encrypter( tenant_id=tenant_id, config=[ @@ -221,15 +234,13 @@ class ToolManager: tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id ), ) - - decrypted_credentials = encrypter.decrypt(credentials) - return cast( BuiltinTool, builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=decrypted_credentials, + credentials=encrypter.decrypt(builtin_provider.credentials), + credential_type=ToolProviderCredentialType.of(builtin_provider.credential_type), runtime_parameters={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -362,6 +373,7 @@ class ToolManager: tenant_id=tenant_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, + credential_id=workflow_tool.credential_id, ) runtime_parameters = {} parameters = tool_runtime.get_merged_runtime_parameters() diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 21023d4ab7..2ce6ac3fc1 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -14,6 +14,7 @@ class ToolEntity(BaseModel): tool_name: str tool_label: str # redundancy tool_configurations: dict[str, Any] + credential_id: str | None = None plugin_unique_identifier: str | None = None # redundancy @field_validator("tool_configurations", mode="before") diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 20257fa345..f53048a690 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -582,6 +582,11 @@ class AppDslService: cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) for dataset_id in dataset_ids ] + # filter credential id from tool node + if node.get("data", {}).get("type", "") == NodeType.TOOL.value: + node["data"]["credential_id"] = None + + export_data["workflow"] = workflow_dict dependencies = cls._extract_dependencies_from_workflow(workflow) export_data["dependencies"] = [