From 7de3436e6bdcdddfd41c748a5af8de9891e588b8 Mon Sep 17 00:00:00 2001 From: Harry Date: Sun, 13 Jul 2025 01:44:29 +0800 Subject: [PATCH] feat(oauth): add credential handling and context support for tool invocations --- api/controllers/inner_api/plugin/plugin.py | 1 + api/core/agent/entities.py | 1 + api/core/agent/strategy/base.py | 5 ++- api/core/agent/strategy/plugin.py | 5 +++ api/core/plugin/backwards_invocation/tool.py | 5 +-- api/core/plugin/entities/request.py | 15 +++++++++ api/core/plugin/impl/agent.py | 3 ++ api/core/tools/entities/tool_entities.py | 1 + api/core/tools/tool_manager.py | 3 ++ api/core/workflow/nodes/agent/agent_node.py | 34 +++++++++++++++++++- 10 files changed, 69 insertions(+), 4 deletions(-) diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 327e9ce834..6c4500af50 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -175,6 +175,7 @@ class PluginInvokeToolApi(Resource): provider=payload.provider, tool_name=payload.tool, tool_parameters=payload.tool_parameters, + credential_id=payload.credential_id ), ) diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 143a3a51aa..a31c1050bd 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel): tool_name: str tool_parameters: dict[str, Any] = Field(default_factory=dict) plugin_unique_identifier: str | None = None + credential_id: str | None = None class AgentPromptEntity(BaseModel): diff --git a/api/core/agent/strategy/base.py b/api/core/agent/strategy/base.py index ead81a7a0e..a52a1dfd7a 100644 --- a/api/core/agent/strategy/base.py +++ b/api/core/agent/strategy/base.py @@ -4,6 +4,7 @@ from typing import Any, Optional from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyParameter +from core.plugin.entities.request import InvokeCredentials class BaseAgentStrategy(ABC): @@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + credentials: Optional[InvokeCredentials] = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent strategy. """ - yield from self._invoke(params, user_id, conversation_id, app_id, message_id) + yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials) def get_parameters(self) -> Sequence[AgentStrategyParameter]: """ @@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + credentials: Optional[InvokeCredentials] = None, ) -> Generator[AgentInvokeMessage, None, None]: pass diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py index 4cfcfbf86a..90e5782c47 100644 --- a/api/core/agent/strategy/plugin.py +++ b/api/core/agent/strategy/plugin.py @@ -4,6 +4,7 @@ from typing import Any, Optional from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter from core.agent.strategy.base import BaseAgentStrategy +from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext from core.plugin.impl.agent import PluginAgentClient from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -40,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + credentials: Optional[InvokeCredentials] = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent strategy. @@ -58,4 +60,7 @@ class PluginAgentStrategy(BaseAgentStrategy): conversation_id=conversation_id, app_id=app_id, message_id=message_id, + context=PluginInvokeContext( + credentials=credentials or InvokeCredentials() + ), ) diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index 1d62743f13..06773504d9 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any +from typing import Any, Optional from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.plugin.backwards_invocation.base import BaseBackwardsInvocation @@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): provider: str, tool_name: str, tool_parameters: dict[str, Any], + credential_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tool @@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): # get tool runtime try: tool_runtime = ToolManager.get_tool_runtime_from_plugin( - tool_type, tenant_id, provider, tool_name, tool_parameters + tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id ) response = ToolEngine.generic_invoke( tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 89f595ec46..3a783dad3e 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -27,6 +27,20 @@ from core.workflow.nodes.question_classifier.entities import ( ) +class InvokeCredentials(BaseModel): + tool_credentials: dict[str, str] = Field( + default_factory=dict, + description="Map of tool provider to credential id, used to store the credential id for the tool provider.", + ) + + +class PluginInvokeContext(BaseModel): + credentials: Optional[InvokeCredentials] = Field( + default_factory=InvokeCredentials, + description="Credentials context for the plugin invocation or backward invocation.", + ) + + class RequestInvokeTool(BaseModel): """ Request to invoke a tool @@ -36,6 +50,7 @@ class RequestInvokeTool(BaseModel): provider: str tool: str tool_parameters: dict + credential_id: Optional[str] = None class BaseRequestInvokeModel(BaseModel): diff --git a/api/core/plugin/impl/agent.py b/api/core/plugin/impl/agent.py index 66b77c7489..9575c57ac8 100644 --- a/api/core/plugin/impl/agent.py +++ b/api/core/plugin/impl/agent.py @@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID from core.plugin.entities.plugin_daemon import ( PluginAgentProviderEntity, ) +from core.plugin.entities.request import PluginInvokeContext from core.plugin.impl.base import BasePluginClient @@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + context: Optional[PluginInvokeContext] = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent with the given tenant, user, plugin, provider, name and parameters. @@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient): "conversation_id": conversation_id, "app_id": app_id, "message_id": message_id, + "context": context.model_dump() if context else {}, "data": { "agent_strategy_provider": agent_provider_id.provider_name, "agent_strategy": agent_strategy, diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 2e53f1ae73..08fd632ba5 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -446,6 +446,7 @@ class ToolSelector(BaseModel): options: Optional[list[PluginParameterOption]] = None provider_id: str = Field(..., description="The id of the provider") + credential_id: Optional[str] = Field(default=None, description="The id of the credential") tool_name: str = Field(..., description="The name of the tool") tool_description: str = Field(..., description="The description of the tool") tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 9d5000f8d4..d38b80372b 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -321,6 +321,7 @@ class ToolManager: tenant_id=tenant_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, + credential_id=agent_tool.credential_id, ) runtime_parameters = {} parameters = tool_entity.get_merged_runtime_parameters() @@ -393,6 +394,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], + credential_id: Optional[str] = None, ) -> Tool: """ get tool runtime from plugin @@ -404,6 +406,7 @@ class ToolManager: tenant_id=tenant_id, invoke_from=InvokeFrom.SERVICE_API, tool_invoke_from=ToolInvokeFrom.PLUGIN, + credential_id=credential_id, ) runtime_parameters = {} parameters = tool_entity.get_merged_runtime_parameters() diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 678b99d546..ce67197a58 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -4,6 +4,7 @@ from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast from packaging.version import Version +from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session @@ -13,10 +14,16 @@ from core.agent.strategy.plugin import PluginAgentStrategy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.plugin.entities.request import InvokeCredentials from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.plugin import PluginInstaller from core.provider_manager import ProviderManager -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ( + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) from core.tools.tool_manager import ToolManager from core.variables.segments import StringSegment from core.workflow.entities.node_entities import NodeRunResult @@ -84,6 +91,7 @@ class AgentNode(ToolNode): for_log=True, strategy=strategy, ) + credentials = self._generate_credentials(parameters=parameters) # get conversation id conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) @@ -94,6 +102,7 @@ class AgentNode(ToolNode): user_id=self.user_id, app_id=self.app_id, conversation_id=conversation_id.text if conversation_id else None, + credentials=credentials, ) except Exception as e: yield RunCompletedEvent( @@ -246,6 +255,7 @@ class AgentNode(ToolNode): tool_name=tool.get("tool_name", ""), tool_parameters=parameters, plugin_unique_identifier=tool.get("plugin_unique_identifier", None), + credential_id=tool.get("credential_id", None), ) extra = tool.get("extra", {}) @@ -276,6 +286,7 @@ class AgentNode(ToolNode): { **tool_runtime.entity.model_dump(mode="json"), "runtime_parameters": runtime_parameters, + "credential_id": tool.get("credential_id", None), "provider_type": provider_type.value, } ) @@ -305,6 +316,27 @@ class AgentNode(ToolNode): return result + def _generate_credentials( + self, + parameters: dict[str, Any], + ) -> InvokeCredentials: + """ + Generate credentials based on the given agent parameters. + """ + + credentials = InvokeCredentials() + + # generate credentials for tools selector + credentials.tool_credentials = {} + for tool in parameters.get("tools", []): + if tool.get("credential_id"): + try: + identity = ToolIdentity.model_validate(tool.get("identity", {})) + credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) + except ValidationError: + continue + return credentials + @classmethod def _extract_variable_selector_to_variable_mapping( cls,