mirror of https://github.com/langgenius/dify.git
feat(oauth): enhance OAuth client handling and add custom client support
This commit is contained in:
parent
6ef1e017df
commit
988a76066d
|
|
@ -675,18 +675,17 @@ class ToolPluginOAuthApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client(
|
||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider
|
||||
)
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("no oauth available client config found for this tool provider")
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
||||
)
|
||||
# TODO decrypt oauth params
|
||||
oauth_params = plugin_oauth_config.oauth_params
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -694,7 +693,7 @@ class ToolPluginOAuthApi(Resource):
|
|||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_params,
|
||||
system_credentials=oauth_client_params,
|
||||
)
|
||||
response = make_response(jsonable_encoder(authorization_url_response))
|
||||
response.set_cookie(
|
||||
|
|
@ -724,12 +723,10 @@ class ToolOAuthCallback(Resource):
|
|||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
oauth_params = plugin_oauth_config.oauth_params
|
||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("no oauth available client config found for this tool provider")
|
||||
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||
credentials = oauth_handler.get_credentials(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -737,7 +734,7 @@ class ToolOAuthCallback(Resource):
|
|||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_params,
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
).credentials
|
||||
|
||||
|
|
@ -774,7 +771,8 @@ class ToolOAuthCustomClient(Resource):
|
|||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
user = current_user
|
||||
|
|
@ -782,18 +780,21 @@ class ToolOAuthCustomClient(Resource):
|
|||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return BuiltinToolManageService.setup_oauth_custom_client(
|
||||
return BuiltinToolManageService.save_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider=provider,
|
||||
client_params=args["client_params"],
|
||||
client_params=args.get("client_params", {}),
|
||||
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
|
||||
)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||
tenant_id=current_user.current_tenant_id, provider_name=provider
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_custom_oauth_client_params(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -85,4 +85,7 @@ class ToolProviderCredentialApiEntity(BaseModel):
|
|||
|
||||
class ToolProviderCredentialInfoApiEntity(BaseModel):
|
||||
supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
|
||||
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")
|
||||
is_oauth_custom_client_enabled: bool = Field(
|
||||
default=False, description="Whether the OAuth custom client is enabled for the provider"
|
||||
)
|
||||
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import json
|
|||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -21,6 +21,7 @@ from core.tools.entities.api_entities import (
|
|||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentialType
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import create_encrypter
|
||||
|
|
@ -41,7 +42,12 @@ class BuiltinToolManageService:
|
|||
get builtin tool provider oauth client schema
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
return provider.get_oauth_client_schema()
|
||||
return {
|
||||
"schema": provider.get_oauth_client_schema(),
|
||||
"is_oauth_custom_client_enabled": BuiltinToolManageService.is_oauth_custom_client_enabled(
|
||||
tenant_id, provider_name
|
||||
),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
||||
|
|
@ -139,7 +145,7 @@ class BuiltinToolManageService:
|
|||
|
||||
# encrypt credentials
|
||||
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
|
||||
|
||||
|
||||
cache.delete()
|
||||
|
||||
# update name if provided
|
||||
|
|
@ -279,20 +285,16 @@ class BuiltinToolManageService:
|
|||
"""
|
||||
with db.session.no_autoflush:
|
||||
providers = (
|
||||
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider_name)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(providers) == 0:
|
||||
return []
|
||||
|
||||
default_provider = sorted(
|
||||
providers,
|
||||
key=lambda p: (
|
||||
not getattr(p, "is_default", False),
|
||||
getattr(p, "created_at", None) or 0,
|
||||
),
|
||||
)[0]
|
||||
|
||||
default_provider = providers[0]
|
||||
default_provider.is_default = True
|
||||
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
|
||||
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||
|
|
@ -319,6 +321,7 @@ class BuiltinToolManageService:
|
|||
credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
|
||||
credential_info = ToolProviderCredentialInfoApiEntity(
|
||||
supported_credential_types=supported_credential_types,
|
||||
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
|
|
@ -368,30 +371,61 @@ class BuiltinToolManageService:
|
|||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_oauth_client(
|
||||
tenant_id: str, provider: str, plugin_id: str
|
||||
) -> Union[ToolOAuthTenantClient, ToolOAuthSystemClient]:
|
||||
def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
|
||||
"""
|
||||
get builtin tool provider
|
||||
check if oauth custom client is enabled
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
user_client = (
|
||||
tool_provider = ToolProviderID(provider)
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
user_client: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
provider=tool_provider.provider_name,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
enabled=True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user_client:
|
||||
return user_client
|
||||
return user_client is not None and user_client.enabled
|
||||
|
||||
system_client = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
|
||||
if system_client is None:
|
||||
raise ValueError("no oauth available client config found for this tool provider")
|
||||
return system_client
|
||||
@staticmethod
|
||||
def get_oauth_client(tenant_id: str, provider: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
get builtin tool provider
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
encrypter, _ = create_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
user_client: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=tool_provider.provider_name,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
enabled=True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
oauth_params: dict[str, Any] | None = None
|
||||
if user_client:
|
||||
oauth_params = encrypter.decrypt(user_client.oauth_params)
|
||||
return oauth_params
|
||||
|
||||
system_client: ToolOAuthSystemClient | None = (
|
||||
session.query(ToolOAuthSystemClient)
|
||||
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
|
||||
.first()
|
||||
)
|
||||
if system_client:
|
||||
oauth_params = encrypter.decrypt(system_client.oauth_params)
|
||||
|
||||
return oauth_params
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_icon(provider: str):
|
||||
|
|
@ -533,12 +567,79 @@ class BuiltinToolManageService:
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def setup_oauth_custom_client(tenant_id: str, provider: str, client_params: dict):
|
||||
def save_custom_oauth_client_params(
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
client_params: Optional[dict] = None,
|
||||
enable_oauth_custom_client: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
setup oauth custom client
|
||||
"""
|
||||
if client_params is None and enable_oauth_custom_client is None:
|
||||
return {"result": "success"}
|
||||
|
||||
tool_provider = ToolProviderID(provider)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller:
|
||||
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
||||
|
||||
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
|
||||
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
custom_client_params = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=tool_provider.provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# if the record does not exist, create a basic record
|
||||
if custom_client_params is None:
|
||||
custom_client_params = ToolOAuthTenantClient(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=tool_provider.provider_name,
|
||||
)
|
||||
session.add(custom_client_params)
|
||||
|
||||
if client_params is not None:
|
||||
encrypter, _ = create_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(client_params))
|
||||
|
||||
if enable_oauth_custom_client is not None:
|
||||
custom_client_params.enabled = enable_oauth_custom_client
|
||||
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_custom_oauth_client_params(tenant_id: str, provider: str):
|
||||
"""
|
||||
get custom oauth client params
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
tool_provider = ToolProviderID(provider)
|
||||
custom_oauth_client_params: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=tool_provider.provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if custom_oauth_client_params is None:
|
||||
return {}
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller:
|
||||
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
||||
|
|
@ -551,17 +652,4 @@ class BuiltinToolManageService:
|
|||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
# encrypt credentials
|
||||
encrypted_credentials = encrypter.encrypt(client_params)
|
||||
session.add(
|
||||
ToolOAuthTenantClient(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=tool_provider.provider_name,
|
||||
enabled=True,
|
||||
encrypted_oauth_params=json.dumps(encrypted_credentials),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
|
||||
|
|
|
|||
Loading…
Reference in New Issue