mirror of
https://github.com/langgenius/dify.git
synced 2026-04-24 17:16:37 +08:00
feat(oauth): add credential handling and context support for tool invocations
This commit is contained in:
parent
8fc5ccab35
commit
7de3436e6b
@ -175,6 +175,7 @@ class PluginInvokeToolApi(Resource):
|
|||||||
provider=payload.provider,
|
provider=payload.provider,
|
||||||
tool_name=payload.tool,
|
tool_name=payload.tool,
|
||||||
tool_parameters=payload.tool_parameters,
|
tool_parameters=payload.tool_parameters,
|
||||||
|
credential_id=payload.credential_id
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel):
|
|||||||
tool_name: str
|
tool_name: str
|
||||||
tool_parameters: dict[str, Any] = Field(default_factory=dict)
|
tool_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||||
plugin_unique_identifier: str | None = None
|
plugin_unique_identifier: str | None = None
|
||||||
|
credential_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class AgentPromptEntity(BaseModel):
|
class AgentPromptEntity(BaseModel):
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from core.agent.entities import AgentInvokeMessage
|
from core.agent.entities import AgentInvokeMessage
|
||||||
from core.agent.plugin_entities import AgentStrategyParameter
|
from core.agent.plugin_entities import AgentStrategyParameter
|
||||||
|
from core.plugin.entities.request import InvokeCredentials
|
||||||
|
|
||||||
|
|
||||||
class BaseAgentStrategy(ABC):
|
class BaseAgentStrategy(ABC):
|
||||||
@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC):
|
|||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
app_id: Optional[str] = None,
|
app_id: Optional[str] = None,
|
||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
|
credentials: Optional[InvokeCredentials] = None,
|
||||||
) -> Generator[AgentInvokeMessage, None, None]:
|
) -> Generator[AgentInvokeMessage, None, None]:
|
||||||
"""
|
"""
|
||||||
Invoke the agent strategy.
|
Invoke the agent strategy.
|
||||||
"""
|
"""
|
||||||
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
|
yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials)
|
||||||
|
|
||||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||||
"""
|
"""
|
||||||
@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC):
|
|||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
app_id: Optional[str] = None,
|
app_id: Optional[str] = None,
|
||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
|
credentials: Optional[InvokeCredentials] = None,
|
||||||
) -> Generator[AgentInvokeMessage, None, None]:
|
) -> Generator[AgentInvokeMessage, None, None]:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||||||
from core.agent.entities import AgentInvokeMessage
|
from core.agent.entities import AgentInvokeMessage
|
||||||
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
||||||
from core.agent.strategy.base import BaseAgentStrategy
|
from core.agent.strategy.base import BaseAgentStrategy
|
||||||
|
from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext
|
||||||
from core.plugin.impl.agent import PluginAgentClient
|
from core.plugin.impl.agent import PluginAgentClient
|
||||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||||
|
|
||||||
@ -40,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
|||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
app_id: Optional[str] = None,
|
app_id: Optional[str] = None,
|
||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
|
credentials: Optional[InvokeCredentials] = None,
|
||||||
) -> Generator[AgentInvokeMessage, None, None]:
|
) -> Generator[AgentInvokeMessage, None, None]:
|
||||||
"""
|
"""
|
||||||
Invoke the agent strategy.
|
Invoke the agent strategy.
|
||||||
@ -58,4 +60,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
|||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
|
context=PluginInvokeContext(
|
||||||
|
credentials=credentials or InvokeCredentials()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||||
@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
provider: str,
|
provider: str,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
tool_parameters: dict[str, Any],
|
tool_parameters: dict[str, Any],
|
||||||
|
credential_id: Optional[str] = None,
|
||||||
) -> Generator[ToolInvokeMessage, None, None]:
|
) -> Generator[ToolInvokeMessage, None, None]:
|
||||||
"""
|
"""
|
||||||
invoke tool
|
invoke tool
|
||||||
@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
# get tool runtime
|
# get tool runtime
|
||||||
try:
|
try:
|
||||||
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
|
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
|
||||||
tool_type, tenant_id, provider, tool_name, tool_parameters
|
tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
|
||||||
)
|
)
|
||||||
response = ToolEngine.generic_invoke(
|
response = ToolEngine.generic_invoke(
|
||||||
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
|
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
|
||||||
|
|||||||
@ -27,6 +27,20 @@ from core.workflow.nodes.question_classifier.entities import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvokeCredentials(BaseModel):
|
||||||
|
tool_credentials: dict[str, str] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Map of tool provider to credential id, used to store the credential id for the tool provider.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeContext(BaseModel):
|
||||||
|
credentials: Optional[InvokeCredentials] = Field(
|
||||||
|
default_factory=InvokeCredentials,
|
||||||
|
description="Credentials context for the plugin invocation or backward invocation.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RequestInvokeTool(BaseModel):
|
class RequestInvokeTool(BaseModel):
|
||||||
"""
|
"""
|
||||||
Request to invoke a tool
|
Request to invoke a tool
|
||||||
@ -36,6 +50,7 @@ class RequestInvokeTool(BaseModel):
|
|||||||
provider: str
|
provider: str
|
||||||
tool: str
|
tool: str
|
||||||
tool_parameters: dict
|
tool_parameters: dict
|
||||||
|
credential_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class BaseRequestInvokeModel(BaseModel):
|
class BaseRequestInvokeModel(BaseModel):
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID
|
|||||||
from core.plugin.entities.plugin_daemon import (
|
from core.plugin.entities.plugin_daemon import (
|
||||||
PluginAgentProviderEntity,
|
PluginAgentProviderEntity,
|
||||||
)
|
)
|
||||||
|
from core.plugin.entities.request import PluginInvokeContext
|
||||||
from core.plugin.impl.base import BasePluginClient
|
from core.plugin.impl.base import BasePluginClient
|
||||||
|
|
||||||
|
|
||||||
@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient):
|
|||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
app_id: Optional[str] = None,
|
app_id: Optional[str] = None,
|
||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
|
context: Optional[PluginInvokeContext] = None,
|
||||||
) -> Generator[AgentInvokeMessage, None, None]:
|
) -> Generator[AgentInvokeMessage, None, None]:
|
||||||
"""
|
"""
|
||||||
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
|
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
|
||||||
@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient):
|
|||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
"app_id": app_id,
|
"app_id": app_id,
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
|
"context": context.model_dump() if context else {},
|
||||||
"data": {
|
"data": {
|
||||||
"agent_strategy_provider": agent_provider_id.provider_name,
|
"agent_strategy_provider": agent_provider_id.provider_name,
|
||||||
"agent_strategy": agent_strategy,
|
"agent_strategy": agent_strategy,
|
||||||
|
|||||||
@ -446,6 +446,7 @@ class ToolSelector(BaseModel):
|
|||||||
options: Optional[list[PluginParameterOption]] = None
|
options: Optional[list[PluginParameterOption]] = None
|
||||||
|
|
||||||
provider_id: str = Field(..., description="The id of the provider")
|
provider_id: str = Field(..., description="The id of the provider")
|
||||||
|
credential_id: Optional[str] = Field(default=None, description="The id of the credential")
|
||||||
tool_name: str = Field(..., description="The name of the tool")
|
tool_name: str = Field(..., description="The name of the tool")
|
||||||
tool_description: str = Field(..., description="The description of the tool")
|
tool_description: str = Field(..., description="The description of the tool")
|
||||||
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
|
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
|
||||||
|
|||||||
@ -321,6 +321,7 @@ class ToolManager:
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||||
|
credential_id=agent_tool.credential_id,
|
||||||
)
|
)
|
||||||
runtime_parameters = {}
|
runtime_parameters = {}
|
||||||
parameters = tool_entity.get_merged_runtime_parameters()
|
parameters = tool_entity.get_merged_runtime_parameters()
|
||||||
@ -393,6 +394,7 @@ class ToolManager:
|
|||||||
provider: str,
|
provider: str,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
tool_parameters: dict[str, Any],
|
tool_parameters: dict[str, Any],
|
||||||
|
credential_id: Optional[str] = None,
|
||||||
) -> Tool:
|
) -> Tool:
|
||||||
"""
|
"""
|
||||||
get tool runtime from plugin
|
get tool runtime from plugin
|
||||||
@ -404,6 +406,7 @@ class ToolManager:
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
invoke_from=InvokeFrom.SERVICE_API,
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
tool_invoke_from=ToolInvokeFrom.PLUGIN,
|
tool_invoke_from=ToolInvokeFrom.PLUGIN,
|
||||||
|
credential_id=credential_id,
|
||||||
)
|
)
|
||||||
runtime_parameters = {}
|
runtime_parameters = {}
|
||||||
parameters = tool_entity.get_merged_runtime_parameters()
|
parameters = tool_entity.get_merged_runtime_parameters()
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
|
|||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
from pydantic import ValidationError
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@ -13,10 +14,16 @@ from core.agent.strategy.plugin import PluginAgentStrategy
|
|||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||||
|
from core.plugin.entities.request import InvokeCredentials
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
from core.plugin.impl.plugin import PluginInstaller
|
from core.plugin.impl.plugin import PluginInstaller
|
||||||
from core.provider_manager import ProviderManager
|
from core.provider_manager import ProviderManager
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
from core.tools.entities.tool_entities import (
|
||||||
|
ToolIdentity,
|
||||||
|
ToolInvokeMessage,
|
||||||
|
ToolParameter,
|
||||||
|
ToolProviderType,
|
||||||
|
)
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.variables.segments import StringSegment
|
from core.variables.segments import StringSegment
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
@ -84,6 +91,7 @@ class AgentNode(ToolNode):
|
|||||||
for_log=True,
|
for_log=True,
|
||||||
strategy=strategy,
|
strategy=strategy,
|
||||||
)
|
)
|
||||||
|
credentials = self._generate_credentials(parameters=parameters)
|
||||||
|
|
||||||
# get conversation id
|
# get conversation id
|
||||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||||
@ -94,6 +102,7 @@ class AgentNode(ToolNode):
|
|||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
app_id=self.app_id,
|
app_id=self.app_id,
|
||||||
conversation_id=conversation_id.text if conversation_id else None,
|
conversation_id=conversation_id.text if conversation_id else None,
|
||||||
|
credentials=credentials,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield RunCompletedEvent(
|
yield RunCompletedEvent(
|
||||||
@ -246,6 +255,7 @@ class AgentNode(ToolNode):
|
|||||||
tool_name=tool.get("tool_name", ""),
|
tool_name=tool.get("tool_name", ""),
|
||||||
tool_parameters=parameters,
|
tool_parameters=parameters,
|
||||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||||
|
credential_id=tool.get("credential_id", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
extra = tool.get("extra", {})
|
extra = tool.get("extra", {})
|
||||||
@ -276,6 +286,7 @@ class AgentNode(ToolNode):
|
|||||||
{
|
{
|
||||||
**tool_runtime.entity.model_dump(mode="json"),
|
**tool_runtime.entity.model_dump(mode="json"),
|
||||||
"runtime_parameters": runtime_parameters,
|
"runtime_parameters": runtime_parameters,
|
||||||
|
"credential_id": tool.get("credential_id", None),
|
||||||
"provider_type": provider_type.value,
|
"provider_type": provider_type.value,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -305,6 +316,27 @@ class AgentNode(ToolNode):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _generate_credentials(
|
||||||
|
self,
|
||||||
|
parameters: dict[str, Any],
|
||||||
|
) -> InvokeCredentials:
|
||||||
|
"""
|
||||||
|
Generate credentials based on the given agent parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
credentials = InvokeCredentials()
|
||||||
|
|
||||||
|
# generate credentials for tools selector
|
||||||
|
credentials.tool_credentials = {}
|
||||||
|
for tool in parameters.get("tools", []):
|
||||||
|
if tool.get("credential_id"):
|
||||||
|
try:
|
||||||
|
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||||
|
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||||
|
except ValidationError:
|
||||||
|
continue
|
||||||
|
return credentials
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user