chore: handle session

This commit is contained in:
Novice 2025-10-27 13:02:13 +08:00
parent e7a575a33c
commit 0b021273bc
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
3 changed files with 41 additions and 19 deletions

View File

@ -1021,28 +1021,32 @@ class ToolMCPAuthApi(Resource):
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Create new transaction for update
with session.begin():
# Update credentials in new transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider_credentials(
provider=db_provider,
provider_id=provider_id,
tenant_id=tenant_id,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
service = MCPToolManageService(session=session)
try:
auth_result = auth(provider_entity, args.get("authorization_code"))
with session.begin():
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
return response
return response
except MCPRefreshTokenError as e:
with session.begin():
service.clear_provider_credentials(provider=db_provider)
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
with session.begin():
service.clear_provider_credentials(provider=db_provider)
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to connect to MCP server: {e}") from e

View File

@ -296,18 +296,22 @@ class MCPToolManageService:
# ========== OAuth and Credentials Operations ==========
def update_provider_credentials(
self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool | None = None
self, *, provider_id: str, tenant_id: str, credentials: dict[str, Any], authed: bool | None = None
) -> None:
"""
Update provider credentials with encryption.
Args:
provider: Provider instance
provider_id: Provider ID
tenant_id: Tenant ID
credentials: Credentials to save
authed: Whether provider is authenticated (None means keep current state)
"""
from core.tools.mcp_tool.provider import MCPToolProviderController
# Get provider from current session
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Encrypt new credentials
provider_controller = MCPToolProviderController.from_db(provider)
tool_configuration = ProviderConfigEncrypter(
@ -341,17 +345,25 @@ class MCPToolManageService:
data: Data to save (tokens, client info, or code verifier)
data_type: Type of OAuth data to save
"""
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Determine if this makes the provider authenticated
authed = (
data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None
)
self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed)
# update_provider_credentials will validate provider existence
self.update_provider_credentials(provider_id=provider_id, tenant_id=tenant_id, credentials=data, authed=authed)
def clear_provider_credentials(self, *, provider_id: str, tenant_id: str) -> None:
"""
Clear all credentials for a provider.
Args:
provider_id: Provider ID
tenant_id: Tenant ID
"""
# Get provider from current session
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
"""Clear all credentials for a provider."""
provider.tools = EMPTY_TOOLS_JSON
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
provider.updated_at = datetime.now()

View File

@ -1206,7 +1206,10 @@ class TestMCPToolManageService:
service = MCPToolManageService(db.session())
service.update_provider_credentials(
provider=mcp_provider, credentials={"new_key": "new_value"}, authed=True
provider_id=mcp_provider.id,
tenant_id=tenant.id,
credentials={"new_key": "new_value"},
authed=True,
)
# Assert: Verify the expected outcomes
@ -1267,7 +1270,10 @@ class TestMCPToolManageService:
service = MCPToolManageService(db.session())
service.update_provider_credentials(
provider=mcp_provider, credentials={"new_key": "new_value"}, authed=False
provider_id=mcp_provider.id,
tenant_id=tenant.id,
credentials={"new_key": "new_value"},
authed=False,
)
# Assert: Verify the expected outcomes