mirror of https://github.com/langgenius/dify.git
feat(oauth): add multi credentials support
This commit is contained in:
parent
b316867bab
commit
26b46b88c9
|
|
@ -6,7 +6,7 @@ from pydantic import BaseModel
|
|||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderCredentialType
|
||||
|
||||
|
||||
class PluginToolManager(BasePluginClient):
|
||||
|
|
@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient):
|
|||
tool_provider: str,
|
||||
tool_name: str,
|
||||
credentials: dict[str, Any],
|
||||
credential_type: ToolProviderCredentialType,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
|
|
@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient):
|
|||
"provider": tool_provider_id.provider_name,
|
||||
"tool": tool_name,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"tool_parameters": tool_parameters,
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from openai import BaseModel
|
|||
from pydantic import Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolProviderCredentialType
|
||||
|
||||
|
||||
class ToolRuntime(BaseModel):
|
||||
|
|
@ -17,6 +17,7 @@ class ToolRuntime(BaseModel):
|
|||
invoke_from: Optional[InvokeFrom] = None
|
||||
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
credential_type: Optional[ToolProviderCredentialType] = ToolProviderCredentialType.API_KEY
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ class PluginTool(Tool):
|
|||
tool_provider=self.entity.identity.provider,
|
||||
tool_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
credential_type=self.runtime.credential_type,
|
||||
tool_parameters=tool_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import mimetypes
|
|||
from collections.abc import Generator
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
|
|
@ -39,6 +39,7 @@ from core.tools.entities.tool_entities import (
|
|||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderCredentialType,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
|
|
@ -148,6 +149,7 @@ class ToolManager:
|
|||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||
credential_id: Optional[str] = None,
|
||||
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
|
@ -158,6 +160,7 @@ class ToolManager:
|
|||
:param tenant_id: the tenant id
|
||||
:param invoke_from: invoke from
|
||||
:param tool_invoke_from: the tool invoke from
|
||||
:param credential_id: the credential id
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
|
|
@ -185,19 +188,31 @@ class ToolManager:
|
|||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
# get credentials
|
||||
builtin_provider: BuiltinToolProvider | None = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
if credential_id:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {credential_id} not found")
|
||||
else:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
else:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
|
|
@ -209,8 +224,6 @@ class ToolManager:
|
|||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
|
||||
# decrypt the credentials
|
||||
credentials = builtin_provider.credentials
|
||||
encrypter, _ = create_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
|
|
@ -221,15 +234,13 @@ class ToolManager:
|
|||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||
),
|
||||
)
|
||||
|
||||
decrypted_credentials = encrypter.decrypt(credentials)
|
||||
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=decrypted_credentials,
|
||||
credentials=encrypter.decrypt(builtin_provider.credentials),
|
||||
credential_type=ToolProviderCredentialType.of(builtin_provider.credential_type),
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
|
|
@ -362,6 +373,7 @@ class ToolManager:
|
|||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
credential_id=workflow_tool.credential_id,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_runtime.get_merged_runtime_parameters()
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ class ToolEntity(BaseModel):
|
|||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
tool_configurations: dict[str, Any]
|
||||
credential_id: str | None = None
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
@field_validator("tool_configurations", mode="before")
|
||||
|
|
|
|||
|
|
@ -582,6 +582,11 @@ class AppDslService:
|
|||
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
|
||||
for dataset_id in dataset_ids
|
||||
]
|
||||
# filter credential id from tool node
|
||||
if node.get("data", {}).get("type", "") == NodeType.TOOL.value:
|
||||
node["data"]["credential_id"] = None
|
||||
|
||||
|
||||
export_data["workflow"] = workflow_dict
|
||||
dependencies = cls._extract_dependencies_from_workflow(workflow)
|
||||
export_data["dependencies"] = [
|
||||
|
|
|
|||
Loading…
Reference in New Issue