From 0b021273bc1b7db8fde704de29c7348b77dbd8f7 Mon Sep 17 00:00:00 2001 From: Novice Date: Mon, 27 Oct 2025 13:02:13 +0800 Subject: [PATCH] chore: handle session --- .../console/workspace/tool_providers.py | 24 ++++++++++------- .../tools/mcp_tools_manage_service.py | 26 ++++++++++++++----- .../tools/test_mcp_tools_manage_service.py | 10 +++++-- 3 files changed, 41 insertions(+), 19 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index a8d4f0f5de..d8fad83f60 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1021,28 +1021,32 @@ class ToolMCPAuthApi(Resource): timeout=provider_entity.timeout, sse_read_timeout=provider_entity.sse_read_timeout, ): - # Create new transaction for update - with session.begin(): + # Update credentials in new transaction + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) service.update_provider_credentials( - provider=db_provider, + provider_id=provider_id, + tenant_id=tenant_id, credentials=provider_entity.credentials, authed=True, ) return {"result": "success"} except MCPAuthError as e: - service = MCPToolManageService(session=session) try: auth_result = auth(provider_entity, args.get("authorization_code")) - with session.begin(): + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) response = service.execute_auth_actions(auth_result) - return response + return response except MCPRefreshTokenError as e: - with session.begin(): - service.clear_provider_credentials(provider=db_provider) + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e except MCPError as e: - with session.begin(): - service.clear_provider_credentials(provider=db_provider) + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) raise ValueError(f"Failed to connect to MCP server: {e}") from e diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index b24483b9c6..ba664a0154 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -296,18 +296,22 @@ class MCPToolManageService: # ========== OAuth and Credentials Operations ========== def update_provider_credentials( - self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool | None = None + self, *, provider_id: str, tenant_id: str, credentials: dict[str, Any], authed: bool | None = None ) -> None: """ Update provider credentials with encryption. Args: - provider: Provider instance + provider_id: Provider ID + tenant_id: Tenant ID credentials: Credentials to save authed: Whether provider is authenticated (None means keep current state) """ from core.tools.mcp_tool.provider import MCPToolProviderController + # Get provider from current session + provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + # Encrypt new credentials provider_controller = MCPToolProviderController.from_db(provider) tool_configuration = ProviderConfigEncrypter( @@ -341,17 +345,25 @@ class MCPToolManageService: data: Data to save (tokens, client info, or code verifier) data_type: Type of OAuth data to save """ - db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) - # Determine if this makes the provider authenticated authed = ( data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None ) - self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed) + # update_provider_credentials will validate provider existence + self.update_provider_credentials(provider_id=provider_id, tenant_id=tenant_id, credentials=data, authed=authed) + + def clear_provider_credentials(self, *, provider_id: str, tenant_id: str) -> None: + """ + Clear all credentials for a provider. + + Args: + provider_id: Provider ID + tenant_id: Tenant ID + """ + # Get provider from current session + provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) - def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None: - """Clear all credentials for a provider.""" provider.tools = EMPTY_TOOLS_JSON provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON provider.updated_at = datetime.now() diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 3c77d0c0da..8c190762cf 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -1206,7 +1206,10 @@ class TestMCPToolManageService: service = MCPToolManageService(db.session()) service.update_provider_credentials( - provider=mcp_provider, credentials={"new_key": "new_value"}, authed=True + provider_id=mcp_provider.id, + tenant_id=tenant.id, + credentials={"new_key": "new_value"}, + authed=True, ) # Assert: Verify the expected outcomes @@ -1267,7 +1270,10 @@ class TestMCPToolManageService: service = MCPToolManageService(db.session()) service.update_provider_credentials( - provider=mcp_provider, credentials={"new_key": "new_value"}, authed=False + provider_id=mcp_provider.id, + tenant_id=tenant.id, + credentials={"new_key": "new_value"}, + authed=False, ) # Assert: Verify the expected outcomes