mirror of https://github.com/langgenius/dify.git
chore: handle session
This commit is contained in:
parent
e7a575a33c
commit
0b021273bc
|
|
@ -1021,28 +1021,32 @@ class ToolMCPAuthApi(Resource):
|
|||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
):
|
||||
# Create new transaction for update
|
||||
with session.begin():
|
||||
# Update credentials in new transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider_credentials(
|
||||
provider=db_provider,
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
credentials=provider_entity.credentials,
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
service = MCPToolManageService(session=session)
|
||||
try:
|
||||
auth_result = auth(provider_entity, args.get("authorization_code"))
|
||||
with session.begin():
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
response = service.execute_auth_actions(auth_result)
|
||||
return response
|
||||
return response
|
||||
except MCPRefreshTokenError as e:
|
||||
with session.begin():
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except MCPError as e:
|
||||
with session.begin():
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -296,18 +296,22 @@ class MCPToolManageService:
|
|||
# ========== OAuth and Credentials Operations ==========
|
||||
|
||||
def update_provider_credentials(
|
||||
self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool | None = None
|
||||
self, *, provider_id: str, tenant_id: str, credentials: dict[str, Any], authed: bool | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Update provider credentials with encryption.
|
||||
|
||||
Args:
|
||||
provider: Provider instance
|
||||
provider_id: Provider ID
|
||||
tenant_id: Tenant ID
|
||||
credentials: Credentials to save
|
||||
authed: Whether provider is authenticated (None means keep current state)
|
||||
"""
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
|
||||
# Get provider from current session
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
# Encrypt new credentials
|
||||
provider_controller = MCPToolProviderController.from_db(provider)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
|
|
@ -341,17 +345,25 @@ class MCPToolManageService:
|
|||
data: Data to save (tokens, client info, or code verifier)
|
||||
data_type: Type of OAuth data to save
|
||||
"""
|
||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
# Determine if this makes the provider authenticated
|
||||
authed = (
|
||||
data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None
|
||||
)
|
||||
|
||||
self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed)
|
||||
# update_provider_credentials will validate provider existence
|
||||
self.update_provider_credentials(provider_id=provider_id, tenant_id=tenant_id, credentials=data, authed=authed)
|
||||
|
||||
def clear_provider_credentials(self, *, provider_id: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Clear all credentials for a provider.
|
||||
|
||||
Args:
|
||||
provider_id: Provider ID
|
||||
tenant_id: Tenant ID
|
||||
"""
|
||||
# Get provider from current session
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
|
||||
"""Clear all credentials for a provider."""
|
||||
provider.tools = EMPTY_TOOLS_JSON
|
||||
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
|
||||
provider.updated_at = datetime.now()
|
||||
|
|
|
|||
|
|
@ -1206,7 +1206,10 @@ class TestMCPToolManageService:
|
|||
|
||||
service = MCPToolManageService(db.session())
|
||||
service.update_provider_credentials(
|
||||
provider=mcp_provider, credentials={"new_key": "new_value"}, authed=True
|
||||
provider_id=mcp_provider.id,
|
||||
tenant_id=tenant.id,
|
||||
credentials={"new_key": "new_value"},
|
||||
authed=True,
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
|
@ -1267,7 +1270,10 @@ class TestMCPToolManageService:
|
|||
|
||||
service = MCPToolManageService(db.session())
|
||||
service.update_provider_credentials(
|
||||
provider=mcp_provider, credentials={"new_key": "new_value"}, authed=False
|
||||
provider_id=mcp_provider.id,
|
||||
tenant_id=tenant.id,
|
||||
credentials={"new_key": "new_value"},
|
||||
authed=False,
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
|
|
|||
Loading…
Reference in New Issue