diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index ceea178214..5da20c3d29 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,6 +1,6 @@ import io -from flask import redirect, request, send_file +from flask import make_response, redirect, request, send_file from flask_login import current_user from flask_restful import ( Resource, @@ -17,6 +17,7 @@ from controllers.console.wraps import ( setup_required, ) from core.model_runtime.utils.encoders import jsonable_encoder +from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ToolProviderCredentialType from extensions.ext_database import db @@ -127,7 +128,7 @@ class ToolBuiltinProviderAddApi(Resource): return BuiltinToolManageService.add_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, - provider_name=provider, + provider=provider, credentials=args["credentials"], name=args["name"], api_type=ToolProviderCredentialType.of(args["type"]), @@ -373,10 +374,11 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @account_initialization_required def get(self, provider, credential_type): user = current_user - tenant_id = user.current_tenant_id - return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, credential_type, tenant_id) + return BuiltinToolManageService.list_builtin_provider_credentials_schema( + provider, ToolProviderCredentialType.of(credential_type), tenant_id + ) class ToolApiProviderSchemaApi(Resource): @@ -613,15 +615,12 @@ class ToolApiListApi(Resource): @account_initialization_required def get(self): user = current_user - - user_id = user.id tenant_id = user.current_tenant_id return jsonable_encoder( [ provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, tenant_id, ) ] @@ -662,13 +661,10 @@ class ToolPluginOAuthApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - parser = reqparse.RequestParser() - parser.add_argument("provider", type=str, required=True, nullable=False, location="args") - parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() - provider = args["provider"] - plugin_id = args["plugin_id"] + def get(self, provider): + tool_provider = ToolProviderID(provider) + plugin_id = tool_provider.plugin_id + provider_name = tool_provider.provider_name # todo check permission user = current_user @@ -679,63 +675,66 @@ class ToolPluginOAuthApi(Resource): tenant_id = user.current_tenant_id plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client( tenant_id=tenant_id, - provider=provider, + provider=provider_name, plugin_id=plugin_id, ) oauth_handler = OAuthHandler() context_id = OAuthProxyService.create_proxy_context( - user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider + user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name ) - # todo decrypt oauth params + # TODO decrypt oauth params oauth_params = plugin_oauth_config.oauth_params - redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" - oauth_params["redirect_uri"] = redirect_uri - - response = oauth_handler.get_authorization_url( - tenant_id, - user.id, - plugin_id, - provider, + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" + authorization_url_response = oauth_handler.get_authorization_url( + tenant_id=tenant_id, + user_id=user.id, + plugin_id=plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, system_credentials=oauth_params, ) - return response.model_dump() + response = make_response(jsonable_encoder(authorization_url_response)) + response.set_cookie( + "context_id", + context_id, + httponly=True, + samesite="Lax", + max_age=OAuthProxyService.__MAX_AGE__, + ) + return response class ToolOAuthCallback(Resource): @setup_required - def get(self): - args = ( - reqparse.RequestParser() - .add_argument("context_id", type=str, required=True, nullable=False, location="args") - .parse_args() - ) - context_id = args["context_id"] + def get(self, provider): + context_id = request.cookies.get("context_id") + if not context_id: + raise Forbidden("context_id not found") + context = OAuthProxyService.use_proxy_context(context_id) if context is None: raise Forbidden("Invalid context_id") - user_id, tenant_id, plugin_id, provider = ( - context.get("user_id"), - context.get("tenant_id"), - context.get("plugin_id"), - context.get("provider"), - ) + tool_provider = ToolProviderID(provider) + plugin_id = tool_provider.plugin_id + provider_name = tool_provider.provider_name + user_id, tenant_id = context.get("user_id"), context.get("tenant_id") + oauth_handler = OAuthHandler() plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client( tenant_id=tenant_id, - provider=provider, + provider=provider_name, plugin_id=plugin_id, ) oauth_params = plugin_oauth_config.oauth_params - redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" - oauth_params["redirect_uri"] = redirect_uri - + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" credentials = oauth_handler.get_credentials( - tenant_id, - user_id, - plugin_id, - provider, + tenant_id=tenant_id, + user_id=user_id, + plugin_id=plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, system_credentials=oauth_params, request=request, ).credentials @@ -747,12 +746,11 @@ class ToolOAuthCallback(Resource): BuiltinToolManageService.add_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, - provider_name=provider, + provider=provider, credentials=dict(credentials), - name=provider, api_type=ToolProviderCredentialType.OAUTH2, ) - return redirect(f"{dify_config.CONSOLE_WEB_URL}") + return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth/plugin/{provider}/tool/success") class ToolBuiltinProviderSetDefaultApi(Resource): @@ -768,9 +766,41 @@ class ToolBuiltinProviderSetDefaultApi(Resource): ) +class ToolOAuthCustomClient(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + parser = reqparse.RequestParser() + parser.add_argument("client_params", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + user = current_user + + if not user.is_admin_or_owner: + raise Forbidden() + + return BuiltinToolManageService.setup_oauth_custom_client( + tenant_id=user.current_tenant_id, + user_id=user.id, + provider=provider, + client_params=args["client_params"], + ) + + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + return BuiltinToolManageService.get_builtin_tool_provider_credentials( + tenant_id=current_user.current_tenant_id, provider_name=provider + ) + + # tool oauth -api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/tool") -api.add_resource(ToolOAuthCallback, "/oauth/plugin/tool/callback") +api.add_resource(ToolPluginOAuthApi, "/oauth/plugin//tool/authorization-url") +api.add_resource(ToolOAuthCallback, "/oauth/plugin//tool/callback") + +api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin//oauth/custom-client") # tool provider api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") @@ -782,14 +812,14 @@ api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/b api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") api.add_resource( - ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin//set-default" + ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin//default-credential" ) api.add_resource( ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" ) api.add_resource( ToolBuiltinProviderCredentialsSchemaApi, - "/workspaces/current/tool-provider/builtin///credentials_schema", + "/workspaces/current/tool-provider/builtin//credentials_schema/", ) api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index cf75bd3d7e..9e3c13849f 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool -from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType +from core.tools.entities.tool_entities import ( + OAuthSchema, + ToolEntity, + ToolProviderCredentialType, + ToolProviderEntity, + ToolProviderType, +) from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolProviderNotFoundError, @@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController): credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {}) credentials_schema.append(credential_dict) + oauth_schema = None + if provider_yaml.get("oauth_schema", None) is not None: + oauth_schema = OAuthSchema( + client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []), + credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []), + ) + super().__init__( entity=ToolProviderEntity( identity=provider_yaml["identity"], credentials_schema=credentials_schema, + oauth_schema=oauth_schema, ), ) @@ -91,16 +105,20 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.tools - def get_credentials_schema(self) -> list[ProviderConfig]: + def get_credentials_schema( + self, credential_type: ToolProviderCredentialType = ToolProviderCredentialType.API_KEY + ) -> list[ProviderConfig]: """ returns the credentials schema of the provider :return: the credentials schema """ - if not self.entity.credentials_schema: - return [] - - return self.entity.credentials_schema.copy() + if credential_type == ToolProviderCredentialType.OAUTH2: + return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] + elif credential_type == ToolProviderCredentialType.API_KEY: + return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] + else: + raise ValueError(f"Invalid credential type: {credential_type}") def get_tools(self) -> list[BuiltinTool]: """ diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 5094519b6f..922e30b2e0 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -344,10 +344,18 @@ class ToolEntity(BaseModel): return v or [] +class OAuthSchema(BaseModel): + client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client") + credentials_schema: list[ProviderConfig] = Field( + default_factory=list, description="The schema of the OAuth credentials" + ) + + class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity plugin_id: Optional[str] = None credentials_schema: list[ProviderConfig] = Field(default_factory=list) + oauth_schema: Optional[OAuthSchema] = None class ToolProviderEntityWithPlugin(ToolProviderEntity): @@ -437,7 +445,7 @@ class ToolSelector(BaseModel): class ToolProviderCredentialType(enum.StrEnum): - API_KEY = "api_key" + API_KEY = "api-key" OAUTH2 = "oauth2" def get_name(self): @@ -446,7 +454,7 @@ class ToolProviderCredentialType(enum.StrEnum): elif self == ToolProviderCredentialType.OAUTH2: return "AUTH" else: - return self.value.replace("_", " ").upper() + return self.value.replace("-", " ").upper() def is_editable(self): return self == ToolProviderCredentialType.API_KEY @@ -461,7 +469,7 @@ class ToolProviderCredentialType(enum.StrEnum): @classmethod def of(cls, credential_type: str) -> "ToolProviderCredentialType": type_name = credential_type.lower() - if type_name == "api_key": + if type_name == "api-key": return cls.API_KEY elif type_name == "oauth2": return cls.OAUTH2 diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index bd4a635923..35d4eb0c7e 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -34,7 +34,13 @@ from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.custom_tool.tool import ApiTool from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolInvokeFrom, + ToolParameter, + ToolProviderCredentialType, + ToolProviderType, +) from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager @@ -202,7 +208,12 @@ class ToolManager: credentials = builtin_provider.credentials tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema( + ToolProviderCredentialType.of(builtin_provider.credential_type) + ) + ], provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) diff --git a/api/models/tools.py b/api/models/tools.py index b2979a69dc..ef2f7bcdde 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -64,7 +64,10 @@ class BuiltinToolProvider(Base): """ __tablename__ = "tool_builtin_providers" - __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),) + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), + db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"), + ) # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) @@ -86,9 +89,9 @@ class BuiltinToolProvider(Base): db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - # credential type, e.g., "api_key", "oauth2" + # credential type, e.g., "api-key", "oauth2" credential_type: Mapped[str] = mapped_column( - db.String(32), nullable=False, server_default=db.text("'api_key'::character varying") + db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") ) @property diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 4ad3335ff6..b84dd0afc5 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -1,3 +1,6 @@ +import json +import uuid + from core.plugin.impl.base import BasePluginClient from extensions.ext_redis import redis_client diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 6f848d49c4..b429851349 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -446,7 +446,7 @@ class ApiToolManageService: return {"result": result or "empty response"} @staticmethod - def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: + def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]: """ list api tools """ @@ -474,7 +474,7 @@ class ApiToolManageService: for tool in tools or []: user_provider.tools.append( ToolTransformService.convert_tool_entity_to_api_entity( - tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels + tenant_id=tenant_id, tool=tool, labels=labels ) ) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 0137e13b20..80ee9b080c 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -4,7 +4,6 @@ import re from pathlib import Path from typing import Optional, Union -from sqlalchemy import ColumnExpressionArgument from sqlalchemy.orm import Session from configs import dify_config @@ -13,10 +12,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.exc import PluginDaemonClientSideError from core.tools.__base.tool_provider import ToolProviderController +from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError +from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ProviderConfigEncrypter @@ -29,6 +30,8 @@ logger = logging.getLogger(__name__) class BuiltinToolManageService: + __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 + @staticmethod def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: """ @@ -42,22 +45,11 @@ class BuiltinToolManageService: provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) tools = provider_controller.get_tools() - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) - # check if user has added the provider - builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) - - credentials = {} - if builtin_provider is not None: - # get credentials - credentials = builtin_provider.credentials - credentials = tool_configuration.decrypt(credentials) - result: list[ToolApiEntity] = [] for tool in tools or []: result.append( ToolTransformService.convert_tool_entity_to_api_entity( tool=tool, - credentials=credentials, tenant_id=tenant_id, labels=ToolLabelManager.get_tool_labels(provider_controller), ) @@ -73,7 +65,7 @@ class BuiltinToolManageService: provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # check if user has added the provider - builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) + builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id) credentials = {} if builtin_provider is not None: @@ -92,16 +84,19 @@ class BuiltinToolManageService: return entity @staticmethod - def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str): + def list_builtin_provider_credentials_schema( + provider_name: str, credential_type: ToolProviderCredentialType, tenant_id: str + ): """ list builtin provider credentials schema + :param credential_type: credential type :param provider_name: the name of the provider :param tenant_id: the id of the tenant :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return jsonable_encoder(provider.get_credentials_schema()) + return jsonable_encoder(provider.get_credentials_schema(credential_type)) @staticmethod def update_builtin_tool_provider( @@ -111,11 +106,11 @@ class BuiltinToolManageService: update builtin tool provider """ # get if the provider exists - provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) + provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) if provider is None: raise ValueError(f"you have not added provider {provider_name}") - + try: if ToolProviderCredentialType.of(provider.credential_type).is_editable(): provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) @@ -133,10 +128,12 @@ class BuiltinToolManageService: if key in masked_credentials and value == masked_credentials[key]: credentials[key] = original_credentials[key] - # Encrypt and save the credentials - BuiltinToolManageService._encrypt_and_save_credentials( - provider_controller, tool_configuration, provider, credentials, user_id - ) + provider_controller.validate_credentials(user_id, credentials) + + # encrypt credentials + encrypted_credentials = tool_configuration.encrypt(credentials) + provider.encrypted_credentials = json.dumps(encrypted_credentials) + tool_configuration.delete_tool_credentials_cache() # update name if provided if name is not None and provider.name != name: @@ -158,68 +155,84 @@ class BuiltinToolManageService: user_id: str, api_type: ToolProviderCredentialType, tenant_id: str, - provider_name: str, + provider: str, credentials: dict, name: str | None = None, ): """ add builtin tool provider """ - lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}" + lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" with redis_client.lock(lock, timeout=20): - if name is None: - name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type) + # check if the provider count is over the limit + provider_count = ( + db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() + ) + if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__: + raise ValueError(f"you have reached the maximum number of providers for {provider}") - provider = BuiltinToolProvider( + # TODO should we get name from oauth authentication? + name = ( + name + if name + else BuiltinToolManageService.generate_builtin_tool_provider_name( + tenant_id, provider, credential_type=api_type + ) + ) + + db_provider = BuiltinToolProvider( tenant_id=tenant_id, user_id=user_id, - provider=provider_name, + provider=provider, encrypted_credentials=json.dumps(credentials), credential_type=api_type.value, name=name, ) - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) if not provider_controller.need_credentials: - raise ValueError(f"provider {provider_name} does not need credentials") + raise ValueError(f"provider {provider} does not need credentials") tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # Encrypt and save the credentials BuiltinToolManageService._encrypt_and_save_credentials( - provider_controller, tool_configuration, provider, credentials, user_id + provider_controller=provider_controller, + tool_configuration=tool_configuration, + provider=db_provider, + credentials=credentials, + user_id=user_id, ) - db.session.add(provider) + db.session.add(db_provider) db.session.commit() return {"result": "success"} @staticmethod - def get_next_builtin_tool_provider_name( - tenant_id: str, provider_name: str, type: ToolProviderCredentialType + def generate_builtin_tool_provider_name( + tenant_id: str, provider: str, credential_type: ToolProviderCredentialType ) -> str: try: - providers = ( + db_providers = ( db.session.query(BuiltinToolProvider) .filter_by( tenant_id=tenant_id, - provider=provider_name, - credential_type=type.value, + provider=provider, + credential_type=credential_type.value, ) .order_by(BuiltinToolProvider.created_at.desc()) - .limit(10) .all() ) # Get the default name pattern - default_pattern = type.get_name() + default_pattern = f"{credential_type.get_name()}" # Find all names that match the default pattern: "{default_pattern} {number}" pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$" numbers = [] - for provider in providers: - if provider.name: - match = re.match(pattern, provider.name.strip()) + for db_provider in db_providers: + if db_provider.name: + match = re.match(pattern, db_provider.name.strip()) if match: numbers.append(int(match.group(1))) @@ -231,9 +244,9 @@ class BuiltinToolManageService: max_number = max(numbers) return f"{default_pattern} {max_number + 1}" except Exception as e: - logger.warning(f"Error generating next provider name for {provider_name}: {str(e)}") + logger.warning(f"Error generating next provider name for {provider}: {str(e)}") # fallback - return f"{type.get_name()} 1" + return f"{credential_type.get_name()} 1" @staticmethod def get_builtin_tool_provider_credentials( @@ -242,31 +255,43 @@ class BuiltinToolManageService: """ get builtin tool provider credentials """ - providers = db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all() - - if len(providers) == 0: - return [] - - provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id) - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) - credentials: list[ToolProviderCredentialApiEntity] = [] - for provider in providers: - decrypt_credential = tool_configuration.mask_tool_credentials( - tool_configuration.decrypt(provider.credentials) + with db.session.no_autoflush: + providers = ( + db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all() ) - credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( - provider=provider, - credentials=decrypt_credential, - ) - credentials.append(credential_entity) - return credentials + + if len(providers) == 0: + return [] + + default_provider = sorted( + providers, + key=lambda p: ( + not getattr(p, "is_default", False), + getattr(p, "created_at", None) or 0, + ), + )[0] + + default_provider.is_default = True + provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) + credentials: list[ToolProviderCredentialApiEntity] = [] + for provider in providers: + decrypt_credential = tool_configuration.mask_tool_credentials( + tool_configuration.decrypt(provider.credentials) + ) + credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( + provider=provider, + credentials=decrypt_credential, + ) + credentials.append(credential_entity) + return credentials @staticmethod def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str): """ delete tool provider """ - tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) + tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) if tool_provider is None: raise ValueError(f"you have not added provider {provider_name}") @@ -387,7 +412,6 @@ class BuiltinToolManageService: ToolTransformService.convert_tool_entity_to_api_entity( tenant_id=tenant_id, tool=tool, - credentials=user_builtin_provider.original_credentials, labels=ToolLabelManager.get_tool_labels(provider_controller), ) ) @@ -399,7 +423,7 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) @staticmethod - def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]: + def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]: provider: Optional[BuiltinToolProvider] = ( db.session.query(BuiltinToolProvider) .filter( @@ -411,47 +435,62 @@ class BuiltinToolManageService: return provider @staticmethod - def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: + def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: """ This method is used to fetch the builtin provider from the database 1.if the default provider exists, return the default provider 2.if the default provider does not exist, return the oldest provider """ + with Session(db.engine) as session: + try: + full_provider_name = provider_name + provider_id_entity = ToolProviderID(provider_name) + provider_name = provider_id_entity.provider_name - def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]: - return ( - db.session.query(BuiltinToolProvider) - .filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters) - .order_by( - BuiltinToolProvider.is_default.desc(), # default=True first - BuiltinToolProvider.created_at.asc(), # oldest first + if provider_id_entity.organization != "langgenius": + provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == full_provider_name, + ) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() + ) + else: + provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == provider_name) + | (BuiltinToolProvider.provider == full_provider_name), + ) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() + ) + + if provider is None: + return None + + provider.provider = ToolProviderID(provider.provider).to_string() + return provider + except Exception: + # it's an old provider without organization + return ( + session.query(BuiltinToolProvider) + .filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() ) - .first() - ) - - try: - full_provider_name = provider_name - provider_id_entity = ToolProviderID(provider_name) - provider_name = provider_id_entity.provider_name - - if provider_id_entity.organization != "langgenius": - provider = _query([BuiltinToolProvider.provider == full_provider_name]) - else: - provider = _query( - [ - (BuiltinToolProvider.provider == provider_name) - | (BuiltinToolProvider.provider == full_provider_name) - ] - ) - - if provider is None: - return None - - provider.provider = ToolProviderID(provider.provider).to_string() - return provider - except Exception: - # it's an old provider without organization - return _query([BuiltinToolProvider.provider == provider_name]) @staticmethod def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController): @@ -463,7 +502,13 @@ class BuiltinToolManageService: ) @staticmethod - def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id): + def _encrypt_and_save_credentials( + provider_controller: BuiltinToolProviderController | PluginToolProviderController, + tool_configuration: ProviderConfigEncrypter, + provider: BuiltinToolProvider, + credentials: dict, + user_id: str, + ): """ Validate and encrypt credentials, then save to database @@ -480,3 +525,25 @@ class BuiltinToolManageService: encrypted_credentials = tool_configuration.encrypt(credentials) provider.encrypted_credentials = json.dumps(encrypted_credentials) tool_configuration.delete_tool_credentials_cache() + + @staticmethod + def setup_oauth_custom_client(tenant_id: str, user_id: str, provider: str, client_params: dict): + """ + setup oauth custom client + """ + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller: + raise ToolProviderNotFoundError(f"Provider {provider} not found") + + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) + + # Validate and encrypt credentials + BuiltinToolManageService._encrypt_and_save_credentials( + provider_controller=provider_controller, + tool_configuration=tool_configuration, + provider=None, # No need to save in DB + credentials=client_params, + user_id=user_id, + ) + + return {"result": "success"} diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 66be67dbe6..160352c4c0 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -255,7 +255,6 @@ class ToolTransformService: def convert_tool_entity_to_api_entity( tool: Union[ApiToolBundle, WorkflowTool, Tool], tenant_id: str, - credentials: dict | None = None, labels: list[str] | None = None, ) -> ToolApiEntity: """ @@ -265,7 +264,7 @@ class ToolTransformService: # fork tool runtime tool = tool.fork_tool_runtime( runtime=ToolRuntime( - credentials=credentials or {}, + credentials= {}, tenant_id=tenant_id, ) )