mirror of https://github.com/langgenius/dify.git
feat(oauth): add credential handling and context support for tool invocations
This commit is contained in:
parent
8fc5ccab35
commit
7de3436e6b
|
|
@ -175,6 +175,7 @@ class PluginInvokeToolApi(Resource):
|
|||
provider=payload.provider,
|
||||
tool_name=payload.tool,
|
||||
tool_parameters=payload.tool_parameters,
|
||||
credential_id=payload.credential_id
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel):
|
|||
tool_name: str
|
||||
tool_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
plugin_unique_identifier: str | None = None
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
class AgentPromptEntity(BaseModel):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
|
||||
class BaseAgentStrategy(ABC):
|
||||
|
|
@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC):
|
|||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
credentials: Optional[InvokeCredentials] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the agent strategy.
|
||||
"""
|
||||
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
|
||||
yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials)
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
"""
|
||||
|
|
@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC):
|
|||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
credentials: Optional[InvokeCredentials] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
||||
from core.agent.strategy.base import BaseAgentStrategy
|
||||
from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext
|
||||
from core.plugin.impl.agent import PluginAgentClient
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
|
||||
|
|
@ -40,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
|||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
credentials: Optional[InvokeCredentials] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the agent strategy.
|
||||
|
|
@ -58,4 +60,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
|||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
context=PluginInvokeContext(
|
||||
credentials=credentials or InvokeCredentials()
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
|
|
@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
|
|||
provider: str,
|
||||
tool_name: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
credential_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke tool
|
||||
|
|
@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
|
|||
# get tool runtime
|
||||
try:
|
||||
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
|
||||
tool_type, tenant_id, provider, tool_name, tool_parameters
|
||||
tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
|
||||
)
|
||||
response = ToolEngine.generic_invoke(
|
||||
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
|
||||
|
|
|
|||
|
|
@ -27,6 +27,20 @@ from core.workflow.nodes.question_classifier.entities import (
|
|||
)
|
||||
|
||||
|
||||
class InvokeCredentials(BaseModel):
|
||||
tool_credentials: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of tool provider to credential id, used to store the credential id for the tool provider.",
|
||||
)
|
||||
|
||||
|
||||
class PluginInvokeContext(BaseModel):
|
||||
credentials: Optional[InvokeCredentials] = Field(
|
||||
default_factory=InvokeCredentials,
|
||||
description="Credentials context for the plugin invocation or backward invocation.",
|
||||
)
|
||||
|
||||
|
||||
class RequestInvokeTool(BaseModel):
|
||||
"""
|
||||
Request to invoke a tool
|
||||
|
|
@ -36,6 +50,7 @@ class RequestInvokeTool(BaseModel):
|
|||
provider: str
|
||||
tool: str
|
||||
tool_parameters: dict
|
||||
credential_id: Optional[str] = None
|
||||
|
||||
|
||||
class BaseRequestInvokeModel(BaseModel):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID
|
|||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginAgentProviderEntity,
|
||||
)
|
||||
from core.plugin.entities.request import PluginInvokeContext
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
|
|
@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient):
|
|||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
context: Optional[PluginInvokeContext] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
|
||||
|
|
@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient):
|
|||
"conversation_id": conversation_id,
|
||||
"app_id": app_id,
|
||||
"message_id": message_id,
|
||||
"context": context.model_dump() if context else {},
|
||||
"data": {
|
||||
"agent_strategy_provider": agent_provider_id.provider_name,
|
||||
"agent_strategy": agent_strategy,
|
||||
|
|
|
|||
|
|
@ -446,6 +446,7 @@ class ToolSelector(BaseModel):
|
|||
options: Optional[list[PluginParameterOption]] = None
|
||||
|
||||
provider_id: str = Field(..., description="The id of the provider")
|
||||
credential_id: Optional[str] = Field(default=None, description="The id of the credential")
|
||||
tool_name: str = Field(..., description="The name of the tool")
|
||||
tool_description: str = Field(..., description="The description of the tool")
|
||||
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
|
||||
|
|
|
|||
|
|
@ -321,6 +321,7 @@ class ToolManager:
|
|||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
credential_id=agent_tool.credential_id,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
|
|
@ -393,6 +394,7 @@ class ToolManager:
|
|||
provider: str,
|
||||
tool_name: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
credential_id: Optional[str] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get tool runtime from plugin
|
||||
|
|
@ -404,6 +406,7 @@ class ToolManager:
|
|||
tenant_id=tenant_id,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.PLUGIN,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
|
|||
from typing import Any, Optional, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -13,10 +14,16 @@ from core.agent.strategy.plugin import PluginAgentStrategy
|
|||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables.segments import StringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
|
|
@ -84,6 +91,7 @@ class AgentNode(ToolNode):
|
|||
for_log=True,
|
||||
strategy=strategy,
|
||||
)
|
||||
credentials = self._generate_credentials(parameters=parameters)
|
||||
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
|
@ -94,6 +102,7 @@ class AgentNode(ToolNode):
|
|||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
except Exception as e:
|
||||
yield RunCompletedEvent(
|
||||
|
|
@ -246,6 +255,7 @@ class AgentNode(ToolNode):
|
|||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
|
|
@ -276,6 +286,7 @@ class AgentNode(ToolNode):
|
|||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
|
|
@ -305,6 +316,27 @@ class AgentNode(ToolNode):
|
|||
|
||||
return result
|
||||
|
||||
def _generate_credentials(
|
||||
self,
|
||||
parameters: dict[str, Any],
|
||||
) -> InvokeCredentials:
|
||||
"""
|
||||
Generate credentials based on the given agent parameters.
|
||||
"""
|
||||
|
||||
credentials = InvokeCredentials()
|
||||
|
||||
# generate credentials for tools selector
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
if tool.get("credential_id"):
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
except ValidationError:
|
||||
continue
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
|
|
|||
Loading…
Reference in New Issue