feat(oauth): refactor tool provider methods and enhance credential handling

This commit is contained in:
Harry 2025-06-27 13:17:09 +08:00
parent 8a954c0b19
commit daec82bd44
9 changed files with 309 additions and 170 deletions

View File

@ -1,6 +1,6 @@
import io
from flask import redirect, request, send_file
from flask import make_response, redirect, request, send_file
from flask_login import current_user
from flask_restful import (
Resource,
@ -17,6 +17,7 @@ from controllers.console.wraps import (
setup_required,
)
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import ToolProviderCredentialType
from extensions.ext_database import db
@ -127,7 +128,7 @@ class ToolBuiltinProviderAddApi(Resource):
return BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider_name=provider,
provider=provider,
credentials=args["credentials"],
name=args["name"],
api_type=ToolProviderCredentialType.of(args["type"]),
@ -373,10 +374,11 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@account_initialization_required
def get(self, provider, credential_type):
user = current_user
tenant_id = user.current_tenant_id
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, credential_type, tenant_id)
return BuiltinToolManageService.list_builtin_provider_credentials_schema(
provider, ToolProviderCredentialType.of(credential_type), tenant_id
)
class ToolApiProviderSchemaApi(Resource):
@ -613,15 +615,12 @@ class ToolApiListApi(Resource):
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
provider.to_dict()
for provider in ApiToolManageService.list_api_tools(
user_id,
tenant_id,
)
]
@ -662,13 +661,10 @@ class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
def get(self, provider):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
# todo check permission
user = current_user
@ -679,63 +675,66 @@ class ToolPluginOAuthApi(Resource):
tenant_id = user.current_tenant_id
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client(
tenant_id=tenant_id,
provider=provider,
provider=provider_name,
plugin_id=plugin_id,
)
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
)
# todo decrypt oauth params
# TODO decrypt oauth params
oauth_params = plugin_oauth_config.oauth_params
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
oauth_params["redirect_uri"] = redirect_uri
response = oauth_handler.get_authorization_url(
tenant_id,
user.id,
plugin_id,
provider,
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_params,
)
return response.model_dump()
response = make_response(jsonable_encoder(authorization_url_response))
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
class ToolOAuthCallback(Resource):
@setup_required
def get(self):
args = (
reqparse.RequestParser()
.add_argument("context_id", type=str, required=True, nullable=False, location="args")
.parse_args()
)
context_id = args["context_id"]
def get(self, provider):
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
user_id, tenant_id, plugin_id, provider = (
context.get("user_id"),
context.get("tenant_id"),
context.get("plugin_id"),
context.get("provider"),
)
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
oauth_handler = OAuthHandler()
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client(
tenant_id=tenant_id,
provider=provider,
provider=provider_name,
plugin_id=plugin_id,
)
oauth_params = plugin_oauth_config.oauth_params
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
oauth_params["redirect_uri"] = redirect_uri
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials(
tenant_id,
user_id,
plugin_id,
provider,
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_params,
request=request,
).credentials
@ -747,12 +746,11 @@ class ToolOAuthCallback(Resource):
BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider_name=provider,
provider=provider,
credentials=dict(credentials),
name=provider,
api_type=ToolProviderCredentialType.OAUTH2,
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth/plugin/{provider}/tool/success")
class ToolBuiltinProviderSetDefaultApi(Resource):
@ -768,9 +766,41 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
)
class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.setup_oauth_custom_client(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider=provider,
client_params=args["client_params"],
)
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
tenant_id=current_user.current_tenant_id, provider_name=provider
)
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/tool")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/tool/callback")
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
@ -782,14 +812,14 @@ api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/b
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/set-default"
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
)
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/<path:credential_type>/credentials_schema",
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema/<path:credential_type>",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")

View File

@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
from core.tools.entities.tool_entities import (
OAuthSchema,
ToolEntity,
ToolProviderCredentialType,
ToolProviderEntity,
ToolProviderType,
)
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
credentials_schema.append(credential_dict)
oauth_schema = None
if provider_yaml.get("oauth_schema", None) is not None:
oauth_schema = OAuthSchema(
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
)
super().__init__(
entity=ToolProviderEntity(
identity=provider_yaml["identity"],
credentials_schema=credentials_schema,
oauth_schema=oauth_schema,
),
)
@ -91,16 +105,20 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self.tools
def get_credentials_schema(self) -> list[ProviderConfig]:
def get_credentials_schema(
self, credential_type: ToolProviderCredentialType = ToolProviderCredentialType.API_KEY
) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
if not self.entity.credentials_schema:
return []
return self.entity.credentials_schema.copy()
if credential_type == ToolProviderCredentialType.OAUTH2:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
elif credential_type == ToolProviderCredentialType.API_KEY:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
else:
raise ValueError(f"Invalid credential type: {credential_type}")
def get_tools(self) -> list[BuiltinTool]:
"""

View File

@ -344,10 +344,18 @@ class ToolEntity(BaseModel):
return v or []
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
)
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
class ToolProviderEntityWithPlugin(ToolProviderEntity):
@ -437,7 +445,7 @@ class ToolSelector(BaseModel):
class ToolProviderCredentialType(enum.StrEnum):
API_KEY = "api_key"
API_KEY = "api-key"
OAUTH2 = "oauth2"
def get_name(self):
@ -446,7 +454,7 @@ class ToolProviderCredentialType(enum.StrEnum):
elif self == ToolProviderCredentialType.OAUTH2:
return "AUTH"
else:
return self.value.replace("_", " ").upper()
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == ToolProviderCredentialType.API_KEY
@ -461,7 +469,7 @@ class ToolProviderCredentialType(enum.StrEnum):
@classmethod
def of(cls, credential_type: str) -> "ToolProviderCredentialType":
type_name = credential_type.lower()
if type_name == "api_key":
if type_name == "api-key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2

View File

@ -34,7 +34,13 @@ from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.custom_tool.tool import ApiTool
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeFrom,
ToolParameter,
ToolProviderCredentialType,
ToolProviderType,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
@ -202,7 +208,12 @@ class ToolManager:
credentials = builtin_provider.credentials
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema(
ToolProviderCredentialType.of(builtin_provider.credential_type)
)
],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)

View File

@ -64,7 +64,10 @@ class BuiltinToolProvider(Base):
"""
__tablename__ = "tool_builtin_providers"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),)
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
)
# id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
@ -86,9 +89,9 @@ class BuiltinToolProvider(Base):
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
# credential type, e.g., "api_key", "oauth2"
# credential type, e.g., "api-key", "oauth2"
credential_type: Mapped[str] = mapped_column(
db.String(32), nullable=False, server_default=db.text("'api_key'::character varying")
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
)
@property

View File

@ -1,3 +1,6 @@
import json
import uuid
from core.plugin.impl.base import BasePluginClient
from extensions.ext_redis import redis_client

View File

@ -446,7 +446,7 @@ class ApiToolManageService:
return {"result": result or "empty response"}
@staticmethod
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
"""
list api tools
"""
@ -474,7 +474,7 @@ class ApiToolManageService:
for tool in tools or []:
user_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
tenant_id=tenant_id, tool=tool, labels=labels
)
)

View File

@ -4,7 +4,6 @@ import re
from pathlib import Path
from typing import Optional, Union
from sqlalchemy import ColumnExpressionArgument
from sqlalchemy.orm import Session
from configs import dify_config
@ -13,10 +12,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.tool_entities import ToolProviderCredentialType
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
@ -29,6 +30,8 @@ logger = logging.getLogger(__name__)
class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
@staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
"""
@ -42,22 +45,11 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools()
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_configuration.decrypt(credentials)
result: list[ToolApiEntity] = []
for tool in tools or []:
result.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
@ -73,7 +65,7 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
@ -92,16 +84,19 @@ class BuiltinToolManageService:
return entity
@staticmethod
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str):
def list_builtin_provider_credentials_schema(
provider_name: str, credential_type: ToolProviderCredentialType, tenant_id: str
):
"""
list builtin provider credentials schema
:param credential_type: credential type
:param provider_name: the name of the provider
:param tenant_id: the id of the tenant
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return jsonable_encoder(provider.get_credentials_schema())
return jsonable_encoder(provider.get_credentials_schema(credential_type))
@staticmethod
def update_builtin_tool_provider(
@ -111,11 +106,11 @@ class BuiltinToolManageService:
update builtin tool provider
"""
# get if the provider exists
provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
if provider is None:
raise ValueError(f"you have not added provider {provider_name}")
try:
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
@ -133,10 +128,12 @@ class BuiltinToolManageService:
if key in masked_credentials and value == masked_credentials[key]:
credentials[key] = original_credentials[key]
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
)
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
encrypted_credentials = tool_configuration.encrypt(credentials)
provider.encrypted_credentials = json.dumps(encrypted_credentials)
tool_configuration.delete_tool_credentials_cache()
# update name if provided
if name is not None and provider.name != name:
@ -158,68 +155,84 @@ class BuiltinToolManageService:
user_id: str,
api_type: ToolProviderCredentialType,
tenant_id: str,
provider_name: str,
provider: str,
credentials: dict,
name: str | None = None,
):
"""
add builtin tool provider
"""
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}"
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
if name is None:
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type)
# check if the provider count is over the limit
provider_count = (
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
)
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
raise ValueError(f"you have reached the maximum number of providers for {provider}")
provider = BuiltinToolProvider(
# TODO should we get name from oauth authentication?
name = (
name
if name
else BuiltinToolManageService.generate_builtin_tool_provider_name(
tenant_id, provider, credential_type=api_type
)
)
db_provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
provider=provider,
encrypted_credentials=json.dumps(credentials),
credential_type=api_type.value,
name=name,
)
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
raise ValueError(f"provider {provider} does not need credentials")
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
provider_controller=provider_controller,
tool_configuration=tool_configuration,
provider=db_provider,
credentials=credentials,
user_id=user_id,
)
db.session.add(provider)
db.session.add(db_provider)
db.session.commit()
return {"result": "success"}
@staticmethod
def get_next_builtin_tool_provider_name(
tenant_id: str, provider_name: str, type: ToolProviderCredentialType
def generate_builtin_tool_provider_name(
tenant_id: str, provider: str, credential_type: ToolProviderCredentialType
) -> str:
try:
providers = (
db_providers = (
db.session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider_name,
credential_type=type.value,
provider=provider,
credential_type=credential_type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.limit(10)
.all()
)
# Get the default name pattern
default_pattern = type.get_name()
default_pattern = f"{credential_type.get_name()}"
# Find all names that match the default pattern: "{default_pattern} {number}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for provider in providers:
if provider.name:
match = re.match(pattern, provider.name.strip())
for db_provider in db_providers:
if db_provider.name:
match = re.match(pattern, db_provider.name.strip())
if match:
numbers.append(int(match.group(1)))
@ -231,9 +244,9 @@ class BuiltinToolManageService:
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {provider_name}: {str(e)}")
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
# fallback
return f"{type.get_name()} 1"
return f"{credential_type.get_name()} 1"
@staticmethod
def get_builtin_tool_provider_credentials(
@ -242,31 +255,43 @@ class BuiltinToolManageService:
"""
get builtin tool provider credentials
"""
providers = db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
if len(providers) == 0:
return []
provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
credentials: list[ToolProviderCredentialApiEntity] = []
for provider in providers:
decrypt_credential = tool_configuration.mask_tool_credentials(
tool_configuration.decrypt(provider.credentials)
with db.session.no_autoflush:
providers = (
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
)
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
)
credentials.append(credential_entity)
return credentials
if len(providers) == 0:
return []
default_provider = sorted(
providers,
key=lambda p: (
not getattr(p, "is_default", False),
getattr(p, "created_at", None) or 0,
),
)[0]
default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
credentials: list[ToolProviderCredentialApiEntity] = []
for provider in providers:
decrypt_credential = tool_configuration.mask_tool_credentials(
tool_configuration.decrypt(provider.credentials)
)
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
)
credentials.append(credential_entity)
return credentials
@staticmethod
def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str):
"""
delete tool provider
"""
tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
if tool_provider is None:
raise ValueError(f"you have not added provider {provider_name}")
@ -387,7 +412,6 @@ class BuiltinToolManageService:
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
@ -399,7 +423,7 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
provider: Optional[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider)
.filter(
@ -411,47 +435,62 @@ class BuiltinToolManageService:
return provider
@staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
with Session(db.engine) as session:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]:
return (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
if provider_id_entity.organization != "langgenius":
provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
else:
provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name),
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
if provider is None:
return None
provider.provider = ToolProviderID(provider.provider).to_string()
return provider
except Exception:
# it's an old provider without organization
return (
session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
.first()
)
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider = _query([BuiltinToolProvider.provider == full_provider_name])
else:
provider = _query(
[
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name)
]
)
if provider is None:
return None
provider.provider = ToolProviderID(provider.provider).to_string()
return provider
except Exception:
# it's an old provider without organization
return _query([BuiltinToolProvider.provider == provider_name])
@staticmethod
def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
@ -463,7 +502,13 @@ class BuiltinToolManageService:
)
@staticmethod
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):
def _encrypt_and_save_credentials(
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
tool_configuration: ProviderConfigEncrypter,
provider: BuiltinToolProvider,
credentials: dict,
user_id: str,
):
"""
Validate and encrypt credentials, then save to database
@ -480,3 +525,25 @@ class BuiltinToolManageService:
encrypted_credentials = tool_configuration.encrypt(credentials)
provider.encrypted_credentials = json.dumps(encrypted_credentials)
tool_configuration.delete_tool_credentials_cache()
@staticmethod
def setup_oauth_custom_client(tenant_id: str, user_id: str, provider: str, client_params: dict):
"""
setup oauth custom client
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# Validate and encrypt credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller=provider_controller,
tool_configuration=tool_configuration,
provider=None, # No need to save in DB
credentials=client_params,
user_id=user_id,
)
return {"result": "success"}

View File

@ -255,7 +255,6 @@ class ToolTransformService:
def convert_tool_entity_to_api_entity(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
tenant_id: str,
credentials: dict | None = None,
labels: list[str] | None = None,
) -> ToolApiEntity:
"""
@ -265,7 +264,7 @@ class ToolTransformService:
# fork tool runtime
tool = tool.fork_tool_runtime(
runtime=ToolRuntime(
credentials=credentials or {},
credentials= {},
tenant_id=tenant_id,
)
)