mirror of https://github.com/langgenius/dify.git
feat(oauth): refactor tool provider methods and enhance credential handling
This commit is contained in:
parent
8a954c0b19
commit
daec82bd44
|
|
@ -1,6 +1,6 @@
|
|||
import io
|
||||
|
||||
from flask import redirect, request, send_file
|
||||
from flask import make_response, redirect, request, send_file
|
||||
from flask_login import current_user
|
||||
from flask_restful import (
|
||||
Resource,
|
||||
|
|
@ -17,6 +17,7 @@ from controllers.console.wraps import (
|
|||
setup_required,
|
||||
)
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentialType
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -127,7 +128,7 @@ class ToolBuiltinProviderAddApi(Resource):
|
|||
return BuiltinToolManageService.add_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
api_type=ToolProviderCredentialType.of(args["type"]),
|
||||
|
|
@ -373,10 +374,11 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, provider, credential_type):
|
||||
user = current_user
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, credential_type, tenant_id)
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(
|
||||
provider, ToolProviderCredentialType.of(credential_type), tenant_id
|
||||
)
|
||||
|
||||
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
|
|
@ -613,15 +615,12 @@ class ToolApiListApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
provider.to_dict()
|
||||
for provider in ApiToolManageService.list_api_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
]
|
||||
|
|
@ -662,13 +661,10 @@ class ToolPluginOAuthApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
provider = args["provider"]
|
||||
plugin_id = args["plugin_id"]
|
||||
def get(self, provider):
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
|
||||
# todo check permission
|
||||
user = current_user
|
||||
|
|
@ -679,63 +675,66 @@ class ToolPluginOAuthApi(Resource):
|
|||
tenant_id = user.current_tenant_id
|
||||
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider
|
||||
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
||||
)
|
||||
# todo decrypt oauth params
|
||||
# TODO decrypt oauth params
|
||||
oauth_params = plugin_oauth_config.oauth_params
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
|
||||
oauth_params["redirect_uri"] = redirect_uri
|
||||
|
||||
response = oauth_handler.get_authorization_url(
|
||||
tenant_id,
|
||||
user.id,
|
||||
plugin_id,
|
||||
provider,
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_params,
|
||||
)
|
||||
return response.model_dump()
|
||||
response = make_response(jsonable_encoder(authorization_url_response))
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
context_id,
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
max_age=OAuthProxyService.__MAX_AGE__,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class ToolOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
args = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("context_id", type=str, required=True, nullable=False, location="args")
|
||||
.parse_args()
|
||||
)
|
||||
context_id = args["context_id"]
|
||||
def get(self, provider):
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
user_id, tenant_id, plugin_id, provider = (
|
||||
context.get("user_id"),
|
||||
context.get("tenant_id"),
|
||||
context.get("plugin_id"),
|
||||
context.get("provider"),
|
||||
)
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
oauth_params = plugin_oauth_config.oauth_params
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
|
||||
oauth_params["redirect_uri"] = redirect_uri
|
||||
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||
credentials = oauth_handler.get_credentials(
|
||||
tenant_id,
|
||||
user_id,
|
||||
plugin_id,
|
||||
provider,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_params,
|
||||
request=request,
|
||||
).credentials
|
||||
|
|
@ -747,12 +746,11 @@ class ToolOAuthCallback(Resource):
|
|||
BuiltinToolManageService.add_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider,
|
||||
provider=provider,
|
||||
credentials=dict(credentials),
|
||||
name=provider,
|
||||
api_type=ToolProviderCredentialType.OAUTH2,
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth/plugin/{provider}/tool/success")
|
||||
|
||||
|
||||
class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
|
|
@ -768,9 +766,41 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return BuiltinToolManageService.setup_oauth_custom_client(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider=provider,
|
||||
client_params=args["client_params"],
|
||||
)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||
tenant_id=current_user.current_tenant_id, provider_name=provider
|
||||
)
|
||||
|
||||
|
||||
# tool oauth
|
||||
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/tool")
|
||||
api.add_resource(ToolOAuthCallback, "/oauth/plugin/tool/callback")
|
||||
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
|
||||
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
|
||||
|
||||
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
|
||||
|
||||
# tool provider
|
||||
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
||||
|
|
@ -782,14 +812,14 @@ api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/b
|
|||
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/set-default"
|
||||
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderCredentialsSchemaApi,
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/<path:credential_type>/credentials_schema",
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema/<path:credential_type>",
|
||||
)
|
||||
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
|
|||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
|
||||
from core.tools.entities.tool_entities import (
|
||||
OAuthSchema,
|
||||
ToolEntity,
|
||||
ToolProviderCredentialType,
|
||||
ToolProviderEntity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
|
|
@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
|
||||
credentials_schema.append(credential_dict)
|
||||
|
||||
oauth_schema = None
|
||||
if provider_yaml.get("oauth_schema", None) is not None:
|
||||
oauth_schema = OAuthSchema(
|
||||
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
|
||||
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
entity=ToolProviderEntity(
|
||||
identity=provider_yaml["identity"],
|
||||
credentials_schema=credentials_schema,
|
||||
oauth_schema=oauth_schema,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -91,16 +105,20 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
"""
|
||||
return self.tools
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
def get_credentials_schema(
|
||||
self, credential_type: ToolProviderCredentialType = ToolProviderCredentialType.API_KEY
|
||||
) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
if not self.entity.credentials_schema:
|
||||
return []
|
||||
|
||||
return self.entity.credentials_schema.copy()
|
||||
if credential_type == ToolProviderCredentialType.OAUTH2:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
elif credential_type == ToolProviderCredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -344,10 +344,18 @@ class ToolEntity(BaseModel):
|
|||
return v or []
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: Optional[str] = None
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
oauth_schema: Optional[OAuthSchema] = None
|
||||
|
||||
|
||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||
|
|
@ -437,7 +445,7 @@ class ToolSelector(BaseModel):
|
|||
|
||||
|
||||
class ToolProviderCredentialType(enum.StrEnum):
|
||||
API_KEY = "api_key"
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
|
||||
def get_name(self):
|
||||
|
|
@ -446,7 +454,7 @@ class ToolProviderCredentialType(enum.StrEnum):
|
|||
elif self == ToolProviderCredentialType.OAUTH2:
|
||||
return "AUTH"
|
||||
else:
|
||||
return self.value.replace("_", " ").upper()
|
||||
return self.value.replace("-", " ").upper()
|
||||
|
||||
def is_editable(self):
|
||||
return self == ToolProviderCredentialType.API_KEY
|
||||
|
|
@ -461,7 +469,7 @@ class ToolProviderCredentialType(enum.StrEnum):
|
|||
@classmethod
|
||||
def of(cls, credential_type: str) -> "ToolProviderCredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name == "api_key":
|
||||
if type_name == "api-key":
|
||||
return cls.API_KEY
|
||||
elif type_name == "oauth2":
|
||||
return cls.OAUTH2
|
||||
|
|
|
|||
|
|
@ -34,7 +34,13 @@ from core.tools.custom_tool.provider import ApiToolProviderController
|
|||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderCredentialType,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
|
||||
|
|
@ -202,7 +208,12 @@ class ToolManager:
|
|||
credentials = builtin_provider.credentials
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema(
|
||||
ToolProviderCredentialType.of(builtin_provider.credential_type)
|
||||
)
|
||||
],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -64,7 +64,10 @@ class BuiltinToolProvider(Base):
|
|||
"""
|
||||
|
||||
__tablename__ = "tool_builtin_providers"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),)
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
|
||||
)
|
||||
|
||||
# id of the tool provider
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
|
|
@ -86,9 +89,9 @@ class BuiltinToolProvider(Base):
|
|||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
# credential type, e.g., "api_key", "oauth2"
|
||||
# credential type, e.g., "api-key", "oauth2"
|
||||
credential_type: Mapped[str] = mapped_column(
|
||||
db.String(32), nullable=False, server_default=db.text("'api_key'::character varying")
|
||||
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
import json
|
||||
import uuid
|
||||
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
|
|
|||
|
|
@ -446,7 +446,7 @@ class ApiToolManageService:
|
|||
return {"result": result or "empty response"}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list api tools
|
||||
"""
|
||||
|
|
@ -474,7 +474,7 @@ class ApiToolManageService:
|
|||
for tool in tools or []:
|
||||
user_provider.tools.append(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
|
||||
tenant_id=tenant_id, tool=tool, labels=labels
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import re
|
|||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy import ColumnExpressionArgument
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -13,10 +12,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentialType
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
|
|
@ -29,6 +30,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class BuiltinToolManageService:
|
||||
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
|
|
@ -42,22 +45,11 @@ class BuiltinToolManageService:
|
|||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tools = provider_controller.get_tools()
|
||||
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
# check if user has added the provider
|
||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
# get credentials
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_configuration.decrypt(credentials)
|
||||
|
||||
result: list[ToolApiEntity] = []
|
||||
for tool in tools or []:
|
||||
result.append(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
|
|
@ -73,7 +65,7 @@ class BuiltinToolManageService:
|
|||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
# check if user has added the provider
|
||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
||||
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
|
|
@ -92,16 +84,19 @@ class BuiltinToolManageService:
|
|||
return entity
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str):
|
||||
def list_builtin_provider_credentials_schema(
|
||||
provider_name: str, credential_type: ToolProviderCredentialType, tenant_id: str
|
||||
):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
|
||||
:param credential_type: credential type
|
||||
:param provider_name: the name of the provider
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
return jsonable_encoder(provider.get_credentials_schema())
|
||||
return jsonable_encoder(provider.get_credentials_schema(credential_type))
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(
|
||||
|
|
@ -111,11 +106,11 @@ class BuiltinToolManageService:
|
|||
update builtin tool provider
|
||||
"""
|
||||
# get if the provider exists
|
||||
provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
|
||||
provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
|
||||
try:
|
||||
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
|
|
@ -133,10 +128,12 @@ class BuiltinToolManageService:
|
|||
if key in masked_credentials and value == masked_credentials[key]:
|
||||
credentials[key] = original_credentials[key]
|
||||
|
||||
# Encrypt and save the credentials
|
||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
||||
provider_controller, tool_configuration, provider, credentials, user_id
|
||||
)
|
||||
provider_controller.validate_credentials(user_id, credentials)
|
||||
|
||||
# encrypt credentials
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
provider.encrypted_credentials = json.dumps(encrypted_credentials)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
# update name if provided
|
||||
if name is not None and provider.name != name:
|
||||
|
|
@ -158,68 +155,84 @@ class BuiltinToolManageService:
|
|||
user_id: str,
|
||||
api_type: ToolProviderCredentialType,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
provider: str,
|
||||
credentials: dict,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}"
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
if name is None:
|
||||
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type)
|
||||
# check if the provider count is over the limit
|
||||
provider_count = (
|
||||
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
|
||||
)
|
||||
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
|
||||
raise ValueError(f"you have reached the maximum number of providers for {provider}")
|
||||
|
||||
provider = BuiltinToolProvider(
|
||||
# TODO should we get name from oauth authentication?
|
||||
name = (
|
||||
name
|
||||
if name
|
||||
else BuiltinToolManageService.generate_builtin_tool_provider_name(
|
||||
tenant_id, provider, credential_type=api_type
|
||||
)
|
||||
)
|
||||
|
||||
db_provider = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
provider=provider,
|
||||
encrypted_credentials=json.dumps(credentials),
|
||||
credential_type=api_type.value,
|
||||
name=name,
|
||||
)
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
raise ValueError(f"provider {provider} does not need credentials")
|
||||
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
|
||||
# Encrypt and save the credentials
|
||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
||||
provider_controller, tool_configuration, provider, credentials, user_id
|
||||
provider_controller=provider_controller,
|
||||
tool_configuration=tool_configuration,
|
||||
provider=db_provider,
|
||||
credentials=credentials,
|
||||
user_id=user_id,
|
||||
)
|
||||
db.session.add(provider)
|
||||
db.session.add(db_provider)
|
||||
db.session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_next_builtin_tool_provider_name(
|
||||
tenant_id: str, provider_name: str, type: ToolProviderCredentialType
|
||||
def generate_builtin_tool_provider_name(
|
||||
tenant_id: str, provider: str, credential_type: ToolProviderCredentialType
|
||||
) -> str:
|
||||
try:
|
||||
providers = (
|
||||
db_providers = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_name,
|
||||
credential_type=type.value,
|
||||
provider=provider,
|
||||
credential_type=credential_type.value,
|
||||
)
|
||||
.order_by(BuiltinToolProvider.created_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get the default name pattern
|
||||
default_pattern = type.get_name()
|
||||
default_pattern = f"{credential_type.get_name()}"
|
||||
|
||||
# Find all names that match the default pattern: "{default_pattern} {number}"
|
||||
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
|
||||
numbers = []
|
||||
|
||||
for provider in providers:
|
||||
if provider.name:
|
||||
match = re.match(pattern, provider.name.strip())
|
||||
for db_provider in db_providers:
|
||||
if db_provider.name:
|
||||
match = re.match(pattern, db_provider.name.strip())
|
||||
if match:
|
||||
numbers.append(int(match.group(1)))
|
||||
|
||||
|
|
@ -231,9 +244,9 @@ class BuiltinToolManageService:
|
|||
max_number = max(numbers)
|
||||
return f"{default_pattern} {max_number + 1}"
|
||||
except Exception as e:
|
||||
logger.warning(f"Error generating next provider name for {provider_name}: {str(e)}")
|
||||
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
|
||||
# fallback
|
||||
return f"{type.get_name()} 1"
|
||||
return f"{credential_type.get_name()} 1"
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(
|
||||
|
|
@ -242,31 +255,43 @@ class BuiltinToolManageService:
|
|||
"""
|
||||
get builtin tool provider credentials
|
||||
"""
|
||||
providers = db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
|
||||
|
||||
if len(providers) == 0:
|
||||
return []
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
credentials: list[ToolProviderCredentialApiEntity] = []
|
||||
for provider in providers:
|
||||
decrypt_credential = tool_configuration.mask_tool_credentials(
|
||||
tool_configuration.decrypt(provider.credentials)
|
||||
with db.session.no_autoflush:
|
||||
providers = (
|
||||
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
|
||||
)
|
||||
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
||||
provider=provider,
|
||||
credentials=decrypt_credential,
|
||||
)
|
||||
credentials.append(credential_entity)
|
||||
return credentials
|
||||
|
||||
if len(providers) == 0:
|
||||
return []
|
||||
|
||||
default_provider = sorted(
|
||||
providers,
|
||||
key=lambda p: (
|
||||
not getattr(p, "is_default", False),
|
||||
getattr(p, "created_at", None) or 0,
|
||||
),
|
||||
)[0]
|
||||
|
||||
default_provider.is_default = True
|
||||
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
credentials: list[ToolProviderCredentialApiEntity] = []
|
||||
for provider in providers:
|
||||
decrypt_credential = tool_configuration.mask_tool_credentials(
|
||||
tool_configuration.decrypt(provider.credentials)
|
||||
)
|
||||
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
||||
provider=provider,
|
||||
credentials=decrypt_credential,
|
||||
)
|
||||
credentials.append(credential_entity)
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str):
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
|
||||
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
||||
|
||||
if tool_provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
|
@ -387,7 +412,6 @@ class BuiltinToolManageService:
|
|||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_builtin_provider.original_credentials,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
|
|
@ -399,7 +423,7 @@ class BuiltinToolManageService:
|
|||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
|
||||
def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
|
||||
provider: Optional[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
|
|
@ -411,47 +435,62 @@ class BuiltinToolManageService:
|
|||
return provider
|
||||
|
||||
@staticmethod
|
||||
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
|
||||
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
|
||||
"""
|
||||
This method is used to fetch the builtin provider from the database
|
||||
1.if the default provider exists, return the default provider
|
||||
2.if the default provider does not exist, return the oldest provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
full_provider_name = provider_name
|
||||
provider_id_entity = ToolProviderID(provider_name)
|
||||
provider_name = provider_id_entity.provider_name
|
||||
|
||||
def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]:
|
||||
return (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters)
|
||||
.order_by(
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
if provider_id_entity.organization != "langgenius":
|
||||
provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == full_provider_name,
|
||||
)
|
||||
.order_by(
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == provider_name)
|
||||
| (BuiltinToolProvider.provider == full_provider_name),
|
||||
)
|
||||
.order_by(
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
return None
|
||||
|
||||
provider.provider = ToolProviderID(provider.provider).to_string()
|
||||
return provider
|
||||
except Exception:
|
||||
# it's an old provider without organization
|
||||
return (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
|
||||
.order_by(
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
)
|
||||
.first()
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
try:
|
||||
full_provider_name = provider_name
|
||||
provider_id_entity = ToolProviderID(provider_name)
|
||||
provider_name = provider_id_entity.provider_name
|
||||
|
||||
if provider_id_entity.organization != "langgenius":
|
||||
provider = _query([BuiltinToolProvider.provider == full_provider_name])
|
||||
else:
|
||||
provider = _query(
|
||||
[
|
||||
(BuiltinToolProvider.provider == provider_name)
|
||||
| (BuiltinToolProvider.provider == full_provider_name)
|
||||
]
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
return None
|
||||
|
||||
provider.provider = ToolProviderID(provider.provider).to_string()
|
||||
return provider
|
||||
except Exception:
|
||||
# it's an old provider without organization
|
||||
return _query([BuiltinToolProvider.provider == provider_name])
|
||||
|
||||
@staticmethod
|
||||
def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
|
||||
|
|
@ -463,7 +502,13 @@ class BuiltinToolManageService:
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):
|
||||
def _encrypt_and_save_credentials(
|
||||
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
||||
tool_configuration: ProviderConfigEncrypter,
|
||||
provider: BuiltinToolProvider,
|
||||
credentials: dict,
|
||||
user_id: str,
|
||||
):
|
||||
"""
|
||||
Validate and encrypt credentials, then save to database
|
||||
|
||||
|
|
@ -480,3 +525,25 @@ class BuiltinToolManageService:
|
|||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
provider.encrypted_credentials = json.dumps(encrypted_credentials)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
@staticmethod
|
||||
def setup_oauth_custom_client(tenant_id: str, user_id: str, provider: str, client_params: dict):
|
||||
"""
|
||||
setup oauth custom client
|
||||
"""
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller:
|
||||
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
||||
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
|
||||
# Validate and encrypt credentials
|
||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
||||
provider_controller=provider_controller,
|
||||
tool_configuration=tool_configuration,
|
||||
provider=None, # No need to save in DB
|
||||
credentials=client_params,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -255,7 +255,6 @@ class ToolTransformService:
|
|||
def convert_tool_entity_to_api_entity(
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
tenant_id: str,
|
||||
credentials: dict | None = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> ToolApiEntity:
|
||||
"""
|
||||
|
|
@ -265,7 +264,7 @@ class ToolTransformService:
|
|||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
credentials=credentials or {},
|
||||
credentials= {},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue