feat(api): add MCP user-identity forwarding (#36839)

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Charles Yao 2026-06-08 06:32:11 +02:00 committed by GitHub
parent db1aa683bc
commit 37e1d452b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 673 additions and 41 deletions

View File

@ -20,7 +20,7 @@ from controllers.console.wraps import (
setup_required,
)
from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.entities.mcp_provider import IdentityMode, MCPAuthentication, MCPConfiguration
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient
@ -210,6 +210,30 @@ class MCPProviderBasePayload(BaseModel):
configuration: dict[str, Any] | None = Field(default_factory=dict)
headers: dict[str, Any] | None = Field(default_factory=dict)
authentication: dict[str, Any] | None = Field(default_factory=dict)
# None means "leave unchanged" on update; the controller resolves it to a
# concrete IdentityMode before calling the service (see _resolve_identity_mode).
identity_mode: IdentityMode | None = None
def _resolve_identity_mode(requested: IdentityMode | None, *, current: IdentityMode) -> IdentityMode:
"""Resolve the effective MCP identity_mode for a create/update request.
Keeps two API-layer concerns out of the service so the service always
receives a concrete value:
* ``None`` means "leave unchanged" (update semantics) fall back to
``current`` (``IdentityMode.OFF`` for a brand-new provider).
* Identity forwarding is an enterprise-only capability. On non-enterprise
deployments any non-OFF value is coerced back to OFF so a persisted row
can never imply forwarding that the runtime won't perform. This gates the
API surface to match the backend gate in
``MCPTool._forwarding_requested`` both the API and the backend
invocation must be gated on ``dify_config.ENTERPRISE_ENABLED``.
"""
mode = current if requested is None else requested
if mode != IdentityMode.OFF and not dify_config.ENTERPRISE_ENABLED:
return IdentityMode.OFF
return mode
class MCPProviderCreatePayload(MCPProviderBasePayload):
@ -1000,6 +1024,7 @@ class ToolProviderMCPApi(Resource):
headers=payload.headers or {},
configuration=configuration,
authentication=authentication,
identity_mode=_resolve_identity_mode(payload.identity_mode, current=IdentityMode.OFF),
)
# 2) Try to fetch tools immediately after creation so they appear without a second save.
@ -1054,6 +1079,11 @@ class ToolProviderMCPApi(Resource):
# Step 3: Perform database update in a transaction
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
# Resolve "leave unchanged" (None) against the stored value, and gate
# the result on ENTERPRISE_ENABLED — both are API-layer concerns, so
# the service receives a concrete IdentityMode.
existing = service.get_provider(provider_id=payload.provider_id, tenant_id=current_tenant_id)
identity_mode = _resolve_identity_mode(payload.identity_mode, current=IdentityMode(existing.identity_mode))
service.update_provider(
tenant_id=current_tenant_id,
provider_id=payload.provider_id,
@ -1067,6 +1097,7 @@ class ToolProviderMCPApi(Resource):
configuration=configuration,
authentication=authentication,
validation_result=validation_result,
identity_mode=identity_mode,
)
return {"result": "success"}

View File

@ -37,6 +37,13 @@ class MCPSupportGrantType(StrEnum):
REFRESH_TOKEN = "refresh_token"
class IdentityMode(StrEnum):
"""How Dify forwards the end-user's identity to an MCP server."""
OFF = "off"
IDP_TOKEN = "idp_token"
class MCPAuthentication(BaseModel):
client_id: str
client_secret: str | None = None
@ -76,6 +83,8 @@ class MCPProviderEntity(BaseModel):
created_at: datetime
updated_at: datetime
identity_mode: IdentityMode = IdentityMode.OFF
@classmethod
def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
"""Create entity from database model with decryption"""
@ -96,6 +105,7 @@ class MCPProviderEntity(BaseModel):
icon=db_provider.icon or "",
created_at=db_provider.created_at,
updated_at=db_provider.updated_at,
identity_mode=IdentityMode(db_provider.identity_mode),
)
@property
@ -170,6 +180,7 @@ class MCPProviderEntity(BaseModel):
"updated_at": int(self.updated_at.timestamp()),
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
"identity_mode": self.identity_mode,
}
# Add configuration

View File

@ -40,6 +40,7 @@ class MCPClientWithAuthRetry(MCPClient):
provider_entity: MCPProviderEntity | None = None,
authorization_code: str | None = None,
by_server_id: bool = False,
forward_identity_active: bool = False,
):
"""
Initialize the MCP client with auth retry capability.
@ -52,12 +53,15 @@ class MCPClientWithAuthRetry(MCPClient):
provider_entity: Provider entity for authentication
authorization_code: Optional authorization code for initial auth
by_server_id: Whether to look up provider by server ID
forward_identity_active: If True, suppress the static-OAuth retry
on 401 the forwarded identity must propagate as-is.
"""
super().__init__(server_url, headers, timeout, sse_read_timeout)
self.provider_entity = provider_entity
self.authorization_code = authorization_code
self.by_server_id = by_server_id
self.forward_identity_active = forward_identity_active
self._has_retried = False
def _handle_auth_error(self, error: MCPAuthError) -> None:
@ -73,6 +77,8 @@ class MCPClientWithAuthRetry(MCPClient):
Raises:
MCPAuthError: If authentication fails or max retries reached
"""
if self.forward_identity_active:
raise error
if not self.provider_entity:
raise error
if self._has_retried:

View File

@ -54,6 +54,9 @@ class ToolProviderApiEntity(BaseModel):
configuration: MCPConfiguration | None = Field(
default=None, description="The timeout and sse_read_timeout of the MCP tool"
)
# M3 — user-identity forwarding selector. Round-tripped through the
# console API so the create/edit modal can hydrate the toggle state.
identity_mode: str = Field(default="off", description="Identity-forwarding mechanism: 'off' or 'idp_token'")
# Workflow
workflow_app_id: str | None = Field(default=None, description="The app id of the workflow tool")
@ -92,6 +95,9 @@ class ToolProviderApiEntity(BaseModel):
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
# M3 — forwarding selector. Always emit ("off" is a valid
# value that the UI must hydrate, not skip).
optional_fields["identity_mode"] = self.identity_mode
case ToolProviderType.WORKFLOW:
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
case _:

View File

@ -1,6 +1,6 @@
from typing import Any, Self
from core.entities.mcp_provider import MCPProviderEntity
from core.entities.mcp_provider import IdentityMode, MCPProviderEntity
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@ -28,6 +28,7 @@ class MCPToolProviderController(ToolProviderController):
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
identity_mode: IdentityMode = IdentityMode.OFF,
):
super().__init__(entity)
self.entity: ToolProviderEntityWithPlugin = entity
@ -37,6 +38,7 @@ class MCPToolProviderController(ToolProviderController):
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.identity_mode: IdentityMode = identity_mode
@property
def provider_type(self) -> ToolProviderType:
@ -105,6 +107,7 @@ class MCPToolProviderController(ToolProviderController):
headers=entity.headers,
timeout=entity.timeout,
sse_read_timeout=entity.sse_read_timeout,
identity_mode=entity.identity_mode,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
@ -134,6 +137,7 @@ class MCPToolProviderController(ToolProviderController):
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
identity_mode=self.identity_mode,
)
def get_tools(self) -> list[MCPTool]:
@ -151,6 +155,7 @@ class MCPToolProviderController(ToolProviderController):
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
identity_mode=self.identity_mode,
)
for tool_entity in self.entity.tools
]

View File

@ -6,6 +6,8 @@ import logging
from collections.abc import Generator, Mapping
from typing import Any, cast
from configs import dify_config
from core.entities.mcp_provider import IdentityMode
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import (
@ -25,6 +27,11 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetada
logger = logging.getLogger(__name__)
# Custom header used to carry the forwarded SSO access token. Picked to avoid
# stomping on the workspace-scoped Authorization header (provider OAuth /
# user-supplied custom credentials), which would silently break those flows.
FORWARDED_IDENTITY_HEADER = "X-Dify-SSO-Access-Token"
class MCPTool(Tool):
def __init__(
@ -38,6 +45,7 @@ class MCPTool(Tool):
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
identity_mode: IdentityMode = IdentityMode.OFF,
):
super().__init__(entity, runtime)
self.tenant_id = tenant_id
@ -47,6 +55,7 @@ class MCPTool(Tool):
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.identity_mode: IdentityMode = identity_mode
self._latest_usage = LLMUsage.empty_usage()
def tool_provider_type(self) -> ToolProviderType:
@ -60,7 +69,7 @@ class MCPTool(Tool):
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
result = self.invoke_remote_mcp_tool(tool_parameters)
result = self.invoke_remote_mcp_tool(tool_parameters, user_id=user_id, app_id=app_id)
# Extract usage metadata from MCP protocol's _meta field
self._latest_usage = self._derive_usage_from_result(result)
@ -234,6 +243,7 @@ class MCPTool(Tool):
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
identity_mode=self.identity_mode,
)
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
@ -246,7 +256,26 @@ class MCPTool(Tool):
if value is not None and not (isinstance(value, str) and value.strip() == "")
}
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
@property
def _forwarding_requested(self) -> bool:
"""True only when the configured identity_mode wants forwarding AND
the deployment actually has the enterprise side that can mint tokens.
Non-enterprise installs treat the DB value as a no-op a stale row
won't trigger a 5xx against a missing inner-API endpoint."""
return self.identity_mode != IdentityMode.OFF and dify_config.ENTERPRISE_ENABLED
def invoke_remote_mcp_tool(
self,
tool_parameters: dict[str, Any],
user_id: str | None = None,
app_id: str | None = None,
) -> CallToolResult:
# Fail closed: forwarding requires user_id (refuse before any DB I/O).
if self._forwarding_requested and not user_id:
raise ToolInvokeError(
"Forward-user-identity is enabled for this MCP provider but no end-user context was supplied."
)
headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters)
@ -271,6 +300,15 @@ class MCPTool(Tool):
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# Forwarded identity rides in a custom header so workspace-scoped
# provider credentials (Authorization / custom Headers) keep working
# untouched. The MCP server is expected to read X-Dify-SSO-Access-Token
# when identity forwarding is configured.
forward_identity_active = False
if self._forwarding_requested and user_id:
self._inject_forwarded_identity(headers, user_id=user_id, app_id=app_id, audience=server_url)
forward_identity_active = True
# Step 2: Session is now closed, perform network operations without holding database connection
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
try:
@ -280,9 +318,44 @@ class MCPTool(Tool):
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
forward_identity_active=forward_identity_active,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
def _inject_forwarded_identity(
self,
headers: dict[str, str],
*,
user_id: str,
app_id: str | None,
audience: str,
) -> None:
"""Call the enterprise IssueMCPToken endpoint and stamp the issued
token into X-Dify-SSO-Access-Token.
A custom header is used (rather than Authorization) so it composes
with workspace-scoped provider credentials the user may have OAuth
tokens or a custom Authorization header configured on the MCP
provider, and forwarding must not silently overwrite them.
Errors are surfaced as ToolInvokeError so the workflow halts with a
clear message instead of silently dropping identity and hitting the
MCP server unauthenticated.
"""
from services.enterprise.base import MCPTokenError
from services.enterprise.enterprise_service import EnterpriseService
try:
token, _expires_at = EnterpriseService.issue_mcp_token(
user_id=user_id,
tenant_id=self.tenant_id,
app_id=app_id,
audience=audience,
)
except MCPTokenError as e:
raise ToolInvokeError(f"Failed to obtain forwarded identity token: {e}") from e
headers[FORWARDED_IDENTITY_HEADER] = token

View File

@ -0,0 +1,44 @@
"""add identity mode to mcp tool provider
Revision ID: 3df4dbcc1e21
Revises: 2b3c4d5e6f70
Create Date: 2026-05-29 15:00:00.000000
Adds the `identity_mode` column to `tool_mcp_providers` to drive the M2 MCP
user-identity forwarding feature. Reserved values:
"off" no header forwarded (default; pre-M2 behaviour).
"idp_token" call dify-enterprise /inner/api/mcp/issue-token, stamp the
returned SSO access token on the outbound MCP request as
`X-Dify-SSO-Access-Token: <token>`.
The column is filled with the safe default "off" for existing rows so older
providers keep their current behaviour until an admin opts in.
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = "3df4dbcc1e21"
down_revision = "2b3c4d5e6f70"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
"tool_mcp_providers",
sa.Column(
"identity_mode",
sa.String(length=32),
nullable=False,
server_default=sa.text("'off'"),
),
)
def downgrade():
op.drop_column("tool_mcp_providers", "identity_mode")

View File

@ -350,6 +350,14 @@ class MCPToolProvider(TypeBase):
# encrypted headers for MCP server requests
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
# M2 (MCP user-identity forwarding) — which identity-forwarding mechanism
# this provider uses. Reserved values:
# "off" — no forwarding (default; preserves pre-M2 behaviour).
# "idp_token" — forward an SSO access token minted by dify-enterprise.
identity_mode: Mapped[str] = mapped_column(
sa.String(32), nullable=False, server_default=sa.text("'off'"), default="off"
)
def load_user(self) -> Account | None:
return db.session.scalar(select(Account).where(Account.id == self.user_id))

View File

@ -14907,6 +14907,14 @@ Icon information model.
| ---- | ---- | ----------- | -------- |
| IconType | string | | |
#### IdentityMode
How Dify forwards the end-user's identity to an MCP server.
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| IdentityMode | string | How Dify forwards the end-user's identity to an MCP server. | |
#### Import
| Name | Type | Description | Required |
@ -15217,6 +15225,7 @@ Enum class for large language model mode.
| icon | string | | Yes |
| icon_background | string | | No |
| icon_type | string | | Yes |
| identity_mode | [IdentityMode](#identitymode) | | No |
| name | string | | Yes |
| server_identifier | string | | Yes |
| server_url | string | | Yes |
@ -15237,6 +15246,7 @@ Enum class for large language model mode.
| icon | string | | Yes |
| icon_background | string | | No |
| icon_type | string | | Yes |
| identity_mode | [IdentityMode](#identitymode) | | No |
| name | string | | Yes |
| provider_id | string | | Yes |
| server_identifier | string | | Yes |

View File

@ -18,7 +18,7 @@ import yaml
from sqlalchemy import or_
from sqlalchemy.orm import Session, sessionmaker
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.entities.mcp_provider import IdentityMode, MCPAuthentication, MCPConfiguration
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@ -748,6 +748,9 @@ class MigrationImportService:
headers=mcp_data.get("headers") if isinstance(mcp_data.get("headers"), dict) else {},
configuration=configuration,
authentication=authentication,
# Re-import must not silently reset forwarding: preserve the
# stored mode (update_provider now defaults to OFF when omitted).
identity_mode=IdentityMode(existing.identity_mode),
)
db.session.commit()
status = "updated"

View File

@ -12,8 +12,28 @@ from services.errors.enterprise import (
EnterpriseAPIForbiddenError,
EnterpriseAPINotFoundError,
EnterpriseAPIUnauthorizedError,
EnterpriseServiceError,
)
class MCPTokenError(EnterpriseServiceError):
"""Generic failure of the IssueMCPToken RPC."""
class MCPNoRefreshTokenError(MCPTokenError):
"""User has no stored SSO refresh_token; ask them to re-authenticate."""
def __init__(self, description: str = ""):
super().__init__(description, status_code=428)
class MCPIdentityRefreshError(MCPTokenError):
"""IdP rejected the refresh attempt (revoked/expired session)."""
def __init__(self, description: str = ""):
super().__init__(description, status_code=401)
logger = logging.getLogger(__name__)

View File

@ -11,7 +11,15 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
from extensions.ext_redis import redis_client
from services.enterprise.base import EnterpriseRequest
from services.enterprise.base import (
EnterpriseRequest,
MCPIdentityRefreshError,
MCPNoRefreshTokenError,
MCPTokenError,
)
from services.errors.enterprise import (
EnterpriseServiceError,
)
if TYPE_CHECKING:
from services.feature_service import LicenseStatus
@ -121,6 +129,77 @@ class EnterpriseService:
def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
@classmethod
def issue_mcp_token(
cls,
user_id: str,
tenant_id: str,
app_id: str | None,
audience: str,
) -> tuple[str, int]:
"""Mint a short-lived SSO id_token (or OAuth2 access_token) representing
the calling Dify user, audience-scoped to the given MCP server identifier.
Used by MCPTool.invoke_remote_mcp_tool to stamp the
X-Dify-SSO-Access-Token header on outbound MCP requests when the
provider's identity_mode is set to "idp_token".
Returns:
(token, expires_at_unix_seconds)
Raises:
MCPNoRefreshTokenError: user has no stored SSO refresh_token on the
enterprise side; surface to the workflow as "please log in via SSO".
MCPIdentityRefreshError: enterprise tried to refresh against the IdP
and the IdP rejected (revoked/expired session).
MCPTokenError: any other failure of the enterprise endpoint.
"""
try:
response = EnterpriseRequest.send_request(
"POST",
"/mcp/issue-token",
json={
"user_id": user_id,
"tenant_id": tenant_id,
"app_id": app_id or "",
"audience": audience,
},
)
except EnterpriseServiceError as e:
# The HTTP-status subclasses (400/401/403/404) inherit directly
# from EnterpriseServiceError, not EnterpriseAPIError, so we
# must catch the base class to route them all.
status = getattr(e, "status_code", None)
if status == 401:
# Enterprise side returns 401 when the IdP rejected the refresh.
raise MCPIdentityRefreshError(str(e) or "identity refresh failed; please re-authenticate") from e
if status == 428:
raise MCPNoRefreshTokenError(
str(e) or "user has no stored SSO refresh token; please re-authenticate"
) from e
if status == 403:
# 403 most often means the tenant isn't licensed for MCP
# identity-forwarding. Surface as identity-refresh-failure so
# the workflow halts loudly rather than retrying.
raise MCPIdentityRefreshError(
str(e) or "enterprise refused to issue an MCP identity token (license or policy)"
) from e
raise MCPTokenError(f"issue_mcp_token failed (status={status}): {e}") from e
if not isinstance(response, dict):
raise MCPTokenError("invalid response shape from enterprise /mcp/issue-token")
token = response.get("token")
expires_at = response.get("expires_at")
# Accept int or float for expires_at (some clocks emit float
# seconds-since-epoch). Reject bools explicitly because `bool` is
# an `int` subclass in Python and would pass isinstance(_, int).
if not isinstance(token, str) or not token:
raise MCPTokenError(f"missing or non-string token in enterprise response: {response!r}")
if isinstance(expires_at, bool) or not isinstance(expires_at, (int, float)):
raise MCPTokenError(f"missing or non-numeric expires_at in enterprise response: {response!r}")
return token, int(expires_at)
@classmethod
def initiate_device_flow_sso(cls, signed_state: str) -> dict:
return EnterpriseRequest.send_request(

View File

@ -12,7 +12,7 @@ from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.entities.mcp_provider import IdentityMode, MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.auth.auth_flow import auth
@ -136,6 +136,7 @@ class MCPToolManageService:
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
headers: dict[str, str] | None = None,
identity_mode: IdentityMode = IdentityMode.OFF,
) -> ToolProviderApiEntity:
"""Create a new MCP provider."""
# Validate URL format
@ -171,6 +172,7 @@ class MCPToolManageService:
sse_read_timeout=configuration.sse_read_timeout,
encrypted_headers=encrypted_headers,
encrypted_credentials=encrypted_credentials,
identity_mode=identity_mode,
)
self._session.add(mcp_tool)
@ -194,6 +196,7 @@ class MCPToolManageService:
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
validation_result: ServerUrlValidationResult | None = None,
identity_mode: IdentityMode = IdentityMode.OFF,
) -> None:
"""
Update an MCP provider.
@ -255,6 +258,11 @@ class MCPToolManageService:
if authentication and authentication.client_id:
mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id)
# Update user-identity forwarding mode. The controller has already
# resolved "leave unchanged" and applied the ENTERPRISE_ENABLED gate,
# so this is always a concrete, vetted value.
mcp_provider.identity_mode = identity_mode
# Flush changes to database
self._session.flush()

View File

@ -380,3 +380,53 @@ def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.Mon
resp = controller_module.ToolLabelsApi().get()
assert resp == ["a", "b"]
# --- _resolve_identity_mode: gating + None-resolution (PR #36839 review) ---
def test_resolve_identity_mode_none_keeps_current_when_enterprise(controller_module, monkeypatch: pytest.MonkeyPatch):
"""None means 'leave unchanged' — fall back to the stored mode (update path)."""
identity_mode = importlib.import_module("core.entities.mcp_provider").IdentityMode
monkeypatch.setattr(controller_module.dify_config, "ENTERPRISE_ENABLED", True)
resolved = controller_module._resolve_identity_mode(None, current=identity_mode.IDP_TOKEN)
assert resolved == identity_mode.IDP_TOKEN
def test_resolve_identity_mode_explicit_value_overrides_current(controller_module, monkeypatch: pytest.MonkeyPatch):
"""An explicit value wins over the stored mode."""
identity_mode = importlib.import_module("core.entities.mcp_provider").IdentityMode
monkeypatch.setattr(controller_module.dify_config, "ENTERPRISE_ENABLED", True)
resolved = controller_module._resolve_identity_mode(identity_mode.OFF, current=identity_mode.IDP_TOKEN)
assert resolved == identity_mode.OFF
def test_resolve_identity_mode_coerces_non_off_to_off_when_not_enterprise(
controller_module, monkeypatch: pytest.MonkeyPatch
):
"""Gate: a non-EE deployment must never persist a non-OFF mode — the
runtime won't forward, so the stored row must not imply it does."""
identity_mode = importlib.import_module("core.entities.mcp_provider").IdentityMode
monkeypatch.setattr(controller_module.dify_config, "ENTERPRISE_ENABLED", False)
# Both an explicit idp_token request AND an inherited non-OFF current
# must collapse to OFF.
assert (
controller_module._resolve_identity_mode(identity_mode.IDP_TOKEN, current=identity_mode.OFF)
== identity_mode.OFF
)
assert controller_module._resolve_identity_mode(None, current=identity_mode.IDP_TOKEN) == identity_mode.OFF
def test_resolve_identity_mode_off_is_passthrough_when_not_enterprise(
controller_module, monkeypatch: pytest.MonkeyPatch
):
"""OFF is always fine — the gate only neutralizes non-OFF values."""
identity_mode = importlib.import_module("core.entities.mcp_provider").IdentityMode
monkeypatch.setattr(controller_module.dify_config, "ENTERPRISE_ENABLED", False)
assert controller_module._resolve_identity_mode(None, current=identity_mode.OFF) == identity_mode.OFF

View File

@ -53,6 +53,7 @@ def test_from_db_model_maps_fields() -> None:
icon=None,
created_at=now,
updated_at=now,
identity_mode="off",
)
# Act

View File

@ -0,0 +1,41 @@
from __future__ import annotations
import pytest
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError
class TestForwardIdentityShortCircuit:
def test_forward_identity_active_reraises_without_retry(self):
client = MCPClientWithAuthRetry(
server_url="https://mcp.example.com",
headers={"Authorization": "Bearer user-jwt"},
forward_identity_active=True,
)
with pytest.raises(MCPAuthError):
client._handle_auth_error(MCPAuthError("unauthorized"))
assert client.headers["Authorization"] == "Bearer user-jwt"
assert client._has_retried is False
def test_forward_identity_active_takes_precedence_over_provider_entity(self):
sentinel_entity = object()
client = MCPClientWithAuthRetry(
server_url="https://mcp.example.com",
provider_entity=sentinel_entity, # type: ignore[arg-type]
forward_identity_active=True,
)
with pytest.raises(MCPAuthError, match="forwarded-id-401"):
client._handle_auth_error(MCPAuthError("forwarded-id-401"))
def test_default_path_unchanged_without_provider_entity(self):
client = MCPClientWithAuthRetry(server_url="https://mcp.example.com")
with pytest.raises(MCPAuthError, match="no-provider"):
client._handle_auth_error(MCPAuthError("no-provider"))
def test_default_constructor_defaults_forward_identity_to_false(self):
client = MCPClientWithAuthRetry(server_url="https://mcp.example.com")
assert client.forward_identity_active is False

View File

@ -148,3 +148,112 @@ def test_mcp_tool_handle_none_parameter_filters_empty_values():
tool = _build_mcp_tool()
cleaned = tool._handle_none_parameter({"a": 1, "b": None, "c": "", "d": " ", "e": "ok"})
assert cleaned == {"a": 1, "e": "ok"}
# ----- M2/M3 user-identity forwarding ---------------------------------------
def _build_forwarding_tool(*, mode: str = "idp_token") -> MCPTool:
"""Helper that builds an MCPTool with the identity_mode set."""
entity = ToolEntity(
identity=ToolIdentity(
author="author",
name="remote-tool",
label=I18nObject(en_US="remote-tool"),
provider="provider-id",
),
parameters=[],
output_schema={},
)
return MCPTool(
entity=entity,
runtime=ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER),
tenant_id="tenant-1",
icon="icon.svg",
server_url="https://mcp.example.com/mcp/",
provider_id="provider-id",
identity_mode=mode,
)
def test_inject_forwarded_identity_stamps_custom_header():
"""The minted SSO token must be placed in X-Dify-SSO-Access-Token; the
workspace-scoped Authorization header and any other custom headers must
pass through untouched so provider credentials keep working."""
from core.tools.mcp_tool.tool import FORWARDED_IDENTITY_HEADER
tool = _build_forwarding_tool()
headers: dict[str, str] = {"Authorization": "Bearer static-client-token", "X-Other": "keep"}
with patch(
"services.enterprise.enterprise_service.EnterpriseService.issue_mcp_token",
return_value=("forwarded.jwt.payload", 1900000000),
):
tool._inject_forwarded_identity(headers, user_id="alice", app_id=None, audience="https://mcp.example.com/mcp/")
assert headers[FORWARDED_IDENTITY_HEADER] == "forwarded.jwt.payload"
assert headers["Authorization"] == "Bearer static-client-token"
assert headers["X-Other"] == "keep"
def test_inject_forwarded_identity_translates_token_error_to_invoke_error():
"""EnterpriseService failures must surface as ToolInvokeError so the
workflow halts loudly instead of proceeding without identity."""
from core.tools.mcp_tool.tool import FORWARDED_IDENTITY_HEADER
from services.enterprise.base import MCPNoRefreshTokenError
tool = _build_forwarding_tool()
headers: dict[str, str] = {}
with patch(
"services.enterprise.enterprise_service.EnterpriseService.issue_mcp_token",
side_effect=MCPNoRefreshTokenError("please re-sso"),
):
with pytest.raises(ToolInvokeError, match="forwarded identity token"):
tool._inject_forwarded_identity(
headers, user_id="alice", app_id=None, audience="https://mcp.example.com/mcp/"
)
# Headers must NOT have been mutated when token-issuance failed.
assert FORWARDED_IDENTITY_HEADER not in headers
assert "Authorization" not in headers
def test_invoke_remote_mcp_tool_fails_closed_when_user_id_missing():
"""When forwarding is enabled AND the deployment is enterprise, missing
user_id must raise never silently invoke as the static identity."""
tool = _build_forwarding_tool()
with patch("core.tools.mcp_tool.tool.dify_config") as cfg:
cfg.ENTERPRISE_ENABLED = True
with pytest.raises(ToolInvokeError, match="no end-user context"):
tool.invoke_remote_mcp_tool({}, user_id=None, app_id=None)
def test_invoke_skips_forwarding_when_enterprise_disabled():
"""Non-enterprise deployments treat the DB selector as a no-op: a stale
`identity_mode="idp_token"` row must NOT raise (fail-closed) AND must
NOT call the enterprise inner API. The runtime falls through to the
legacy provider-identity path."""
tool = _build_forwarding_tool()
with patch("core.tools.mcp_tool.tool.dify_config") as cfg:
cfg.ENTERPRISE_ENABLED = False
# The fail-closed branch must NOT fire (no enterprise → no forwarding).
# The function will still try the legacy DB-load path; we patch that
# to keep the test unit-scoped.
with patch("core.tools.mcp_tool.tool.MCPClientWithAuthRetry") as client_cls:
client_cls.return_value.__enter__.return_value.invoke_tool.return_value = CallToolResult(
content=[],
_meta=None,
)
with patch.object(tool, "_inject_forwarded_identity") as inject:
with patch("services.tools.mcp_tools_manage_service.MCPToolManageService"):
with patch("core.entities.mcp_provider.MCPProviderEntity.decrypt_server_url", return_value="u"):
with patch("core.entities.mcp_provider.MCPProviderEntity.decrypt_headers", return_value={}):
# Should not raise; should not call enterprise.
try:
tool.invoke_remote_mcp_tool({}, user_id=None, app_id=None)
except Exception:
pass
inject.assert_not_called()

View File

@ -596,7 +596,9 @@ def test_api_tool_create_records_id_mapping(monkeypatch):
def test_mcp_tool_import_restores_exported_tool_list(monkeypatch):
provider = type("Provider", (), {"id": "target-provider-id", "tools": "[]", "authed": False})()
provider = type(
"Provider", (), {"id": "target-provider-id", "tools": "[]", "authed": False, "identity_mode": "off"}
)()
report_items = []
class StubSession:
@ -652,7 +654,9 @@ def test_mcp_tool_import_restores_exported_tool_list(monkeypatch):
@pytest.mark.parametrize("conflict_strategy", [ConflictStrategy.SKIP, ConflictStrategy.UPDATE])
def test_mcp_tool_existing_provider_records_id_mapping(monkeypatch, conflict_strategy):
provider = type("Provider", (), {"id": "target-mcp-provider-id", "tools": "[]", "authed": False})()
provider = type(
"Provider", (), {"id": "target-mcp-provider-id", "tools": "[]", "authed": False, "identity_mode": "off"}
)()
id_mapping = {}
id_mapping_details = []
@ -712,7 +716,9 @@ def test_mcp_tool_existing_provider_records_id_mapping(monkeypatch, conflict_str
def test_mcp_tool_create_records_id_mapping(monkeypatch):
provider = type("Provider", (), {"id": "target-mcp-provider-id", "tools": "[]", "authed": False})()
provider = type(
"Provider", (), {"id": "target-mcp-provider-id", "tools": "[]", "authed": False, "identity_mode": "off"}
)()
id_mapping = {}
provider_created = False

View File

@ -466,3 +466,111 @@ class TestGetCachedLicenseStatus:
assert EnterpriseService.get_cached_license_status() is None
mock_redis.setex.assert_not_called()
class TestIssueMCPToken:
"""Coverage for EnterpriseService.issue_mcp_token (M2).
The function wraps `POST /inner/api/mcp/issue-token` and must map
EnterpriseServiceError subclasses to MCP-typed errors so the workflow
layer can halt with a precise message instead of leaking transport text.
"""
@staticmethod
def _call():
return EnterpriseService.issue_mcp_token(
user_id="user-uuid",
tenant_id="tenant-uuid",
app_id="app-uuid",
audience="https://mcp.example.com/mcp/",
)
def test_happy_path_returns_token_and_expiry(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {"token": "abc.def.ghi", "expires_at": 1900000000}
token, exp = self._call()
assert token == "abc.def.ghi"
assert exp == 1900000000
req.send_request.assert_called_once_with(
"POST",
"/mcp/issue-token",
json={
"user_id": "user-uuid",
"tenant_id": "tenant-uuid",
"app_id": "app-uuid",
"audience": "https://mcp.example.com/mcp/",
},
)
def test_401_maps_to_identity_refresh_error(self):
from services.enterprise.base import MCPIdentityRefreshError
from services.errors.enterprise import EnterpriseAPIUnauthorizedError
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.side_effect = EnterpriseAPIUnauthorizedError("refresh rejected by IdP")
with pytest.raises(MCPIdentityRefreshError, match="refresh rejected"):
self._call()
def test_428_maps_to_no_refresh_token_error(self):
from services.enterprise.base import MCPNoRefreshTokenError
from services.errors.enterprise import EnterpriseAPIError
with patch(f"{MODULE}.EnterpriseRequest") as req:
# 428 PreconditionRequired is what EE returns when there's no
# stored SSO refresh token for the user.
req.send_request.side_effect = EnterpriseAPIError("user has not completed SSO", status_code=428)
with pytest.raises(MCPNoRefreshTokenError, match="SSO"):
self._call()
def test_403_maps_to_identity_refresh_error_for_license(self):
from services.enterprise.base import MCPIdentityRefreshError
from services.errors.enterprise import EnterpriseAPIForbiddenError
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.side_effect = EnterpriseAPIForbiddenError("not licensed for MCP forwarding")
with pytest.raises(MCPIdentityRefreshError, match="not licensed"):
self._call()
def test_other_status_maps_to_generic_token_error(self):
from services.enterprise.base import MCPTokenError
from services.errors.enterprise import EnterpriseAPIError
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.side_effect = EnterpriseAPIError("upstream 502", status_code=502)
with pytest.raises(MCPTokenError, match="status=502"):
self._call()
def test_malformed_response_shape_raises_token_error(self):
from services.enterprise.base import MCPTokenError
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = "not-a-dict"
with pytest.raises(MCPTokenError, match="invalid response shape"):
self._call()
def test_missing_token_field_raises_token_error(self):
from services.enterprise.base import MCPTokenError
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {"expires_at": 1700000000} # no token
with pytest.raises(MCPTokenError, match="missing or non-string token"):
self._call()
def test_float_expires_at_is_accepted(self):
"""expires_at may arrive as float (time.time()) — must be coerced."""
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {"token": "t", "expires_at": 1900000000.5}
token, exp = self._call()
assert token == "t"
assert exp == 1900000000
assert isinstance(exp, int)
def test_bool_expires_at_is_rejected(self):
"""bool is a subclass of int — must NOT be accepted as expires_at."""
from services.enterprise.base import MCPTokenError
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {"token": "t", "expires_at": True}
with pytest.raises(MCPTokenError, match="non-numeric expires_at"):
self._call()

View File

@ -459,6 +459,7 @@ export type McpProviderCreatePayload = {
icon: string
icon_background?: string
icon_type: string
identity_mode?: IdentityMode
name: string
server_identifier: string
server_url: string
@ -477,6 +478,7 @@ export type McpProviderUpdatePayload = {
icon: string
icon_background?: string
icon_type: string
identity_mode?: IdentityMode
name: string
provider_id: string
server_identifier: string
@ -669,6 +671,8 @@ export type ApiProviderSchemaType = 'openai_actions' | 'openai_plugin' | 'openap
export type CredentialType = 'api-key' | 'oauth2' | 'unauthorized'
export type IdentityMode = 'idp_token' | 'off'
export type WorkflowToolParameterConfiguration = {
description: string
form: ToolParameterForm

View File

@ -351,37 +351,6 @@ export const zMcpProviderDeletePayload = z.object({
provider_id: z.string(),
})
/**
* MCPProviderCreatePayload
*/
export const zMcpProviderCreatePayload = z.object({
authentication: z.record(z.string(), z.unknown()).nullish(),
configuration: z.record(z.string(), z.unknown()).nullish(),
headers: z.record(z.string(), z.unknown()).nullish(),
icon: z.string(),
icon_background: z.string().optional().default(''),
icon_type: z.string(),
name: z.string(),
server_identifier: z.string(),
server_url: z.string(),
})
/**
* MCPProviderUpdatePayload
*/
export const zMcpProviderUpdatePayload = z.object({
authentication: z.record(z.string(), z.unknown()).nullish(),
configuration: z.record(z.string(), z.unknown()).nullish(),
headers: z.record(z.string(), z.unknown()).nullish(),
icon: z.string(),
icon_background: z.string().optional().default(''),
icon_type: z.string(),
name: z.string(),
provider_id: z.string(),
server_identifier: z.string(),
server_url: z.string(),
})
/**
* MCPAuthPayload
*/
@ -829,6 +798,46 @@ export const zBuiltinToolAddPayload = z.object({
visibility: z.string().nullish(),
})
/**
* IdentityMode
*
* How Dify forwards the end-user's identity to an MCP server.
*/
export const zIdentityMode = z.enum(['idp_token', 'off'])
/**
* MCPProviderCreatePayload
*/
export const zMcpProviderCreatePayload = z.object({
authentication: z.record(z.string(), z.unknown()).nullish(),
configuration: z.record(z.string(), z.unknown()).nullish(),
headers: z.record(z.string(), z.unknown()).nullish(),
icon: z.string(),
icon_background: z.string().optional().default(''),
icon_type: z.string(),
identity_mode: zIdentityMode.optional(),
name: z.string(),
server_identifier: z.string(),
server_url: z.string(),
})
/**
* MCPProviderUpdatePayload
*/
export const zMcpProviderUpdatePayload = z.object({
authentication: z.record(z.string(), z.unknown()).nullish(),
configuration: z.record(z.string(), z.unknown()).nullish(),
headers: z.record(z.string(), z.unknown()).nullish(),
icon: z.string(),
icon_background: z.string().optional().default(''),
icon_type: z.string(),
identity_mode: zIdentityMode.optional(),
name: z.string(),
provider_id: z.string(),
server_identifier: z.string(),
server_url: z.string(),
})
/**
* StrategySetting
*/