mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 20:17:29 +08:00
chore: handle session
This commit is contained in:
parent
e7a575a33c
commit
0b021273bc
@ -1021,28 +1021,32 @@ class ToolMCPAuthApi(Resource):
|
|||||||
timeout=provider_entity.timeout,
|
timeout=provider_entity.timeout,
|
||||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||||
):
|
):
|
||||||
# Create new transaction for update
|
# Update credentials in new transaction
|
||||||
with session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
|
service = MCPToolManageService(session=session)
|
||||||
service.update_provider_credentials(
|
service.update_provider_credentials(
|
||||||
provider=db_provider,
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
credentials=provider_entity.credentials,
|
credentials=provider_entity.credentials,
|
||||||
authed=True,
|
authed=True,
|
||||||
)
|
)
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
except MCPAuthError as e:
|
except MCPAuthError as e:
|
||||||
service = MCPToolManageService(session=session)
|
|
||||||
try:
|
try:
|
||||||
auth_result = auth(provider_entity, args.get("authorization_code"))
|
auth_result = auth(provider_entity, args.get("authorization_code"))
|
||||||
with session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
|
service = MCPToolManageService(session=session)
|
||||||
response = service.execute_auth_actions(auth_result)
|
response = service.execute_auth_actions(auth_result)
|
||||||
return response
|
return response
|
||||||
except MCPRefreshTokenError as e:
|
except MCPRefreshTokenError as e:
|
||||||
with session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
service.clear_provider_credentials(provider=db_provider)
|
service = MCPToolManageService(session=session)
|
||||||
|
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||||
except MCPError as e:
|
except MCPError as e:
|
||||||
with session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
service.clear_provider_credentials(provider=db_provider)
|
service = MCPToolManageService(session=session)
|
||||||
|
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -296,18 +296,22 @@ class MCPToolManageService:
|
|||||||
# ========== OAuth and Credentials Operations ==========
|
# ========== OAuth and Credentials Operations ==========
|
||||||
|
|
||||||
def update_provider_credentials(
|
def update_provider_credentials(
|
||||||
self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool | None = None
|
self, *, provider_id: str, tenant_id: str, credentials: dict[str, Any], authed: bool | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Update provider credentials with encryption.
|
Update provider credentials with encryption.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: Provider instance
|
provider_id: Provider ID
|
||||||
|
tenant_id: Tenant ID
|
||||||
credentials: Credentials to save
|
credentials: Credentials to save
|
||||||
authed: Whether provider is authenticated (None means keep current state)
|
authed: Whether provider is authenticated (None means keep current state)
|
||||||
"""
|
"""
|
||||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||||
|
|
||||||
|
# Get provider from current session
|
||||||
|
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
|
||||||
# Encrypt new credentials
|
# Encrypt new credentials
|
||||||
provider_controller = MCPToolProviderController.from_db(provider)
|
provider_controller = MCPToolProviderController.from_db(provider)
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
tool_configuration = ProviderConfigEncrypter(
|
||||||
@ -341,17 +345,25 @@ class MCPToolManageService:
|
|||||||
data: Data to save (tokens, client info, or code verifier)
|
data: Data to save (tokens, client info, or code verifier)
|
||||||
data_type: Type of OAuth data to save
|
data_type: Type of OAuth data to save
|
||||||
"""
|
"""
|
||||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
|
||||||
|
|
||||||
# Determine if this makes the provider authenticated
|
# Determine if this makes the provider authenticated
|
||||||
authed = (
|
authed = (
|
||||||
data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None
|
data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None
|
||||||
)
|
)
|
||||||
|
|
||||||
self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed)
|
# update_provider_credentials will validate provider existence
|
||||||
|
self.update_provider_credentials(provider_id=provider_id, tenant_id=tenant_id, credentials=data, authed=authed)
|
||||||
|
|
||||||
|
def clear_provider_credentials(self, *, provider_id: str, tenant_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Clear all credentials for a provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: Provider ID
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
"""
|
||||||
|
# Get provider from current session
|
||||||
|
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
|
||||||
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
|
|
||||||
"""Clear all credentials for a provider."""
|
|
||||||
provider.tools = EMPTY_TOOLS_JSON
|
provider.tools = EMPTY_TOOLS_JSON
|
||||||
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
|
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
|
||||||
provider.updated_at = datetime.now()
|
provider.updated_at = datetime.now()
|
||||||
|
|||||||
@ -1206,7 +1206,10 @@ class TestMCPToolManageService:
|
|||||||
|
|
||||||
service = MCPToolManageService(db.session())
|
service = MCPToolManageService(db.session())
|
||||||
service.update_provider_credentials(
|
service.update_provider_credentials(
|
||||||
provider=mcp_provider, credentials={"new_key": "new_value"}, authed=True
|
provider_id=mcp_provider.id,
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
credentials={"new_key": "new_value"},
|
||||||
|
authed=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
@ -1267,7 +1270,10 @@ class TestMCPToolManageService:
|
|||||||
|
|
||||||
service = MCPToolManageService(db.session())
|
service = MCPToolManageService(db.session())
|
||||||
service.update_provider_credentials(
|
service.update_provider_credentials(
|
||||||
provider=mcp_provider, credentials={"new_key": "new_value"}, authed=False
|
provider_id=mcp_provider.id,
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
credentials={"new_key": "new_value"},
|
||||||
|
authed=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user