mirror of
https://github.com/langgenius/dify.git
synced 2026-06-10 18:24:09 +08:00
feat(api): add MCP user-identity forwarding (#36839)
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
db1aa683bc
commit
37e1d452b8
@ -20,7 +20,7 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.entities.mcp_provider import IdentityMode, MCPAuthentication, MCPConfiguration
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
@ -210,6 +210,30 @@ class MCPProviderBasePayload(BaseModel):
|
||||
configuration: dict[str, Any] | None = Field(default_factory=dict)
|
||||
headers: dict[str, Any] | None = Field(default_factory=dict)
|
||||
authentication: dict[str, Any] | None = Field(default_factory=dict)
|
||||
# None means "leave unchanged" on update; the controller resolves it to a
|
||||
# concrete IdentityMode before calling the service (see _resolve_identity_mode).
|
||||
identity_mode: IdentityMode | None = None
|
||||
|
||||
|
||||
def _resolve_identity_mode(requested: IdentityMode | None, *, current: IdentityMode) -> IdentityMode:
|
||||
"""Resolve the effective MCP identity_mode for a create/update request.
|
||||
|
||||
Keeps two API-layer concerns out of the service so the service always
|
||||
receives a concrete value:
|
||||
|
||||
* ``None`` means "leave unchanged" (update semantics) — fall back to
|
||||
``current`` (``IdentityMode.OFF`` for a brand-new provider).
|
||||
* Identity forwarding is an enterprise-only capability. On non-enterprise
|
||||
deployments any non-OFF value is coerced back to OFF so a persisted row
|
||||
can never imply forwarding that the runtime won't perform. This gates the
|
||||
API surface to match the backend gate in
|
||||
``MCPTool._forwarding_requested`` — both the API and the backend
|
||||
invocation must be gated on ``dify_config.ENTERPRISE_ENABLED``.
|
||||
"""
|
||||
mode = current if requested is None else requested
|
||||
if mode != IdentityMode.OFF and not dify_config.ENTERPRISE_ENABLED:
|
||||
return IdentityMode.OFF
|
||||
return mode
|
||||
|
||||
|
||||
class MCPProviderCreatePayload(MCPProviderBasePayload):
|
||||
@ -1000,6 +1024,7 @@ class ToolProviderMCPApi(Resource):
|
||||
headers=payload.headers or {},
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
identity_mode=_resolve_identity_mode(payload.identity_mode, current=IdentityMode.OFF),
|
||||
)
|
||||
|
||||
# 2) Try to fetch tools immediately after creation so they appear without a second save.
|
||||
@ -1054,6 +1079,11 @@ class ToolProviderMCPApi(Resource):
|
||||
# Step 3: Perform database update in a transaction
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
# Resolve "leave unchanged" (None) against the stored value, and gate
|
||||
# the result on ENTERPRISE_ENABLED — both are API-layer concerns, so
|
||||
# the service receives a concrete IdentityMode.
|
||||
existing = service.get_provider(provider_id=payload.provider_id, tenant_id=current_tenant_id)
|
||||
identity_mode = _resolve_identity_mode(payload.identity_mode, current=IdentityMode(existing.identity_mode))
|
||||
service.update_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=payload.provider_id,
|
||||
@ -1067,6 +1097,7 @@ class ToolProviderMCPApi(Resource):
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
validation_result=validation_result,
|
||||
identity_mode=identity_mode,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -37,6 +37,13 @@ class MCPSupportGrantType(StrEnum):
|
||||
REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
|
||||
class IdentityMode(StrEnum):
|
||||
"""How Dify forwards the end-user's identity to an MCP server."""
|
||||
|
||||
OFF = "off"
|
||||
IDP_TOKEN = "idp_token"
|
||||
|
||||
|
||||
class MCPAuthentication(BaseModel):
|
||||
client_id: str
|
||||
client_secret: str | None = None
|
||||
@ -76,6 +83,8 @@ class MCPProviderEntity(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
identity_mode: IdentityMode = IdentityMode.OFF
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
|
||||
"""Create entity from database model with decryption"""
|
||||
@ -96,6 +105,7 @@ class MCPProviderEntity(BaseModel):
|
||||
icon=db_provider.icon or "",
|
||||
created_at=db_provider.created_at,
|
||||
updated_at=db_provider.updated_at,
|
||||
identity_mode=IdentityMode(db_provider.identity_mode),
|
||||
)
|
||||
|
||||
@property
|
||||
@ -170,6 +180,7 @@ class MCPProviderEntity(BaseModel):
|
||||
"updated_at": int(self.updated_at.timestamp()),
|
||||
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
|
||||
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
|
||||
"identity_mode": self.identity_mode,
|
||||
}
|
||||
|
||||
# Add configuration
|
||||
|
||||
@ -40,6 +40,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
provider_entity: MCPProviderEntity | None = None,
|
||||
authorization_code: str | None = None,
|
||||
by_server_id: bool = False,
|
||||
forward_identity_active: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the MCP client with auth retry capability.
|
||||
@ -52,12 +53,15 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
provider_entity: Provider entity for authentication
|
||||
authorization_code: Optional authorization code for initial auth
|
||||
by_server_id: Whether to look up provider by server ID
|
||||
forward_identity_active: If True, suppress the static-OAuth retry
|
||||
on 401 — the forwarded identity must propagate as-is.
|
||||
"""
|
||||
super().__init__(server_url, headers, timeout, sse_read_timeout)
|
||||
|
||||
self.provider_entity = provider_entity
|
||||
self.authorization_code = authorization_code
|
||||
self.by_server_id = by_server_id
|
||||
self.forward_identity_active = forward_identity_active
|
||||
self._has_retried = False
|
||||
|
||||
def _handle_auth_error(self, error: MCPAuthError) -> None:
|
||||
@ -73,6 +77,8 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails or max retries reached
|
||||
"""
|
||||
if self.forward_identity_active:
|
||||
raise error
|
||||
if not self.provider_entity:
|
||||
raise error
|
||||
if self._has_retried:
|
||||
|
||||
@ -54,6 +54,9 @@ class ToolProviderApiEntity(BaseModel):
|
||||
configuration: MCPConfiguration | None = Field(
|
||||
default=None, description="The timeout and sse_read_timeout of the MCP tool"
|
||||
)
|
||||
# M3 — user-identity forwarding selector. Round-tripped through the
|
||||
# console API so the create/edit modal can hydrate the toggle state.
|
||||
identity_mode: str = Field(default="off", description="Identity-forwarding mechanism: 'off' or 'idp_token'")
|
||||
# Workflow
|
||||
workflow_app_id: str | None = Field(default=None, description="The app id of the workflow tool")
|
||||
|
||||
@ -92,6 +95,9 @@ class ToolProviderApiEntity(BaseModel):
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
# M3 — forwarding selector. Always emit ("off" is a valid
|
||||
# value that the UI must hydrate, not skip).
|
||||
optional_fields["identity_mode"] = self.identity_mode
|
||||
case ToolProviderType.WORKFLOW:
|
||||
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
|
||||
case _:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Any, Self
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.entities.mcp_provider import IdentityMode, MCPProviderEntity
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@ -28,6 +28,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
identity_mode: IdentityMode = IdentityMode.OFF,
|
||||
):
|
||||
super().__init__(entity)
|
||||
self.entity: ToolProviderEntityWithPlugin = entity
|
||||
@ -37,6 +38,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.identity_mode: IdentityMode = identity_mode
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
@ -105,6 +107,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
headers=entity.headers,
|
||||
timeout=entity.timeout,
|
||||
sse_read_timeout=entity.sse_read_timeout,
|
||||
identity_mode=entity.identity_mode,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
@ -134,6 +137,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
identity_mode=self.identity_mode,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[MCPTool]:
|
||||
@ -151,6 +155,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
identity_mode=self.identity_mode,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
||||
|
||||
@ -6,6 +6,8 @@ import logging
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.mcp_provider import IdentityMode
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.types import (
|
||||
@ -25,6 +27,11 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetada
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Custom header used to carry the forwarded SSO access token. Picked to avoid
|
||||
# stomping on the workspace-scoped Authorization header (provider OAuth /
|
||||
# user-supplied custom credentials), which would silently break those flows.
|
||||
FORWARDED_IDENTITY_HEADER = "X-Dify-SSO-Access-Token"
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
def __init__(
|
||||
@ -38,6 +45,7 @@ class MCPTool(Tool):
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
identity_mode: IdentityMode = IdentityMode.OFF,
|
||||
):
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
@ -47,6 +55,7 @@ class MCPTool(Tool):
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.identity_mode: IdentityMode = identity_mode
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
@ -60,7 +69,7 @@ class MCPTool(Tool):
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
result = self.invoke_remote_mcp_tool(tool_parameters)
|
||||
result = self.invoke_remote_mcp_tool(tool_parameters, user_id=user_id, app_id=app_id)
|
||||
|
||||
# Extract usage metadata from MCP protocol's _meta field
|
||||
self._latest_usage = self._derive_usage_from_result(result)
|
||||
@ -234,6 +243,7 @@ class MCPTool(Tool):
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
identity_mode=self.identity_mode,
|
||||
)
|
||||
|
||||
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
|
||||
@ -246,7 +256,26 @@ class MCPTool(Tool):
|
||||
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
||||
}
|
||||
|
||||
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
|
||||
@property
|
||||
def _forwarding_requested(self) -> bool:
|
||||
"""True only when the configured identity_mode wants forwarding AND
|
||||
the deployment actually has the enterprise side that can mint tokens.
|
||||
Non-enterprise installs treat the DB value as a no-op — a stale row
|
||||
won't trigger a 5xx against a missing inner-API endpoint."""
|
||||
return self.identity_mode != IdentityMode.OFF and dify_config.ENTERPRISE_ENABLED
|
||||
|
||||
def invoke_remote_mcp_tool(
|
||||
self,
|
||||
tool_parameters: dict[str, Any],
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
) -> CallToolResult:
|
||||
# Fail closed: forwarding requires user_id (refuse before any DB I/O).
|
||||
if self._forwarding_requested and not user_id:
|
||||
raise ToolInvokeError(
|
||||
"Forward-user-identity is enabled for this MCP provider but no end-user context was supplied."
|
||||
)
|
||||
|
||||
headers = self.headers.copy() if self.headers else {}
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
|
||||
@ -271,6 +300,15 @@ class MCPTool(Tool):
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
|
||||
# Forwarded identity rides in a custom header so workspace-scoped
|
||||
# provider credentials (Authorization / custom Headers) keep working
|
||||
# untouched. The MCP server is expected to read X-Dify-SSO-Access-Token
|
||||
# when identity forwarding is configured.
|
||||
forward_identity_active = False
|
||||
if self._forwarding_requested and user_id:
|
||||
self._inject_forwarded_identity(headers, user_id=user_id, app_id=app_id, audience=server_url)
|
||||
forward_identity_active = True
|
||||
|
||||
# Step 2: Session is now closed, perform network operations without holding database connection
|
||||
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
|
||||
try:
|
||||
@ -280,9 +318,44 @@ class MCPTool(Tool):
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
forward_identity_active=forward_identity_active,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
||||
def _inject_forwarded_identity(
|
||||
self,
|
||||
headers: dict[str, str],
|
||||
*,
|
||||
user_id: str,
|
||||
app_id: str | None,
|
||||
audience: str,
|
||||
) -> None:
|
||||
"""Call the enterprise IssueMCPToken endpoint and stamp the issued
|
||||
token into X-Dify-SSO-Access-Token.
|
||||
|
||||
A custom header is used (rather than Authorization) so it composes
|
||||
with workspace-scoped provider credentials — the user may have OAuth
|
||||
tokens or a custom Authorization header configured on the MCP
|
||||
provider, and forwarding must not silently overwrite them.
|
||||
|
||||
Errors are surfaced as ToolInvokeError so the workflow halts with a
|
||||
clear message instead of silently dropping identity and hitting the
|
||||
MCP server unauthenticated.
|
||||
"""
|
||||
from services.enterprise.base import MCPTokenError
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
try:
|
||||
token, _expires_at = EnterpriseService.issue_mcp_token(
|
||||
user_id=user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=app_id,
|
||||
audience=audience,
|
||||
)
|
||||
except MCPTokenError as e:
|
||||
raise ToolInvokeError(f"Failed to obtain forwarded identity token: {e}") from e
|
||||
headers[FORWARDED_IDENTITY_HEADER] = token
|
||||
|
||||
@ -0,0 +1,44 @@
|
||||
"""add identity mode to mcp tool provider
|
||||
|
||||
Revision ID: 3df4dbcc1e21
|
||||
Revises: 2b3c4d5e6f70
|
||||
Create Date: 2026-05-29 15:00:00.000000
|
||||
|
||||
Adds the `identity_mode` column to `tool_mcp_providers` to drive the M2 MCP
|
||||
user-identity forwarding feature. Reserved values:
|
||||
|
||||
"off" — no header forwarded (default; pre-M2 behaviour).
|
||||
"idp_token" — call dify-enterprise /inner/api/mcp/issue-token, stamp the
|
||||
returned SSO access token on the outbound MCP request as
|
||||
`X-Dify-SSO-Access-Token: <token>`.
|
||||
|
||||
The column is filled with the safe default "off" for existing rows so older
|
||||
providers keep their current behaviour until an admin opts in.
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3df4dbcc1e21"
|
||||
down_revision = "2b3c4d5e6f70"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column(
|
||||
"tool_mcp_providers",
|
||||
sa.Column(
|
||||
"identity_mode",
|
||||
sa.String(length=32),
|
||||
nullable=False,
|
||||
server_default=sa.text("'off'"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("tool_mcp_providers", "identity_mode")
|
||||
@ -350,6 +350,14 @@ class MCPToolProvider(TypeBase):
|
||||
# encrypted headers for MCP server requests
|
||||
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
|
||||
# M2 (MCP user-identity forwarding) — which identity-forwarding mechanism
|
||||
# this provider uses. Reserved values:
|
||||
# "off" — no forwarding (default; preserves pre-M2 behaviour).
|
||||
# "idp_token" — forward an SSO access token minted by dify-enterprise.
|
||||
identity_mode: Mapped[str] = mapped_column(
|
||||
sa.String(32), nullable=False, server_default=sa.text("'off'"), default="off"
|
||||
)
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
|
||||
@ -14907,6 +14907,14 @@ Icon information model.
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| IconType | string | | |
|
||||
|
||||
#### IdentityMode
|
||||
|
||||
How Dify forwards the end-user's identity to an MCP server.
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| IdentityMode | string | How Dify forwards the end-user's identity to an MCP server. | |
|
||||
|
||||
#### Import
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
@ -15217,6 +15225,7 @@ Enum class for large language model mode.
|
||||
| icon | string | | Yes |
|
||||
| icon_background | string | | No |
|
||||
| icon_type | string | | Yes |
|
||||
| identity_mode | [IdentityMode](#identitymode) | | No |
|
||||
| name | string | | Yes |
|
||||
| server_identifier | string | | Yes |
|
||||
| server_url | string | | Yes |
|
||||
@ -15237,6 +15246,7 @@ Enum class for large language model mode.
|
||||
| icon | string | | Yes |
|
||||
| icon_background | string | | No |
|
||||
| icon_type | string | | Yes |
|
||||
| identity_mode | [IdentityMode](#identitymode) | | No |
|
||||
| name | string | | Yes |
|
||||
| provider_id | string | | Yes |
|
||||
| server_identifier | string | | Yes |
|
||||
|
||||
@ -18,7 +18,7 @@ import yaml
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.entities.mcp_provider import IdentityMode, MCPAuthentication, MCPConfiguration
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -748,6 +748,9 @@ class MigrationImportService:
|
||||
headers=mcp_data.get("headers") if isinstance(mcp_data.get("headers"), dict) else {},
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
# Re-import must not silently reset forwarding: preserve the
|
||||
# stored mode (update_provider now defaults to OFF when omitted).
|
||||
identity_mode=IdentityMode(existing.identity_mode),
|
||||
)
|
||||
db.session.commit()
|
||||
status = "updated"
|
||||
|
||||
@ -12,8 +12,28 @@ from services.errors.enterprise import (
|
||||
EnterpriseAPIForbiddenError,
|
||||
EnterpriseAPINotFoundError,
|
||||
EnterpriseAPIUnauthorizedError,
|
||||
EnterpriseServiceError,
|
||||
)
|
||||
|
||||
|
||||
class MCPTokenError(EnterpriseServiceError):
|
||||
"""Generic failure of the IssueMCPToken RPC."""
|
||||
|
||||
|
||||
class MCPNoRefreshTokenError(MCPTokenError):
|
||||
"""User has no stored SSO refresh_token; ask them to re-authenticate."""
|
||||
|
||||
def __init__(self, description: str = ""):
|
||||
super().__init__(description, status_code=428)
|
||||
|
||||
|
||||
class MCPIdentityRefreshError(MCPTokenError):
|
||||
"""IdP rejected the refresh attempt (revoked/expired session)."""
|
||||
|
||||
def __init__(self, description: str = ""):
|
||||
super().__init__(description, status_code=401)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@ -11,7 +11,15 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
from services.enterprise.base import (
|
||||
EnterpriseRequest,
|
||||
MCPIdentityRefreshError,
|
||||
MCPNoRefreshTokenError,
|
||||
MCPTokenError,
|
||||
)
|
||||
from services.errors.enterprise import (
|
||||
EnterpriseServiceError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.feature_service import LicenseStatus
|
||||
@ -121,6 +129,77 @@ class EnterpriseService:
|
||||
def get_workspace_info(cls, tenant_id: str):
|
||||
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
|
||||
|
||||
@classmethod
|
||||
def issue_mcp_token(
|
||||
cls,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
app_id: str | None,
|
||||
audience: str,
|
||||
) -> tuple[str, int]:
|
||||
"""Mint a short-lived SSO id_token (or OAuth2 access_token) representing
|
||||
the calling Dify user, audience-scoped to the given MCP server identifier.
|
||||
|
||||
Used by MCPTool.invoke_remote_mcp_tool to stamp the
|
||||
X-Dify-SSO-Access-Token header on outbound MCP requests when the
|
||||
provider's identity_mode is set to "idp_token".
|
||||
|
||||
Returns:
|
||||
(token, expires_at_unix_seconds)
|
||||
|
||||
Raises:
|
||||
MCPNoRefreshTokenError: user has no stored SSO refresh_token on the
|
||||
enterprise side; surface to the workflow as "please log in via SSO".
|
||||
MCPIdentityRefreshError: enterprise tried to refresh against the IdP
|
||||
and the IdP rejected (revoked/expired session).
|
||||
MCPTokenError: any other failure of the enterprise endpoint.
|
||||
"""
|
||||
try:
|
||||
response = EnterpriseRequest.send_request(
|
||||
"POST",
|
||||
"/mcp/issue-token",
|
||||
json={
|
||||
"user_id": user_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id or "",
|
||||
"audience": audience,
|
||||
},
|
||||
)
|
||||
except EnterpriseServiceError as e:
|
||||
# The HTTP-status subclasses (400/401/403/404) inherit directly
|
||||
# from EnterpriseServiceError, not EnterpriseAPIError, so we
|
||||
# must catch the base class to route them all.
|
||||
status = getattr(e, "status_code", None)
|
||||
if status == 401:
|
||||
# Enterprise side returns 401 when the IdP rejected the refresh.
|
||||
raise MCPIdentityRefreshError(str(e) or "identity refresh failed; please re-authenticate") from e
|
||||
if status == 428:
|
||||
raise MCPNoRefreshTokenError(
|
||||
str(e) or "user has no stored SSO refresh token; please re-authenticate"
|
||||
) from e
|
||||
if status == 403:
|
||||
# 403 most often means the tenant isn't licensed for MCP
|
||||
# identity-forwarding. Surface as identity-refresh-failure so
|
||||
# the workflow halts loudly rather than retrying.
|
||||
raise MCPIdentityRefreshError(
|
||||
str(e) or "enterprise refused to issue an MCP identity token (license or policy)"
|
||||
) from e
|
||||
raise MCPTokenError(f"issue_mcp_token failed (status={status}): {e}") from e
|
||||
|
||||
if not isinstance(response, dict):
|
||||
raise MCPTokenError("invalid response shape from enterprise /mcp/issue-token")
|
||||
|
||||
token = response.get("token")
|
||||
expires_at = response.get("expires_at")
|
||||
# Accept int or float for expires_at (some clocks emit float
|
||||
# seconds-since-epoch). Reject bools explicitly because `bool` is
|
||||
# an `int` subclass in Python and would pass isinstance(_, int).
|
||||
if not isinstance(token, str) or not token:
|
||||
raise MCPTokenError(f"missing or non-string token in enterprise response: {response!r}")
|
||||
if isinstance(expires_at, bool) or not isinstance(expires_at, (int, float)):
|
||||
raise MCPTokenError(f"missing or non-numeric expires_at in enterprise response: {response!r}")
|
||||
return token, int(expires_at)
|
||||
|
||||
@classmethod
|
||||
def initiate_device_flow_sso(cls, signed_state: str) -> dict:
|
||||
return EnterpriseRequest.send_request(
|
||||
|
||||
@ -12,7 +12,7 @@ from sqlalchemy import or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||
from core.entities.mcp_provider import IdentityMode, MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
@ -136,6 +136,7 @@ class MCPToolManageService:
|
||||
configuration: MCPConfiguration,
|
||||
authentication: MCPAuthentication | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
identity_mode: IdentityMode = IdentityMode.OFF,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Create a new MCP provider."""
|
||||
# Validate URL format
|
||||
@ -171,6 +172,7 @@ class MCPToolManageService:
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
encrypted_headers=encrypted_headers,
|
||||
encrypted_credentials=encrypted_credentials,
|
||||
identity_mode=identity_mode,
|
||||
)
|
||||
|
||||
self._session.add(mcp_tool)
|
||||
@ -194,6 +196,7 @@ class MCPToolManageService:
|
||||
configuration: MCPConfiguration,
|
||||
authentication: MCPAuthentication | None = None,
|
||||
validation_result: ServerUrlValidationResult | None = None,
|
||||
identity_mode: IdentityMode = IdentityMode.OFF,
|
||||
) -> None:
|
||||
"""
|
||||
Update an MCP provider.
|
||||
@ -255,6 +258,11 @@ class MCPToolManageService:
|
||||
if authentication and authentication.client_id:
|
||||
mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id)
|
||||
|
||||
# Update user-identity forwarding mode. The controller has already
|
||||
# resolved "leave unchanged" and applied the ENTERPRISE_ENABLED gate,
|
||||
# so this is always a concrete, vetted value.
|
||||
mcp_provider.identity_mode = identity_mode
|
||||
|
||||
# Flush changes to database
|
||||
self._session.flush()
|
||||
|
||||
|
||||
@ -380,3 +380,53 @@ def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.Mon
|
||||
resp = controller_module.ToolLabelsApi().get()
|
||||
|
||||
assert resp == ["a", "b"]
|
||||
|
||||
|
||||
# --- _resolve_identity_mode: gating + None-resolution (PR #36839 review) ---
|
||||
|
||||
|
||||
def test_resolve_identity_mode_none_keeps_current_when_enterprise(controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
"""None means 'leave unchanged' — fall back to the stored mode (update path)."""
|
||||
identity_mode = importlib.import_module("core.entities.mcp_provider").IdentityMode
|
||||
monkeypatch.setattr(controller_module.dify_config, "ENTERPRISE_ENABLED", True)
|
||||
|
||||
resolved = controller_module._resolve_identity_mode(None, current=identity_mode.IDP_TOKEN)
|
||||
|
||||
assert resolved == identity_mode.IDP_TOKEN
|
||||
|
||||
|
||||
def test_resolve_identity_mode_explicit_value_overrides_current(controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
"""An explicit value wins over the stored mode."""
|
||||
identity_mode = importlib.import_module("core.entities.mcp_provider").IdentityMode
|
||||
monkeypatch.setattr(controller_module.dify_config, "ENTERPRISE_ENABLED", True)
|
||||
|
||||
resolved = controller_module._resolve_identity_mode(identity_mode.OFF, current=identity_mode.IDP_TOKEN)
|
||||
|
||||
assert resolved == identity_mode.OFF
|
||||
|
||||
|
||||
def test_resolve_identity_mode_coerces_non_off_to_off_when_not_enterprise(
|
||||
controller_module, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Gate: a non-EE deployment must never persist a non-OFF mode — the
|
||||
runtime won't forward, so the stored row must not imply it does."""
|
||||
identity_mode = importlib.import_module("core.entities.mcp_provider").IdentityMode
|
||||
monkeypatch.setattr(controller_module.dify_config, "ENTERPRISE_ENABLED", False)
|
||||
|
||||
# Both an explicit idp_token request AND an inherited non-OFF current
|
||||
# must collapse to OFF.
|
||||
assert (
|
||||
controller_module._resolve_identity_mode(identity_mode.IDP_TOKEN, current=identity_mode.OFF)
|
||||
== identity_mode.OFF
|
||||
)
|
||||
assert controller_module._resolve_identity_mode(None, current=identity_mode.IDP_TOKEN) == identity_mode.OFF
|
||||
|
||||
|
||||
def test_resolve_identity_mode_off_is_passthrough_when_not_enterprise(
|
||||
controller_module, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""OFF is always fine — the gate only neutralizes non-OFF values."""
|
||||
identity_mode = importlib.import_module("core.entities.mcp_provider").IdentityMode
|
||||
monkeypatch.setattr(controller_module.dify_config, "ENTERPRISE_ENABLED", False)
|
||||
|
||||
assert controller_module._resolve_identity_mode(None, current=identity_mode.OFF) == identity_mode.OFF
|
||||
|
||||
@ -53,6 +53,7 @@ def test_from_db_model_maps_fields() -> None:
|
||||
icon=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
identity_mode="off",
|
||||
)
|
||||
|
||||
# Act
|
||||
|
||||
@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError
|
||||
|
||||
|
||||
class TestForwardIdentityShortCircuit:
|
||||
def test_forward_identity_active_reraises_without_retry(self):
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="https://mcp.example.com",
|
||||
headers={"Authorization": "Bearer user-jwt"},
|
||||
forward_identity_active=True,
|
||||
)
|
||||
|
||||
with pytest.raises(MCPAuthError):
|
||||
client._handle_auth_error(MCPAuthError("unauthorized"))
|
||||
|
||||
assert client.headers["Authorization"] == "Bearer user-jwt"
|
||||
assert client._has_retried is False
|
||||
|
||||
def test_forward_identity_active_takes_precedence_over_provider_entity(self):
|
||||
sentinel_entity = object()
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="https://mcp.example.com",
|
||||
provider_entity=sentinel_entity, # type: ignore[arg-type]
|
||||
forward_identity_active=True,
|
||||
)
|
||||
|
||||
with pytest.raises(MCPAuthError, match="forwarded-id-401"):
|
||||
client._handle_auth_error(MCPAuthError("forwarded-id-401"))
|
||||
|
||||
def test_default_path_unchanged_without_provider_entity(self):
|
||||
client = MCPClientWithAuthRetry(server_url="https://mcp.example.com")
|
||||
with pytest.raises(MCPAuthError, match="no-provider"):
|
||||
client._handle_auth_error(MCPAuthError("no-provider"))
|
||||
|
||||
def test_default_constructor_defaults_forward_identity_to_false(self):
|
||||
client = MCPClientWithAuthRetry(server_url="https://mcp.example.com")
|
||||
assert client.forward_identity_active is False
|
||||
@ -148,3 +148,112 @@ def test_mcp_tool_handle_none_parameter_filters_empty_values():
|
||||
tool = _build_mcp_tool()
|
||||
cleaned = tool._handle_none_parameter({"a": 1, "b": None, "c": "", "d": " ", "e": "ok"})
|
||||
assert cleaned == {"a": 1, "e": "ok"}
|
||||
|
||||
|
||||
# ----- M2/M3 user-identity forwarding ---------------------------------------
|
||||
|
||||
|
||||
def _build_forwarding_tool(*, mode: str = "idp_token") -> MCPTool:
|
||||
"""Helper that builds an MCPTool with the identity_mode set."""
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author="author",
|
||||
name="remote-tool",
|
||||
label=I18nObject(en_US="remote-tool"),
|
||||
provider="provider-id",
|
||||
),
|
||||
parameters=[],
|
||||
output_schema={},
|
||||
)
|
||||
return MCPTool(
|
||||
entity=entity,
|
||||
runtime=ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER),
|
||||
tenant_id="tenant-1",
|
||||
icon="icon.svg",
|
||||
server_url="https://mcp.example.com/mcp/",
|
||||
provider_id="provider-id",
|
||||
identity_mode=mode,
|
||||
)
|
||||
|
||||
|
||||
def test_inject_forwarded_identity_stamps_custom_header():
|
||||
"""The minted SSO token must be placed in X-Dify-SSO-Access-Token; the
|
||||
workspace-scoped Authorization header and any other custom headers must
|
||||
pass through untouched so provider credentials keep working."""
|
||||
from core.tools.mcp_tool.tool import FORWARDED_IDENTITY_HEADER
|
||||
|
||||
tool = _build_forwarding_tool()
|
||||
headers: dict[str, str] = {"Authorization": "Bearer static-client-token", "X-Other": "keep"}
|
||||
|
||||
with patch(
|
||||
"services.enterprise.enterprise_service.EnterpriseService.issue_mcp_token",
|
||||
return_value=("forwarded.jwt.payload", 1900000000),
|
||||
):
|
||||
tool._inject_forwarded_identity(headers, user_id="alice", app_id=None, audience="https://mcp.example.com/mcp/")
|
||||
|
||||
assert headers[FORWARDED_IDENTITY_HEADER] == "forwarded.jwt.payload"
|
||||
assert headers["Authorization"] == "Bearer static-client-token"
|
||||
assert headers["X-Other"] == "keep"
|
||||
|
||||
|
||||
def test_inject_forwarded_identity_translates_token_error_to_invoke_error():
|
||||
"""EnterpriseService failures must surface as ToolInvokeError so the
|
||||
workflow halts loudly instead of proceeding without identity."""
|
||||
from core.tools.mcp_tool.tool import FORWARDED_IDENTITY_HEADER
|
||||
from services.enterprise.base import MCPNoRefreshTokenError
|
||||
|
||||
tool = _build_forwarding_tool()
|
||||
headers: dict[str, str] = {}
|
||||
|
||||
with patch(
|
||||
"services.enterprise.enterprise_service.EnterpriseService.issue_mcp_token",
|
||||
side_effect=MCPNoRefreshTokenError("please re-sso"),
|
||||
):
|
||||
with pytest.raises(ToolInvokeError, match="forwarded identity token"):
|
||||
tool._inject_forwarded_identity(
|
||||
headers, user_id="alice", app_id=None, audience="https://mcp.example.com/mcp/"
|
||||
)
|
||||
|
||||
# Headers must NOT have been mutated when token-issuance failed.
|
||||
assert FORWARDED_IDENTITY_HEADER not in headers
|
||||
assert "Authorization" not in headers
|
||||
|
||||
|
||||
def test_invoke_remote_mcp_tool_fails_closed_when_user_id_missing():
|
||||
"""When forwarding is enabled AND the deployment is enterprise, missing
|
||||
user_id must raise — never silently invoke as the static identity."""
|
||||
tool = _build_forwarding_tool()
|
||||
|
||||
with patch("core.tools.mcp_tool.tool.dify_config") as cfg:
|
||||
cfg.ENTERPRISE_ENABLED = True
|
||||
with pytest.raises(ToolInvokeError, match="no end-user context"):
|
||||
tool.invoke_remote_mcp_tool({}, user_id=None, app_id=None)
|
||||
|
||||
|
||||
def test_invoke_skips_forwarding_when_enterprise_disabled():
|
||||
"""Non-enterprise deployments treat the DB selector as a no-op: a stale
|
||||
`identity_mode="idp_token"` row must NOT raise (fail-closed) AND must
|
||||
NOT call the enterprise inner API. The runtime falls through to the
|
||||
legacy provider-identity path."""
|
||||
tool = _build_forwarding_tool()
|
||||
|
||||
with patch("core.tools.mcp_tool.tool.dify_config") as cfg:
|
||||
cfg.ENTERPRISE_ENABLED = False
|
||||
# The fail-closed branch must NOT fire (no enterprise → no forwarding).
|
||||
# The function will still try the legacy DB-load path; we patch that
|
||||
# to keep the test unit-scoped.
|
||||
with patch("core.tools.mcp_tool.tool.MCPClientWithAuthRetry") as client_cls:
|
||||
client_cls.return_value.__enter__.return_value.invoke_tool.return_value = CallToolResult(
|
||||
content=[],
|
||||
_meta=None,
|
||||
)
|
||||
with patch.object(tool, "_inject_forwarded_identity") as inject:
|
||||
with patch("services.tools.mcp_tools_manage_service.MCPToolManageService"):
|
||||
with patch("core.entities.mcp_provider.MCPProviderEntity.decrypt_server_url", return_value="u"):
|
||||
with patch("core.entities.mcp_provider.MCPProviderEntity.decrypt_headers", return_value={}):
|
||||
# Should not raise; should not call enterprise.
|
||||
try:
|
||||
tool.invoke_remote_mcp_tool({}, user_id=None, app_id=None)
|
||||
except Exception:
|
||||
pass
|
||||
inject.assert_not_called()
|
||||
|
||||
@ -596,7 +596,9 @@ def test_api_tool_create_records_id_mapping(monkeypatch):
|
||||
|
||||
|
||||
def test_mcp_tool_import_restores_exported_tool_list(monkeypatch):
|
||||
provider = type("Provider", (), {"id": "target-provider-id", "tools": "[]", "authed": False})()
|
||||
provider = type(
|
||||
"Provider", (), {"id": "target-provider-id", "tools": "[]", "authed": False, "identity_mode": "off"}
|
||||
)()
|
||||
report_items = []
|
||||
|
||||
class StubSession:
|
||||
@ -652,7 +654,9 @@ def test_mcp_tool_import_restores_exported_tool_list(monkeypatch):
|
||||
|
||||
@pytest.mark.parametrize("conflict_strategy", [ConflictStrategy.SKIP, ConflictStrategy.UPDATE])
|
||||
def test_mcp_tool_existing_provider_records_id_mapping(monkeypatch, conflict_strategy):
|
||||
provider = type("Provider", (), {"id": "target-mcp-provider-id", "tools": "[]", "authed": False})()
|
||||
provider = type(
|
||||
"Provider", (), {"id": "target-mcp-provider-id", "tools": "[]", "authed": False, "identity_mode": "off"}
|
||||
)()
|
||||
id_mapping = {}
|
||||
id_mapping_details = []
|
||||
|
||||
@ -712,7 +716,9 @@ def test_mcp_tool_existing_provider_records_id_mapping(monkeypatch, conflict_str
|
||||
|
||||
|
||||
def test_mcp_tool_create_records_id_mapping(monkeypatch):
|
||||
provider = type("Provider", (), {"id": "target-mcp-provider-id", "tools": "[]", "authed": False})()
|
||||
provider = type(
|
||||
"Provider", (), {"id": "target-mcp-provider-id", "tools": "[]", "authed": False, "identity_mode": "off"}
|
||||
)()
|
||||
id_mapping = {}
|
||||
provider_created = False
|
||||
|
||||
|
||||
@ -466,3 +466,111 @@ class TestGetCachedLicenseStatus:
|
||||
|
||||
assert EnterpriseService.get_cached_license_status() is None
|
||||
mock_redis.setex.assert_not_called()
|
||||
|
||||
|
||||
class TestIssueMCPToken:
|
||||
"""Coverage for EnterpriseService.issue_mcp_token (M2).
|
||||
|
||||
The function wraps `POST /inner/api/mcp/issue-token` and must map
|
||||
EnterpriseServiceError subclasses to MCP-typed errors so the workflow
|
||||
layer can halt with a precise message instead of leaking transport text.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _call():
|
||||
return EnterpriseService.issue_mcp_token(
|
||||
user_id="user-uuid",
|
||||
tenant_id="tenant-uuid",
|
||||
app_id="app-uuid",
|
||||
audience="https://mcp.example.com/mcp/",
|
||||
)
|
||||
|
||||
def test_happy_path_returns_token_and_expiry(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"token": "abc.def.ghi", "expires_at": 1900000000}
|
||||
token, exp = self._call()
|
||||
|
||||
assert token == "abc.def.ghi"
|
||||
assert exp == 1900000000
|
||||
req.send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/mcp/issue-token",
|
||||
json={
|
||||
"user_id": "user-uuid",
|
||||
"tenant_id": "tenant-uuid",
|
||||
"app_id": "app-uuid",
|
||||
"audience": "https://mcp.example.com/mcp/",
|
||||
},
|
||||
)
|
||||
|
||||
def test_401_maps_to_identity_refresh_error(self):
|
||||
from services.enterprise.base import MCPIdentityRefreshError
|
||||
from services.errors.enterprise import EnterpriseAPIUnauthorizedError
|
||||
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.side_effect = EnterpriseAPIUnauthorizedError("refresh rejected by IdP")
|
||||
with pytest.raises(MCPIdentityRefreshError, match="refresh rejected"):
|
||||
self._call()
|
||||
|
||||
def test_428_maps_to_no_refresh_token_error(self):
|
||||
from services.enterprise.base import MCPNoRefreshTokenError
|
||||
from services.errors.enterprise import EnterpriseAPIError
|
||||
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
# 428 PreconditionRequired is what EE returns when there's no
|
||||
# stored SSO refresh token for the user.
|
||||
req.send_request.side_effect = EnterpriseAPIError("user has not completed SSO", status_code=428)
|
||||
with pytest.raises(MCPNoRefreshTokenError, match="SSO"):
|
||||
self._call()
|
||||
|
||||
def test_403_maps_to_identity_refresh_error_for_license(self):
|
||||
from services.enterprise.base import MCPIdentityRefreshError
|
||||
from services.errors.enterprise import EnterpriseAPIForbiddenError
|
||||
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.side_effect = EnterpriseAPIForbiddenError("not licensed for MCP forwarding")
|
||||
with pytest.raises(MCPIdentityRefreshError, match="not licensed"):
|
||||
self._call()
|
||||
|
||||
def test_other_status_maps_to_generic_token_error(self):
|
||||
from services.enterprise.base import MCPTokenError
|
||||
from services.errors.enterprise import EnterpriseAPIError
|
||||
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.side_effect = EnterpriseAPIError("upstream 502", status_code=502)
|
||||
with pytest.raises(MCPTokenError, match="status=502"):
|
||||
self._call()
|
||||
|
||||
def test_malformed_response_shape_raises_token_error(self):
|
||||
from services.enterprise.base import MCPTokenError
|
||||
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = "not-a-dict"
|
||||
with pytest.raises(MCPTokenError, match="invalid response shape"):
|
||||
self._call()
|
||||
|
||||
def test_missing_token_field_raises_token_error(self):
|
||||
from services.enterprise.base import MCPTokenError
|
||||
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"expires_at": 1700000000} # no token
|
||||
with pytest.raises(MCPTokenError, match="missing or non-string token"):
|
||||
self._call()
|
||||
|
||||
def test_float_expires_at_is_accepted(self):
|
||||
"""expires_at may arrive as float (time.time()) — must be coerced."""
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"token": "t", "expires_at": 1900000000.5}
|
||||
token, exp = self._call()
|
||||
assert token == "t"
|
||||
assert exp == 1900000000
|
||||
assert isinstance(exp, int)
|
||||
|
||||
def test_bool_expires_at_is_rejected(self):
|
||||
"""bool is a subclass of int — must NOT be accepted as expires_at."""
|
||||
from services.enterprise.base import MCPTokenError
|
||||
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"token": "t", "expires_at": True}
|
||||
with pytest.raises(MCPTokenError, match="non-numeric expires_at"):
|
||||
self._call()
|
||||
|
||||
@ -459,6 +459,7 @@ export type McpProviderCreatePayload = {
|
||||
icon: string
|
||||
icon_background?: string
|
||||
icon_type: string
|
||||
identity_mode?: IdentityMode
|
||||
name: string
|
||||
server_identifier: string
|
||||
server_url: string
|
||||
@ -477,6 +478,7 @@ export type McpProviderUpdatePayload = {
|
||||
icon: string
|
||||
icon_background?: string
|
||||
icon_type: string
|
||||
identity_mode?: IdentityMode
|
||||
name: string
|
||||
provider_id: string
|
||||
server_identifier: string
|
||||
@ -669,6 +671,8 @@ export type ApiProviderSchemaType = 'openai_actions' | 'openai_plugin' | 'openap
|
||||
|
||||
export type CredentialType = 'api-key' | 'oauth2' | 'unauthorized'
|
||||
|
||||
export type IdentityMode = 'idp_token' | 'off'
|
||||
|
||||
export type WorkflowToolParameterConfiguration = {
|
||||
description: string
|
||||
form: ToolParameterForm
|
||||
|
||||
@ -351,37 +351,6 @@ export const zMcpProviderDeletePayload = z.object({
|
||||
provider_id: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* MCPProviderCreatePayload
|
||||
*/
|
||||
export const zMcpProviderCreatePayload = z.object({
|
||||
authentication: z.record(z.string(), z.unknown()).nullish(),
|
||||
configuration: z.record(z.string(), z.unknown()).nullish(),
|
||||
headers: z.record(z.string(), z.unknown()).nullish(),
|
||||
icon: z.string(),
|
||||
icon_background: z.string().optional().default(''),
|
||||
icon_type: z.string(),
|
||||
name: z.string(),
|
||||
server_identifier: z.string(),
|
||||
server_url: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* MCPProviderUpdatePayload
|
||||
*/
|
||||
export const zMcpProviderUpdatePayload = z.object({
|
||||
authentication: z.record(z.string(), z.unknown()).nullish(),
|
||||
configuration: z.record(z.string(), z.unknown()).nullish(),
|
||||
headers: z.record(z.string(), z.unknown()).nullish(),
|
||||
icon: z.string(),
|
||||
icon_background: z.string().optional().default(''),
|
||||
icon_type: z.string(),
|
||||
name: z.string(),
|
||||
provider_id: z.string(),
|
||||
server_identifier: z.string(),
|
||||
server_url: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* MCPAuthPayload
|
||||
*/
|
||||
@ -829,6 +798,46 @@ export const zBuiltinToolAddPayload = z.object({
|
||||
visibility: z.string().nullish(),
|
||||
})
|
||||
|
||||
/**
|
||||
* IdentityMode
|
||||
*
|
||||
* How Dify forwards the end-user's identity to an MCP server.
|
||||
*/
|
||||
export const zIdentityMode = z.enum(['idp_token', 'off'])
|
||||
|
||||
/**
|
||||
* MCPProviderCreatePayload
|
||||
*/
|
||||
export const zMcpProviderCreatePayload = z.object({
|
||||
authentication: z.record(z.string(), z.unknown()).nullish(),
|
||||
configuration: z.record(z.string(), z.unknown()).nullish(),
|
||||
headers: z.record(z.string(), z.unknown()).nullish(),
|
||||
icon: z.string(),
|
||||
icon_background: z.string().optional().default(''),
|
||||
icon_type: z.string(),
|
||||
identity_mode: zIdentityMode.optional(),
|
||||
name: z.string(),
|
||||
server_identifier: z.string(),
|
||||
server_url: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* MCPProviderUpdatePayload
|
||||
*/
|
||||
export const zMcpProviderUpdatePayload = z.object({
|
||||
authentication: z.record(z.string(), z.unknown()).nullish(),
|
||||
configuration: z.record(z.string(), z.unknown()).nullish(),
|
||||
headers: z.record(z.string(), z.unknown()).nullish(),
|
||||
icon: z.string(),
|
||||
icon_background: z.string().optional().default(''),
|
||||
icon_type: z.string(),
|
||||
identity_mode: zIdentityMode.optional(),
|
||||
name: z.string(),
|
||||
provider_id: z.string(),
|
||||
server_identifier: z.string(),
|
||||
server_url: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* StrategySetting
|
||||
*/
|
||||
|
||||
Loading…
Reference in New Issue
Block a user