mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 20:17:29 +08:00
refactor(tool): implement multi provider credentials support
This commit is contained in:
parent
daec82bd44
commit
7951a1c4df
@ -82,7 +82,7 @@ class ToolBuiltinProviderInfoApi(Resource):
|
|||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
|
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||||
|
|
||||||
|
|
||||||
class ToolBuiltinProviderDeleteApi(Resource):
|
class ToolBuiltinProviderDeleteApi(Resource):
|
||||||
@ -159,7 +159,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
|||||||
result = BuiltinToolManageService.update_builtin_tool_provider(
|
result = BuiltinToolManageService.update_builtin_tool_provider(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider_name=provider,
|
provider=provider,
|
||||||
credentials=args["credentials"],
|
credentials=args["credentials"],
|
||||||
credential_id=args["credential_id"],
|
credential_id=args["credential_id"],
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
@ -782,7 +782,6 @@ class ToolOAuthCustomClient(Resource):
|
|||||||
|
|
||||||
return BuiltinToolManageService.setup_oauth_custom_client(
|
return BuiltinToolManageService.setup_oauth_custom_client(
|
||||||
tenant_id=user.current_tenant_id,
|
tenant_id=user.current_tenant_id,
|
||||||
user_id=user.id,
|
|
||||||
provider=provider,
|
provider=provider,
|
||||||
client_params=args["client_params"],
|
client_params=args["client_params"],
|
||||||
)
|
)
|
||||||
|
|||||||
77
api/core/helper/provider_cache.py
Normal file
77
api/core/helper/provider_cache.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from json import JSONDecodeError
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderCredentialsCache(ABC):
|
||||||
|
"""Base class for provider credentials cache"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.cache_key = self._generate_cache_key(**kwargs)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _generate_cache_key(self, **kwargs) -> str:
|
||||||
|
"""Generate cache key based on subclass implementation"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get(self) -> Optional[dict]:
|
||||||
|
"""Get cached provider credentials"""
|
||||||
|
cached_credentials = redis_client.get(self.cache_key)
|
||||||
|
if cached_credentials:
|
||||||
|
try:
|
||||||
|
cached_credentials = cached_credentials.decode("utf-8")
|
||||||
|
return dict(json.loads(cached_credentials))
|
||||||
|
except JSONDecodeError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, config: dict[str, Any]) -> None:
|
||||||
|
"""Cache provider credentials"""
|
||||||
|
redis_client.setex(self.cache_key, 86400, json.dumps(config))
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
"""Delete cached provider credentials"""
|
||||||
|
redis_client.delete(self.cache_key)
|
||||||
|
|
||||||
|
|
||||||
|
class GenericProviderCredentialsCache(ProviderCredentialsCache):
|
||||||
|
"""Cache for generic provider credentials"""
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, identity_id: str):
|
||||||
|
super().__init__(tenant_id=tenant_id, identity_id=identity_id)
|
||||||
|
|
||||||
|
def _generate_cache_key(self, **kwargs) -> str:
|
||||||
|
tenant_id = kwargs["tenant_id"]
|
||||||
|
identity_id = kwargs["identity_id"]
|
||||||
|
return f"generic_provider_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||||
|
|
||||||
|
class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
||||||
|
"""Cache for tool provider credentials"""
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, provider: str, credential_id: str):
|
||||||
|
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
|
||||||
|
|
||||||
|
def _generate_cache_key(self, **kwargs) -> str:
|
||||||
|
tenant_id = kwargs["tenant_id"]
|
||||||
|
provider = kwargs["provider"]
|
||||||
|
credential_id = kwargs["credential_id"]
|
||||||
|
return f"provider_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
|
||||||
|
|
||||||
|
|
||||||
|
class NoOpProviderCredentialCache:
|
||||||
|
"""No-op provider credential cache"""
|
||||||
|
|
||||||
|
def get(self) -> Optional[dict]:
|
||||||
|
"""Get cached provider credentials"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, config: dict[str, Any]) -> None:
|
||||||
|
"""Cache provider credentials"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
"""Delete cached provider credentials"""
|
||||||
|
pass
|
||||||
@ -1,51 +0,0 @@
|
|||||||
import json
|
|
||||||
from enum import Enum
|
|
||||||
from json import JSONDecodeError
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from extensions.ext_redis import redis_client
|
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderCredentialsCacheType(Enum):
|
|
||||||
PROVIDER = "tool_provider"
|
|
||||||
ENDPOINT = "endpoint"
|
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderCredentialsCache:
|
|
||||||
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
|
|
||||||
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
|
||||||
|
|
||||||
def get(self) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
Get cached model provider credentials.
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
cached_provider_credentials = redis_client.get(self.cache_key)
|
|
||||||
if cached_provider_credentials:
|
|
||||||
try:
|
|
||||||
cached_provider_credentials = cached_provider_credentials.decode("utf-8")
|
|
||||||
cached_provider_credentials = json.loads(cached_provider_credentials)
|
|
||||||
except JSONDecodeError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return dict(cached_provider_credentials)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def set(self, credentials: dict) -> None:
|
|
||||||
"""
|
|
||||||
Cache model provider credentials.
|
|
||||||
|
|
||||||
:param credentials: provider credentials
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
|
|
||||||
|
|
||||||
def delete(self) -> None:
|
|
||||||
"""
|
|
||||||
Delete cached model provider credentials.
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
redis_client.delete(self.cache_key)
|
|
||||||
@ -1,12 +1,12 @@
|
|||||||
from core.plugin.entities.request import RequestInvokeEncrypt
|
from core.plugin.entities.request import RequestInvokeEncrypt
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
from core.tools.utils.configuration import create_generic_encrypter
|
||||||
from models.account import Tenant
|
from models.account import Tenant
|
||||||
|
|
||||||
|
|
||||||
class PluginEncrypter:
|
class PluginEncrypter:
|
||||||
@classmethod
|
@classmethod
|
||||||
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
|
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
|
||||||
encrypter = ProviderConfigEncrypter(
|
encrypter, cache = create_generic_encrypter(
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
config=payload.config,
|
config=payload.config,
|
||||||
provider_type=payload.namespace,
|
provider_type=payload.namespace,
|
||||||
@ -22,7 +22,7 @@ class PluginEncrypter:
|
|||||||
"data": encrypter.decrypt(payload.data),
|
"data": encrypter.decrypt(payload.data),
|
||||||
}
|
}
|
||||||
elif payload.opt == "clear":
|
elif payload.opt == "clear":
|
||||||
encrypter.delete_tool_credentials_cache()
|
cache.delete()
|
||||||
return {
|
return {
|
||||||
"data": {},
|
"data": {},
|
||||||
}
|
}
|
||||||
|
|||||||
@ -105,20 +105,34 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||||||
"""
|
"""
|
||||||
return self.tools
|
return self.tools
|
||||||
|
|
||||||
def get_credentials_schema(
|
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||||
self, credential_type: ToolProviderCredentialType = ToolProviderCredentialType.API_KEY
|
|
||||||
) -> list[ProviderConfig]:
|
|
||||||
"""
|
"""
|
||||||
returns the credentials schema of the provider
|
returns the credentials schema of the provider
|
||||||
|
|
||||||
:return: the credentials schema
|
:return: the credentials schema
|
||||||
"""
|
"""
|
||||||
if credential_type == ToolProviderCredentialType.OAUTH2:
|
return self.get_credentials_schema_by_type(ToolProviderCredentialType.API_KEY.value)
|
||||||
|
|
||||||
|
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
|
||||||
|
"""
|
||||||
|
returns the credentials schema of the provider
|
||||||
|
|
||||||
|
:param credential_type: the type of the credential
|
||||||
|
:return: the credentials schema of the provider
|
||||||
|
"""
|
||||||
|
if credential_type == ToolProviderCredentialType.OAUTH2.value:
|
||||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||||
elif credential_type == ToolProviderCredentialType.API_KEY:
|
if credential_type == ToolProviderCredentialType.API_KEY.value:
|
||||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||||
else:
|
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
|
||||||
|
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||||
|
"""
|
||||||
|
returns the oauth client schema of the provider
|
||||||
|
|
||||||
|
:return: the oauth client schema
|
||||||
|
"""
|
||||||
|
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
|
||||||
|
|
||||||
def get_tools(self) -> list[BuiltinTool]:
|
def get_tools(self) -> list[BuiltinTool]:
|
||||||
"""
|
"""
|
||||||
@ -141,7 +155,11 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||||||
|
|
||||||
:return: whether the provider needs credentials
|
:return: whether the provider needs credentials
|
||||||
"""
|
"""
|
||||||
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
|
return (
|
||||||
|
self.entity.credentials_schema is not None
|
||||||
|
and len(self.entity.credentials_schema) != 0
|
||||||
|
or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_type(self) -> ToolProviderType:
|
def provider_type(self) -> ToolProviderType:
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast
|
|||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
|
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||||
from core.plugin.entities.plugin import ToolProviderID
|
from core.plugin.entities.plugin import ToolProviderID
|
||||||
from core.plugin.impl.tool import PluginToolManager
|
from core.plugin.impl.tool import PluginToolManager
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
@ -38,12 +39,16 @@ from core.tools.entities.tool_entities import (
|
|||||||
ApiProviderAuthType,
|
ApiProviderAuthType,
|
||||||
ToolInvokeFrom,
|
ToolInvokeFrom,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolProviderCredentialType,
|
|
||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
from core.tools.errors import ToolProviderNotFoundError
|
from core.tools.errors import ToolProviderNotFoundError
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
|
from core.tools.utils.configuration import (
|
||||||
|
ProviderConfigEncrypter,
|
||||||
|
ToolParameterConfigurationManager,
|
||||||
|
create_encrypter,
|
||||||
|
create_generic_encrypter,
|
||||||
|
)
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||||
@ -206,19 +211,18 @@ class ToolManager:
|
|||||||
|
|
||||||
# decrypt the credentials
|
# decrypt the credentials
|
||||||
credentials = builtin_provider.credentials
|
credentials = builtin_provider.credentials
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, _ = create_encrypter(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=[
|
config=[
|
||||||
x.to_basic_provider_config()
|
x.to_basic_provider_config()
|
||||||
for x in provider_controller.get_credentials_schema(
|
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
|
||||||
ToolProviderCredentialType.of(builtin_provider.credential_type)
|
|
||||||
)
|
|
||||||
],
|
],
|
||||||
provider_type=provider_controller.provider_type.value,
|
cache=ToolProviderCredentialsCache(
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
decrypted_credentials = encrypter.decrypt(credentials)
|
||||||
|
|
||||||
return cast(
|
return cast(
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
@ -235,22 +239,18 @@ class ToolManager:
|
|||||||
|
|
||||||
elif provider_type == ToolProviderType.API:
|
elif provider_type == ToolProviderType.API:
|
||||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||||
|
encrypter, _ = create_generic_encrypter(
|
||||||
# decrypt the credentials
|
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
|
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
|
||||||
provider_type=api_provider.provider_type.value,
|
provider_type=api_provider.provider_type.value,
|
||||||
provider_identity=api_provider.entity.identity.name,
|
provider_identity=api_provider.entity.identity.name,
|
||||||
)
|
)
|
||||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
|
||||||
|
|
||||||
return cast(
|
return cast(
|
||||||
ApiTool,
|
ApiTool,
|
||||||
api_provider.get_tool(tool_name).fork_tool_runtime(
|
api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||||
runtime=ToolRuntime(
|
runtime=ToolRuntime(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
credentials=decrypted_credentials,
|
credentials=encrypter.decrypt(credentials),
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
tool_invoke_from=tool_invoke_from,
|
tool_invoke_from=tool_invoke_from,
|
||||||
)
|
)
|
||||||
@ -730,7 +730,7 @@ class ToolManager:
|
|||||||
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
|
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
|
||||||
)
|
)
|
||||||
# init tool configuration
|
# init tool configuration
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
tool_configuration = ProviderConfigEncrypter.create_cached(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
||||||
provider_type=controller.provider_type.value,
|
provider_type=controller.provider_type.value,
|
||||||
|
|||||||
@ -1,12 +1,10 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any
|
from typing import Any, Optional, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from core.entities.provider_entities import BasicProviderConfig
|
from core.entities.provider_entities import BasicProviderConfig
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
|
from core.helper.provider_cache import GenericProviderCredentialsCache
|
||||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
@ -14,11 +12,38 @@ from core.tools.entities.tool_entities import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProviderConfigEncrypter(BaseModel):
|
class ProviderConfigCache(Protocol):
|
||||||
|
"""
|
||||||
|
Interface for provider configuration cache operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get(self) -> Optional[dict]:
|
||||||
|
"""Get cached provider configuration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def set(self, config: dict[str, Any]) -> None:
|
||||||
|
"""Cache provider configuration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
"""Delete cached provider configuration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderConfigEncrypter:
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
config: list[BasicProviderConfig]
|
config: list[BasicProviderConfig]
|
||||||
provider_type: str
|
provider_config_cache: ProviderConfigCache
|
||||||
provider_identity: str
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
config: list[BasicProviderConfig],
|
||||||
|
provider_config_cache: ProviderConfigCache,
|
||||||
|
):
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.config = config
|
||||||
|
self.provider_config_cache = provider_config_cache
|
||||||
|
|
||||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
@ -72,18 +97,13 @@ class ProviderConfigEncrypter(BaseModel):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
|
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
decrypt tool credentials with tenant id
|
decrypt tool credentials with tenant id
|
||||||
|
|
||||||
return a deep copy of credentials with decrypted values
|
return a deep copy of credentials with decrypted values
|
||||||
"""
|
"""
|
||||||
cache = ToolProviderCredentialsCache(
|
cached_credentials = self.provider_config_cache.get()
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
|
||||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
|
||||||
)
|
|
||||||
cached_credentials = cache.get()
|
|
||||||
if cached_credentials:
|
if cached_credentials:
|
||||||
return cached_credentials
|
return cached_credentials
|
||||||
data = self._deep_copy(data)
|
data = self._deep_copy(data)
|
||||||
@ -104,16 +124,24 @@ class ProviderConfigEncrypter(BaseModel):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
cache.set(data)
|
self.provider_config_cache.set(data)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def delete_tool_credentials_cache(self):
|
|
||||||
cache = ToolProviderCredentialsCache(
|
def create_encrypter(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache
|
||||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
):
|
||||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
return ProviderConfigEncrypter(
|
||||||
)
|
tenant_id=tenant_id, config=config, provider_config_cache=cache
|
||||||
cache.delete()
|
), cache
|
||||||
|
|
||||||
|
|
||||||
|
def create_generic_encrypter(
|
||||||
|
tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str
|
||||||
|
):
|
||||||
|
cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}")
|
||||||
|
encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache)
|
||||||
|
return encrypt, cache
|
||||||
|
|
||||||
|
|
||||||
class ToolParameterConfigurationManager:
|
class ToolParameterConfigurationManager:
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
|
|||||||
)
|
)
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
from core.tools.utils.configuration import ProviderConfigEncrypter, create_generic_encrypter
|
||||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import ApiToolProvider
|
from models.tools import ApiToolProvider
|
||||||
@ -297,28 +297,28 @@ class ApiToolManageService:
|
|||||||
provider_controller.load_bundled_tools(tool_bundles)
|
provider_controller.load_bundled_tools(tool_bundles)
|
||||||
|
|
||||||
# get original credentials if exists
|
# get original credentials if exists
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, cache = create_generic_encrypter(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=list(provider_controller.get_credentials_schema()),
|
config=list(provider_controller.get_credentials_schema()),
|
||||||
provider_type=provider_controller.provider_type.value,
|
provider_type=provider_controller.provider_type.value,
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
provider_identity=provider_controller.entity.identity.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
original_credentials = encrypter.decrypt(provider.credentials)
|
||||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
|
||||||
# check if the credential has changed, save the original credential
|
# check if the credential has changed, save the original credential
|
||||||
for name, value in credentials.items():
|
for name, value in credentials.items():
|
||||||
if name in masked_credentials and value == masked_credentials[name]:
|
if name in masked_credentials and value == masked_credentials[name]:
|
||||||
credentials[name] = original_credentials[name]
|
credentials[name] = original_credentials[name]
|
||||||
|
|
||||||
credentials = tool_configuration.encrypt(credentials)
|
credentials = encrypter.encrypt(credentials)
|
||||||
provider.credentials_str = json.dumps(credentials)
|
provider.credentials_str = json.dumps(credentials)
|
||||||
|
|
||||||
db.session.add(provider)
|
db.session.add(provider)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# delete cache
|
# delete cache
|
||||||
tool_configuration.delete_tool_credentials_cache()
|
cache.delete()
|
||||||
|
|
||||||
# update labels
|
# update labels
|
||||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||||
@ -416,15 +416,15 @@ class ApiToolManageService:
|
|||||||
|
|
||||||
# decrypt credentials
|
# decrypt credentials
|
||||||
if db_provider.id:
|
if db_provider.id:
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, _ = create_generic_encrypter(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=list(provider_controller.get_credentials_schema()),
|
config=list(provider_controller.get_credentials_schema()),
|
||||||
provider_type=provider_controller.provider_type.value,
|
provider_type=provider_controller.provider_type.value,
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
provider_identity=provider_controller.entity.identity.name,
|
||||||
)
|
)
|
||||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
decrypted_credentials = encrypter.decrypt(credentials)
|
||||||
# check if the credential has changed, save the original credential
|
# check if the credential has changed, save the original credential
|
||||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
|
||||||
for name, value in credentials.items():
|
for name, value in credentials.items():
|
||||||
if name in masked_credentials and value == masked_credentials[name]:
|
if name in masked_credentials and value == masked_credentials[name]:
|
||||||
credentials[name] = decrypted_credentials[name]
|
credentials[name] = decrypted_credentials[name]
|
||||||
|
|||||||
@ -8,19 +8,18 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.helper.position_helper import is_filtered
|
from core.helper.position_helper import is_filtered
|
||||||
|
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.entities.plugin import ToolProviderID
|
from core.plugin.entities.plugin import ToolProviderID
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
|
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
|
||||||
from core.tools.entities.tool_entities import ToolProviderCredentialType
|
from core.tools.entities.tool_entities import ToolProviderCredentialType
|
||||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
from core.tools.utils.configuration import create_encrypter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
|
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
|
||||||
@ -58,20 +57,15 @@ class BuiltinToolManageService:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
|
def get_builtin_tool_provider_info(tenant_id: str, provider: str):
|
||||||
"""
|
"""
|
||||||
get builtin tool provider info
|
get builtin tool provider info
|
||||||
"""
|
"""
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
|
||||||
# check if user has added the provider
|
# check if user has added the provider
|
||||||
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
|
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
|
||||||
|
if builtin_provider is None:
|
||||||
credentials = {}
|
raise ValueError(f"you have not added provider {provider}")
|
||||||
if builtin_provider is not None:
|
|
||||||
# get credentials
|
|
||||||
credentials = builtin_provider.credentials
|
|
||||||
credentials = tool_configuration.decrypt(credentials)
|
|
||||||
|
|
||||||
entity = ToolTransformService.builtin_provider_to_user_provider(
|
entity = ToolTransformService.builtin_provider_to_user_provider(
|
||||||
provider_controller=provider_controller,
|
provider_controller=provider_controller,
|
||||||
@ -80,7 +74,6 @@ class BuiltinToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
entity.original_credentials = {}
|
entity.original_credentials = {}
|
||||||
|
|
||||||
return entity
|
return entity
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -96,32 +89,34 @@ class BuiltinToolManageService:
|
|||||||
:return: the list of tool providers
|
:return: the list of tool providers
|
||||||
"""
|
"""
|
||||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||||
return jsonable_encoder(provider.get_credentials_schema(credential_type))
|
return jsonable_encoder(provider.get_credentials_schema_by_type(credential_type))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_builtin_tool_provider(
|
def update_builtin_tool_provider(
|
||||||
user_id: str, tenant_id: str, provider_name: str, credentials: dict, credential_id: str, name: str | None = None
|
user_id: str, tenant_id: str, provider: str, credentials: dict, credential_id: str, name: str | None = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
update builtin tool provider
|
update builtin tool provider
|
||||||
"""
|
"""
|
||||||
# get if the provider exists
|
# get if the provider exists
|
||||||
provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
||||||
|
|
||||||
if provider is None:
|
if db_provider is None:
|
||||||
raise ValueError(f"you have not added provider {provider_name}")
|
raise ValueError(f"you have not added provider {provider}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
|
if ToolProviderCredentialType.of(db_provider.credential_type).is_editable():
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
if not provider_controller.need_credentials:
|
if not provider_controller.need_credentials:
|
||||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
raise ValueError(f"provider {provider} does not need credentials")
|
||||||
|
|
||||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||||
|
tenant_id, db_provider, provider, provider_controller
|
||||||
|
)
|
||||||
|
|
||||||
# Decrypt and restore original credentials for masked values
|
# Decrypt and restore original credentials for masked values
|
||||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
original_credentials = encrypter.decrypt(db_provider.credentials)
|
||||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
|
||||||
|
|
||||||
# check if the credential has changed, save the original credential
|
# check if the credential has changed, save the original credential
|
||||||
for key, value in credentials.items():
|
for key, value in credentials.items():
|
||||||
@ -131,13 +126,13 @@ class BuiltinToolManageService:
|
|||||||
provider_controller.validate_credentials(user_id, credentials)
|
provider_controller.validate_credentials(user_id, credentials)
|
||||||
|
|
||||||
# encrypt credentials
|
# encrypt credentials
|
||||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
|
||||||
provider.encrypted_credentials = json.dumps(encrypted_credentials)
|
|
||||||
tool_configuration.delete_tool_credentials_cache()
|
cache.delete()
|
||||||
|
|
||||||
# update name if provided
|
# update name if provided
|
||||||
if name is not None and provider.name != name:
|
if name is not None and db_provider.name != name:
|
||||||
provider.name = name
|
db_provider.name = name
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
except (
|
except (
|
||||||
@ -176,7 +171,7 @@ class BuiltinToolManageService:
|
|||||||
name
|
name
|
||||||
if name
|
if name
|
||||||
else BuiltinToolManageService.generate_builtin_tool_provider_name(
|
else BuiltinToolManageService.generate_builtin_tool_provider_name(
|
||||||
tenant_id, provider, credential_type=api_type
|
tenant_id=tenant_id, provider=provider, credential_type=api_type
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -193,20 +188,35 @@ class BuiltinToolManageService:
|
|||||||
if not provider_controller.need_credentials:
|
if not provider_controller.need_credentials:
|
||||||
raise ValueError(f"provider {provider} does not need credentials")
|
raise ValueError(f"provider {provider} does not need credentials")
|
||||||
|
|
||||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||||
|
tenant_id, db_provider, provider, provider_controller
|
||||||
# Encrypt and save the credentials
|
|
||||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
|
||||||
provider_controller=provider_controller,
|
|
||||||
tool_configuration=tool_configuration,
|
|
||||||
provider=db_provider,
|
|
||||||
credentials=credentials,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# encrypt credentials
|
||||||
|
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
|
||||||
|
|
||||||
|
cache.delete()
|
||||||
db.session.add(db_provider)
|
db.session.add(db_provider)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_tool_encrypter(
|
||||||
|
tenant_id: str,
|
||||||
|
db_provider: BuiltinToolProvider,
|
||||||
|
provider: str,
|
||||||
|
provider_controller: BuiltinToolProviderController,
|
||||||
|
):
|
||||||
|
encrypter, cache = create_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=[
|
||||||
|
x.to_basic_provider_config()
|
||||||
|
for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
|
||||||
|
],
|
||||||
|
cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
|
||||||
|
)
|
||||||
|
return encrypter, cache
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_builtin_tool_provider_name(
|
def generate_builtin_tool_provider_name(
|
||||||
tenant_id: str, provider: str, credential_type: ToolProviderCredentialType
|
tenant_id: str, provider: str, credential_type: ToolProviderCredentialType
|
||||||
@ -273,12 +283,13 @@ class BuiltinToolManageService:
|
|||||||
|
|
||||||
default_provider.is_default = True
|
default_provider.is_default = True
|
||||||
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
|
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
|
||||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||||
|
tenant_id, default_provider, default_provider.provider, provider_controller
|
||||||
|
)
|
||||||
|
|
||||||
credentials: list[ToolProviderCredentialApiEntity] = []
|
credentials: list[ToolProviderCredentialApiEntity] = []
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
decrypt_credential = tool_configuration.mask_tool_credentials(
|
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
|
||||||
tool_configuration.decrypt(provider.credentials)
|
|
||||||
)
|
|
||||||
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
||||||
provider=provider,
|
provider=provider,
|
||||||
credentials=decrypt_credential,
|
credentials=decrypt_credential,
|
||||||
@ -287,22 +298,24 @@ class BuiltinToolManageService:
|
|||||||
return credentials
|
return credentials
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str):
|
def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
|
||||||
"""
|
"""
|
||||||
delete tool provider
|
delete tool provider
|
||||||
"""
|
"""
|
||||||
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
||||||
|
|
||||||
if tool_provider is None:
|
if tool_provider is None:
|
||||||
raise ValueError(f"you have not added provider {provider_name}")
|
raise ValueError(f"you have not added provider {provider}")
|
||||||
|
|
||||||
db.session.delete(tool_provider)
|
db.session.delete(tool_provider)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# delete cache
|
# delete cache
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
_, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||||
tool_configuration.delete_tool_credentials_cache()
|
tenant_id, tool_provider, provider, provider_controller
|
||||||
|
)
|
||||||
|
cache.delete()
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@ -493,57 +506,35 @@ class BuiltinToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
|
def setup_oauth_custom_client(tenant_id: str, provider: str, client_params: dict):
|
||||||
return ProviderConfigEncrypter(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
|
||||||
provider_type=provider_controller.provider_type.value,
|
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _encrypt_and_save_credentials(
|
|
||||||
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
|
||||||
tool_configuration: ProviderConfigEncrypter,
|
|
||||||
provider: BuiltinToolProvider,
|
|
||||||
credentials: dict,
|
|
||||||
user_id: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Validate and encrypt credentials, then save to database
|
|
||||||
|
|
||||||
:param provider_controller: the provider controller
|
|
||||||
:param tool_configuration: the tool configuration encrypter
|
|
||||||
:param provider: the provider object from database
|
|
||||||
:param credentials: the credentials to encrypt and save
|
|
||||||
:param user_id: the user id for validation
|
|
||||||
"""
|
|
||||||
if ToolProviderCredentialType.of(provider.credential_type).is_validate_allowed():
|
|
||||||
provider_controller.validate_credentials(user_id, credentials)
|
|
||||||
|
|
||||||
# encrypt credentials
|
|
||||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
|
||||||
provider.encrypted_credentials = json.dumps(encrypted_credentials)
|
|
||||||
tool_configuration.delete_tool_credentials_cache()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def setup_oauth_custom_client(tenant_id: str, user_id: str, provider: str, client_params: dict):
|
|
||||||
"""
|
"""
|
||||||
setup oauth custom client
|
setup oauth custom client
|
||||||
"""
|
"""
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
with Session(db.engine) as session:
|
||||||
if not provider_controller:
|
tool_provider = ToolProviderID(provider)
|
||||||
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
|
if not provider_controller:
|
||||||
|
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
||||||
|
|
||||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
if not isinstance(provider_controller, BuiltinToolProviderController):
|
||||||
|
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||||
|
|
||||||
# Validate and encrypt credentials
|
encrypter, _ = create_encrypter(
|
||||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
tenant_id=tenant_id,
|
||||||
provider_controller=provider_controller,
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||||
tool_configuration=tool_configuration,
|
cache=NoOpProviderCredentialCache(),
|
||||||
provider=None, # No need to save in DB
|
)
|
||||||
credentials=client_params,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# encrypt credentials
|
||||||
|
encrypted_credentials = encrypter.encrypt(client_params)
|
||||||
|
session.add(
|
||||||
|
ToolOAuthTenantClient(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
plugin_id=tool_provider.plugin_id,
|
||||||
|
provider=tool_provider.provider_name,
|
||||||
|
enabled=True,
|
||||||
|
encrypted_oauth_params=json.dumps(encrypted_credentials),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from typing import Optional, Union, cast
|
|||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
@ -19,7 +20,7 @@ from core.tools.entities.tool_entities import (
|
|||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
from core.tools.utils.configuration import create_encrypter, create_generic_encrypter
|
||||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||||
@ -109,7 +110,14 @@ class ToolTransformService:
|
|||||||
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
|
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
|
||||||
|
|
||||||
# get credentials schema
|
# get credentials schema
|
||||||
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
|
schema = {
|
||||||
|
x.to_basic_provider_config().name: x
|
||||||
|
for x in provider_controller.get_credentials_schema_by_type(
|
||||||
|
ToolProviderCredentialType.of(db_provider.credential_type)
|
||||||
|
if db_provider
|
||||||
|
else ToolProviderCredentialType.API_KEY
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
for name, value in schema.items():
|
for name, value in schema.items():
|
||||||
if result.masked_credentials:
|
if result.masked_credentials:
|
||||||
@ -126,15 +134,23 @@ class ToolTransformService:
|
|||||||
credentials = db_provider.credentials
|
credentials = db_provider.credentials
|
||||||
|
|
||||||
# init tool configuration
|
# init tool configuration
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, _ = create_encrypter(
|
||||||
tenant_id=db_provider.tenant_id,
|
tenant_id=db_provider.tenant_id,
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
config=[
|
||||||
provider_type=provider_controller.provider_type.value,
|
x.to_basic_provider_config()
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
for x in provider_controller.get_credentials_schema_by_type(
|
||||||
|
ToolProviderCredentialType.of(db_provider.credential_type)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
cache=ToolProviderCredentialsCache(
|
||||||
|
tenant_id=db_provider.tenant_id,
|
||||||
|
provider=db_provider.provider,
|
||||||
|
credential_id=db_provider.id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
# decrypt the credentials and mask the credentials
|
# decrypt the credentials and mask the credentials
|
||||||
decrypted_credentials = tool_configuration.decrypt(data=credentials)
|
decrypted_credentials = encrypter.decrypt(data=credentials)
|
||||||
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
|
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
|
||||||
|
|
||||||
result.masked_credentials = masked_credentials
|
result.masked_credentials = masked_credentials
|
||||||
result.original_credentials = decrypted_credentials
|
result.original_credentials = decrypted_credentials
|
||||||
@ -236,7 +252,7 @@ class ToolTransformService:
|
|||||||
|
|
||||||
if decrypt_credentials:
|
if decrypt_credentials:
|
||||||
# init tool configuration
|
# init tool configuration
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, _ = create_generic_encrypter(
|
||||||
tenant_id=db_provider.tenant_id,
|
tenant_id=db_provider.tenant_id,
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||||
provider_type=provider_controller.provider_type.value,
|
provider_type=provider_controller.provider_type.value,
|
||||||
@ -244,8 +260,8 @@ class ToolTransformService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# decrypt the credentials and mask the credentials
|
# decrypt the credentials and mask the credentials
|
||||||
decrypted_credentials = tool_configuration.decrypt(data=credentials)
|
decrypted_credentials = encrypter.decrypt(data=credentials)
|
||||||
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
|
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
|
||||||
|
|
||||||
result.masked_credentials = masked_credentials
|
result.masked_credentials = masked_credentials
|
||||||
|
|
||||||
@ -264,7 +280,7 @@ class ToolTransformService:
|
|||||||
# fork tool runtime
|
# fork tool runtime
|
||||||
tool = tool.fork_tool_runtime(
|
tool = tool.fork_tool_runtime(
|
||||||
runtime=ToolRuntime(
|
runtime=ToolRuntime(
|
||||||
credentials= {},
|
credentials={},
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user