feat(api,web): OAuth 2.0 device flow + bearer auth (RFC 8628)

Adds a CLI-friendly authorization flow so difyctl (and future
non-browser clients) can obtain user-scoped tokens without copy-
pasting cookies or raw API keys. Two grant paths share one device
flow surface:

  1. Account branch — user signs in via the existing /signin
     methods, /device page calls console-authed approve, mints a
     dfoa_ token tied to (account_id, tenant).
  2. External-SSO branch (EE) — /v1/oauth/device/sso-initiate signs
     an SSOState envelope, hands off to Enterprise's external ACS,
     receives a signed external-subject assertion, mints a dfoe_
     token tied to (subject_email, subject_issuer).

API surface (all under /v1, EE-only endpoints 404 on CE):

  POST   /v1/oauth/device/code              — RFC 8628 start
  POST   /v1/oauth/device/token             — RFC 8628 poll
  GET    /v1/oauth/device/lookup            — pre-validate user_code
  GET    /v1/oauth/device/sso-initiate      — SSO branch entry
  GET    /v1/device/sso-complete            — SSO callback sink
  GET    /v1/oauth/device/approval-context  — /device cookie probe
  POST   /v1/oauth/device/approve-external  — SSO approve
  GET    /v1/me                             — bearer subject lookup
  DELETE /v1/oauth/authorizations/self      — self-revoke
  POST   /console/api/oauth/device/approve  — account approve
  POST   /console/api/oauth/device/deny     — account deny

Core primitives:
- libs/oauth_bearer.py: prefix-keyed TokenKindRegistry +
  BearerAuthenticator + validate_bearer decorator. Two-tier scope
  (full vs apps:run) stamped from the registry, never from the DB.
- libs/jws.py: HS256 compact JWS keyed on the shared Dify
  SECRET_KEY — same key-set verifies the SSOState envelope, the
  external-subject assertion (minted by Enterprise), and the
  approval-grant cookie.
- libs/device_flow_security.py: enterprise_only gate, approval-
  grant cookie mint/verify/consume (Path=/v1/oauth/device,
  HttpOnly, SameSite=Lax, Secure follows is_secure()), anti-
  framing headers.
- libs/rate_limit.py: typed RateLimit / RateLimitScope dispatch
  with composite-key buckets; both decorator + imperative form.
- services/oauth_device_flow.py: Redis state machine (PENDING ->
  APPROVED|DENIED with atomic consume-on-poll), token mint via
  partial unique index uq_oauth_active_per_device (rotates in
  place), env-driven TTL policy.

Storage: oauth_access_tokens table with partial unique index on
(subject_email, subject_issuer, client_id, device_label) WHERE
revoked_at IS NULL. account_id NULL distinguishes external-SSO
rows. Hard-expire is CAS UPDATE (revoked_at + nullify token_hash)
so audit events keep their token_id. Retention pruner DELETEs
revoked + zombie-expired rows past OAUTH_ACCESS_TOKEN_RETENTION_DAYS.

Frontend: /device page with code-entry, chooser (account vs SSO),
authorize-account, authorize-sso views. SSO branch detaches from
the URL user_code and reads everything from the cookie via
/approval-context. Anti-framing headers on all responses.

Wiring: ENABLE_OAUTH_BEARER feature flag; ext_oauth_bearer binds
the authenticator at startup; clean_oauth_access_tokens_task
scheduled in ext_celery.

Spec: docs/specs/v1.0/server/{device-flow,tokens,middleware,security}.md
This commit is contained in:
GareArc 2026-04-26 20:06:43 -07:00
parent 8f070f2190
commit fe8510ad1a
No known key found for this signature in database
30 changed files with 2967 additions and 14 deletions

View File

@ -159,6 +159,7 @@ def initialize_extensions(app: DifyApp):
ext_logstore,
ext_mail,
ext_migrate,
ext_oauth_bearer,
ext_orjson,
ext_otel,
ext_proxy_fix,
@ -203,6 +204,7 @@ def initialize_extensions(app: DifyApp):
ext_enterprise_telemetry,
ext_request_logging,
ext_session_factory,
ext_oauth_bearer,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]

View File

@ -874,6 +874,11 @@ class AuthConfig(BaseSettings):
default=86400,
)
ENABLE_OAUTH_BEARER: bool = Field(
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
default=True,
)
class ModerationConfig(BaseSettings):
"""
@ -1148,6 +1153,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable scheduled workflow run cleanup task",
default=False,
)
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
default=True,
)
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
description="Days to retain revoked OAuth access-token rows before deletion.",
default=30,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,

View File

@ -81,6 +81,7 @@ from .auth import (
forgot_password,
login,
oauth,
oauth_device,
oauth_server,
)
@ -189,6 +190,7 @@ __all__ = [
"models",
"notification",
"oauth",
"oauth_device",
"oauth_server",
"ops_trace",
"parameter",

View File

@ -0,0 +1,221 @@
"""Console-session-authed device-flow approve/deny. Called by the
/device page after the user signs in. Public lookup is in service_api/oauth.py.
"""
from __future__ import annotations
import logging
from functools import wraps
from flask_login import login_required
from flask_restx import Resource, reqparse
from werkzeug.exceptions import ServiceUnavailable
from configs import dify_config
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant
from libs.oauth_bearer import SubjectType
from libs.rate_limit import LIMIT_APPROVE_CONSOLE, rate_limit
def bearer_feature_required(fn):
"""503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable
without the authenticator, so fail fast instead of approving silently.
"""
@wraps(fn)
def inner(*args, **kwargs):
if not dify_config.ENABLE_OAUTH_BEARER:
raise ServiceUnavailable(
"bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable"
)
return fn(*args, **kwargs)
return inner
from services.oauth_device_flow import (
PREFIX_OAUTH_ACCOUNT,
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransition,
StateNotFound,
mint_oauth_token,
oauth_ttl_days,
)
logger = logging.getLogger(__name__)
_mutate_parser = reqparse.RequestParser()
_mutate_parser.add_argument("user_code", type=str, required=True, location="json")
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
_APPROVE_GUARD_TTL_SECONDS = 10
@console_ns.route("/oauth/device/approve")
class DeviceApproveApi(Resource):
@setup_required
@login_required
@account_initialization_required
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
def post(self):
args = _mutate_parser.parse_args()
user_code = args["user_code"].strip().upper()
account, tenant = current_account_with_tenant()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
return {"error": "expired_or_unknown"}, 404
device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
return {"error": "already_resolved"}, 409
# SET NX guard — without it, two in-flight approves both pass
# PENDING, both mint, and the second upsert silently rotates the
# first caller into an already-revoked token.
guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
return {"error": "approve_in_progress"}, 409
try:
ttl_days = oauth_ttl_days(tenant_id=tenant)
mint = mint_oauth_token(
db.session,
redis_client,
subject_email=account.email,
subject_issuer=None,
account_id=str(account.id),
client_id=state.client_id,
device_label=state.device_label,
prefix=PREFIX_OAUTH_ACCOUNT,
ttl_days=ttl_days,
)
poll_payload = _build_account_poll_payload(account, tenant, mint)
try:
store.approve(
device_code,
subject_email=account.email,
account_id=str(account.id),
subject_issuer=None,
minted_token=mint.token,
token_id=str(mint.token_id),
poll_payload=poll_payload,
)
except (StateNotFound, InvalidTransition) as e:
# Row minted but state vanished — roll forward; the orphan
# token is revocable via auth devices list / Authorized Apps.
logger.error("device_flow: approve raced on %s: %s", device_code, e)
return {"error": "state_lost"}, 409
finally:
redis_client.delete(guard_key)
_emit_approve_audit(state, account, tenant, mint)
return {"status": "approved"}, 200
@console_ns.route("/oauth/device/deny")
class DeviceDenyApi(Resource):
@setup_required
@login_required
@account_initialization_required
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
def post(self):
args = _mutate_parser.parse_args()
user_code = args["user_code"].strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
return {"error": "expired_or_unknown"}, 404
device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
return {"error": "already_resolved"}, 409
try:
store.deny(device_code)
except (StateNotFound, InvalidTransition) as e:
logger.error("device_flow: deny raced on %s: %s", device_code, e)
return {"error": "state_lost"}, 409
_emit_deny_audit(state)
return {"status": "denied"}, 200
def _build_account_poll_payload(account, tenant, mint) -> dict:
"""Pre-render the poll-response body so the unauthenticated poll
handler doesn't re-query accounts/tenants for authz data.
"""
from models import Tenant, TenantAccountJoin
rows = (
db.session.query(Tenant, TenantAccountJoin)
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.account_id == account.id)
.all()
)
workspaces = [
{"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")}
for t, m in rows
]
# Prefer active session tenant → DB-flagged current join → first membership.
default_ws_id = None
if tenant and any(w["id"] == str(tenant) for w in workspaces):
default_ws_id = str(tenant)
if default_ws_id is None:
for _t, m in rows:
if getattr(m, "current", False):
default_ws_id = str(m.tenant_id)
break
if default_ws_id is None and workspaces:
default_ws_id = workspaces[0]["id"]
return {
"token": mint.token,
"expires_at": mint.expires_at.isoformat(),
"subject_type": SubjectType.ACCOUNT,
"account": {"id": str(account.id), "email": account.email, "name": account.name},
"workspaces": workspaces,
"default_workspace_id": default_ws_id,
"token_id": str(mint.token_id),
}
def _emit_approve_audit(state, account, tenant, mint) -> None:
logger.warning(
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
mint.token_id, account.email, state.client_id, state.device_label, mint.expires_at,
extra={
"audit": True,
"event": "oauth.device_flow_approved",
"token_id": str(mint.token_id),
"subject_type": SubjectType.ACCOUNT,
"subject_email": account.email,
"account_id": str(account.id),
"tenant_id": tenant,
"client_id": state.client_id,
"device_label": state.device_label,
"scopes": ["full"],
"expires_at": mint.expires_at.isoformat(),
},
)
def _emit_deny_audit(state) -> None:
logger.warning(
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
state.client_id, state.device_label,
extra={
"audit": True,
"event": "oauth.device_flow_denied",
"client_id": state.client_id,
"device_label": state.device_label,
},
)

View File

@ -0,0 +1,264 @@
"""SSO-branch device-flow endpoints. Browser hits sso-initiate → API
signs an SSOState envelope Enterprise inner-API returns IdP authorize
URL 302. IdP Enterprise ACS DeviceFlowDispatcher mints a signed
external-subject assertion 302 to /v1/device/sso-complete API mints
the approval-grant cookie /device user clicks Approve /approve-
external mints the OAuth token. All four endpoints are EE-only.
"""
from __future__ import annotations
import logging
import secrets
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from flask import Blueprint, jsonify, make_response, redirect, request
from libs import jws
from libs.oauth_bearer import SubjectType
from libs.rate_limit import (
LIMIT_APPROVE_EXT_PER_EMAIL,
LIMIT_SSO_INITIATE_PER_IP,
enforce,
rate_limit,
)
from libs.device_flow_security import (APPROVAL_GRANT_COOKIE_NAME, ApprovalGrantClaims,
approval_grant_cleared_cookie_kwargs,
approval_grant_cookie_kwargs,
attach_anti_framing,
consume_approval_grant_nonce,
consume_sso_assertion_nonce,
enterprise_only, mint_approval_grant,
verify_approval_grant)
from services.enterprise.enterprise_service import EnterpriseService
from services.oauth_device_flow import (PREFIX_OAUTH_EXTERNAL_SSO,
DeviceFlowRedis, DeviceFlowStatus,
InvalidTransition, StateNotFound,
mint_oauth_token, oauth_ttl_days)
from werkzeug.exceptions import (BadGateway, BadRequest, Conflict, Forbidden,
NotFound, Unauthorized)
logger = logging.getLogger(__name__)
bp = Blueprint("oauth_device_sso", __name__, url_prefix="/v1")
attach_anti_framing(bp)
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
# device_code it references.
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
@enterprise_only
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
def sso_initiate():
user_code = (request.args.get("user_code") or "").strip().upper()
if not user_code:
raise BadRequest("user_code required")
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
raise BadRequest("invalid_user_code")
_, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise BadRequest("invalid_user_code")
keyset = jws.KeySet.from_shared_secret()
signed_state = jws.sign(
keyset,
payload={
"redirect_url": "",
"app_code": "",
"intent": "device_flow",
"user_code": user_code,
"nonce": secrets.token_urlsafe(16),
"return_to": "",
"idp_callback_url": f"{request.host_url.rstrip('/')}/v1/device/sso-complete",
},
aud=jws.AUD_STATE_ENVELOPE,
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
)
try:
reply = EnterpriseService.initiate_device_flow_sso(signed_state)
except Exception as e:
logger.warning("sso-initiate: enterprise call failed: %s", e)
raise BadGateway("sso_initiate_failed") from e
url = (reply or {}).get("url")
if not url:
raise BadGateway("sso_initiate_missing_url")
# Clear stale approval-grant — defends against cross-tab/back-button mixing.
resp = redirect(url, code=302)
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
return resp
@bp.route("/device/sso-complete", methods=["GET"])
@enterprise_only
def sso_complete():
blob = request.args.get("sso_assertion")
if not blob:
raise BadRequest("sso_assertion required")
keyset = jws.KeySet.from_shared_secret()
try:
claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
except jws.VerifyError as e:
logger.warning("sso-complete: rejected assertion: %s", e)
raise BadRequest("invalid_sso_assertion") from e
if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")):
raise BadRequest("invalid_sso_assertion")
user_code = (claims.get("user_code") or "").strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
raise Conflict("user_code_not_pending")
_, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise Conflict("user_code_not_pending")
iss = request.host_url.rstrip("/")
cookie_value, _ = mint_approval_grant(
keyset=keyset,
iss=iss,
subject_email=claims["email"],
subject_issuer=claims["issuer"],
user_code=user_code,
)
resp = redirect("/device?sso_verified=1", code=302)
resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
return resp
@bp.route("/oauth/device/approval-context", methods=["GET"])
@enterprise_only
def approval_context():
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
if not token:
raise Unauthorized("no_session")
keyset = jws.KeySet.from_shared_secret()
try:
claims = verify_approval_grant(keyset, token)
except jws.VerifyError as e:
logger.warning("approval-context: bad cookie: %s", e)
raise Unauthorized("no_session") from e
return jsonify({
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"user_code": claims.user_code,
"csrf_token": claims.csrf_token,
"expires_at": claims.expires_at.isoformat(),
}), 200
@bp.route("/oauth/device/approve-external", methods=["POST"])
@enterprise_only
def approve_external():
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
if not token:
raise Unauthorized("invalid_session")
keyset = jws.KeySet.from_shared_secret()
try:
claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
except jws.VerifyError as e:
logger.warning("approve-external: bad cookie: %s", e)
raise Unauthorized("invalid_session") from e
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
csrf_header = request.headers.get("X-CSRF-Token", "")
if not csrf_header or csrf_header != claims.csrf_token:
raise Forbidden("csrf_mismatch")
data = request.get_json(silent=True) or {}
body_user_code = (data.get("user_code") or "").strip().upper()
if body_user_code != claims.user_code:
raise BadRequest("user_code_mismatch")
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(claims.user_code)
if found is None:
raise NotFound("user_code_not_pending")
device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise Conflict("user_code_not_pending")
if not consume_approval_grant_nonce(redis_client, claims.nonce):
raise Unauthorized("session_already_consumed")
ttl_days = oauth_ttl_days(tenant_id=None)
mint = mint_oauth_token(
db.session,
redis_client,
subject_email=claims.subject_email,
subject_issuer=claims.subject_issuer,
account_id=None,
client_id=state.client_id,
device_label=state.device_label,
prefix=PREFIX_OAUTH_EXTERNAL_SSO,
ttl_days=ttl_days,
)
poll_payload = {
"token": mint.token,
"expires_at": mint.expires_at.isoformat(),
"subject_type": SubjectType.EXTERNAL_SSO,
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"account": None,
"workspaces": [],
"default_workspace_id": None,
"token_id": str(mint.token_id),
}
try:
store.approve(
device_code,
subject_email=claims.subject_email,
account_id=None,
subject_issuer=claims.subject_issuer,
minted_token=mint.token,
token_id=str(mint.token_id),
poll_payload=poll_payload,
)
except (StateNotFound, InvalidTransition) as e:
logger.error("approve-external: state transition raced: %s", e)
raise Conflict("state_lost") from e
_emit_approve_external_audit(state, claims, mint)
resp = make_response(jsonify({"status": "approved"}), 200)
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
return resp
def _emit_approve_external_audit(state, claims, mint) -> None:
logger.warning(
"audit: oauth.device_flow_approved subject_type=%s "
"subject_email=%s subject_issuer=%s token_id=%s",
SubjectType.EXTERNAL_SSO, claims.subject_email, claims.subject_issuer, mint.token_id,
extra={
"audit": True,
"event": "oauth.device_flow_approved",
"subject_type": SubjectType.EXTERNAL_SSO,
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"token_id": str(mint.token_id),
"client_id": state.client_id,
"device_label": state.device_label,
"scopes": ["apps:run"],
"expires_at": mint.expires_at.isoformat(),
},
)

View File

@ -14,7 +14,7 @@ api = ExternalApi(
service_api_ns = Namespace("service_api", description="Service operations", path="/")
from . import index
from . import index, oauth
from .app import (
annotation,
app,
@ -54,6 +54,7 @@ __all__ = [
"message",
"metadata",
"models",
"oauth",
"rag_pipeline_workflow",
"segment",
"site",

View File

@ -0,0 +1,302 @@
"""``/v1`` OAuth bearer + device-flow endpoints. ``/me`` and self-revoke
are bearer-authed; the device-flow trio (code/token/lookup) is public
code/token per RFC 8628, lookup so the /device page can pre-validate
before the user has a console session.
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from flask import g, request
from flask_restx import Resource, reqparse
from sqlalchemy import update
from werkzeug.exceptions import BadRequest
from controllers.service_api import service_api_ns
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import extract_remote_ip
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
SubjectType,
TOKEN_CACHE_KEY_FMT,
validate_bearer,
)
from libs.rate_limit import (
LIMIT_DEVICE_CODE_PER_IP,
LIMIT_LOOKUP_PUBLIC,
LIMIT_ME_PER_ACCOUNT,
LIMIT_ME_PER_EMAIL,
enforce,
rate_limit,
)
from models import Account, OAuthAccessToken, Tenant, TenantAccountJoin
from services.oauth_device_flow import (
DEFAULT_POLL_INTERVAL_SECONDS,
DEVICE_FLOW_TTL_SECONDS,
DeviceFlowRedis,
DeviceFlowStatus,
SlowDownDecision,
)
logger = logging.getLogger(__name__)
KNOWN_CLIENT_IDS = frozenset({"difyctl"})
# ============================================================================
# GET /v1/me
# ============================================================================
@service_api_ns.route("/me")
class MeApi(Resource):
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
ctx = g.auth_ctx
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
else:
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return {
"subject_type": ctx.subject_type,
"subject_email": ctx.subject_email,
"subject_issuer": ctx.subject_issuer,
"account": None,
"workspaces": [],
"default_workspace_id": None,
}
account = (
db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none()
if ctx.account_id else None
)
memberships = _load_memberships(ctx.account_id) if ctx.account_id else []
default_ws_id = _pick_default_workspace(memberships)
return {
"subject_type": ctx.subject_type,
"subject_email": ctx.subject_email or (account.email if account else None),
"account": _account_payload(account) if account else None,
"workspaces": [_workspace_payload(m) for m in memberships],
"default_workspace_id": default_ws_id,
}
def _load_memberships(account_id):
return (
db.session.query(TenantAccountJoin, Tenant)
.join(Tenant, Tenant.id == TenantAccountJoin.tenant_id)
.filter(TenantAccountJoin.account_id == account_id)
.all()
)
def _pick_default_workspace(memberships) -> str | None:
if not memberships:
return None
for join, tenant in memberships:
if getattr(join, "current", False):
return str(tenant.id)
return str(memberships[0][1].id)
def _workspace_payload(row) -> dict:
join, tenant = row
return {"id": str(tenant.id), "name": tenant.name, "role": getattr(join, "role", "")}
def _account_payload(account) -> dict:
return {"id": str(account.id), "email": account.email, "name": account.name}
# ============================================================================
# DELETE /v1/oauth/authorizations/self
# ============================================================================
@service_api_ns.route("/oauth/authorizations/self")
class OAuthAuthorizationsSelfApi(Resource):
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self):
ctx = g.auth_ctx
if not ctx.source.startswith("oauth"):
raise BadRequest(
"this endpoint revokes OAuth bearer tokens; "
"use /v1/personal-access-tokens/self for PATs"
)
# Snapshot pre-revoke hash for cache invalidation; UPDATE WHERE
# makes double-revoke idempotent.
row = (
db.session.query(OAuthAccessToken.token_hash)
.filter(
OAuthAccessToken.id == str(ctx.token_id),
OAuthAccessToken.revoked_at.is_(None),
)
.one_or_none()
)
pre_revoke_hash = row[0] if row else None
stmt = (
update(OAuthAccessToken)
.where(
OAuthAccessToken.id == str(ctx.token_id),
OAuthAccessToken.revoked_at.is_(None),
)
.values(revoked_at=datetime.now(UTC), token_hash=None)
)
db.session.execute(stmt)
db.session.commit()
if pre_revoke_hash:
redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=pre_revoke_hash))
return {"status": "revoked"}, 200
# ============================================================================
# POST /v1/oauth/device/code (unauthenticated — CLI starts a flow)
# ============================================================================
_code_parser = reqparse.RequestParser()
_code_parser.add_argument("client_id", type=str, required=True, location="json")
_code_parser.add_argument("device_label", type=str, required=True, location="json")
@service_api_ns.route("/oauth/device/code")
class OAuthDeviceCodeApi(Resource):
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
def post(self):
args = _code_parser.parse_args()
client_id = args["client_id"]
device_label = args["device_label"]
if client_id not in KNOWN_CLIENT_IDS:
return {"error": "unsupported_client"}, 400
store = DeviceFlowRedis(redis_client)
ip = extract_remote_ip(request)
device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)
return {
"device_code": device_code,
"user_code": user_code,
"verification_uri": _verification_uri(),
"expires_in": expires_in,
"interval": DEFAULT_POLL_INTERVAL_SECONDS,
}, 200
def _verification_uri() -> str:
from configs import dify_config
base = getattr(dify_config, "CONSOLE_WEB_URL", None)
if base:
return f"{base.rstrip('/')}/device"
return f"{request.host_url.rstrip('/')}/device"
# ============================================================================
# POST /v1/oauth/device/token (unauthenticated — CLI polls)
# ============================================================================
_poll_parser = reqparse.RequestParser()
_poll_parser.add_argument("device_code", type=str, required=True, location="json")
_poll_parser.add_argument("client_id", type=str, required=True, location="json")
@service_api_ns.route("/oauth/device/token")
class OAuthDeviceTokenApi(Resource):
"""RFC 8628 poll."""
def post(self):
args = _poll_parser.parse_args()
device_code = args["device_code"]
store = DeviceFlowRedis(redis_client)
# slow_down beats every other branch — polling-too-fast clients
# see only that response regardless of underlying state.
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
return {"error": "slow_down"}, 400
state = store.load_by_device_code(device_code)
if state is None:
return {"error": "expired_token"}, 400
if state.status is DeviceFlowStatus.PENDING:
return {"error": "authorization_pending"}, 400
terminal = store.consume_on_poll(device_code)
if terminal is None:
return {"error": "expired_token"}, 400
if terminal.status is DeviceFlowStatus.DENIED:
return {"error": "access_denied"}, 400
poll_payload = terminal.poll_payload or {}
if "token" not in poll_payload:
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
return {"error": "expired_token"}, 400
_audit_cross_ip_if_needed(state)
return poll_payload, 200
# ============================================================================
# GET /v1/oauth/device/lookup (unauthenticated — /device page pre-validates)
# ============================================================================
_lookup_parser = reqparse.RequestParser()
_lookup_parser.add_argument("user_code", type=str, required=True, location="args")
@service_api_ns.route("/oauth/device/lookup")
class OAuthDeviceLookupApi(Resource):
"""Read-only — public for pre-validate before login. user_code is
high-entropy + short-TTL; per-IP rate limit blocks enumeration.
"""
@rate_limit(LIMIT_LOOKUP_PUBLIC)
def get(self):
args = _lookup_parser.parse_args()
user_code = args["user_code"].strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
_device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
return {
"valid": True,
"expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
"client_id": state.client_id,
}, 200
def _audit_cross_ip_if_needed(state) -> None:
poll_ip = extract_remote_ip(request)
if state.created_ip and poll_ip and poll_ip != state.created_ip:
logger.warning(
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
state.token_id, state.created_ip, poll_ip,
extra={
"audit": True,
"token_id": state.token_id,
"creation_ip": state.created_ip,
"poll_ip": poll_ip,
},
)

View File

@ -90,6 +90,15 @@ def init_app(app: DifyApp):
app.register_blueprint(inner_api_bp)
app.register_blueprint(mcp_bp)
# SSO-branch device-flow routes. No CORS config — these endpoints are
# user-interactive (same-origin browser traffic) and cookie-authed;
# allowing cross-origin would defeat the SameSite=Lax cookie's purpose.
# Gated on ENABLE_OAUTH_BEARER: without the bearer authenticator, tokens
# minted here cannot be validated, so skip the blueprint entirely.
if dify_config.ENABLE_OAUTH_BEARER:
from controllers.oauth_device_sso import bp as oauth_device_sso_bp
app.register_blueprint(oauth_device_sso_bp)
# Register trigger blueprint with CORS for webhook calls
_apply_cors_once(
trigger_bp,

View File

@ -222,6 +222,12 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
"schedule": crontab(minute="0", hour="0"),
}
if dify_config.ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK:
imports.append("schedule.clean_oauth_access_tokens_task")
beat_schedule["clean_oauth_access_tokens_task"] = {
"task": "schedule.clean_oauth_access_tokens_task.clean_oauth_access_tokens_task",
"schedule": crontab(minute="0", hour="5", day_of_month=f"*/{day}"),
}
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
imports.append("schedule.workflow_schedule_task")
beat_schedule["workflow_schedule_task"] = {

View File

@ -0,0 +1,22 @@
"""Bind the bearer authenticator at startup. Must run after ext_database
and ext_redis (needs both factories).
"""
from __future__ import annotations
from configs import dify_config
from dify_app import DifyApp
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import build_and_bind
def is_enabled() -> bool:
return dify_config.ENABLE_OAUTH_BEARER
def init_app(app: DifyApp) -> None:
# scoped_session isn't a context manager; request teardown closes it.
def session_factory():
return db.session
build_and_bind(session_factory=session_factory, redis_client=redis_client)

View File

@ -0,0 +1,187 @@
"""Device-flow security primitives: enterprise_only gate, approval-grant
cookie mint/verify/consume, and anti-framing headers.
"""
from __future__ import annotations
import logging
import secrets
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from functools import wraps
from typing import Callable
from flask import Blueprint
from werkzeug.exceptions import NotFound
from libs import jws
from libs.token import is_secure
from services.feature_service import FeatureService, LicenseStatus
logger = logging.getLogger(__name__)
# ============================================================================
# enterprise_only decorator
# ============================================================================
_CE_LIKE_STATUSES = {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST}
def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""404 on CE, passthrough on EE. Apply before rate-limit so CE
responses don't consume the bucket.
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features()
if settings.license.status in _CE_LIKE_STATUSES:
raise NotFound()
return view(*args, **kwargs)
return decorated
# ============================================================================
# approval_grant cookie
# ============================================================================
APPROVAL_GRANT_COOKIE_NAME = "device_approval_grant"
APPROVAL_GRANT_COOKIE_PATH = "/v1/oauth/device"
APPROVAL_GRANT_COOKIE_TTL_SECONDS = 300 # 5 min
NONCE_TTL_SECONDS = 600 # 2x cookie TTL — defeats clock-skew late replay
NONCE_KEY_FMT = "device_approval_grant_nonce:{nonce}"
SSO_ASSERTION_NONCE_KEY_FMT = "sso_assertion_nonce:{nonce}"
@dataclass(frozen=True, slots=True)
class ApprovalGrantClaims:
subject_email: str
subject_issuer: str
user_code: str
nonce: str
csrf_token: str
expires_at: datetime
def mint_approval_grant(
*,
keyset: jws.KeySet,
iss: str,
subject_email: str,
subject_issuer: str,
user_code: str,
) -> tuple[str, ApprovalGrantClaims]:
"""Use ``approval_grant_cookie_kwargs`` to set the cookie — single
source of truth for Path/HttpOnly/Secure/SameSite.
"""
now = datetime.now(UTC)
exp = now + timedelta(seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
nonce = _random_opaque()
csrf_token = _random_opaque()
payload = {
"iss": iss,
"subject_email": subject_email,
"subject_issuer": subject_issuer,
"user_code": user_code,
"nonce": nonce,
"csrf_token": csrf_token,
}
token = jws.sign(keyset, payload, aud=jws.AUD_APPROVAL_GRANT, ttl_seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
return token, ApprovalGrantClaims(
subject_email=subject_email,
subject_issuer=subject_issuer,
user_code=user_code,
nonce=nonce,
csrf_token=csrf_token,
expires_at=exp,
)
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
"""Sig + aud + exp only — nonce consumption is the caller's job."""
data = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
return ApprovalGrantClaims(
subject_email=data["subject_email"],
subject_issuer=data["subject_issuer"],
user_code=data["user_code"],
nonce=data["nonce"],
csrf_token=data["csrf_token"],
expires_at=datetime.fromtimestamp(data["exp"], tz=UTC),
)
def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
if not nonce:
return False
return bool(
redis_client.set(
NONCE_KEY_FMT.format(nonce=nonce), "1", nx=True, ex=NONCE_TTL_SECONDS,
)
)
def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
if not nonce:
return False
return bool(
redis_client.set(
SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce), "1", nx=True, ex=NONCE_TTL_SECONDS,
)
)
def approval_grant_cookie_kwargs(value: str) -> dict:
"""``secure`` follows is_secure() so HTTP-only deployments don't
silently drop the cookie.
"""
return {
"key": APPROVAL_GRANT_COOKIE_NAME,
"value": value,
"max_age": APPROVAL_GRANT_COOKIE_TTL_SECONDS,
"path": APPROVAL_GRANT_COOKIE_PATH,
"secure": is_secure(),
"httponly": True,
"samesite": "Lax",
}
def approval_grant_cleared_cookie_kwargs() -> dict:
return {
"key": APPROVAL_GRANT_COOKIE_NAME,
"value": "",
"max_age": 0,
"path": APPROVAL_GRANT_COOKIE_PATH,
"secure": is_secure(),
"httponly": True,
"samesite": "Lax",
}
def _random_opaque() -> str:
return secrets.token_urlsafe(16)
# ============================================================================
# Anti-framing headers
# ============================================================================
_ANTI_FRAMING_HEADERS = {
"X-Frame-Options": "DENY",
"Content-Security-Policy": "frame-ancestors 'none'",
}
def attach_anti_framing(bp: Blueprint) -> None:
"""X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4)."""
@bp.after_request
def _apply_headers(response):
for name, value in _ANTI_FRAMING_HEADERS.items():
response.headers.setdefault(name, value)
return response

106
api/libs/jws.py Normal file
View File

@ -0,0 +1,106 @@
"""HS256 compact JWS keyed on the shared Dify SECRET_KEY. Used by the SSO
state envelope, external subject assertion, and approval-grant cookie
all three share one key-set so api enterprise can verify each other.
"""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
import jwt
from configs import dify_config
AUD_STATE_ENVELOPE = "api.sso.state_envelope"
AUD_EXT_SUBJECT_ASSERTION = "api.device_flow.external_subject_assertion"
AUD_APPROVAL_GRANT = "api.device_flow.approval_grant"
ACTIVE_KID_V1 = "dify-shared-v1"
class KeySetError(Exception):
pass
class KeySet:
"""``from_entries`` reserves multi-kid construction for rotation slots."""
def __init__(self, entries: dict[str, bytes], active_kid: str) -> None:
if active_kid not in entries:
raise KeySetError(f"active kid {active_kid!r} missing from key-set")
if not entries[active_kid]:
raise KeySetError(f"active kid {active_kid!r} has empty secret")
self._entries: dict[str, bytes] = {k: bytes(v) for k, v in entries.items()}
self._active_kid = active_kid
@classmethod
def from_shared_secret(cls) -> "KeySet":
secret = dify_config.SECRET_KEY
if not secret:
raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1)
@classmethod
def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> "KeySet":
return cls(entries, active_kid)
@property
def active_kid(self) -> str:
return self._active_kid
def lookup(self, kid: str) -> bytes | None:
return self._entries.get(kid)
def sign(keyset: KeySet, payload: dict, aud: str, ttl_seconds: int) -> str:
"""``iat`` + ``exp`` are injected here; callers must not set them."""
if "aud" in payload or "iat" in payload or "exp" in payload:
raise ValueError("reserved claim present in payload (aud/iat/exp)")
if ttl_seconds <= 0:
raise ValueError("ttl_seconds must be positive")
kid = keyset.active_kid
secret = keyset.lookup(kid)
if secret is None:
raise KeySetError(f"active kid {kid!r} lookup miss")
iat = datetime.now(UTC)
exp = iat + timedelta(seconds=ttl_seconds)
claims = {**payload, "aud": aud, "iat": iat, "exp": exp}
return jwt.encode(
claims,
secret,
algorithm="HS256",
headers={"kid": kid, "typ": "JWT"},
)
class VerifyError(Exception):
pass
def verify(keyset: KeySet, token: str, expected_aud: str) -> dict:
"""Unknown kid is rejected — never fall back to the active kid, since
a past kid value would otherwise be forgeable by anyone who saw it.
"""
try:
header = jwt.get_unverified_header(token)
except jwt.PyJWTError as e:
raise VerifyError(f"decode header: {e}") from e
kid = header.get("kid")
if not kid:
raise VerifyError("no kid in header")
secret = keyset.lookup(kid)
if secret is None:
raise VerifyError(f"unknown kid {kid!r}")
try:
return jwt.decode(
token,
secret,
algorithms=["HS256"],
audience=expected_aud,
)
except jwt.ExpiredSignatureError as e:
raise VerifyError("token expired") from e
except jwt.InvalidAudienceError as e:
raise VerifyError("aud mismatch") from e
except jwt.PyJWTError as e:
raise VerifyError(f"decode: {e}") from e

425
api/libs/oauth_bearer.py Normal file
View File

@ -0,0 +1,425 @@
"""OAuth bearer primitives.
To add a token kind: write a Resolver, add a SubjectType + Accepts member,
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
Authenticator + validate_bearer stay untouched.
"""
from __future__ import annotations
import hashlib
import json
import logging
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime
from enum import StrEnum
from functools import wraps
from typing import Callable, Iterable, Literal, Protocol
from flask import g, request
from sqlalchemy import update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized
from models import OAuthAccessToken
logger = logging.getLogger(__name__)
# ============================================================================
# Contract — types, enums, protocols
# ============================================================================
class SubjectType(StrEnum):
ACCOUNT = "account"
EXTERNAL_SSO = "external_sso"
@dataclass(frozen=True, slots=True)
class AuthContext:
"""Attached to ``g.auth_ctx``. ``scopes`` / ``subject_type`` / ``source``
come from the TokenKind, not the DB corrupt rows can't elevate scope.
"""
subject_type: SubjectType
subject_email: str | None
subject_issuer: str | None
account_id: uuid.UUID | None
scopes: frozenset[str]
token_id: uuid.UUID
source: str
expires_at: datetime | None
@dataclass(frozen=True, slots=True)
class ResolvedRow:
subject_email: str | None
subject_issuer: str | None
account_id: uuid.UUID | None
token_id: uuid.UUID
expires_at: datetime | None
class Resolver(Protocol):
def resolve(self, token_hash: str) -> ResolvedRow | None: # pragma: no cover - contract
...
@dataclass(frozen=True, slots=True)
class TokenKind:
prefix: str
subject_type: SubjectType
scopes: frozenset[str]
source: str
resolver: Resolver
def matches(self, token: str) -> bool:
return token.startswith(self.prefix)
class InvalidBearer(Exception):
"""Token missing, unknown prefix, or no live row."""
class TokenExpired(Exception):
"""Hard-expire bookkeeping is the resolver's job before raising."""
# ============================================================================
# Registry
# ============================================================================
class TokenKindRegistry:
def __init__(self, kinds: Iterable[TokenKind]) -> None:
self._kinds: tuple[TokenKind, ...] = tuple(kinds)
prefixes = [k.prefix for k in self._kinds]
if len(set(prefixes)) != len(prefixes):
raise ValueError(f"duplicate prefix in registry: {prefixes}")
def find(self, token: str) -> TokenKind | None:
for k in self._kinds:
if k.matches(token):
return k
return None
def kinds(self) -> tuple[TokenKind, ...]:
return self._kinds
# ============================================================================
# Authenticator
# ============================================================================
def sha256_hex(token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
class BearerAuthenticator:
def __init__(self, registry: TokenKindRegistry) -> None:
self._registry = registry
def authenticate(self, token: str) -> AuthContext:
kind = self._registry.find(token)
if kind is None:
raise InvalidBearer("unknown token prefix")
row = kind.resolver.resolve(sha256_hex(token))
if row is None:
raise InvalidBearer("token unknown or revoked")
return AuthContext(
subject_type=kind.subject_type,
subject_email=row.subject_email,
subject_issuer=row.subject_issuer,
account_id=row.account_id,
scopes=kind.scopes,
token_id=row.token_id,
source=kind.source,
expires_at=row.expires_at,
)
# ============================================================================
# OAuth access token resolver (PAT resolver would be a sibling class)
# ============================================================================
TOKEN_CACHE_KEY_FMT = "auth:token:{hash}"
POSITIVE_TTL_SECONDS = 60
NEGATIVE_TTL_SECONDS = 10
AUDIT_OAUTH_EXPIRED = "oauth.token_expired"
ScopeVariant = Literal["account", "external_sso"]
class OAuthAccessTokenResolver:
"""``.for_account()`` / ``.for_external_sso()`` are variant-scoped views
sharing DB + cache plumbing.
"""
def __init__(
self,
session_factory,
redis_client,
positive_ttl: int = POSITIVE_TTL_SECONDS,
negative_ttl: int = NEGATIVE_TTL_SECONDS,
) -> None:
self._session_factory = session_factory
self._redis = redis_client
self._positive_ttl = positive_ttl
self._negative_ttl = negative_ttl
def for_account(self) -> Resolver:
return _VariantResolver(self, variant="account")
def for_external_sso(self) -> Resolver:
return _VariantResolver(self, variant="external_sso")
def _cache_key(self, token_hash: str) -> str:
return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
def _cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
raw = self._redis.get(self._cache_key(token_hash))
if raw is None:
return None
text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
if text == "invalid":
return "invalid"
try:
data = json.loads(text)
return _row_from_cache(data)
except (ValueError, KeyError):
logger.warning("auth:token cache entry malformed; treating as miss")
return None
def _cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
self._redis.setex(
self._cache_key(token_hash),
self._positive_ttl,
json.dumps(_row_to_cache(row)),
)
def _cache_set_negative(self, token_hash: str) -> None:
self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")
def _hard_expire(self, session: Session, row_id: uuid.UUID, token_hash: str) -> None:
"""Atomic CAS — only the worker that flips revoked_at emits audit;
replays are idempotent. Spec: tokens.md §Detection + hard-expire.
"""
stmt = (
update(OAuthAccessToken)
.where(OAuthAccessToken.id == row_id, OAuthAccessToken.revoked_at.is_(None))
.values(revoked_at=datetime.now(UTC), token_hash=None)
)
result = session.execute(stmt)
session.commit()
if result.rowcount == 1:
logger.warning(
"audit: %s token_id=%s", AUDIT_OAUTH_EXPIRED, row_id,
extra={"audit": True, "token_id": str(row_id)},
)
self._redis.delete(self._cache_key(token_hash))
self._cache_set_negative(token_hash)
class _VariantResolver:
def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
self._parent = parent
self._variant = variant
def resolve(self, token_hash: str) -> ResolvedRow | None:
cached = self._parent._cache_get(token_hash)
if cached == "invalid":
return None
if cached is not None and not isinstance(cached, str):
if not self._matches_variant(cached):
return None
return cached
# _session_factory returns Flask-SQLAlchemy's scoped_session, which is
# request-bound and not a context manager; use it directly.
session = self._parent._session_factory()
row = self._load_from_db(session, token_hash)
if row is None:
self._parent._cache_set_negative(token_hash)
return None
now = datetime.now(UTC)
if row.expires_at is not None and row.expires_at <= now:
self._parent._hard_expire(session, row.id, token_hash)
return None
if not self._matches_variant_model(row):
logger.error(
"internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
row.id, row.prefix,
)
return None
resolved = ResolvedRow(
subject_email=row.subject_email,
subject_issuer=row.subject_issuer,
account_id=uuid.UUID(str(row.account_id)) if row.account_id else None,
token_id=uuid.UUID(str(row.id)),
expires_at=row.expires_at,
)
self._parent._cache_set_positive(token_hash, resolved)
return resolved
def _matches_variant(self, row: ResolvedRow) -> bool:
has_account = row.account_id is not None
if self._variant == "account":
return has_account
return not has_account
def _matches_variant_model(self, row: OAuthAccessToken) -> bool:
has_account = row.account_id is not None
if self._variant == "account":
return has_account and row.prefix == "dfoa_"
return (not has_account) and row.prefix == "dfoe_"
def _load_from_db(self, session: Session, token_hash: str) -> OAuthAccessToken | None:
return (
session.query(OAuthAccessToken)
.filter(
OAuthAccessToken.token_hash == token_hash,
OAuthAccessToken.revoked_at.is_(None),
)
.one_or_none()
)
def _row_to_cache(row: ResolvedRow) -> dict:
return {
"subject_email": row.subject_email,
"subject_issuer": row.subject_issuer,
"account_id": str(row.account_id) if row.account_id else None,
"token_id": str(row.token_id),
"expires_at": row.expires_at.isoformat() if row.expires_at else None,
}
def _row_from_cache(data: dict) -> ResolvedRow:
return ResolvedRow(
subject_email=data["subject_email"],
subject_issuer=data["subject_issuer"],
account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
token_id=uuid.UUID(data["token_id"]),
expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
)
# ============================================================================
# Decorator — route-level bearer gate
# ============================================================================
class Accepts(StrEnum):
USER_ACCOUNT = "user_account"
USER_EXT_SSO = "user_ext_sso"
APP = "app"
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
}
_authenticator: BearerAuthenticator | None = None
def bind_authenticator(authenticator: BearerAuthenticator) -> None:
global _authenticator
_authenticator = authenticator
def get_authenticator() -> BearerAuthenticator:
if _authenticator is None:
raise RuntimeError("BearerAuthenticator not bound; call bind_authenticator at startup")
return _authenticator
def _extract_bearer(req) -> str | None:
header = req.headers.get("Authorization", "")
scheme, _, value = header.partition(" ")
if scheme.lower() != "bearer" or not value:
return None
return value.strip()
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable:
"""Opt-in: omitting it leaves the route unauthenticated.
Coexists with legacy ``app-`` keys (tenant+app scoped, resolved in
``service_api/wraps.py``) and user-level OAuth bearers (resolved here).
"""
def wrap(fn: Callable) -> Callable:
@wraps(fn)
def inner(*args, **kwargs):
token = _extract_bearer(request)
if token is None:
raise Unauthorized("missing bearer token")
# app- keys bypass the OAuth authenticator (work even when disabled).
if token.startswith("app-"):
if Accepts.APP not in accept:
raise Unauthorized("app-scoped keys not accepted here")
return fn(*args, **kwargs)
if _authenticator is None:
raise ServiceUnavailable(
"bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable"
)
try:
ctx = get_authenticator().authenticate(token)
except InvalidBearer as e:
raise Unauthorized(str(e))
if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
raise Forbidden("token subject type not accepted here")
g.auth_ctx = ctx
return fn(*args, **kwargs)
return inner
return wrap
# ============================================================================
# Wiring — called once from the app factory
# ============================================================================
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
oauth = OAuthAccessTokenResolver(session_factory, redis_client)
return TokenKindRegistry([
TokenKind(
prefix="dfoa_",
subject_type=SubjectType.ACCOUNT,
scopes=frozenset({"full"}),
source="oauth_account",
resolver=oauth.for_account(),
),
TokenKind(
prefix="dfoe_",
subject_type=SubjectType.EXTERNAL_SSO,
scopes=frozenset({"apps:run"}),
source="oauth_external_sso",
resolver=oauth.for_external_sso(),
),
])
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
registry = build_registry(session_factory, redis_client)
auth = BearerAuthenticator(registry)
bind_authenticator(auth)
return auth

109
api/libs/rate_limit.py Normal file
View File

@ -0,0 +1,109 @@
"""Typed rate-limit decorator over ``libs.helper.RateLimiter`` (sliding-
window Redis ZSET). Apply after auth decorators so scopes can read
``g.auth_ctx``. Use :func:`enforce` when the bucket key is computed
in-handler. RFC-8628 ``slow_down`` is inline its response shape isn't
generic 429. Spec: docs/specs/v1.0/server/security.md.
"""
from __future__ import annotations
from dataclasses import dataclass
from datetime import timedelta
from enum import StrEnum
from functools import wraps
from typing import Callable
from flask import g, request, session
from werkzeug.exceptions import TooManyRequests
from libs.helper import RateLimiter, extract_remote_ip
class RateLimitScope(StrEnum):
IP = "ip"
SESSION = "session"
ACCOUNT = "account"
SUBJECT_EMAIL = "subject_email"
TOKEN_ID = "token_id"
@dataclass(frozen=True, slots=True)
class RateLimit:
limit: int
window: timedelta
scopes: tuple[RateLimitScope, ...]
LIMIT_DEVICE_CODE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
LIMIT_SSO_INITIATE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
LIMIT_APPROVE_EXT_PER_EMAIL = RateLimit(10, timedelta(hours=1), (RateLimitScope.SUBJECT_EMAIL,))
LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSION,))
LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,))
LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,))
LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,))
def _one_key(scope: RateLimitScope) -> str:
match scope:
case RateLimitScope.IP:
return f"ip:{extract_remote_ip(request) or 'unknown'}"
case RateLimitScope.SESSION:
return f"session:{session.get('_id', 'anon')}"
case RateLimitScope.ACCOUNT:
ctx = getattr(g, "auth_ctx", None)
if ctx and ctx.account_id:
return f"account:{ctx.account_id}"
return "account:anon"
case RateLimitScope.SUBJECT_EMAIL:
ctx = getattr(g, "auth_ctx", None)
if ctx and ctx.subject_email:
return f"subject:{ctx.subject_email}"
return "subject:anon"
case RateLimitScope.TOKEN_ID:
ctx = getattr(g, "auth_ctx", None)
if ctx and ctx.token_id:
return f"token:{ctx.token_id}"
return "token:anon"
def _composite_key(scopes: tuple[RateLimitScope, ...]) -> str:
return "|".join(_one_key(s) for s in scopes)
def _limiter_prefix(scopes: tuple[RateLimitScope, ...]) -> str:
return "rl:" + "+".join(s.value for s in scopes)
def _build_limiter(spec: RateLimit) -> RateLimiter:
return RateLimiter(
prefix=_limiter_prefix(spec.scopes),
max_attempts=spec.limit,
time_window=int(spec.window.total_seconds()),
)
def rate_limit(spec: RateLimit) -> Callable:
"""Apply after auth decorators that the scopes read from."""
limiter = _build_limiter(spec)
def wrap(fn: Callable) -> Callable:
@wraps(fn)
def inner(*args, **kwargs):
key = _composite_key(spec.scopes)
if limiter.is_rate_limited(key):
raise TooManyRequests("rate_limited")
limiter.increment_rate_limit(key)
return fn(*args, **kwargs)
return inner
return wrap
def enforce(spec: RateLimit, *, key: str) -> None:
"""Imperative form — caller composes the bucket key to match scope
semantics (the key is opaque here).
"""
limiter = _build_limiter(spec)
if limiter.is_rate_limited(key):
raise TooManyRequests("rate_limited")
limiter.increment_rate_limit(key)

View File

@ -0,0 +1,102 @@
"""add oauth_access_tokens table
Revision ID: d4a5e1f3c9b7
Revises: 227822d22895, b69ca54b9208, 2a3aebbbf4bb
Create Date: 2026-04-23 22:00:00.000000
Merges the three open heads at time of authoring (add_workflow_comments_table,
add_chatbot_color_theme, add_app_tracing) into a single parent so the new
oauth_access_tokens table sits on a definite linear chain thereafter.
Table stores user-level OAuth bearer tokens minted via the device-flow grant
(difyctl auth login). PAT storage (personal_access_tokens) is a separate
table not added in this migration.
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "d4a5e1f3c9b7"
down_revision = ("227822d22895", "b69ca54b9208", "2a3aebbbf4bb")
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"oauth_access_tokens",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
primary_key=True,
),
sa.Column("subject_email", sa.Text(), nullable=False),
sa.Column("subject_issuer", sa.Text(), nullable=True),
sa.Column("account_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("device_label", sa.Text(), nullable=False),
sa.Column("prefix", sa.String(length=8), nullable=False),
sa.Column("token_hash", sa.String(length=64), nullable=True, unique=True),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("NOW()"),
nullable=False,
),
sa.Column("last_used_at", sa.TIMESTAMP(timezone=True), nullable=True),
sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["account_id"],
["accounts.id"],
name="fk_oauth_access_tokens_account_id",
ondelete="SET NULL",
),
)
op.create_index(
"idx_oauth_subject_email",
"oauth_access_tokens",
["subject_email"],
postgresql_where=sa.text("revoked_at IS NULL"),
)
op.create_index(
"idx_oauth_account",
"oauth_access_tokens",
["account_id"],
postgresql_where=sa.text("revoked_at IS NULL AND account_id IS NOT NULL"),
)
op.create_index(
"idx_oauth_client",
"oauth_access_tokens",
["subject_email", "client_id"],
postgresql_where=sa.text("revoked_at IS NULL"),
)
op.create_index(
"idx_oauth_token_hash",
"oauth_access_tokens",
["token_hash"],
postgresql_where=sa.text("revoked_at IS NULL"),
)
# Partial unique index — rotate-in-place keyed on (subject, client, device).
# subject_issuer NULL vs populated distinguishes account vs external-SSO rows
# for the same email, because Postgres treats NULL as distinct.
op.create_index(
"uq_oauth_active_per_device",
"oauth_access_tokens",
["subject_email", "subject_issuer", "client_id", "device_label"],
unique=True,
postgresql_where=sa.text("revoked_at IS NULL"),
)
def downgrade():
op.drop_index("uq_oauth_active_per_device", table_name="oauth_access_tokens")
op.drop_index("idx_oauth_token_hash", table_name="oauth_access_tokens")
op.drop_index("idx_oauth_client", table_name="oauth_access_tokens")
op.drop_index("idx_oauth_account", table_name="oauth_access_tokens")
op.drop_index("idx_oauth_subject_email", table_name="oauth_access_tokens")
op.drop_table("oauth_access_tokens")

View File

@ -73,7 +73,7 @@ from .model import (
TrialApp,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
from .oauth import DatasourceOauthParamConfig, DatasourceProvider, OAuthAccessToken
from .provider import (
LoadBalancingModelConfig,
Provider,
@ -177,6 +177,7 @@ __all__ = [
"MessageChain",
"MessageFeedback",
"MessageFile",
"OAuthAccessToken",
"OperationLog",
"PinnedConversation",
"Provider",

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any
from typing import Any, Optional
import sqlalchemy as sa
from sqlalchemy import func
@ -84,3 +84,38 @@ class DatasourceOauthTenantParamConfig(TypeBase):
onupdate=func.current_timestamp(),
init=False,
)
class OAuthAccessToken(TypeBase):
"""Device-flow bearer. account_id NOT NULL ⇒ dfoa_ (Dify account);
account_id NULL + subject_issuer dfoe_ (external SSO, EE-only).
Partial unique index on (subject_email, subject_issuer, client_id,
device_label) WHERE revoked_at IS NULL lets re-login rotate in place.
"""
__tablename__ = "oauth_access_tokens"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),
)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
subject_email: Mapped[str] = mapped_column(sa.Text, nullable=False)
client_id: Mapped[str] = mapped_column(sa.String(64), nullable=False)
device_label: Mapped[str] = mapped_column(sa.Text, nullable=False)
prefix: Mapped[str] = mapped_column(sa.String(8), nullable=False)
expires_at: Mapped[datetime] = mapped_column(sa.DateTime(timezone=True), nullable=False)
subject_issuer: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True, default=None)
account_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True, default=None)
token_hash: Mapped[Optional[str]] = mapped_column(sa.String(64), nullable=True, default=None)
last_used_at: Mapped[Optional[datetime]] = mapped_column(
sa.DateTime(timezone=True), nullable=True, default=None
)
revoked_at: Mapped[Optional[datetime]] = mapped_column(
sa.DateTime(timezone=True), nullable=True, default=None
)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), nullable=False, server_default=func.now(), init=False
)

View File

@ -0,0 +1,57 @@
"""DELETE oauth_access_tokens past retention. Revocation is UPDATE
(token_id stays for audits) so rows accumulate across re-logins, and
expired-but-never-presented rows have no hard-expire trigger both get
pruned here. Spec: docs/specs/v1.0/server/tokens.md §Hard-expire.
"""
from __future__ import annotations
import logging
import time
from datetime import UTC, datetime, timedelta
import click
from sqlalchemy import delete, or_, select
import app
from configs import dify_config
from extensions.ext_database import db
from models.oauth import OAuthAccessToken
logger = logging.getLogger(__name__)
DELETE_BATCH_SIZE = 500
@app.celery.task(queue="retention")
def clean_oauth_access_tokens_task():
click.echo(click.style("Start clean oauth_access_tokens.", fg="green"))
retention_days = int(dify_config.OAUTH_ACCESS_TOKEN_RETENTION_DAYS)
cutoff = datetime.now(UTC) - timedelta(days=retention_days)
start_at = time.perf_counter()
candidates = or_(
OAuthAccessToken.revoked_at < cutoff,
# Zombies: expired but never re-presented, so middleware never flipped them.
(OAuthAccessToken.revoked_at.is_(None))
& (OAuthAccessToken.expires_at < cutoff),
)
total = 0
while True:
ids = db.session.scalars(
select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)
).all()
if not ids:
break
db.session.execute(
delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids))
)
db.session.commit()
total += len(ids)
end_at = time.perf_counter()
click.echo(click.style(
f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d "
f"in {end_at - start_at:.2f}s",
fg="green",
))

View File

@ -106,6 +106,15 @@ class EnterpriseService:
def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
@classmethod
def initiate_device_flow_sso(cls, signed_state: str) -> dict:
return EnterpriseRequest.send_request(
"POST",
"/device-flow/sso-initiate",
json={"signed_state": signed_state},
raise_for_status=True,
)
@classmethod
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
"""

View File

@ -0,0 +1,417 @@
"""Device-flow service layer: Redis state machine, OAuth token mint
(DB upsert + plaintext generation), and TTL policy. Specs:
docs/specs/v1.0/server/{device-flow.md, tokens.md}.
"""
from __future__ import annotations
import hashlib
import json
import logging
import os
import secrets
import time
import uuid
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime, timedelta
from enum import StrEnum
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT
from models.oauth import OAuthAccessToken
from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
# ============================================================================
# Redis state machine — device_code + user_code ephemeral state
# ============================================================================
DEVICE_CODE_KEY_FMT = "device_code:{code}"
USER_CODE_KEY_FMT = "user_code:{code}"
DEVICE_FLOW_TTL_SECONDS = 15 * 60 # RFC 8628 expires_in
APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor
USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXY3456789" # ambiguous chars dropped
USER_CODE_SEGMENT_LEN = 4
USER_CODE_MAX_CLAIM_ATTEMPTS = 5
DEFAULT_POLL_INTERVAL_SECONDS = 5 # RFC 8628 minimum
class DeviceFlowStatus(StrEnum):
PENDING = "pending"
APPROVED = "approved"
DENIED = "denied"
class SlowDownDecision(StrEnum):
OK = "ok"
SLOW_DOWN = "slow_down"
@dataclass
class DeviceFlowState:
"""``minted_token`` is plaintext between approve and the next poll;
DEL'd after the poll reads it.
"""
user_code: str
client_id: str
device_label: str
status: DeviceFlowStatus
subject_email: str | None = None
account_id: str | None = None
subject_issuer: str | None = None
minted_token: str | None = None
token_id: str | None = None
created_at: str = ""
created_ip: str = ""
last_poll_at: str = ""
poll_payload: dict | None = field(default=None)
def to_json(self) -> str:
return json.dumps(asdict(self))
@classmethod
def from_json(cls, raw: str) -> "DeviceFlowState":
data = json.loads(raw)
if "status" in data:
data["status"] = DeviceFlowStatus(data["status"])
return cls(**data)
def _random_device_code() -> str:
return "dc_" + secrets.token_urlsafe(24)
def _random_user_code_segment() -> str:
return "".join(secrets.choice(USER_CODE_ALPHABET) for _ in range(USER_CODE_SEGMENT_LEN))
def _random_user_code() -> str:
return f"{_random_user_code_segment()}-{_random_user_code_segment()}"
class StateNotFound(Exception):
pass
class InvalidTransition(Exception):
pass
class UserCodeExhausted(Exception):
pass
class DeviceFlowRedis:
def __init__(self, redis_client) -> None:
self._redis = redis_client
def start(self, client_id: str, device_label: str, created_ip: str) -> tuple[str, str, int]:
device_code = _random_device_code()
user_code = self._claim_user_code(device_code)
state = DeviceFlowState(
user_code=user_code,
client_id=client_id,
device_label=device_label,
status=DeviceFlowStatus.PENDING,
created_at=datetime.now(UTC).isoformat(),
created_ip=created_ip,
)
self._redis.setex(
DEVICE_CODE_KEY_FMT.format(code=device_code),
DEVICE_FLOW_TTL_SECONDS,
state.to_json(),
)
return device_code, user_code, DEVICE_FLOW_TTL_SECONDS
def _claim_user_code(self, device_code: str) -> str:
for _ in range(USER_CODE_MAX_CLAIM_ATTEMPTS):
user_code = _random_user_code()
key = USER_CODE_KEY_FMT.format(code=user_code)
ok = self._redis.set(key, device_code, nx=True, ex=DEVICE_FLOW_TTL_SECONDS)
if ok:
return user_code
raise UserCodeExhausted("could not allocate a unique user_code in 5 attempts")
def load_by_user_code(self, user_code: str) -> tuple[str, DeviceFlowState] | None:
raw_dc = self._redis.get(USER_CODE_KEY_FMT.format(code=user_code))
if not raw_dc:
return None
device_code = raw_dc.decode() if isinstance(raw_dc, (bytes, bytearray)) else raw_dc
state = self._load_state(device_code)
if state is None:
return None
return device_code, state
def load_by_device_code(self, device_code: str) -> DeviceFlowState | None:
return self._load_state(device_code)
def _load_state(self, device_code: str) -> DeviceFlowState | None:
raw = self._redis.get(DEVICE_CODE_KEY_FMT.format(code=device_code))
if not raw:
return None
text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
try:
return DeviceFlowState.from_json(text_)
except (ValueError, KeyError):
logger.error("device_flow: corrupt state for %s", device_code)
return None
def approve(
self,
device_code: str,
subject_email: str,
account_id: str | None,
minted_token: str,
token_id: str,
subject_issuer: str | None = None,
poll_payload: dict | None = None,
) -> None:
state = self._load_state(device_code)
if state is None:
raise StateNotFound(device_code)
if state.status is not DeviceFlowStatus.PENDING:
raise InvalidTransition(f"cannot approve {state.status}")
state.status = DeviceFlowStatus.APPROVED
state.subject_email = subject_email
state.account_id = account_id
state.subject_issuer = subject_issuer
state.minted_token = minted_token
state.token_id = token_id
state.poll_payload = poll_payload
new_ttl = self._remaining_ttl(device_code, floor=APPROVED_TTL_SECONDS_MIN)
self._redis.setex(DEVICE_CODE_KEY_FMT.format(code=device_code), new_ttl, state.to_json())
def deny(self, device_code: str) -> None:
state = self._load_state(device_code)
if state is None:
raise StateNotFound(device_code)
if state.status is not DeviceFlowStatus.PENDING:
raise InvalidTransition(f"cannot deny {state.status}")
state.status = DeviceFlowStatus.DENIED
self._redis.setex(
DEVICE_CODE_KEY_FMT.format(code=device_code),
self._remaining_ttl(device_code, floor=1),
state.to_json(),
)
def consume_on_poll(self, device_code: str) -> DeviceFlowState | None:
"""Race-safe via DEL: concurrent polls — one wins, the other gets
None and the caller maps that to expired_token.
"""
state = self._load_state(device_code)
if state is None:
return None
if state.status is DeviceFlowStatus.PENDING:
return None
self._redis.delete(
DEVICE_CODE_KEY_FMT.format(code=device_code),
USER_CODE_KEY_FMT.format(code=state.user_code),
)
return state
def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision:
now = time.time()
key = f"device_code:{device_code}:last_poll"
prev_raw = self._redis.get(key)
self._redis.setex(key, DEVICE_FLOW_TTL_SECONDS, str(now))
if prev_raw is None:
return SlowDownDecision.OK
prev_s = prev_raw.decode() if isinstance(prev_raw, (bytes, bytearray)) else prev_raw
try:
prev = float(prev_s)
except ValueError:
return SlowDownDecision.OK
if now - prev < interval_seconds:
return SlowDownDecision.SLOW_DOWN
return SlowDownDecision.OK
def _remaining_ttl(self, device_code: str, floor: int) -> int:
"""``max(remaining, floor)`` — guarantees the CLI has at least
``floor`` seconds to poll after a near-expiry approve.
"""
ttl = self._redis.ttl(DEVICE_CODE_KEY_FMT.format(code=device_code))
if ttl is None or ttl < 0:
return floor
return max(int(ttl), floor)
# ============================================================================
# Token mint — generate + upsert
# ============================================================================
OAUTH_BODY_BYTES = 32 # ~256 bits entropy
PREFIX_OAUTH_ACCOUNT = "dfoa_"
PREFIX_OAUTH_EXTERNAL_SSO = "dfoe_"
@dataclass(frozen=True, slots=True)
class MintResult:
"""Plaintext token surfaces to the caller once."""
token: str
token_id: uuid.UUID
expires_at: datetime
@dataclass(frozen=True, slots=True)
class UpsertOutcome:
token_id: uuid.UUID
rotated: bool
old_hash: str | None
def generate_token(prefix: str) -> str:
return prefix + secrets.token_urlsafe(OAUTH_BODY_BYTES)
def sha256_hex(token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def mint_oauth_token(
session: Session,
redis_client,
*,
subject_email: str,
subject_issuer: str | None,
account_id: str | None,
client_id: str,
device_label: str,
prefix: str,
ttl_days: int,
) -> MintResult:
"""Live row rotates in place via partial unique index
``uq_oauth_active_per_device``; hard-expired rows are excluded by the
index predicate so re-login INSERTs fresh. Pre-rotate Redis entry is
deleted so stale AuthContext drops immediately.
"""
if prefix not in (PREFIX_OAUTH_ACCOUNT, PREFIX_OAUTH_EXTERNAL_SSO):
raise ValueError(f"unknown oauth prefix: {prefix!r}")
token = generate_token(prefix)
new_hash = sha256_hex(token)
expires_at = datetime.now(UTC) + timedelta(days=ttl_days)
outcome = _upsert(
session,
subject_email=subject_email,
subject_issuer=subject_issuer,
account_id=account_id,
client_id=client_id,
device_label=device_label,
prefix=prefix,
new_hash=new_hash,
expires_at=expires_at,
)
if outcome.rotated and outcome.old_hash:
redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=outcome.old_hash))
return MintResult(token=token, token_id=outcome.token_id, expires_at=expires_at)
def _upsert(
session: Session,
*,
subject_email: str,
subject_issuer: str | None,
account_id: str | None,
client_id: str,
device_label: str,
prefix: str,
new_hash: str,
expires_at: datetime,
) -> UpsertOutcome:
# Snapshot prior live row's hash for Redis invalidation post-rotate.
prior = session.execute(
select(OAuthAccessToken.id, OAuthAccessToken.token_hash)
.where(
OAuthAccessToken.subject_email == subject_email,
OAuthAccessToken.subject_issuer.is_not_distinct_from(subject_issuer),
OAuthAccessToken.client_id == client_id,
OAuthAccessToken.device_label == device_label,
OAuthAccessToken.revoked_at.is_(None),
)
.limit(1)
).first()
old_hash = prior.token_hash if prior else None
insert_stmt = pg_insert(OAuthAccessToken).values(
subject_email=subject_email,
subject_issuer=subject_issuer,
account_id=account_id,
client_id=client_id,
device_label=device_label,
prefix=prefix,
token_hash=new_hash,
expires_at=expires_at,
)
upsert_stmt = insert_stmt.on_conflict_do_update(
index_elements=["subject_email", "subject_issuer", "client_id", "device_label"],
index_where=OAuthAccessToken.revoked_at.is_(None),
set_={
"token_hash": insert_stmt.excluded.token_hash,
"prefix": insert_stmt.excluded.prefix,
"account_id": insert_stmt.excluded.account_id,
"expires_at": insert_stmt.excluded.expires_at,
"created_at": func.now(),
"last_used_at": None,
},
).returning(OAuthAccessToken.id)
row = session.execute(upsert_stmt).first()
session.commit()
token_id = uuid.UUID(str(row.id))
return UpsertOutcome(
token_id=token_id,
rotated=prior is not None,
old_hash=old_hash,
)
# ============================================================================
# TTL policy — days new OAuth tokens live
# ============================================================================
DEFAULT_OAUTH_TTL_DAYS = 14
MIN_TTL_DAYS = 1
MAX_TTL_DAYS = 365
_TTL_ENV_VAR = "OAUTH_TTL_DAYS"
def oauth_ttl_days(tenant_id: str | None = None) -> int:
"""``OAUTH_TTL_DAYS`` env, else default. EE tenant-level lookup
is deferred; when it lands it wins over the env (Redis-cached 60s).
"""
_ = tenant_id
raw = os.environ.get(_TTL_ENV_VAR)
if raw is None:
return DEFAULT_OAUTH_TTL_DAYS
try:
value = int(raw)
except ValueError:
logger.warning(
"%s=%r is not an int; falling back to %d",
_TTL_ENV_VAR, raw, DEFAULT_OAUTH_TTL_DAYS,
)
return DEFAULT_OAUTH_TTL_DAYS
if value < MIN_TTL_DAYS:
logger.warning("%s=%d below min %d; clamping", _TTL_ENV_VAR, value, MIN_TTL_DAYS)
return MIN_TTL_DAYS
if value > MAX_TTL_DAYS:
logger.warning("%s=%d above max %d; clamping", _TTL_ENV_VAR, value, MAX_TTL_DAYS)
return MAX_TTL_DAYS
return value

View File

@ -0,0 +1,96 @@
'use client'
import type { FC } from 'react'
import { useState } from 'react'
import { deviceApproveAccount, deviceDenyAccount } from '@/service/device-flow'
type Props = {
userCode: string
accountEmail?: string
defaultWorkspace?: string
onApproved: () => void
onDenied: () => void
onError: (message: string) => void
}
/**
* AuthorizeAccount is the account-branch authorize screen. Called with a
* live console session already established (user bounced through /signin).
* Posts to /console/api/oauth/device/{approve,deny}; these endpoints mint
* the dfoa_ token server-side.
*/
const AuthorizeAccount: FC<Props> = ({
userCode, accountEmail, defaultWorkspace, onApproved, onDenied, onError,
}) => {
const [busy, setBusy] = useState(false)
const approve = async () => {
setBusy(true)
try {
await deviceApproveAccount(userCode)
onApproved()
}
catch (e: any) {
onError(e?.message || 'Approve failed')
}
finally {
setBusy(false)
}
}
const deny = async () => {
setBusy(true)
try {
await deviceDenyAccount(userCode)
onDenied()
}
catch (e: any) {
onError(e?.message || 'Deny failed')
}
finally {
setBusy(false)
}
}
return (
<div className="flex flex-col gap-6">
<div>
<h2 className="text-2xl font-semibold text-text-primary">Authorize Dify CLI</h2>
<p className="mt-2 text-sm text-text-secondary">
Dify CLI (difyctl) is requesting access to your account.
{' '}If you did not start this from your terminal, click Cancel.
</p>
</div>
<div className="rounded-lg border border-components-panel-border bg-components-panel-bg px-4 py-3">
{accountEmail && (
<p className="text-sm text-text-secondary">
Signed in as <span className="font-medium text-text-primary">{accountEmail}</span>
</p>
)}
{defaultWorkspace && (
<p className="mt-1 text-sm text-text-secondary">
Default workspace: <span className="font-medium text-text-primary">{defaultWorkspace}</span>
</p>
)}
</div>
<div className="flex gap-3">
<button
onClick={approve}
disabled={busy}
className="flex-1 rounded-lg bg-components-button-primary-bg px-4 py-3 text-components-button-primary-text font-medium hover:bg-components-button-primary-bg-hover disabled:opacity-50"
>
Authorize
</button>
<button
onClick={deny}
disabled={busy}
className="flex-1 rounded-lg border border-components-button-secondary-border bg-components-button-secondary-bg px-4 py-3 text-components-button-secondary-text font-medium hover:bg-components-button-secondary-bg-hover disabled:opacity-50"
>
Cancel
</button>
</div>
</div>
)
}
export default AuthorizeAccount

View File

@ -0,0 +1,96 @@
'use client'
import type { FC } from 'react'
import { useEffect, useState } from 'react'
import type { ApprovalContext } from '@/service/device-flow'
import { approveExternal, fetchApprovalContext } from '@/service/device-flow'
type Props = {
onApproved: () => void
onError: (message: string) => void
}
/**
* AuthorizeSSO is the external-SSO branch authorize screen. On mount it
* fetches /v1/oauth/device/approval-context to learn subject_email, issuer,
* user_code, and csrf_token from the device_approval_grant cookie. On
* Approve click, posts /v1/oauth/device/approve-external with the CSRF header.
*
* The user_code in state is bound to the cookie by server; we do not accept
* one from the URL because the SSO branch deliberately detaches from the
* pre-SSO ?user_code=... query param.
*/
const AuthorizeSSO: FC<Props> = ({ onApproved, onError }) => {
const [ctx, setCtx] = useState<ApprovalContext | null>(null)
const [busy, setBusy] = useState(false)
const [loadErr, setLoadErr] = useState<string | null>(null)
useEffect(() => {
let cancelled = false
fetchApprovalContext()
.then((c) => { if (!cancelled) setCtx(c) })
.catch((e: any) => {
if (!cancelled)
setLoadErr(e?.message || 'Failed to load session')
})
return () => { cancelled = true }
}, [])
const approve = async () => {
if (!ctx) return
setBusy(true)
try {
await approveExternal(ctx, ctx.user_code)
onApproved()
}
catch (e: any) {
onError(e?.message || 'Approve failed')
}
finally {
setBusy(false)
}
}
if (loadErr) {
return (
<div>
<h2 className="text-2xl font-semibold text-text-primary">This session is no longer valid</h2>
<p className="mt-2 text-sm text-text-secondary">
Run <code className="rounded bg-components-panel-bg px-1">difyctl auth login</code> again to start a new sign-in.
</p>
</div>
)
}
if (!ctx) {
return <div className="text-sm text-text-secondary">Loading session</div>
}
return (
<div className="flex flex-col gap-6">
<div>
<h2 className="text-2xl font-semibold text-text-primary">Authorize Dify CLI</h2>
<p className="mt-2 text-sm text-text-secondary">
Dify CLI (difyctl) is requesting access via SSO. If you did not start
this from your terminal, close this tab.
</p>
</div>
<div className="rounded-lg border border-components-panel-border bg-components-panel-bg px-4 py-3">
<p className="text-sm text-text-secondary">
Signed in as <span className="font-medium text-text-primary">{ctx.subject_email}</span>
</p>
<p className="mt-1 text-sm text-text-secondary">
Issuer: <span className="font-medium text-text-primary">{ctx.subject_issuer}</span>
</p>
</div>
<button
onClick={approve}
disabled={busy}
className="rounded-lg bg-components-button-primary-bg px-4 py-3 text-components-button-primary-text font-medium hover:bg-components-button-primary-bg-hover disabled:opacity-50"
>
Authorize
</button>
</div>
)
}
export default AuthorizeSSO

View File

@ -0,0 +1,60 @@
'use client'
import type { FC } from 'react'
import { useRouter } from '@/next/navigation'
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
type Props = {
userCode: string
ssoAvailable: boolean
}
/**
* Chooser renders the two-button device-auth login selector. Account button
* seeds postLoginRedirect + navigates to /signin so every existing account
* login method (password / email-code / social OAuth / account-SSO) flows
* through its usual plumbing. SSO button hits /v1/oauth/device/sso-initiate
* directly the SSO branch skips /signin entirely.
*
* v1.0 scope: only account-SSO honours postLoginRedirect (via sso-auth's
* return_to plumbing). Password / email-code / social-OAuth users land on
* /signin's default post-login target and manually return to the /device
* URL printed by the CLI. That's not great UX; a follow-up milestone
* generalises post-signin redirect to all methods.
*/
const Chooser: FC<Props> = ({ userCode, ssoAvailable }) => {
const router = useRouter()
const onAccount = () => {
setPostLoginRedirect(`/device?user_code=${encodeURIComponent(userCode)}`)
router.push('/signin')
}
const onSSO = () => {
// Full-page navigation, not router.push — /v1/oauth/device/sso-initiate
// issues a 302 to the IdP. Next's client router can't follow cross-
// origin redirects; a plain window.location assignment handles it.
window.location.href = `/v1/oauth/device/sso-initiate?user_code=${encodeURIComponent(userCode)}`
}
return (
<div className="flex flex-col gap-3">
<button
onClick={onAccount}
className="rounded-lg bg-components-button-primary-bg px-4 py-3 text-components-button-primary-text font-medium hover:bg-components-button-primary-bg-hover"
>
Sign in with Dify account
</button>
{ssoAvailable && (
<button
onClick={onSSO}
className="rounded-lg border border-components-button-secondary-border bg-components-button-secondary-bg px-4 py-3 text-components-button-secondary-text font-medium hover:bg-components-button-secondary-bg-hover"
>
Sign in with SSO
</button>
)}
</div>
)
}
export default Chooser

View File

@ -0,0 +1,45 @@
'use client'
import type { FC } from 'react'
import { useCallback } from 'react'
import { normaliseUserCodeInput } from '../utils/user-code'
type Props = {
value: string
onChange: (normalised: string) => void
disabled?: boolean
autoFocus?: boolean
}
/**
* CodeInput renders the user_code text field with live normalisation
* (uppercase, reduced alphabet, XXXX-XXXX hyphenation).
*
* The onChange callback receives the normalised value only the parent does
* not need to run validation itself.
*/
const CodeInput: FC<Props> = ({ value, onChange, disabled, autoFocus }) => {
const handle = useCallback((raw: string) => {
onChange(normaliseUserCodeInput(raw))
}, [onChange])
return (
<input
type="text"
inputMode="text"
autoCapitalize="characters"
autoComplete="off"
spellCheck={false}
placeholder="ABCD-1234"
maxLength={9}
aria-label="one-time code"
className="w-full rounded-lg border border-components-input-border-normal bg-components-input-bg-normal px-4 py-3 text-center text-2xl font-mono tracking-wider text-text-primary focus:border-components-input-border-active focus:outline-none"
value={value}
disabled={disabled}
autoFocus={autoFocus}
onChange={e => handle(e.target.value)}
/>
)
}
export default CodeInput

173
web/app/device/page.tsx Normal file
View File

@ -0,0 +1,173 @@
'use client'
import { useEffect, useState } from 'react'
import { useSearchParams } from '@/next/navigation'
import { useQuery } from '@tanstack/react-query'
import { systemFeaturesQueryOptions } from '@/service/system-features'
import { commonQueryKeys, userProfileQueryOptions } from '@/service/use-common'
import { post } from '@/service/base'
import type { ICurrentWorkspace } from '@/models/common'
import { deviceLookup } from '@/service/device-flow'
import CodeInput from './components/code-input'
import Chooser from './components/chooser'
import AuthorizeAccount from './components/authorize-account'
import AuthorizeSSO from './components/authorize-sso'
import { isValidUserCode } from './utils/user-code'
type View =
| { kind: 'code_entry' }
| { kind: 'chooser'; userCode: string }
| { kind: 'authorize_account'; userCode: string }
| { kind: 'authorize_sso' }
| { kind: 'success' }
| { kind: 'error_expired' }
export default function DevicePage() {
const searchParams = useSearchParams()
const urlUserCode = (searchParams.get('user_code') || '').trim().toUpperCase()
const ssoVerified = searchParams.get('sso_verified') === '1'
const [typed, setTyped] = useState('')
const [view, setView] = useState<View>({ kind: 'code_entry' })
const [errMsg, setErrMsg] = useState<string | null>(null)
// Account subject + workspace identity (for the authorize-account screen).
// Logged-out is a valid landing state on /device — disable refetch storms
// and skip workspace probe until profile resolves (avoids /current + chained
// /refresh-token 401 loops while the user is still entering the code).
const { data: userResp, isError: profileErr } = useQuery({
...userProfileQueryOptions(),
throwOnError: false,
retry: false,
refetchOnWindowFocus: false,
refetchOnMount: false,
})
const account = userResp?.profile
const { data: currentWorkspace } = useQuery<ICurrentWorkspace>({
queryKey: commonQueryKeys.currentWorkspace,
queryFn: () => post<ICurrentWorkspace>('/workspaces/current'),
enabled: !!account && !profileErr,
retry: false,
refetchOnWindowFocus: false,
})
const { data: sys } = useQuery(systemFeaturesQueryOptions())
// Device-flow SSO branch uses external-user (webapp) SSO, not console SSO —
// backend mints EXTERNAL_SSO tokens via Enterprise's external ACS. Gate on
// webapp_auth.{enabled, allow_sso} + a configured webapp SSO protocol.
const ssoAvailable = !!sys?.webapp_auth?.enabled
&& !!sys?.webapp_auth?.allow_sso
&& (sys?.webapp_auth?.sso_config?.protocol || '') !== ''
// URL-driven view transitions. Only advances while the user is still on
// the entry/chooser screens — never clobbers terminal views (success /
// error_expired / authorize_*) when userProfile refetches.
useEffect(() => {
if (view.kind !== 'code_entry' && view.kind !== 'chooser') return
if (ssoVerified) {
setView({ kind: 'authorize_sso' })
return
}
if (urlUserCode && isValidUserCode(urlUserCode)) {
if (account)
setView({ kind: 'authorize_account', userCode: urlUserCode })
else
setView({ kind: 'chooser', userCode: urlUserCode })
}
}, [urlUserCode, ssoVerified, account, view.kind])
const onContinue = async () => {
if (!isValidUserCode(typed)) return
try {
const reply = await deviceLookup(typed)
if (!reply.valid) {
setView({ kind: 'error_expired' })
return
}
}
catch {
setView({ kind: 'error_expired' })
return
}
if (account) setView({ kind: 'authorize_account', userCode: typed })
else setView({ kind: 'chooser', userCode: typed })
}
return (
<main className="mx-auto flex min-h-screen max-w-lg flex-col items-center justify-center px-6 py-10">
<div className="w-full rounded-xl border border-components-panel-border bg-components-panel-bg p-8 shadow-sm">
{view.kind === 'code_entry' && (
<div className="flex flex-col gap-5">
<div>
<h1 className="text-2xl font-semibold text-text-primary">Authorize Dify CLI</h1>
<p className="mt-2 text-sm text-text-secondary">
Enter the code shown in your terminal.
</p>
</div>
<CodeInput value={typed} onChange={setTyped} autoFocus />
<button
onClick={onContinue}
disabled={!isValidUserCode(typed)}
className="rounded-lg bg-components-button-primary-bg px-4 py-3 text-components-button-primary-text font-medium hover:bg-components-button-primary-bg-hover disabled:opacity-50"
>
Continue
</button>
</div>
)}
{view.kind === 'chooser' && (
<div className="flex flex-col gap-5">
<div>
<h1 className="text-2xl font-semibold text-text-primary">Sign in to authorize</h1>
<p className="mt-2 text-sm text-text-secondary">
Code <span className="font-mono">{view.userCode}</span> is valid. Choose how to sign in.
</p>
</div>
<Chooser userCode={view.userCode} ssoAvailable={ssoAvailable} />
</div>
)}
{view.kind === 'authorize_account' && (
<AuthorizeAccount
userCode={view.userCode}
accountEmail={account?.email}
defaultWorkspace={currentWorkspace?.name}
onApproved={() => setView({ kind: 'success' })}
onDenied={() => setView({ kind: 'error_expired' })}
onError={e => setErrMsg(e)}
/>
)}
{view.kind === 'authorize_sso' && (
<AuthorizeSSO
onApproved={() => setView({ kind: 'success' })}
onError={e => setErrMsg(e)}
/>
)}
{view.kind === 'success' && (
<div>
<h1 className="text-2xl font-semibold text-text-primary">You&apos;re signed in</h1>
<p className="mt-2 text-sm text-text-secondary">Return to your terminal to continue.</p>
</div>
)}
{view.kind === 'error_expired' && (
<div>
<h1 className="text-2xl font-semibold text-text-primary">This code is no longer valid</h1>
<p className="mt-2 text-sm text-text-secondary">
The code may have expired or already been used. Run
{' '}
<code className="rounded bg-components-panel-bg px-1">difyctl auth login</code>
{' '}
again to get a new one.
</p>
</div>
)}
{errMsg && (
<p className="mt-4 text-sm text-text-destructive">{errMsg}</p>
)}
</div>
</main>
)
}

View File

@ -0,0 +1,37 @@
// user-code.ts — input normalisation + validation for the RFC 8628
// 8-character user_code format the CLI prints to stderr.
//
// Format: XXXX-XXXX, uppercase, reduced alphabet (no 0/O, 1/I/l, 2/Z). Low
// entropy by design — humans type it — so the server-side rate-limit + TTL +
// single-use properties are what defend it, not the alphabet.
export const USER_CODE_ALPHABET = 'ABCDEFGHJKLMNPQRSTUVWXY3456789' // excludes 0 O 1 I L 2 Z
/**
* normaliseUserCodeInput prepares raw input for display in the code field:
* strips non-alphanumerics, uppercases, drops disallowed characters, and
* inserts the hyphen after the fourth accepted char.
*
* Returns at most 9 chars ("XXXX-XXXX"); longer input is truncated.
*/
export function normaliseUserCodeInput(raw: string): string {
const cleaned: string[] = []
for (const ch of raw.toUpperCase()) {
if (USER_CODE_ALPHABET.includes(ch))
cleaned.push(ch)
if (cleaned.length === 8)
break
}
if (cleaned.length <= 4)
return cleaned.join('')
return `${cleaned.slice(0, 4).join('')}-${cleaned.slice(4).join('')}`
}
/**
* isValidUserCode tests whether the normalised form is a complete XXXX-XXXX
* token suitable for submission to /console/api/oauth/device/lookup.
*/
export function isValidUserCode(normalised: string): boolean {
return /^[A-Z0-9]{4}-[A-Z0-9]{4}$/.test(normalised)
&& [...normalised.replace('-', '')].every(c => USER_CODE_ALPHABET.includes(c))
}

View File

@ -1,15 +1,68 @@
let postLoginRedirect: string | null = null
// Persists target across full-page redirects within the same tab (social
// OAuth, SSO IdP bounce). sessionStorage is tab-scoped so concurrent
// /device tabs don't clobber each other. 15-min TTL drops stale values.
// Same-origin + exact-path whitelist prevents open-redirect.
//
// Signup-via-email-link opening in a new tab is out of scope — that tab
// starts with an empty sessionStorage and falls to /apps default.
const KEY = 'dify_post_login_redirect'
const TTL_MS = 15 * 60 * 1000
const ALLOWED: Record<string, ReadonlySet<string>> = {
'/device': new Set(['user_code', 'sso_verified']),
'/account/oauth/authorize': new Set(['client_id', 'scope', 'state', 'redirect_uri']),
}
function validate(target: string): string | null {
if (typeof window === 'undefined') return null
try {
const url = new URL(target, window.location.origin)
if (url.origin !== window.location.origin) return null
const allowedKeys = ALLOWED[url.pathname]
if (!allowedKeys) return null
for (const key of url.searchParams.keys()) {
if (!allowedKeys.has(key)) return null
}
return url.pathname + (url.search || '')
}
catch {
return null
}
}
export const setPostLoginRedirect = (value: string | null) => {
postLoginRedirect = value
}
export const resolvePostLoginRedirect = () => {
if (postLoginRedirect) {
const redirectUrl = postLoginRedirect
postLoginRedirect = null
return redirectUrl
if (typeof window === 'undefined') return
if (value === null) {
try { sessionStorage.removeItem(KEY) } catch {}
return
}
const safe = validate(value)
if (!safe) return
try {
sessionStorage.setItem(KEY, JSON.stringify({ target: safe, ts: Date.now() }))
}
catch {}
}
export const resolvePostLoginRedirect = (): string | null => {
if (typeof window === 'undefined') return null
let raw: string | null = null
try {
raw = sessionStorage.getItem(KEY)
sessionStorage.removeItem(KEY)
}
catch {
return null
}
if (!raw) return null
try {
const parsed = JSON.parse(raw)
if (typeof parsed?.target !== 'string' || typeof parsed?.ts !== 'number') return null
if (Date.now() - parsed.ts > TTL_MS) return null
return validate(parsed.target)
}
catch {
return null
}
return null
}

View File

@ -30,6 +30,20 @@ const nextConfig: NextConfig = {
},
]
},
// Anti-framing for device-flow surfaces. A framed /device page could UI-trick
// a victim with a valid device_approval_grant cookie into approving a
// device_code — functionally CSRF, bypasses the double-submit token. Deny
// framing outright on every device-flow route; no trusted embedder exists.
async headers() {
const antiFrame = [
{ key: 'X-Frame-Options', value: 'DENY' },
{ key: 'Content-Security-Policy', value: "frame-ancestors 'none'" },
]
return [
{ source: '/device', headers: antiFrame },
{ source: '/device/:path*', headers: antiFrame },
]
},
output: 'standalone',
compiler: {
removeConsole: isDev ? false : { exclude: ['warn', 'error'] },

View File

@ -794,6 +794,11 @@ export const request = async<T>(url: string, options = {}, otherOptions?: IOther
const [refreshErr] = await asyncRunSafe(refreshAccessTokenOrReLogin(TIME_OUT))
if (refreshErr === null)
return baseFetch<T>(url, options, otherOptionsForBaseFetch)
// /device is the device-flow chooser; logged-out is a valid state
// there. Redirecting to /signin loses the user_code context and
// the post-login flow lands on /apps instead of returning here.
if (location.pathname === `${basePath}/device`)
return Promise.reject(err)
if (location.pathname !== `${basePath}/signin` || !IS_CE_EDITION) {
jumpTo(loginUrl)
return Promise.reject(err)

View File

@ -0,0 +1,84 @@
// Web-side calls into the Dify device-flow endpoints:
//
// /v1/oauth/device/lookup (public — GET, no auth, IP-rate-limited)
// /v1/oauth/device/approval-context (cookie-authed — GET)
// /v1/oauth/device/approve-external (cookie-authed + CSRF — POST)
// /console/api/oauth/device/approve (session-authed — POST)
// /console/api/oauth/device/deny (session-authed — POST)
//
// Approve/deny use the standard service/base helpers so they get console-
// session cookies automatically. Lookup + SSO-branch endpoints sit under
// /v1 so they ride the existing service-API gateway route.
import { del, post } from './base'
const DEVICE_BASE = '/v1/oauth/device'
// ----- Account branch --------------------------------------------------------
export type DeviceLookupReply = {
valid: boolean
expires_in_remaining: number
client_id: string
}
export async function deviceLookup(user_code: string): Promise<DeviceLookupReply> {
const res = await fetch(`${DEVICE_BASE}/lookup?user_code=${encodeURIComponent(user_code)}`, {
method: 'GET',
})
if (!res.ok) {
const body = await res.text().catch(() => '')
throw new Error(`lookup ${res.status}: ${body}`)
}
return res.json()
}
export const deviceApproveAccount = (user_code: string) =>
post<{ status: 'approved' }>('/oauth/device/approve', { body: { user_code } })
export const deviceDenyAccount = (user_code: string) =>
post<{ status: 'denied' }>('/oauth/device/deny', { body: { user_code } })
// ----- SSO branch (cookie-authed via /v1/oauth/device/*) --------------------
export type ApprovalContext = {
subject_email: string
subject_issuer: string
user_code: string
csrf_token: string
expires_at: string
}
export async function fetchApprovalContext(): Promise<ApprovalContext> {
const res = await fetch(`${DEVICE_BASE}/approval-context`, {
method: 'GET',
credentials: 'include',
})
if (!res.ok) {
const body = await res.text().catch(() => '')
throw new Error(`approval-context ${res.status}: ${body}`)
}
return res.json()
}
export async function approveExternal(ctx: ApprovalContext, user_code: string): Promise<void> {
const res = await fetch(`${DEVICE_BASE}/approve-external`, {
method: 'POST',
credentials: 'include',
headers: {
'Content-Type': 'application/json',
'X-CSRF-Token': ctx.csrf_token,
},
body: JSON.stringify({ user_code }),
})
if (!res.ok) {
const body = await res.text().catch(() => '')
throw new Error(`approve-external ${res.status}: ${body}`)
}
}
// ----- Export for future PAT revoke; noop in v1.0 --------------------------
// Intentionally left out: personal_access_tokens endpoints are not in this
// milestone; see docs/specs/v1.0/README.md.
void del // keep import live for the TypeScript linter without surfacing usage