fix(api,web): post-review hardening for OAuth device flow

- api: account-flow stores subject_issuer="dify:account" sentinel
  instead of NULL so the rotate-in-place unique index collides as
  intended (Postgres treats NULLs as distinct in unique indices).
  mint_oauth_token validates prefix-specific issuer rules.
- api: enterprise_only inverts to an allowlist (ACTIVE / EXPIRING) so
  any future LicenseStatus value defaults to denial.
- api: consume_on_poll moved to a single Lua script (GET + status-check
  + DEL) so concurrent pollers can't both observe APPROVED.
- web: typed DeviceFlowError + central error-copy mapping; page
  surfaces rate_limited / lookup_failed view states; URL params
  scrubbed after consumption (RFC 8628 §5.4).
This commit is contained in:
GareArc 2026-04-26 23:05:07 -07:00
parent fe8510ad1a
commit 813da349ec
No known key found for this signature in database
10 changed files with 208 additions and 58 deletions

View File

@ -36,6 +36,7 @@ def bearer_feature_required(fn):
return inner
from services.oauth_device_flow import (
ACCOUNT_ISSUER_SENTINEL,
PREFIX_OAUTH_ACCOUNT,
DeviceFlowRedis,
DeviceFlowStatus,
@ -90,7 +91,7 @@ class DeviceApproveApi(Resource):
db.session,
redis_client,
subject_email=account.email,
subject_issuer=None,
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
account_id=str(account.id),
client_id=state.client_id,
device_label=state.device_label,
@ -104,7 +105,7 @@ class DeviceApproveApi(Resource):
device_code,
subject_email=account.email,
account_id=str(account.id),
subject_issuer=None,
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
minted_token=mint.token,
token_id=str(mint.token_id),
poll_payload=poll_payload,

View File

@ -25,7 +25,9 @@ logger = logging.getLogger(__name__)
# ============================================================================
_CE_LIKE_STATUSES = {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST}
# Fail-closed: any non-EE-active status (default NONE on CE, plus INACTIVE / EXPIRED / LOST)
# is denied. Future LicenseStatus values default to denial unless explicitly admitted.
_EE_ENABLED_STATUSES = {LicenseStatus.ACTIVE, LicenseStatus.EXPIRING}
def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@ -36,7 +38,7 @@ def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features()
if settings.license.status in _CE_LIKE_STATUSES:
if settings.license.status not in _EE_ENABLED_STATUSES:
raise NotFound()
return view(*args, **kwargs)

View File

@ -82,8 +82,10 @@ def upgrade():
postgresql_where=sa.text("revoked_at IS NULL"),
)
# Partial unique index — rotate-in-place keyed on (subject, client, device).
# subject_issuer NULL vs populated distinguishes account vs external-SSO rows
# for the same email, because Postgres treats NULL as distinct.
# The app always writes a non-NULL subject_issuer (account flow uses a
# sentinel, external-SSO uses the verified IdP issuer); without that the
# composite key would never collide because Postgres treats NULLs as
# distinct in unique indices.
op.create_index(
"uq_oauth_active_per_device",
"oauth_access_tokens",

View File

@ -87,10 +87,13 @@ class DatasourceOauthTenantParamConfig(TypeBase):
class OAuthAccessToken(TypeBase):
"""Device-flow bearer. account_id NOT NULL ⇒ dfoa_ (Dify account);
account_id NULL + subject_issuer dfoe_ (external SSO, EE-only).
Partial unique index on (subject_email, subject_issuer, client_id,
device_label) WHERE revoked_at IS NULL lets re-login rotate in place.
"""Device-flow bearer. account_id NOT NULL ⇒ dfoa_ (Dify account,
subject_issuer = "dify:account" sentinel); account_id NULL +
subject_issuer = verified IdP issuer ⇒ dfoe_ (external SSO, EE-only).
subject_issuer is non-NULL for all rows the app writes — Postgres
treats NULLs as distinct in unique indices, so the partial unique
index on (subject_email, subject_issuer, client_id, device_label)
WHERE revoked_at IS NULL would otherwise fail to rotate in place.
"""
__tablename__ = "oauth_access_tokens"

View File

@ -29,8 +29,26 @@ logger = logging.getLogger(__name__)
# ============================================================================
DEVICE_CODE_KEY_FMT = "device_code:{code}"
USER_CODE_KEY_FMT = "user_code:{code}"
_DEVICE_CODE_KEY_PREFIX = "device_code:"
_USER_CODE_KEY_PREFIX = "user_code:"
DEVICE_CODE_KEY_FMT = _DEVICE_CODE_KEY_PREFIX + "{code}"
USER_CODE_KEY_FMT = _USER_CODE_KEY_PREFIX + "{code}"
# Atomic GET → status-check → DEL(both keys). Two concurrent pollers must
# not both observe APPROVED — only the winner gets the plaintext token,
# the loser sees nil and the caller maps that to expired_token.
_CONSUME_ON_POLL_LUA = """
local raw = redis.call('GET', KEYS[1])
if not raw then return nil end
local ok, decoded = pcall(cjson.decode, raw)
if not ok then return nil end
if decoded.status == 'pending' then return nil end
if decoded.user_code then
redis.call('DEL', ARGV[1] .. decoded.user_code)
end
redis.call('DEL', KEYS[1])
return raw
"""
DEVICE_FLOW_TTL_SECONDS = 15 * 60 # RFC 8628 expires_in
APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor
@ -112,6 +130,7 @@ class DeviceFlowRedis:
def __init__(self, redis_client) -> None:
self._redis = redis_client
self._consume_on_poll_script = redis_client.register_script(_CONSUME_ON_POLL_LUA)
def start(self, client_id: str, device_label: str, created_ip: str) -> tuple[str, str, int]:
device_code = _random_device_code()
@ -205,19 +224,23 @@ class DeviceFlowRedis:
)
def consume_on_poll(self, device_code: str) -> DeviceFlowState | None:
"""Race-safe via DEL: concurrent polls — one wins, the other gets
None and the caller maps that to expired_token.
"""Race-safe via Lua EVAL: GET + status-check + DEL execute in a
single Redis transaction so only one of N concurrent pollers
observes the APPROVED state. Losers get None, mapped to
expired_token by the caller.
"""
state = self._load_state(device_code)
if state is None:
return None
if state.status is DeviceFlowStatus.PENDING:
return None
self._redis.delete(
DEVICE_CODE_KEY_FMT.format(code=device_code),
USER_CODE_KEY_FMT.format(code=state.user_code),
raw = self._consume_on_poll_script(
keys=[DEVICE_CODE_KEY_FMT.format(code=device_code)],
args=[_USER_CODE_KEY_PREFIX],
)
return state
if raw is None:
return None
text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
try:
return DeviceFlowState.from_json(text_)
except (ValueError, KeyError):
logger.error("device_flow: corrupt state on consume %s", device_code)
return None
def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision:
now = time.time()
@ -254,6 +277,12 @@ OAUTH_BODY_BYTES = 32 # ~256 bits entropy
PREFIX_OAUTH_ACCOUNT = "dfoa_"
PREFIX_OAUTH_EXTERNAL_SSO = "dfoe_"
# Sentinel issuer for account-flow rows. Postgres' default partial unique
# index treats NULLs as distinct, which would let two live `dfoa_` rows
# share (email, client, device) and break rotate-in-place. Storing a
# non-empty literal makes the composite key collide as intended.
ACCOUNT_ISSUER_SENTINEL = "dify:account"
@dataclass(frozen=True, slots=True)
class MintResult:
@ -295,7 +324,20 @@ def mint_oauth_token(
index predicate so re-login INSERTs fresh. Pre-rotate Redis entry is
deleted so stale AuthContext drops immediately.
"""
if prefix not in (PREFIX_OAUTH_ACCOUNT, PREFIX_OAUTH_EXTERNAL_SSO):
if prefix == PREFIX_OAUTH_ACCOUNT:
# Account flow always writes the sentinel — caller may pass None
# (for clarity) or the sentinel itself; nothing else is valid.
if subject_issuer not in (None, ACCOUNT_ISSUER_SENTINEL):
raise ValueError(
f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}"
)
subject_issuer = ACCOUNT_ISSUER_SENTINEL
elif prefix == PREFIX_OAUTH_EXTERNAL_SSO:
# Defense in depth: enterprise canonicalises + rejects empty,
# but a regression there must not yield a NULL composite key here.
if not subject_issuer or not subject_issuer.strip():
raise ValueError("external-SSO token requires non-empty subject_issuer")
else:
raise ValueError(f"unknown oauth prefix: {prefix!r}")
token = generate_token(prefix)
@ -333,11 +375,13 @@ def _upsert(
expires_at: datetime,
) -> UpsertOutcome:
# Snapshot prior live row's hash for Redis invalidation post-rotate.
# subject_issuer is always non-null here (account flow uses sentinel,
# external-SSO is validated upstream), so equality matches the index.
prior = session.execute(
select(OAuthAccessToken.id, OAuthAccessToken.token_hash)
.where(
OAuthAccessToken.subject_email == subject_email,
OAuthAccessToken.subject_issuer.is_not_distinct_from(subject_issuer),
OAuthAccessToken.subject_issuer == subject_issuer,
OAuthAccessToken.client_id == client_id,
OAuthAccessToken.device_label == device_label,
OAuthAccessToken.revoked_at.is_(None),

View File

@ -3,6 +3,7 @@
import type { FC } from 'react'
import { useState } from 'react'
import { deviceApproveAccount, deviceDenyAccount } from '@/service/device-flow'
import { approveErrorCopy } from '../utils/error-copy'
type Props = {
userCode: string
@ -30,8 +31,8 @@ const AuthorizeAccount: FC<Props> = ({
await deviceApproveAccount(userCode)
onApproved()
}
catch (e: any) {
onError(e?.message || 'Approve failed')
catch (e) {
onError(approveErrorCopy(e))
}
finally {
setBusy(false)
@ -44,8 +45,8 @@ const AuthorizeAccount: FC<Props> = ({
await deviceDenyAccount(userCode)
onDenied()
}
catch (e: any) {
onError(e?.message || 'Deny failed')
catch (e) {
onError(approveErrorCopy(e))
}
finally {
setBusy(false)

View File

@ -4,6 +4,7 @@ import type { FC } from 'react'
import { useEffect, useState } from 'react'
import type { ApprovalContext } from '@/service/device-flow'
import { approveExternal, fetchApprovalContext } from '@/service/device-flow'
import { approveErrorCopy } from '../utils/error-copy'
type Props = {
onApproved: () => void
@ -29,9 +30,9 @@ const AuthorizeSSO: FC<Props> = ({ onApproved, onError }) => {
let cancelled = false
fetchApprovalContext()
.then((c) => { if (!cancelled) setCtx(c) })
.catch((e: any) => {
.catch((e) => {
if (!cancelled)
setLoadErr(e?.message || 'Failed to load session')
setLoadErr(approveErrorCopy(e))
})
return () => { cancelled = true }
}, [])
@ -43,8 +44,8 @@ const AuthorizeSSO: FC<Props> = ({ onApproved, onError }) => {
await approveExternal(ctx, ctx.user_code)
onApproved()
}
catch (e: any) {
onError(e?.message || 'Approve failed')
catch (e) {
onError(approveErrorCopy(e))
}
finally {
setBusy(false)

View File

@ -1,7 +1,7 @@
'use client'
import { useEffect, useState } from 'react'
import { useSearchParams } from '@/next/navigation'
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
import { useQuery } from '@tanstack/react-query'
import { systemFeaturesQueryOptions } from '@/service/system-features'
import { commonQueryKeys, userProfileQueryOptions } from '@/service/use-common'
@ -13,6 +13,7 @@ import Chooser from './components/chooser'
import AuthorizeAccount from './components/authorize-account'
import AuthorizeSSO from './components/authorize-sso'
import { isValidUserCode } from './utils/user-code'
import { classifyLookupError } from './utils/error-copy'
type View =
| { kind: 'code_entry' }
@ -21,9 +22,13 @@ type View =
| { kind: 'authorize_sso' }
| { kind: 'success' }
| { kind: 'error_expired' }
| { kind: 'error_rate_limited' }
| { kind: 'error_lookup_failed' }
export default function DevicePage() {
const searchParams = useSearchParams()
const router = useRouter()
const pathname = usePathname()
const urlUserCode = (searchParams.get('user_code') || '').trim().toUpperCase()
const ssoVerified = searchParams.get('sso_verified') === '1'
@ -61,19 +66,25 @@ export default function DevicePage() {
// URL-driven view transitions. Only advances while the user is still on
// the entry/chooser screens — never clobbers terminal views (success /
// error_expired / authorize_*) when userProfile refetches.
// After consuming the params, scrub them from the URL so they don't
// leak via history / Referer / server logs (RFC 8628 §5.4).
useEffect(() => {
if (view.kind !== 'code_entry' && view.kind !== 'chooser') return
let consumed = false
if (ssoVerified) {
setView({ kind: 'authorize_sso' })
return
consumed = true
}
if (urlUserCode && isValidUserCode(urlUserCode)) {
else if (urlUserCode && isValidUserCode(urlUserCode)) {
if (account)
setView({ kind: 'authorize_account', userCode: urlUserCode })
else
setView({ kind: 'chooser', userCode: urlUserCode })
consumed = true
}
}, [urlUserCode, ssoVerified, account, view.kind])
if (consumed && (urlUserCode || ssoVerified))
router.replace(pathname)
}, [urlUserCode, ssoVerified, account, view.kind, router, pathname])
const onContinue = async () => {
if (!isValidUserCode(typed)) return
@ -84,8 +95,14 @@ export default function DevicePage() {
return
}
}
catch {
setView({ kind: 'error_expired' })
catch (e) {
const outcome = classifyLookupError(e)
if (outcome === 'rate_limited')
setView({ kind: 'error_rate_limited' })
else if (outcome === 'failed')
setView({ kind: 'error_lookup_failed' })
else
setView({ kind: 'error_expired' })
return
}
if (account) setView({ kind: 'authorize_account', userCode: typed })
@ -164,6 +181,24 @@ export default function DevicePage() {
</div>
)}
{view.kind === 'error_rate_limited' && (
<div>
<h1 className="text-2xl font-semibold text-text-primary">Too many attempts</h1>
<p className="mt-2 text-sm text-text-secondary">
We&apos;ve received too many requests for this code. Wait a moment and try again.
</p>
</div>
)}
{view.kind === 'error_lookup_failed' && (
<div>
<h1 className="text-2xl font-semibold text-text-primary">Could not verify the code</h1>
<p className="mt-2 text-sm text-text-secondary">
Something went wrong on our side. Try again in a moment.
</p>
</div>
)}
{errMsg && (
<p className="mt-4 text-sm text-text-destructive">{errMsg}</p>
)}

View File

@ -0,0 +1,41 @@
// Translate a DeviceFlowError (or any thrown value) into user-facing copy.
// Centralised so account/SSO branches surface the same words for the same
// failure mode and so a new server error code can be wired up here once.
import { DeviceFlowError } from '@/service/device-flow'
// Message table for approve/deny failures, keyed by DeviceFlowError.code and
// read by approveErrorCopy. Several codes share one message.
const APPROVE_COPY: Record<string, string> = {
// Throttling.
rate_limited: 'Too many attempts. Wait a moment and try again.',
// Session missing, invalid, or already used.
no_session: 'Your session has expired. Run difyctl auth login again to start over.',
invalid_session: 'Your session has expired. Run difyctl auth login again to start over.',
session_already_consumed: 'This session was already used. Run difyctl auth login again.',
// Request could not be verified.
csrf_mismatch: 'Could not verify the request. Refresh the page and try again.',
forbidden: 'Could not verify the request. Refresh the page and try again.',
// Code unknown, expired, or not matching the session.
expired_or_unknown: 'This code is no longer valid.',
not_found: 'This code is no longer valid.',
user_code_mismatch: 'This code does not match the active session. Run difyctl auth login again.',
// Code already resolved, or flow state no longer approvable.
user_code_not_pending: 'This code was already approved or denied.',
already_resolved: 'This code was already approved or denied.',
state_lost: 'The flow expired before approval completed. Run difyctl auth login again.',
approve_in_progress: 'An approval is already in progress for this code.',
conflict: 'This code is no longer in a state we can approve.',
// Server-side failure.
server_error: 'Something went wrong on our side. Try again in a moment.',
}
// Fallback text returned by approveErrorCopy when no table entry applies.
const DEFAULT_MESSAGE = 'Could not complete the request. Please try again.'
/**
 * Map a thrown value to the user-facing message for an approve/deny failure.
 *
 * A DeviceFlowError whose `code` is a key of APPROVE_COPY yields that entry's
 * text. Any other code, and any value that is not a DeviceFlowError, yields
 * DEFAULT_MESSAGE.
 */
export function approveErrorCopy(err: unknown): string {
  if (!(err instanceof DeviceFlowError))
    return DEFAULT_MESSAGE

  // Own-property check: a bare `APPROVE_COPY[code]` also resolves members
  // inherited from Object.prototype ("constructor", "toString", ...). Those
  // are not strings and are not nullish, so they would slip past a `??`
  // fallback and be returned as the message.
  return Object.prototype.hasOwnProperty.call(APPROVE_COPY, err.code)
    ? APPROVE_COPY[err.code]
    : DEFAULT_MESSAGE
}
export type LookupOutcome = 'expired' | 'rate_limited' | 'failed'

/**
 * Classify a failure from the user-code lookup into the outcome the page
 * should display.
 *
 * - 'rate_limited': DeviceFlowError with code `rate_limited` or status 429.
 * - 'failed': DeviceFlowError with code `server_error` or status >= 500, and
 *   any thrown value that is not a DeviceFlowError.
 * - 'expired': every other DeviceFlowError.
 */
export function classifyLookupError(err: unknown): LookupOutcome {
  // A value that is not a DeviceFlowError carries no server verdict about the
  // code (for example, fetch rejects when the request itself fails), so it
  // must not be reported to the user as an expired code.
  if (!(err instanceof DeviceFlowError))
    return 'failed'

  if (err.code === 'rate_limited' || err.status === 429)
    return 'rate_limited'
  if (err.code === 'server_error' || err.status >= 500)
    return 'failed'
  return 'expired'
}

View File

@ -10,10 +10,45 @@
// session cookies automatically. Lookup + SSO-branch endpoints sit under
// /v1 so they ride the existing service-API gateway route.
import { del, post } from './base'
import { post } from './base'
const DEVICE_BASE = '/v1/oauth/device'
// Error type raised by the request wrappers in this module. `code` is the
// machine-readable identifier the page/component layer dispatches on to pick
// user-facing copy or a view; it is also used as the Error message. `status`
// is the numeric status recorded alongside it. Neither `status` nor a raw
// response body should be rendered to the user.
export class DeviceFlowError extends Error {
  code: string
  status: number

  constructor(code: string, status: number) {
    super(code)
    this.code = code
    this.status = status
    this.name = 'DeviceFlowError'
  }
}
// Read the `{"error": "<code>"}` field from a copy of the response body.
// Returns '' when the body is not JSON, has no string `error` field, or
// cannot be read at all.
async function readServerErrorCode(res: Response): Promise<string> {
  try {
    const payload = await res.clone().json()
    return payload && typeof payload.error === 'string' ? payload.error : ''
  }
  catch {
    return ''
  }
}

// Always throws: converts a failed fetch Response into a DeviceFlowError.
// The server-supplied error code wins when present; otherwise the code is
// derived from the HTTP status so callers can still dispatch on it.
async function failFromResponse(res: Response): Promise<never> {
  const fromBody = await readServerErrorCode(res)
  throw new DeviceFlowError(fromBody || statusFallbackCode(res.status), res.status)
}
// Derive an error code from the HTTP status alone, for responses whose body
// carried no usable `error` field. Statuses without a dedicated code map to
// 'server_error' when >= 500 and to 'unknown' otherwise.
function statusFallbackCode(status: number): string {
  switch (status) {
    case 429:
      return 'rate_limited'
    case 401:
      return 'no_session'
    case 403:
      return 'forbidden'
    case 404:
      return 'not_found'
    case 409:
      return 'conflict'
    default:
      return status >= 500 ? 'server_error' : 'unknown'
  }
}
// ----- Account branch --------------------------------------------------------
export type DeviceLookupReply = {
@ -26,10 +61,7 @@ export async function deviceLookup(user_code: string): Promise<DeviceLookupReply
const res = await fetch(`${DEVICE_BASE}/lookup?user_code=${encodeURIComponent(user_code)}`, {
method: 'GET',
})
if (!res.ok) {
const body = await res.text().catch(() => '')
throw new Error(`lookup ${res.status}: ${body}`)
}
if (!res.ok) await failFromResponse(res)
return res.json()
}
@ -54,10 +86,7 @@ export async function fetchApprovalContext(): Promise<ApprovalContext> {
method: 'GET',
credentials: 'include',
})
if (!res.ok) {
const body = await res.text().catch(() => '')
throw new Error(`approval-context ${res.status}: ${body}`)
}
if (!res.ok) await failFromResponse(res)
return res.json()
}
@ -71,14 +100,5 @@ export async function approveExternal(ctx: ApprovalContext, user_code: string):
},
body: JSON.stringify({ user_code }),
})
if (!res.ok) {
const body = await res.text().catch(() => '')
throw new Error(`approve-external ${res.status}: ${body}`)
}
if (!res.ok) await failFromResponse(res)
}
// ----- Export for future PAT revoke; noop in v1.0 --------------------------
// Intentionally left out: personal_access_tokens endpoints are not in this
// milestone; see docs/specs/v1.0/README.md.
void del // keep import live for the TypeScript linter without surfacing usage