From ef330fec2c100f7b12a44c58647c2a4ed024d6e5 Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 9 Jul 2025 11:54:10 +0800 Subject: [PATCH] feat(oauth): add credential validation for providers --- .../tools/builtin_tools_manage_service.py | 97 +++++++++++-------- 1 file changed, 54 insertions(+), 43 deletions(-) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 8e7b179ea7..f32df9763d 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -95,9 +95,7 @@ class BuiltinToolManageService: return entity @staticmethod - def list_builtin_provider_credentials_schema( - provider_name: str, credential_type: CredentialType, tenant_id: str - ): + def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str): """ list builtin provider credentials schema @@ -141,7 +139,8 @@ class BuiltinToolManageService: if key in masked_credentials and value == masked_credentials[key]: credentials[key] = original_credentials[key] - provider_controller.validate_credentials(user_id, credentials) + if CredentialType.of(db_provider.credential_type).is_validate_allowed(): + provider_controller.validate_credentials(user_id, credentials) # encrypt credentials db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) @@ -159,6 +158,7 @@ class BuiltinToolManageService: ToolNotFoundError, ToolProviderCredentialValidationError, ) as e: + db.session.rollback() raise ValueError(str(e)) return {"result": "success"} @@ -176,46 +176,59 @@ class BuiltinToolManageService: add builtin tool provider """ lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" - with redis_client.lock(lock, timeout=20): - # check if the provider count is over the limit - provider_count = ( - db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() - ) - if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__: - raise ValueError(f"you have reached the maximum number of providers for {provider}") + try: + with redis_client.lock(lock, timeout=20): + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider} does not need credentials") - # TODO should we get name from oauth authentication? - name = ( - name - if name - else BuiltinToolManageService.generate_builtin_tool_provider_name( - tenant_id=tenant_id, provider=provider, credential_type=api_type + provider_count = ( + db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() ) - ) - db_provider = BuiltinToolProvider( - tenant_id=tenant_id, - user_id=user_id, - provider=provider, - encrypted_credentials=json.dumps(credentials), - credential_type=api_type.value, - name=name, - ) + # check if the provider count is reached the limit + if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__: + raise ValueError(f"you have reached the maximum number of providers for {provider}") - provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - if not provider_controller.need_credentials: - raise ValueError(f"provider {provider} does not need credentials") + # validate credentials if allowed + if CredentialType.of(api_type).is_validate_allowed(): + provider_controller.validate_credentials(user_id, credentials) - encrypter, cache = BuiltinToolManageService.create_tool_encrypter( - tenant_id, db_provider, provider, provider_controller - ) + # generate name if not provided + if name is None: + name = BuiltinToolManageService.generate_builtin_tool_provider_name( + tenant_id=tenant_id, provider=provider, credential_type=api_type + ) - # encrypt credentials - db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) + # create encrypter + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(api_type) + ], + cache=NoOpProviderCredentialCache(), + ) - cache.delete() - db.session.add(db_provider) - db.session.commit() + db_provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider, + encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), + credential_type=api_type.value, + name=name, + ) + + db.session.add(db_provider) + db.session.commit() + except ( + PluginDaemonClientSideError, + ToolProviderNotFoundError, + ToolNotFoundError, + ToolProviderCredentialValidationError, + ) as e: + db.session.rollback() + raise ValueError(str(e)) return {"result": "success"} @staticmethod @@ -236,9 +249,7 @@ class BuiltinToolManageService: return encrypter, cache @staticmethod - def generate_builtin_tool_provider_name( - tenant_id: str, provider: str, credential_type: CredentialType - ) -> str: + def generate_builtin_tool_provider_name(tenant_id: str, provider: str, credential_type: CredentialType) -> str: try: db_providers = ( db.session.query(BuiltinToolProvider) @@ -324,7 +335,7 @@ class BuiltinToolManageService: is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider), credentials=credentials, ) - + return credential_info @staticmethod @@ -362,8 +373,8 @@ class BuiltinToolManageService: # clear default provider session.query(BuiltinToolProvider).filter_by( - tenant_id=tenant_id, user_id=user_id, provider=provider, default=True - ).update({"default": False}) + tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True + ).update({"is_default": False}) # set new default provider target_provider.is_default = True