From 813da349ec00e5b1a1a28d971713ef11a6b95fcb Mon Sep 17 00:00:00 2001 From: GareArc Date: Sun, 26 Apr 2026 23:05:07 -0700 Subject: [PATCH] fix(api,web): post-review hardening for OAuth device flow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - api: account-flow stores subject_issuer="dify:account" sentinel instead of NULL so the rotate-in-place unique index collides as intended (Postgres treats NULLs as distinct in unique indices). mint_oauth_token validates prefix-specific issuer rules. - api: enterprise_only inverts to an allowlist (ACTIVE / EXPIRING) so any future LicenseStatus value defaults to denial. - api: consume_on_poll moved to a single Lua script (GET + status-check + DEL) so concurrent pollers can't both observe APPROVED. - web: typed DeviceFlowError + central error-copy mapping; page surfaces rate_limited / lookup_failed view states; URL params scrubbed after consumption (RFC 8628 §5.4). --- api/controllers/console/auth/oauth_device.py | 5 +- api/libs/device_flow_security.py | 6 +- ...00-d4a5e1f3c9b7_add_oauth_access_tokens.py | 6 +- api/models/oauth.py | 11 ++- api/services/oauth_device_flow.py | 74 +++++++++++++++---- .../device/components/authorize-account.tsx | 9 ++- web/app/device/components/authorize-sso.tsx | 9 ++- web/app/device/page.tsx | 47 ++++++++++-- web/app/device/utils/error-copy.ts | 41 ++++++++++ web/service/device-flow.ts | 58 ++++++++++----- 10 files changed, 208 insertions(+), 58 deletions(-) create mode 100644 web/app/device/utils/error-copy.ts diff --git a/api/controllers/console/auth/oauth_device.py b/api/controllers/console/auth/oauth_device.py index beef3698de..80ef8c8d5b 100644 --- a/api/controllers/console/auth/oauth_device.py +++ b/api/controllers/console/auth/oauth_device.py @@ -36,6 +36,7 @@ def bearer_feature_required(fn): return inner from services.oauth_device_flow import ( + ACCOUNT_ISSUER_SENTINEL, PREFIX_OAUTH_ACCOUNT, DeviceFlowRedis, DeviceFlowStatus, @@ -90,7 +91,7 @@ class DeviceApproveApi(Resource): db.session, redis_client, subject_email=account.email, - subject_issuer=None, + subject_issuer=ACCOUNT_ISSUER_SENTINEL, account_id=str(account.id), client_id=state.client_id, device_label=state.device_label, @@ -104,7 +105,7 @@ class DeviceApproveApi(Resource): device_code, subject_email=account.email, account_id=str(account.id), - subject_issuer=None, + subject_issuer=ACCOUNT_ISSUER_SENTINEL, minted_token=mint.token, token_id=str(mint.token_id), poll_payload=poll_payload, diff --git a/api/libs/device_flow_security.py b/api/libs/device_flow_security.py index e589a16522..3c081138da 100644 --- a/api/libs/device_flow_security.py +++ b/api/libs/device_flow_security.py @@ -25,7 +25,9 @@ logger = logging.getLogger(__name__) # ============================================================================ -_CE_LIKE_STATUSES = {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST} +# Fail-closed: any non-EE-active status (default NONE on CE, plus INACTIVE / EXPIRED / LOST) +# is denied. Future LicenseStatus values default to denial unless explicitly admitted. +_EE_ENABLED_STATUSES = {LicenseStatus.ACTIVE, LicenseStatus.EXPIRING} def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]: @@ -36,7 +38,7 @@ def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]: @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): settings = FeatureService.get_system_features() - if settings.license.status in _CE_LIKE_STATUSES: + if settings.license.status not in _EE_ENABLED_STATUSES: raise NotFound() return view(*args, **kwargs) diff --git a/api/migrations/versions/2026_04_23_2200-d4a5e1f3c9b7_add_oauth_access_tokens.py b/api/migrations/versions/2026_04_23_2200-d4a5e1f3c9b7_add_oauth_access_tokens.py index a0e34b9a17..fbb2ef801e 100644 --- a/api/migrations/versions/2026_04_23_2200-d4a5e1f3c9b7_add_oauth_access_tokens.py +++ b/api/migrations/versions/2026_04_23_2200-d4a5e1f3c9b7_add_oauth_access_tokens.py @@ -82,8 +82,10 @@ def upgrade(): postgresql_where=sa.text("revoked_at IS NULL"), ) # Partial unique index — rotate-in-place keyed on (subject, client, device). - # subject_issuer NULL vs populated distinguishes account vs external-SSO rows - # for the same email, because Postgres treats NULL as distinct. + # The app always writes a non-NULL subject_issuer (account flow uses a + # sentinel, external-SSO uses the verified IdP issuer); without that the + # composite key would never collide because Postgres treats NULLs as + # distinct in unique indices. op.create_index( "uq_oauth_active_per_device", "oauth_access_tokens", diff --git a/api/models/oauth.py b/api/models/oauth.py index a88dd9345d..5ab10fb7d0 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -87,10 +87,13 @@ class DatasourceOauthTenantParamConfig(TypeBase): class OAuthAccessToken(TypeBase): - """Device-flow bearer. account_id NOT NULL ⇒ dfoa_ (Dify account); - account_id NULL + subject_issuer ⇒ dfoe_ (external SSO, EE-only). - Partial unique index on (subject_email, subject_issuer, client_id, - device_label) WHERE revoked_at IS NULL lets re-login rotate in place. + """Device-flow bearer. account_id NOT NULL ⇒ dfoa_ (Dify account, + subject_issuer = "dify:account" sentinel); account_id NULL + + subject_issuer = verified IdP issuer ⇒ dfoe_ (external SSO, EE-only). + subject_issuer is non-NULL for all rows the app writes — Postgres + treats NULLs as distinct in unique indices, so the partial unique + index on (subject_email, subject_issuer, client_id, device_label) + WHERE revoked_at IS NULL would otherwise fail to rotate in place. """ __tablename__ = "oauth_access_tokens" diff --git a/api/services/oauth_device_flow.py b/api/services/oauth_device_flow.py index 381d6d6a85..6aa12cd536 100644 --- a/api/services/oauth_device_flow.py +++ b/api/services/oauth_device_flow.py @@ -29,8 +29,26 @@ logger = logging.getLogger(__name__) # ============================================================================ -DEVICE_CODE_KEY_FMT = "device_code:{code}" -USER_CODE_KEY_FMT = "user_code:{code}" +_DEVICE_CODE_KEY_PREFIX = "device_code:" +_USER_CODE_KEY_PREFIX = "user_code:" +DEVICE_CODE_KEY_FMT = _DEVICE_CODE_KEY_PREFIX + "{code}" +USER_CODE_KEY_FMT = _USER_CODE_KEY_PREFIX + "{code}" + +# Atomic GET → status-check → DEL(both keys). Two concurrent pollers must +# not both observe APPROVED — only the winner gets the plaintext token, +# the loser sees nil and the caller maps that to expired_token. +_CONSUME_ON_POLL_LUA = """ +local raw = redis.call('GET', KEYS[1]) +if not raw then return nil end +local ok, decoded = pcall(cjson.decode, raw) +if not ok then return nil end +if decoded.status == 'pending' then return nil end +if decoded.user_code then + redis.call('DEL', ARGV[1] .. decoded.user_code) +end +redis.call('DEL', KEYS[1]) +return raw +""" DEVICE_FLOW_TTL_SECONDS = 15 * 60 # RFC 8628 expires_in APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor @@ -112,6 +130,7 @@ class DeviceFlowRedis: def __init__(self, redis_client) -> None: self._redis = redis_client + self._consume_on_poll_script = redis_client.register_script(_CONSUME_ON_POLL_LUA) def start(self, client_id: str, device_label: str, created_ip: str) -> tuple[str, str, int]: device_code = _random_device_code() @@ -205,19 +224,23 @@ class DeviceFlowRedis: ) def consume_on_poll(self, device_code: str) -> DeviceFlowState | None: - """Race-safe via DEL: concurrent polls — one wins, the other gets - None and the caller maps that to expired_token. + """Race-safe via Lua EVAL: GET + status-check + DEL execute in a + single Redis transaction so only one of N concurrent pollers + observes the APPROVED state. Losers get None, mapped to + expired_token by the caller. """ - state = self._load_state(device_code) - if state is None: - return None - if state.status is DeviceFlowStatus.PENDING: - return None - self._redis.delete( - DEVICE_CODE_KEY_FMT.format(code=device_code), - USER_CODE_KEY_FMT.format(code=state.user_code), + raw = self._consume_on_poll_script( + keys=[DEVICE_CODE_KEY_FMT.format(code=device_code)], + args=[_USER_CODE_KEY_PREFIX], ) - return state + if raw is None: + return None + text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw + try: + return DeviceFlowState.from_json(text_) + except (ValueError, KeyError): + logger.error("device_flow: corrupt state on consume %s", device_code) + return None def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision: now = time.time() @@ -254,6 +277,12 @@ OAUTH_BODY_BYTES = 32 # ~256 bits entropy PREFIX_OAUTH_ACCOUNT = "dfoa_" PREFIX_OAUTH_EXTERNAL_SSO = "dfoe_" +# Sentinel issuer for account-flow rows. Postgres' default partial unique +# index treats NULLs as distinct, which would let two live `dfoa_` rows +# share (email, client, device) and break rotate-in-place. Storing a +# non-empty literal makes the composite key collide as intended. +ACCOUNT_ISSUER_SENTINEL = "dify:account" + @dataclass(frozen=True, slots=True) class MintResult: @@ -295,7 +324,20 @@ def mint_oauth_token( index predicate so re-login INSERTs fresh. Pre-rotate Redis entry is deleted so stale AuthContext drops immediately. """ - if prefix not in (PREFIX_OAUTH_ACCOUNT, PREFIX_OAUTH_EXTERNAL_SSO): + if prefix == PREFIX_OAUTH_ACCOUNT: + # Account flow always writes the sentinel — caller may pass None + # (for clarity) or the sentinel itself; nothing else is valid. + if subject_issuer not in (None, ACCOUNT_ISSUER_SENTINEL): + raise ValueError( + f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}" + ) + subject_issuer = ACCOUNT_ISSUER_SENTINEL + elif prefix == PREFIX_OAUTH_EXTERNAL_SSO: + # Defense in depth: enterprise canonicalises + rejects empty, + # but a regression there must not yield a NULL composite key here. + if not subject_issuer or not subject_issuer.strip(): + raise ValueError("external-SSO token requires non-empty subject_issuer") + else: raise ValueError(f"unknown oauth prefix: {prefix!r}") token = generate_token(prefix) @@ -333,11 +375,13 @@ def _upsert( expires_at: datetime, ) -> UpsertOutcome: # Snapshot prior live row's hash for Redis invalidation post-rotate. + # subject_issuer is always non-null here (account flow uses sentinel, + # external-SSO is validated upstream), so equality matches the index. prior = session.execute( select(OAuthAccessToken.id, OAuthAccessToken.token_hash) .where( OAuthAccessToken.subject_email == subject_email, - OAuthAccessToken.subject_issuer.is_not_distinct_from(subject_issuer), + OAuthAccessToken.subject_issuer == subject_issuer, OAuthAccessToken.client_id == client_id, OAuthAccessToken.device_label == device_label, OAuthAccessToken.revoked_at.is_(None), diff --git a/web/app/device/components/authorize-account.tsx b/web/app/device/components/authorize-account.tsx index a02088c48f..0b3cf79866 100644 --- a/web/app/device/components/authorize-account.tsx +++ b/web/app/device/components/authorize-account.tsx @@ -3,6 +3,7 @@ import type { FC } from 'react' import { useState } from 'react' import { deviceApproveAccount, deviceDenyAccount } from '@/service/device-flow' +import { approveErrorCopy } from '../utils/error-copy' type Props = { userCode: string @@ -30,8 +31,8 @@ const AuthorizeAccount: FC = ({ await deviceApproveAccount(userCode) onApproved() } - catch (e: any) { - onError(e?.message || 'Approve failed') + catch (e) { + onError(approveErrorCopy(e)) } finally { setBusy(false) @@ -44,8 +45,8 @@ const AuthorizeAccount: FC = ({ await deviceDenyAccount(userCode) onDenied() } - catch (e: any) { - onError(e?.message || 'Deny failed') + catch (e) { + onError(approveErrorCopy(e)) } finally { setBusy(false) diff --git a/web/app/device/components/authorize-sso.tsx b/web/app/device/components/authorize-sso.tsx index a327c54858..60dc277641 100644 --- a/web/app/device/components/authorize-sso.tsx +++ b/web/app/device/components/authorize-sso.tsx @@ -4,6 +4,7 @@ import type { FC } from 'react' import { useEffect, useState } from 'react' import type { ApprovalContext } from '@/service/device-flow' import { approveExternal, fetchApprovalContext } from '@/service/device-flow' +import { approveErrorCopy } from '../utils/error-copy' type Props = { onApproved: () => void @@ -29,9 +30,9 @@ const AuthorizeSSO: FC = ({ onApproved, onError }) => { let cancelled = false fetchApprovalContext() .then((c) => { if (!cancelled) setCtx(c) }) - .catch((e: any) => { + .catch((e) => { if (!cancelled) - setLoadErr(e?.message || 'Failed to load session') + setLoadErr(approveErrorCopy(e)) }) return () => { cancelled = true } }, []) @@ -43,8 +44,8 @@ const AuthorizeSSO: FC = ({ onApproved, onError }) => { await approveExternal(ctx, ctx.user_code) onApproved() } - catch (e: any) { - onError(e?.message || 'Approve failed') + catch (e) { + onError(approveErrorCopy(e)) } finally { setBusy(false) diff --git a/web/app/device/page.tsx b/web/app/device/page.tsx index 0d19448fd7..0bfa1afbb2 100644 --- a/web/app/device/page.tsx +++ b/web/app/device/page.tsx @@ -1,7 +1,7 @@ 'use client' import { useEffect, useState } from 'react' -import { useSearchParams } from '@/next/navigation' +import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { useQuery } from '@tanstack/react-query' import { systemFeaturesQueryOptions } from '@/service/system-features' import { commonQueryKeys, userProfileQueryOptions } from '@/service/use-common' @@ -13,6 +13,7 @@ import Chooser from './components/chooser' import AuthorizeAccount from './components/authorize-account' import AuthorizeSSO from './components/authorize-sso' import { isValidUserCode } from './utils/user-code' +import { classifyLookupError } from './utils/error-copy' type View = | { kind: 'code_entry' } @@ -21,9 +22,13 @@ type View = | { kind: 'authorize_sso' } | { kind: 'success' } | { kind: 'error_expired' } + | { kind: 'error_rate_limited' } + | { kind: 'error_lookup_failed' } export default function DevicePage() { const searchParams = useSearchParams() + const router = useRouter() + const pathname = usePathname() const urlUserCode = (searchParams.get('user_code') || '').trim().toUpperCase() const ssoVerified = searchParams.get('sso_verified') === '1' @@ -61,19 +66,25 @@ export default function DevicePage() { // URL-driven view transitions. Only advances while the user is still on // the entry/chooser screens — never clobbers terminal views (success / // error_expired / authorize_*) when userProfile refetches. + // After consuming the params, scrub them from the URL so they don't + // leak via history / Referer / server logs (RFC 8628 §5.4). useEffect(() => { if (view.kind !== 'code_entry' && view.kind !== 'chooser') return + let consumed = false if (ssoVerified) { setView({ kind: 'authorize_sso' }) - return + consumed = true } - if (urlUserCode && isValidUserCode(urlUserCode)) { + else if (urlUserCode && isValidUserCode(urlUserCode)) { if (account) setView({ kind: 'authorize_account', userCode: urlUserCode }) else setView({ kind: 'chooser', userCode: urlUserCode }) + consumed = true } - }, [urlUserCode, ssoVerified, account, view.kind]) + if (consumed && (urlUserCode || ssoVerified)) + router.replace(pathname) + }, [urlUserCode, ssoVerified, account, view.kind, router, pathname]) const onContinue = async () => { if (!isValidUserCode(typed)) return @@ -84,8 +95,14 @@ export default function DevicePage() { return } } - catch { - setView({ kind: 'error_expired' }) + catch (e) { + const outcome = classifyLookupError(e) + if (outcome === 'rate_limited') + setView({ kind: 'error_rate_limited' }) + else if (outcome === 'failed') + setView({ kind: 'error_lookup_failed' }) + else + setView({ kind: 'error_expired' }) return } if (account) setView({ kind: 'authorize_account', userCode: typed }) @@ -164,6 +181,24 @@ export default function DevicePage() { )} + {view.kind === 'error_rate_limited' && ( +
+

Too many attempts

+

+ We've received too many requests for this code. Wait a moment and try again. +

+
+ )} + + {view.kind === 'error_lookup_failed' && ( +
+

Could not verify the code

+

+ Something went wrong on our side. Try again in a moment. +

+
+ )} + {errMsg && (

{errMsg}

)} diff --git a/web/app/device/utils/error-copy.ts b/web/app/device/utils/error-copy.ts new file mode 100644 index 0000000000..cfdf2af252 --- /dev/null +++ b/web/app/device/utils/error-copy.ts @@ -0,0 +1,41 @@ +// Translate a DeviceFlowError (or any thrown value) into user-facing copy. +// Centralised so account/SSO branches surface the same words for the same +// failure mode and so a new server error code can be wired up here once. + +import { DeviceFlowError } from '@/service/device-flow' + +const APPROVE_COPY: Record = { + rate_limited: 'Too many attempts. Wait a moment and try again.', + no_session: 'Your session has expired. Run difyctl auth login again to start over.', + invalid_session: 'Your session has expired. Run difyctl auth login again to start over.', + session_already_consumed: 'This session was already used. Run difyctl auth login again.', + csrf_mismatch: 'Could not verify the request. Refresh the page and try again.', + forbidden: 'Could not verify the request. Refresh the page and try again.', + expired_or_unknown: 'This code is no longer valid.', + not_found: 'This code is no longer valid.', + user_code_mismatch: 'This code does not match the active session. Run difyctl auth login again.', + user_code_not_pending: 'This code was already approved or denied.', + already_resolved: 'This code was already approved or denied.', + state_lost: 'The flow expired before approval completed. Run difyctl auth login again.', + approve_in_progress: 'An approval is already in progress for this code.', + conflict: 'This code is no longer in a state we can approve.', + server_error: 'Something went wrong on our side. Try again in a moment.', +} + +const DEFAULT_MESSAGE = 'Could not complete the request. Please try again.' + +export function approveErrorCopy(err: unknown): string { + if (err instanceof DeviceFlowError) + return APPROVE_COPY[err.code] ?? DEFAULT_MESSAGE + return DEFAULT_MESSAGE +} + +export type LookupOutcome = 'expired' | 'rate_limited' | 'failed' + +export function classifyLookupError(err: unknown): LookupOutcome { + if (err instanceof DeviceFlowError) { + if (err.code === 'rate_limited' || err.status === 429) return 'rate_limited' + if (err.code === 'server_error' || err.status >= 500) return 'failed' + } + return 'expired' +} diff --git a/web/service/device-flow.ts b/web/service/device-flow.ts index b64cea0331..51e0c3cecc 100644 --- a/web/service/device-flow.ts +++ b/web/service/device-flow.ts @@ -10,10 +10,45 @@ // session cookies automatically. Lookup + SSO-branch endpoints sit under // /v1 so they ride the existing service-API gateway route. -import { del, post } from './base' +import { post } from './base' const DEVICE_BASE = '/v1/oauth/device' +// Typed error thrown by every wrapper here. The page/component layer +// switches on `code` to choose user-facing copy / view; never render +// `status` or raw body to the user. +export class DeviceFlowError extends Error { + constructor(public code: string, public status: number) { + super(code) + this.name = 'DeviceFlowError' + } +} + +// Translate a non-2xx fetch Response into a DeviceFlowError. Honours the +// server contract `{"error": ""}` and falls back to a status-class +// code so callers can still dispatch (rate_limited / server_error / ...). +async function failFromResponse(res: Response): Promise { + let serverCode = '' + try { + const body = await res.clone().json() + if (body && typeof body.error === 'string') serverCode = body.error + } + catch { /* non-JSON body — fall through to status mapping */ } + + const code = serverCode || statusFallbackCode(res.status) + throw new DeviceFlowError(code, res.status) +} + +function statusFallbackCode(status: number): string { + if (status === 429) return 'rate_limited' + if (status === 401) return 'no_session' + if (status === 403) return 'forbidden' + if (status === 404) return 'not_found' + if (status === 409) return 'conflict' + if (status >= 500) return 'server_error' + return 'unknown' +} + // ----- Account branch -------------------------------------------------------- export type DeviceLookupReply = { @@ -26,10 +61,7 @@ export async function deviceLookup(user_code: string): Promise '') - throw new Error(`lookup ${res.status}: ${body}`) - } + if (!res.ok) await failFromResponse(res) return res.json() } @@ -54,10 +86,7 @@ export async function fetchApprovalContext(): Promise { method: 'GET', credentials: 'include', }) - if (!res.ok) { - const body = await res.text().catch(() => '') - throw new Error(`approval-context ${res.status}: ${body}`) - } + if (!res.ok) await failFromResponse(res) return res.json() } @@ -71,14 +100,5 @@ export async function approveExternal(ctx: ApprovalContext, user_code: string): }, body: JSON.stringify({ user_code }), }) - if (!res.ok) { - const body = await res.text().catch(() => '') - throw new Error(`approve-external ${res.status}: ${body}`) - } + if (!res.ok) await failFromResponse(res) } - -// ----- Export for future PAT revoke; noop in v1.0 -------------------------- - -// Intentionally left out: personal_access_tokens endpoints are not in this -// milestone; see docs/specs/v1.0/README.md. -void del // keep import live for the TypeScript linter without surfacing usage