diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 690b06ea7d..89762d6772 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -17,6 +17,7 @@ from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry +from core.mcp.entities import AuthActionType, AuthResult from core.mcp.error import MCPAuthError, MCPError from core.mcp.types import Tool as MCPTool from core.tools.entities.api_entities import ToolProviderApiEntity @@ -496,7 +497,13 @@ class MCPToolManageService: ) as mcp_client: return mcp_client.list_tools() - def execute_auth_actions(self, auth_result: Any) -> dict[str, str]: + _ACTION_TO_OAUTH: dict[AuthActionType, OAuthDataType] = { + AuthActionType.SAVE_CLIENT_INFO: OAuthDataType.CLIENT_INFO, + AuthActionType.SAVE_TOKENS: OAuthDataType.TOKENS, + AuthActionType.SAVE_CODE_VERIFIER: OAuthDataType.CODE_VERIFIER, + } + + def execute_auth_actions(self, auth_result: AuthResult) -> dict[str, str]: """ Execute the actions returned by the auth function. @@ -508,19 +515,13 @@ class MCPToolManageService: Returns: The response from the auth result """ - from core.mcp.entities import AuthAction, AuthActionType - - action: AuthAction for action in auth_result.actions: if action.provider_id is None or action.tenant_id is None: continue - if action.action_type == AuthActionType.SAVE_CLIENT_INFO: - self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO) - elif action.action_type == AuthActionType.SAVE_TOKENS: - self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS) - elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER: - self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER) + oauth_type = self._ACTION_TO_OAUTH.get(action.action_type) + if oauth_type is not None: + self.save_oauth_data(action.provider_id, action.tenant_id, action.data, oauth_type) return auth_result.response