mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 21:28:25 +08:00
feat(api,web): OAuth 2.0 device flow + bearer auth (RFC 8628)
Adds a CLI-friendly authorization flow so difyctl (and future
non-browser clients) can obtain user-scoped tokens without copy-
pasting cookies or raw API keys. Two grant paths share one device
flow surface:
1. Account branch — user signs in via the existing /signin
methods, /device page calls console-authed approve, mints a
dfoa_ token tied to (account_id, tenant).
2. External-SSO branch (EE) — /v1/oauth/device/sso-initiate signs
an SSOState envelope, hands off to Enterprise's external ACS,
receives a signed external-subject assertion, mints a dfoe_
token tied to (subject_email, subject_issuer).
API surface (all under /v1, EE-only endpoints 404 on CE):
POST /v1/oauth/device/code — RFC 8628 start
POST /v1/oauth/device/token — RFC 8628 poll
GET /v1/oauth/device/lookup — pre-validate user_code
GET /v1/oauth/device/sso-initiate — SSO branch entry
GET /v1/device/sso-complete — SSO callback sink
GET /v1/oauth/device/approval-context — /device cookie probe
POST /v1/oauth/device/approve-external — SSO approve
GET /v1/me — bearer subject lookup
DELETE /v1/oauth/authorizations/self — self-revoke
POST /console/api/oauth/device/approve — account approve
POST /console/api/oauth/device/deny — account deny
Core primitives:
- libs/oauth_bearer.py: prefix-keyed TokenKindRegistry +
BearerAuthenticator + validate_bearer decorator. Two-tier scope
(full vs apps:run) stamped from the registry, never from the DB.
- libs/jws.py: HS256 compact JWS keyed on the shared Dify
SECRET_KEY — same key-set verifies the SSOState envelope, the
external-subject assertion (minted by Enterprise), and the
approval-grant cookie.
- libs/device_flow_security.py: enterprise_only gate, approval-
grant cookie mint/verify/consume (Path=/v1/oauth/device,
HttpOnly, SameSite=Lax, Secure follows is_secure()), anti-
framing headers.
- libs/rate_limit.py: typed RateLimit / RateLimitScope dispatch
with composite-key buckets; both decorator + imperative form.
- services/oauth_device_flow.py: Redis state machine (PENDING ->
APPROVED|DENIED with atomic consume-on-poll), token mint via
partial unique index uq_oauth_active_per_device (rotates in
place), env-driven TTL policy.
Storage: oauth_access_tokens table with partial unique index on
(subject_email, subject_issuer, client_id, device_label) WHERE
revoked_at IS NULL. account_id NULL distinguishes external-SSO
rows. Hard-expire is CAS UPDATE (revoked_at + nullify token_hash)
so audit events keep their token_id. Retention pruner DELETEs
revoked + zombie-expired rows past OAUTH_ACCESS_TOKEN_RETENTION_DAYS.
Frontend: /device page with code-entry, chooser (account vs SSO),
authorize-account, authorize-sso views. SSO branch detaches from
the URL user_code and reads everything from the cookie via
/approval-context. Anti-framing headers on all responses.
Wiring: ENABLE_OAUTH_BEARER feature flag; ext_oauth_bearer binds
the authenticator at startup; clean_oauth_access_tokens_task
scheduled in ext_celery.
Spec: docs/specs/v1.0/server/{device-flow,tokens,middleware,security}.md
This commit is contained in:
parent
8f070f2190
commit
fe8510ad1a
@ -159,6 +159,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_logstore,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_oauth_bearer,
|
||||
ext_orjson,
|
||||
ext_otel,
|
||||
ext_proxy_fix,
|
||||
@ -203,6 +204,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_enterprise_telemetry,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
ext_oauth_bearer,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
|
||||
@ -874,6 +874,11 @@ class AuthConfig(BaseSettings):
|
||||
default=86400,
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_BEARER: bool = Field(
|
||||
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
@ -1148,6 +1153,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable scheduled workflow run cleanup task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
|
||||
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
|
||||
default=True,
|
||||
)
|
||||
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Days to retain revoked OAuth access-token rows before deletion.",
|
||||
default=30,
|
||||
)
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
|
||||
description="Enable mail clean document notify task",
|
||||
default=False,
|
||||
|
||||
@ -81,6 +81,7 @@ from .auth import (
|
||||
forgot_password,
|
||||
login,
|
||||
oauth,
|
||||
oauth_device,
|
||||
oauth_server,
|
||||
)
|
||||
|
||||
@ -189,6 +190,7 @@ __all__ = [
|
||||
"models",
|
||||
"notification",
|
||||
"oauth",
|
||||
"oauth_device",
|
||||
"oauth_server",
|
||||
"ops_trace",
|
||||
"parameter",
|
||||
|
||||
221
api/controllers/console/auth/oauth_device.py
Normal file
221
api/controllers/console/auth/oauth_device.py
Normal file
@ -0,0 +1,221 @@
|
||||
"""Console-session-authed device-flow approve/deny. Called by the
|
||||
/device page after the user signs in. Public lookup is in service_api/oauth.py.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import login_required
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import ServiceUnavailable
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from libs.rate_limit import LIMIT_APPROVE_CONSOLE, rate_limit
|
||||
|
||||
|
||||
def bearer_feature_required(fn):
    """Guard console device-flow mutations behind the bearer feature flag.

    Raises 503 when ENABLE_OAUTH_BEARER is off — a token minted while the
    authenticator is disabled would be unusable, so fail fast rather than
    approve silently.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if dify_config.ENABLE_OAUTH_BEARER:
            return fn(*args, **kwargs)
        raise ServiceUnavailable(
            "bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable"
        )

    return wrapper
|
||||
from services.oauth_device_flow import (
|
||||
PREFIX_OAUTH_ACCOUNT,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransition,
|
||||
StateNotFound,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)


# Shared body parser for approve/deny — both mutations take only the user_code.
_mutate_parser = reqparse.RequestParser()
_mutate_parser.add_argument("user_code", type=str, required=True, location="json")


# Redis SET NX guard key/TTL: serializes concurrent approves of one device_code
# (see the comment inside DeviceApproveApi.post for why this matters).
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
_APPROVE_GUARD_TTL_SECONDS = 10
|
||||
|
||||
|
||||
@console_ns.route("/oauth/device/approve")
class DeviceApproveApi(Resource):
    # Console-session-authed approval of a pending device-flow request.
    # Mints a bearer token for the signed-in account and stashes the
    # pre-rendered poll payload so the unauthenticated poll endpoint can
    # return it without re-querying accounts/tenants.
    @setup_required
    @login_required
    @account_initialization_required
    @bearer_feature_required
    @rate_limit(LIMIT_APPROVE_CONSOLE)
    def post(self):
        args = _mutate_parser.parse_args()
        # user_code is typed by a human — normalize whitespace and case.
        user_code = args["user_code"].strip().upper()

        account, tenant = current_account_with_tenant()
        store = DeviceFlowRedis(redis_client)

        found = store.load_by_user_code(user_code)
        if found is None:
            return {"error": "expired_or_unknown"}, 404
        device_code, state = found
        if state.status is not DeviceFlowStatus.PENDING:
            return {"error": "already_resolved"}, 409

        # SET NX guard — without it, two in-flight approves both pass
        # PENDING, both mint, and the second upsert silently rotates the
        # first caller into an already-revoked token.
        guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
        if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
            return {"error": "approve_in_progress"}, 409

        try:
            # NOTE(review): `tenant` is passed as tenant_id — presumably the
            # active session tenant id, not a Tenant object; confirm against
            # current_account_with_tenant().
            ttl_days = oauth_ttl_days(tenant_id=tenant)
            mint = mint_oauth_token(
                db.session,
                redis_client,
                subject_email=account.email,
                subject_issuer=None,  # account branch — no external issuer
                account_id=str(account.id),
                client_id=state.client_id,
                device_label=state.device_label,
                prefix=PREFIX_OAUTH_ACCOUNT,
                ttl_days=ttl_days,
            )

            poll_payload = _build_account_poll_payload(account, tenant, mint)
            try:
                store.approve(
                    device_code,
                    subject_email=account.email,
                    account_id=str(account.id),
                    subject_issuer=None,
                    minted_token=mint.token,
                    token_id=str(mint.token_id),
                    poll_payload=poll_payload,
                )
            except (StateNotFound, InvalidTransition) as e:
                # Row minted but state vanished — roll forward; the orphan
                # token is revocable via auth devices list / Authorized Apps.
                logger.error("device_flow: approve raced on %s: %s", device_code, e)
                return {"error": "state_lost"}, 409
        finally:
            # Always release the guard so a failed approve can be retried
            # before the 10s TTL elapses.
            redis_client.delete(guard_key)

        _emit_approve_audit(state, account, tenant, mint)
        return {"status": "approved"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/oauth/device/deny")
class DeviceDenyApi(Resource):
    # Console-session-authed denial of a pending device-flow request.
    @setup_required
    @login_required
    @account_initialization_required
    @bearer_feature_required
    @rate_limit(LIMIT_APPROVE_CONSOLE)
    def post(self):
        """Deny the pending flow identified by user_code; no token is minted."""
        parsed = _mutate_parser.parse_args()
        normalized_code = parsed["user_code"].strip().upper()

        flow_store = DeviceFlowRedis(redis_client)
        lookup = flow_store.load_by_user_code(normalized_code)
        if lookup is None:
            return {"error": "expired_or_unknown"}, 404

        device_code, flow_state = lookup
        if flow_state.status is not DeviceFlowStatus.PENDING:
            return {"error": "already_resolved"}, 409

        try:
            flow_store.deny(device_code)
        except (StateNotFound, InvalidTransition) as e:
            logger.error("device_flow: deny raced on %s: %s", device_code, e)
            return {"error": "state_lost"}, 409

        _emit_deny_audit(flow_state)
        return {"status": "denied"}, 200
|
||||
|
||||
|
||||
def _build_account_poll_payload(account, tenant, mint) -> dict:
    """Pre-render the poll-response body so the unauthenticated poll
    handler doesn't re-query accounts/tenants for authz data.
    """
    from models import Tenant, TenantAccountJoin

    # All workspace memberships for the approving account.
    rows = (
        db.session.query(Tenant, TenantAccountJoin)
        .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
        .filter(TenantAccountJoin.account_id == account.id)
        .all()
    )
    workspaces = [
        {"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")}
        for t, m in rows
    ]
    # Prefer active session tenant → DB-flagged current join → first membership.
    default_ws_id = None
    # NOTE(review): `tenant` is compared via str() against workspace ids —
    # presumably it is the tenant id, not a Tenant object; confirm at callers.
    if tenant and any(w["id"] == str(tenant) for w in workspaces):
        default_ws_id = str(tenant)
    if default_ws_id is None:
        for _t, m in rows:
            if getattr(m, "current", False):
                default_ws_id = str(m.tenant_id)
                break
    if default_ws_id is None and workspaces:
        default_ws_id = workspaces[0]["id"]

    return {
        "token": mint.token,
        "expires_at": mint.expires_at.isoformat(),
        "subject_type": SubjectType.ACCOUNT,
        "account": {"id": str(account.id), "email": account.email, "name": account.name},
        "workspaces": workspaces,
        "default_workspace_id": default_ws_id,
        "token_id": str(mint.token_id),
    }
|
||||
|
||||
|
||||
def _emit_approve_audit(state, account, tenant, mint) -> None:
    """Emit the structured audit record for an account-branch approval."""
    structured_fields = {
        "audit": True,
        "event": "oauth.device_flow_approved",
        "token_id": str(mint.token_id),
        "subject_type": SubjectType.ACCOUNT,
        "subject_email": account.email,
        "account_id": str(account.id),
        "tenant_id": tenant,
        "client_id": state.client_id,
        "device_label": state.device_label,
        "scopes": ["full"],
        "expires_at": mint.expires_at.isoformat(),
    }
    logger.warning(
        "audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
        mint.token_id,
        account.email,
        state.client_id,
        state.device_label,
        mint.expires_at,
        extra=structured_fields,
    )
|
||||
|
||||
|
||||
def _emit_deny_audit(state) -> None:
    """Emit the structured audit record for a device-flow denial."""
    structured_fields = {
        "audit": True,
        "event": "oauth.device_flow_denied",
        "client_id": state.client_id,
        "device_label": state.device_label,
    }
    logger.warning(
        "audit: oauth.device_flow_denied client_id=%s device_label=%s",
        state.client_id,
        state.device_label,
        extra=structured_fields,
    )
|
||||
264
api/controllers/oauth_device_sso.py
Normal file
264
api/controllers/oauth_device_sso.py
Normal file
@ -0,0 +1,264 @@
|
||||
"""SSO-branch device-flow endpoints. Browser hits sso-initiate → API
|
||||
signs an SSOState envelope → Enterprise inner-API returns IdP authorize
|
||||
URL → 302. IdP → Enterprise ACS → DeviceFlowDispatcher mints a signed
|
||||
external-subject assertion → 302 to /v1/device/sso-complete → API mints
|
||||
the approval-grant cookie → /device → user clicks Approve → /approve-
|
||||
external mints the OAuth token. All four endpoints are EE-only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from flask import Blueprint, jsonify, make_response, redirect, request
|
||||
from libs import jws
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_EXT_PER_EMAIL,
|
||||
LIMIT_SSO_INITIATE_PER_IP,
|
||||
enforce,
|
||||
rate_limit,
|
||||
)
|
||||
from libs.device_flow_security import (APPROVAL_GRANT_COOKIE_NAME, ApprovalGrantClaims,
|
||||
approval_grant_cleared_cookie_kwargs,
|
||||
approval_grant_cookie_kwargs,
|
||||
attach_anti_framing,
|
||||
consume_approval_grant_nonce,
|
||||
consume_sso_assertion_nonce,
|
||||
enterprise_only, mint_approval_grant,
|
||||
verify_approval_grant)
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.oauth_device_flow import (PREFIX_OAUTH_EXTERNAL_SSO,
|
||||
DeviceFlowRedis, DeviceFlowStatus,
|
||||
InvalidTransition, StateNotFound,
|
||||
mint_oauth_token, oauth_ttl_days)
|
||||
from werkzeug.exceptions import (BadGateway, BadRequest, Conflict, Forbidden,
|
||||
NotFound, Unauthorized)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
bp = Blueprint("oauth_device_sso", __name__, url_prefix="/v1")
|
||||
attach_anti_framing(bp)
|
||||
|
||||
|
||||
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
|
||||
# device_code it references.
|
||||
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
@enterprise_only
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
def sso_initiate():
    """SSO-branch entry: validate the pending user_code, sign an SSOState
    envelope, and 302 the browser to the IdP authorize URL returned by
    the Enterprise inner API.
    """
    user_code = (request.args.get("user_code") or "").strip().upper()
    if not user_code:
        raise BadRequest("user_code required")

    store = DeviceFlowRedis(redis_client)
    found = store.load_by_user_code(user_code)
    if found is None:
        raise BadRequest("invalid_user_code")
    _, state = found
    # Same error for unknown vs. already-resolved codes — no state oracle.
    if state.status is not DeviceFlowStatus.PENDING:
        raise BadRequest("invalid_user_code")

    keyset = jws.KeySet.from_shared_secret()
    signed_state = jws.sign(
        keyset,
        payload={
            # redirect_url / app_code / return_to are sent empty here —
            # presumably part of an envelope shape shared with other SSO
            # paths; confirm against the Enterprise consumer.
            "redirect_url": "",
            "app_code": "",
            "intent": "device_flow",
            "user_code": user_code,
            "nonce": secrets.token_urlsafe(16),
            "return_to": "",
            # Where Enterprise's ACS sends the signed assertion back to us.
            "idp_callback_url": f"{request.host_url.rstrip('/')}/v1/device/sso-complete",
        },
        aud=jws.AUD_STATE_ENVELOPE,
        ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
    )

    try:
        reply = EnterpriseService.initiate_device_flow_sso(signed_state)
    except Exception as e:
        # Upstream enterprise failure surfaces as 502, never a raw 500.
        logger.warning("sso-initiate: enterprise call failed: %s", e)
        raise BadGateway("sso_initiate_failed") from e

    url = (reply or {}).get("url")
    if not url:
        raise BadGateway("sso_initiate_missing_url")

    # Clear stale approval-grant — defends against cross-tab/back-button mixing.
    resp = redirect(url, code=302)
    resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
    return resp
|
||||
|
||||
|
||||
@bp.route("/device/sso-complete", methods=["GET"])
@enterprise_only
def sso_complete():
    """SSO callback sink: verify the signed external-subject assertion,
    burn its single-use nonce, and send the browser back to /device with
    an approval-grant cookie carrying the verified subject.

    Raises:
        BadRequest: assertion missing, signature/aud invalid, nonce
            replayed, or required subject fields absent.
        Conflict: the referenced user_code is expired or no longer PENDING.
    """
    blob = request.args.get("sso_assertion")
    if not blob:
        raise BadRequest("sso_assertion required")

    keyset = jws.KeySet.from_shared_secret()

    try:
        claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
    except jws.VerifyError as e:
        logger.warning("sso-complete: rejected assertion: %s", e)
        raise BadRequest("invalid_sso_assertion") from e

    # Fix: a correctly-signed assertion missing "email"/"issuer" used to
    # KeyError further down (HTTP 500) *after* burning the single-use
    # nonce. Validate the subject fields up front so a malformed
    # assertion answers 400 and does not consume the nonce.
    subject_email = claims.get("email")
    subject_issuer = claims.get("issuer")
    if not subject_email or not subject_issuer:
        logger.warning("sso-complete: assertion missing subject fields")
        raise BadRequest("invalid_sso_assertion")

    # Single-use: replayed assertions are rejected here.
    if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")):
        raise BadRequest("invalid_sso_assertion")

    user_code = (claims.get("user_code") or "").strip().upper()
    store = DeviceFlowRedis(redis_client)
    found = store.load_by_user_code(user_code)
    if found is None:
        raise Conflict("user_code_not_pending")
    _, state = found
    if state.status is not DeviceFlowStatus.PENDING:
        raise Conflict("user_code_not_pending")

    # Grant cookie carries the verified subject + a CSRF token; the
    # /device page reads it back via /approval-context.
    iss = request.host_url.rstrip("/")
    cookie_value, _ = mint_approval_grant(
        keyset=keyset,
        iss=iss,
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        user_code=user_code,
    )

    resp = redirect("/device?sso_verified=1", code=302)
    resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
    return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approval-context", methods=["GET"])
@enterprise_only
def approval_context():
    """Let the /device page read the verified SSO subject out of the
    approval-grant cookie. Answers 401 whenever no usable grant exists.
    """
    grant_cookie = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
    if not grant_cookie:
        raise Unauthorized("no_session")

    keyset = jws.KeySet.from_shared_secret()
    try:
        grant = verify_approval_grant(keyset, grant_cookie)
    except jws.VerifyError as e:
        logger.warning("approval-context: bad cookie: %s", e)
        raise Unauthorized("no_session") from e

    context = {
        "subject_email": grant.subject_email,
        "subject_issuer": grant.subject_issuer,
        "user_code": grant.user_code,
        "csrf_token": grant.csrf_token,
        "expires_at": grant.expires_at.isoformat(),
    }
    return jsonify(context), 200
|
||||
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approve-external", methods=["POST"])
@enterprise_only
def approve_external():
    """Approve a pending device-flow request on behalf of an external-SSO
    subject. Requires a valid approval-grant cookie, a matching CSRF
    header, and a body user_code matching the grant; mints an external-SSO
    token and flips the Redis state to APPROVED.

    Raises:
        Unauthorized: missing/invalid cookie, or grant nonce already used.
        Forbidden: CSRF header absent or mismatched.
        BadRequest: body user_code doesn't match the grant.
        NotFound / Conflict: user_code expired or no longer PENDING.
    """
    token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
    if not token:
        raise Unauthorized("invalid_session")

    keyset = jws.KeySet.from_shared_secret()
    try:
        claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
    except jws.VerifyError as e:
        logger.warning("approve-external: bad cookie: %s", e)
        raise Unauthorized("invalid_session") from e

    # Rate-limit keyed by verified subject email (post-verification, so
    # forged cookies can't charge someone else's bucket).
    enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")

    csrf_header = request.headers.get("X-CSRF-Token", "")
    # Fix: constant-time comparison — the original `!=` is a timing oracle
    # that could let an attacker recover the CSRF token byte-by-byte.
    if not csrf_header or not secrets.compare_digest(csrf_header, claims.csrf_token):
        raise Forbidden("csrf_mismatch")

    data = request.get_json(silent=True) or {}
    body_user_code = (data.get("user_code") or "").strip().upper()
    if body_user_code != claims.user_code:
        raise BadRequest("user_code_mismatch")

    store = DeviceFlowRedis(redis_client)
    found = store.load_by_user_code(claims.user_code)
    if found is None:
        raise NotFound("user_code_not_pending")
    device_code, state = found
    if state.status is not DeviceFlowStatus.PENDING:
        raise Conflict("user_code_not_pending")

    # Single-use: burn the grant nonce only after all cheaper checks pass.
    if not consume_approval_grant_nonce(redis_client, claims.nonce):
        raise Unauthorized("session_already_consumed")

    ttl_days = oauth_ttl_days(tenant_id=None)
    mint = mint_oauth_token(
        db.session,
        redis_client,
        subject_email=claims.subject_email,
        subject_issuer=claims.subject_issuer,
        account_id=None,  # external-SSO rows carry no account_id
        client_id=state.client_id,
        device_label=state.device_label,
        prefix=PREFIX_OAUTH_EXTERNAL_SSO,
        ttl_days=ttl_days,
    )

    # Pre-render the poll body; the unauthenticated poll endpoint returns
    # it verbatim.
    poll_payload = {
        "token": mint.token,
        "expires_at": mint.expires_at.isoformat(),
        "subject_type": SubjectType.EXTERNAL_SSO,
        "subject_email": claims.subject_email,
        "subject_issuer": claims.subject_issuer,
        "account": None,
        "workspaces": [],
        "default_workspace_id": None,
        "token_id": str(mint.token_id),
    }

    try:
        store.approve(
            device_code,
            subject_email=claims.subject_email,
            account_id=None,
            subject_issuer=claims.subject_issuer,
            minted_token=mint.token,
            token_id=str(mint.token_id),
            poll_payload=poll_payload,
        )
    except (StateNotFound, InvalidTransition) as e:
        logger.error("approve-external: state transition raced: %s", e)
        raise Conflict("state_lost") from e

    _emit_approve_external_audit(state, claims, mint)

    # Grant is consumed — clear the cookie so the page can't re-submit.
    resp = make_response(jsonify({"status": "approved"}), 200)
    resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
    return resp
|
||||
|
||||
|
||||
def _emit_approve_external_audit(state, claims, mint) -> None:
    """Emit the structured audit record for an SSO-branch approval."""
    structured_fields = {
        "audit": True,
        "event": "oauth.device_flow_approved",
        "subject_type": SubjectType.EXTERNAL_SSO,
        "subject_email": claims.subject_email,
        "subject_issuer": claims.subject_issuer,
        "token_id": str(mint.token_id),
        "client_id": state.client_id,
        "device_label": state.device_label,
        "scopes": ["apps:run"],
        "expires_at": mint.expires_at.isoformat(),
    }
    logger.warning(
        "audit: oauth.device_flow_approved subject_type=%s "
        "subject_email=%s subject_issuer=%s token_id=%s",
        SubjectType.EXTERNAL_SSO,
        claims.subject_email,
        claims.subject_issuer,
        mint.token_id,
        extra=structured_fields,
    )
|
||||
@ -14,7 +14,7 @@ api = ExternalApi(
|
||||
|
||||
service_api_ns = Namespace("service_api", description="Service operations", path="/")
|
||||
|
||||
from . import index
|
||||
from . import index, oauth
|
||||
from .app import (
|
||||
annotation,
|
||||
app,
|
||||
@ -54,6 +54,7 @@ __all__ = [
|
||||
"message",
|
||||
"metadata",
|
||||
"models",
|
||||
"oauth",
|
||||
"rag_pipeline_workflow",
|
||||
"segment",
|
||||
"site",
|
||||
|
||||
302
api/controllers/service_api/oauth.py
Normal file
302
api/controllers/service_api/oauth.py
Normal file
@ -0,0 +1,302 @@
|
||||
"""``/v1`` OAuth bearer + device-flow endpoints. ``/me`` and self-revoke
|
||||
are bearer-authed; the device-flow trio (code/token/lookup) is public —
|
||||
code/token per RFC 8628, lookup so the /device page can pre-validate
|
||||
before the user has a console session.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import g, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy import update
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.service_api import service_api_ns
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
SubjectType,
|
||||
TOKEN_CACHE_KEY_FMT,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
LIMIT_DEVICE_CODE_PER_IP,
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
LIMIT_ME_PER_ACCOUNT,
|
||||
LIMIT_ME_PER_EMAIL,
|
||||
enforce,
|
||||
rate_limit,
|
||||
)
|
||||
from models import Account, OAuthAccessToken, Tenant, TenantAccountJoin
|
||||
from services.oauth_device_flow import (
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
DEVICE_FLOW_TTL_SECONDS,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
SlowDownDecision,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KNOWN_CLIENT_IDS = frozenset({"difyctl"})
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# GET /v1/me
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@service_api_ns.route("/me")
class MeApi(Resource):
    # Bearer subject lookup: who does this token act as, and in which
    # workspaces?
    @validate_bearer(accept=ACCEPT_USER_ANY)
    def get(self):
        """Return the token subject. External-SSO tokens get a minimal
        payload (no account, no workspaces); account tokens get the
        account plus its workspace memberships and a default workspace.
        """
        ctx = g.auth_ctx

        # Fix: the original evaluated `ctx.subject_type == EXTERNAL_SSO`
        # twice back-to-back (once for rate-limit dispatch, once for the
        # response). Fold rate-limit + response into a single branch —
        # behavior is identical.
        if ctx.subject_type == SubjectType.EXTERNAL_SSO:
            enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
            return {
                "subject_type": ctx.subject_type,
                "subject_email": ctx.subject_email,
                "subject_issuer": ctx.subject_issuer,
                "account": None,
                "workspaces": [],
                "default_workspace_id": None,
            }

        enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")

        account = (
            db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none()
            if ctx.account_id else None
        )
        memberships = _load_memberships(ctx.account_id) if ctx.account_id else []
        default_ws_id = _pick_default_workspace(memberships)

        return {
            "subject_type": ctx.subject_type,
            "subject_email": ctx.subject_email or (account.email if account else None),
            "account": _account_payload(account) if account else None,
            "workspaces": [_workspace_payload(m) for m in memberships],
            "default_workspace_id": default_ws_id,
        }
|
||||
|
||||
|
||||
def _load_memberships(account_id):
    """All (TenantAccountJoin, Tenant) rows for the account — one per
    workspace membership. Order is whatever the DB returns; callers that
    need determinism must sort.
    """
    return (
        db.session.query(TenantAccountJoin, Tenant)
        .join(Tenant, Tenant.id == TenantAccountJoin.tenant_id)
        .filter(TenantAccountJoin.account_id == account_id)
        .all()
    )
|
||||
|
||||
|
||||
def _pick_default_workspace(memberships) -> str | None:
|
||||
if not memberships:
|
||||
return None
|
||||
for join, tenant in memberships:
|
||||
if getattr(join, "current", False):
|
||||
return str(tenant.id)
|
||||
return str(memberships[0][1].id)
|
||||
|
||||
|
||||
def _workspace_payload(row) -> dict:
|
||||
join, tenant = row
|
||||
return {"id": str(tenant.id), "name": tenant.name, "role": getattr(join, "role", "")}
|
||||
|
||||
|
||||
def _account_payload(account) -> dict:
|
||||
return {"id": str(account.id), "email": account.email, "name": account.name}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DELETE /v1/oauth/authorizations/self
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@service_api_ns.route("/oauth/authorizations/self")
class OAuthAuthorizationsSelfApi(Resource):
    # Self-revoke: invalidates the very bearer token that authenticated
    # this request.
    @validate_bearer(accept=ACCEPT_USER_ANY)
    def delete(self):
        ctx = g.auth_ctx

        # Only OAuth-minted tokens may be revoked here; PATs have their
        # own endpoint.
        if not ctx.source.startswith("oauth"):
            raise BadRequest(
                "this endpoint revokes OAuth bearer tokens; "
                "use /v1/personal-access-tokens/self for PATs"
            )

        # Snapshot pre-revoke hash for cache invalidation; UPDATE WHERE
        # makes double-revoke idempotent.
        row = (
            db.session.query(OAuthAccessToken.token_hash)
            .filter(
                OAuthAccessToken.id == str(ctx.token_id),
                OAuthAccessToken.revoked_at.is_(None),
            )
            .one_or_none()
        )
        pre_revoke_hash = row[0] if row else None

        # Revoke = stamp revoked_at AND null the hash, so the row can
        # never authenticate again while audit records keep the token_id.
        stmt = (
            update(OAuthAccessToken)
            .where(
                OAuthAccessToken.id == str(ctx.token_id),
                OAuthAccessToken.revoked_at.is_(None),
            )
            .values(revoked_at=datetime.now(UTC), token_hash=None)
        )
        db.session.execute(stmt)
        db.session.commit()

        # Drop the auth-cache entry so revocation is effective immediately
        # rather than at cache expiry.
        if pre_revoke_hash:
            redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=pre_revoke_hash))

        return {"status": "revoked"}, 200
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# POST /v1/oauth/device/code (unauthenticated — CLI starts a flow)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Body parser for the RFC 8628 device-authorization request.
_code_parser = reqparse.RequestParser()
_code_parser.add_argument("client_id", type=str, required=True, location="json")
_code_parser.add_argument("device_label", type=str, required=True, location="json")


@service_api_ns.route("/oauth/device/code")
class OAuthDeviceCodeApi(Resource):
    # RFC 8628 device-authorization start: unauthenticated, rate-limited
    # per IP.
    @rate_limit(LIMIT_DEVICE_CODE_PER_IP)
    def post(self):
        args = _code_parser.parse_args()
        client_id = args["client_id"]
        device_label = args["device_label"]

        # Closed allow-list of clients — there is no dynamic client
        # registration.
        if client_id not in KNOWN_CLIENT_IDS:
            return {"error": "unsupported_client"}, 400

        store = DeviceFlowRedis(redis_client)
        # Origin IP is stored with the flow state — presumably feeds the
        # cross-IP audit on poll; confirm in DeviceFlowRedis.start.
        ip = extract_remote_ip(request)
        device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)

        return {
            "device_code": device_code,
            "user_code": user_code,
            "verification_uri": _verification_uri(),
            "expires_in": expires_in,
            "interval": DEFAULT_POLL_INTERVAL_SECONDS,
        }, 200
|
||||
|
||||
|
||||
def _verification_uri() -> str:
    """Absolute /device URL shown to the CLI user.

    Prefers the configured CONSOLE_WEB_URL; falls back to the host that
    served this request.
    """
    from configs import dify_config

    root = getattr(dify_config, "CONSOLE_WEB_URL", None) or request.host_url
    return f"{root.rstrip('/')}/device"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# POST /v1/oauth/device/token (unauthenticated — CLI polls)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Body parser for the RFC 8628 token poll.
_poll_parser = reqparse.RequestParser()
_poll_parser.add_argument("device_code", type=str, required=True, location="json")
_poll_parser.add_argument("client_id", type=str, required=True, location="json")


@service_api_ns.route("/oauth/device/token")
class OAuthDeviceTokenApi(Resource):
    """RFC 8628 poll."""

    def post(self):
        args = _poll_parser.parse_args()
        device_code = args["device_code"]

        store = DeviceFlowRedis(redis_client)

        # slow_down beats every other branch — polling-too-fast clients
        # see only that response regardless of underlying state.
        if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
            return {"error": "slow_down"}, 400

        state = store.load_by_device_code(device_code)
        if state is None:
            # Unknown or expired device_code.
            return {"error": "expired_token"}, 400

        if state.status is DeviceFlowStatus.PENDING:
            return {"error": "authorization_pending"}, 400

        # Terminal states are consumed atomically — only one poll ever
        # receives the APPROVED/DENIED result.
        terminal = store.consume_on_poll(device_code)
        if terminal is None:
            # Lost the consume race, or expired between load and consume.
            return {"error": "expired_token"}, 400

        if terminal.status is DeviceFlowStatus.DENIED:
            return {"error": "access_denied"}, 400

        poll_payload = terminal.poll_payload or {}
        if "token" not in poll_payload:
            # Approved state without a payload is an internal invariant
            # violation — log loudly, answer as expired.
            logger.error("device_flow: approved state missing poll_payload for %s", device_code)
            return {"error": "expired_token"}, 400

        _audit_cross_ip_if_needed(state)
        return poll_payload, 200
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# GET /v1/oauth/device/lookup (unauthenticated — /device page pre-validates)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Request schema for the lookup endpoint: user_code arrives as a query arg.
_lookup_parser = reqparse.RequestParser()
_lookup_parser.add_argument("user_code", type=str, required=True, location="args")
|
||||
|
||||
|
||||
@service_api_ns.route("/oauth/device/lookup")
class OAuthDeviceLookupApi(Resource):
    """Pre-validate a user_code before the user signs in.

    Read-only and intentionally public: user_code is high-entropy with a
    short TTL, and the per-IP rate limit blocks enumeration attempts.
    """

    @rate_limit(LIMIT_LOOKUP_PUBLIC)
    def get(self):
        params = _lookup_parser.parse_args()
        # user_code is typed by a human; normalize whitespace and case so
        # transcription differences don't cause false negatives.
        normalized = params["user_code"].strip().upper()

        flow_store = DeviceFlowRedis(redis_client)
        hit = flow_store.load_by_user_code(normalized)
        if hit is None:
            return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200

        _device_code, state = hit
        if state.status is not DeviceFlowStatus.PENDING:
            return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200

        # NOTE(review): this reports the full flow TTL, not the actual time
        # left on this particular code — confirm whether the store can supply
        # a real countdown.
        return {
            "valid": True,
            "expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
            "client_id": state.client_id,
        }, 200
|
||||
|
||||
|
||||
def _audit_cross_ip_if_needed(state) -> None:
    """Emit an audit log when the polling IP differs from the creation IP.

    Detection only — the flow is never blocked on a mismatch, since CLI
    hosts legitimately roam between networks.
    """
    current_ip = extract_remote_ip(request)
    if not state.created_ip or not current_ip:
        return
    if current_ip == state.created_ip:
        return
    logger.warning(
        "audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
        state.token_id,
        state.created_ip,
        current_ip,
        extra={
            "audit": True,
            "token_id": state.token_id,
            "creation_ip": state.created_ip,
            "poll_ip": current_ip,
        },
    )
|
||||
@ -90,6 +90,15 @@ def init_app(app: DifyApp):
|
||||
app.register_blueprint(inner_api_bp)
|
||||
app.register_blueprint(mcp_bp)
|
||||
|
||||
# SSO-branch device-flow routes. No CORS config — these endpoints are
|
||||
# user-interactive (same-origin browser traffic) and cookie-authed;
|
||||
# allowing cross-origin would defeat the SameSite=Lax cookie's purpose.
|
||||
# Gated on ENABLE_OAUTH_BEARER: without the bearer authenticator, tokens
|
||||
# minted here cannot be validated, so skip the blueprint entirely.
|
||||
if dify_config.ENABLE_OAUTH_BEARER:
|
||||
from controllers.oauth_device_sso import bp as oauth_device_sso_bp
|
||||
app.register_blueprint(oauth_device_sso_bp)
|
||||
|
||||
# Register trigger blueprint with CORS for webhook calls
|
||||
_apply_cors_once(
|
||||
trigger_bp,
|
||||
|
||||
@ -222,6 +222,12 @@ def init_app(app: DifyApp) -> Celery:
|
||||
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
|
||||
"schedule": crontab(minute="0", hour="0"),
|
||||
}
|
||||
if dify_config.ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK:
|
||||
imports.append("schedule.clean_oauth_access_tokens_task")
|
||||
beat_schedule["clean_oauth_access_tokens_task"] = {
|
||||
"task": "schedule.clean_oauth_access_tokens_task.clean_oauth_access_tokens_task",
|
||||
"schedule": crontab(minute="0", hour="5", day_of_month=f"*/{day}"),
|
||||
}
|
||||
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
|
||||
imports.append("schedule.workflow_schedule_task")
|
||||
beat_schedule["workflow_schedule_task"] = {
|
||||
|
||||
22
api/extensions/ext_oauth_bearer.py
Normal file
22
api/extensions/ext_oauth_bearer.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""Bind the bearer authenticator at startup. Must run after ext_database
|
||||
and ext_redis (needs both factories).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import build_and_bind
|
||||
|
||||
|
||||
def is_enabled() -> bool:
    # Extension-loader hook: the whole extension is skipped unless the
    # OAuth bearer feature flag is on.
    return dify_config.ENABLE_OAUTH_BEARER
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
    """Build the bearer registry and bind the process-global authenticator."""
    # scoped_session isn't a context manager; request teardown closes it,
    # so the factory simply hands back the current session each call.
    build_and_bind(session_factory=lambda: db.session, redis_client=redis_client)
|
||||
187
api/libs/device_flow_security.py
Normal file
187
api/libs/device_flow_security.py
Normal file
@ -0,0 +1,187 @@
|
||||
"""Device-flow security primitives: enterprise_only gate, approval-grant
|
||||
cookie mint/verify/consume, and anti-framing headers.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from flask import Blueprint
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from libs import jws
|
||||
from libs.token import is_secure
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# enterprise_only decorator
|
||||
# ============================================================================
|
||||
|
||||
|
||||
_CE_LIKE_STATUSES = {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST}
|
||||
|
||||
|
||||
def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""404 on CE, passthrough on EE. Apply before rate-limit so CE
|
||||
responses don't consume the bucket.
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status in _CE_LIKE_STATUSES:
|
||||
raise NotFound()
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# approval_grant cookie
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Cookie carrying the signed approval grant between the SSO callback and
# the approve-external endpoint.
APPROVAL_GRANT_COOKIE_NAME = "device_approval_grant"
# Narrow path scope: the cookie is only ever sent to device-flow endpoints.
APPROVAL_GRANT_COOKIE_PATH = "/v1/oauth/device"
APPROVAL_GRANT_COOKIE_TTL_SECONDS = 300  # 5 min
# The consumed-nonce marker must outlive the cookie so a late replay still
# hits the marker.
NONCE_TTL_SECONDS = 600  # 2x cookie TTL — defeats clock-skew late replay
# Redis key templates for the two single-use nonce namespaces.
NONCE_KEY_FMT = "device_approval_grant_nonce:{nonce}"
SSO_ASSERTION_NONCE_KEY_FMT = "sso_assertion_nonce:{nonce}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class ApprovalGrantClaims:
    """Verified claims carried by the approval-grant cookie JWS."""

    # External-SSO subject identity (as placed in the signed payload).
    subject_email: str
    subject_issuer: str
    # The user_code this grant pertains to.
    user_code: str
    # Single-use marker; consumed via consume_approval_grant_nonce.
    nonce: str
    # Random per-grant secret; presumably echoed back for CSRF defense —
    # NOTE(review): verify usage at the approve-external call site.
    csrf_token: str
    expires_at: datetime
||||
|
||||
|
||||
def mint_approval_grant(
    *,
    keyset: jws.KeySet,
    iss: str,
    subject_email: str,
    subject_issuer: str,
    user_code: str,
) -> tuple[str, ApprovalGrantClaims]:
    """Sign a fresh approval-grant JWS and return ``(token, claims)``.

    A new nonce and CSRF secret are generated per grant. Use
    ``approval_grant_cookie_kwargs`` to set the cookie — single
    source of truth for Path/HttpOnly/Secure/SameSite.

    NOTE(review): ``expires_at`` below is computed from a separate clock
    read than the ``exp`` stamped inside :func:`jws.sign`, so the two can
    differ by microseconds — confirm nothing compares them for equality.
    """
    now = datetime.now(UTC)
    exp = now + timedelta(seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
    nonce = _random_opaque()
    csrf_token = _random_opaque()

    payload = {
        "iss": iss,
        "subject_email": subject_email,
        "subject_issuer": subject_issuer,
        "user_code": user_code,
        "nonce": nonce,
        "csrf_token": csrf_token,
    }
    # jws.sign injects aud/iat/exp; payload must not contain them.
    token = jws.sign(keyset, payload, aud=jws.AUD_APPROVAL_GRANT, ttl_seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)

    return token, ApprovalGrantClaims(
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        user_code=user_code,
        nonce=nonce,
        csrf_token=csrf_token,
        expires_at=exp,
    )
|
||||
|
||||
|
||||
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
    """Check signature, audience and expiry, then materialize the claims.

    Nonce consumption is deliberately NOT done here — the caller decides
    the moment the grant is actually spent.
    """
    claims = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
    expiry = datetime.fromtimestamp(claims["exp"], tz=UTC)
    return ApprovalGrantClaims(
        subject_email=claims["subject_email"],
        subject_issuer=claims["subject_issuer"],
        user_code=claims["user_code"],
        nonce=claims["nonce"],
        csrf_token=claims["csrf_token"],
        expires_at=expiry,
    )
|
||||
|
||||
|
||||
def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
    """Atomically mark ``nonce`` as spent; True only for the first caller.

    Relies on Redis SET NX: a second consume of the same nonce finds the
    marker already present and returns False.
    """
    if not nonce:
        return False
    marker_key = NONCE_KEY_FMT.format(nonce=nonce)
    claimed = redis_client.set(marker_key, "1", nx=True, ex=NONCE_TTL_SECONDS)
    return bool(claimed)
|
||||
|
||||
|
||||
def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
    """Single-use check for external-subject-assertion nonces.

    Same SET-NX pattern as the approval-grant nonce, in its own key
    namespace so the two nonce populations can never collide.
    """
    if not nonce:
        return False
    marker_key = SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce)
    claimed = redis_client.set(marker_key, "1", nx=True, ex=NONCE_TTL_SECONDS)
    return bool(claimed)
|
||||
|
||||
|
||||
def approval_grant_cookie_kwargs(value: str) -> dict:
    """Keyword arguments for ``response.set_cookie`` for the grant cookie.

    ``secure`` follows is_secure() so HTTP-only deployments don't
    silently drop the cookie.
    """
    return dict(
        key=APPROVAL_GRANT_COOKIE_NAME,
        value=value,
        max_age=APPROVAL_GRANT_COOKIE_TTL_SECONDS,
        path=APPROVAL_GRANT_COOKIE_PATH,
        secure=is_secure(),
        httponly=True,
        samesite="Lax",
    )
|
||||
|
||||
|
||||
def approval_grant_cleared_cookie_kwargs() -> dict:
    """Kwargs that delete the grant cookie: same attributes, empty value,
    max_age=0 (attributes must match for the browser to clear it).
    """
    kwargs = approval_grant_cookie_kwargs("")
    kwargs["max_age"] = 0
    return kwargs
|
||||
|
||||
|
||||
def _random_opaque() -> str:
    # 16 random bytes, URL-safe base64 — used for nonces and CSRF secrets.
    return secrets.token_urlsafe(16)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Anti-framing headers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Headers forbidding these pages from being embedded in any frame — the
# legacy X-Frame-Options header plus its CSP equivalent.
_ANTI_FRAMING_HEADERS = {
    "X-Frame-Options": "DENY",
    "Content-Security-Policy": "frame-ancestors 'none'",
}
|
||||
|
||||
|
||||
def attach_anti_framing(bp: Blueprint) -> None:
    """X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4)."""

    @bp.after_request
    def _apply_headers(response):
        # setdefault: a handler that already set its own header wins.
        headers = response.headers
        for header_name, header_value in _ANTI_FRAMING_HEADERS.items():
            headers.setdefault(header_name, header_value)
        return response
|
||||
106
api/libs/jws.py
Normal file
106
api/libs/jws.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""HS256 compact JWS keyed on the shared Dify SECRET_KEY. Used by the SSO
|
||||
state envelope, external subject assertion, and approval-grant cookie —
|
||||
all three share one key-set so api ↔ enterprise can verify each other.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt
|
||||
from configs import dify_config
|
||||
|
||||
# Audience strings — each token purpose has its own aud so a token minted
# for one purpose is rejected when presented as another.
AUD_STATE_ENVELOPE = "api.sso.state_envelope"
AUD_EXT_SUBJECT_ASSERTION = "api.device_flow.external_subject_assertion"
AUD_APPROVAL_GRANT = "api.device_flow.approval_grant"

# kid of the current shared-secret slot; rotation adds new kids alongside.
ACTIVE_KID_V1 = "dify-shared-v1"
|
||||
|
||||
|
||||
class KeySetError(Exception):
    """Key-set construction or lookup failure (missing kid, empty secret)."""

    pass
|
||||
|
||||
|
||||
class KeySet:
|
||||
"""``from_entries`` reserves multi-kid construction for rotation slots."""
|
||||
|
||||
def __init__(self, entries: dict[str, bytes], active_kid: str) -> None:
|
||||
if active_kid not in entries:
|
||||
raise KeySetError(f"active kid {active_kid!r} missing from key-set")
|
||||
if not entries[active_kid]:
|
||||
raise KeySetError(f"active kid {active_kid!r} has empty secret")
|
||||
self._entries: dict[str, bytes] = {k: bytes(v) for k, v in entries.items()}
|
||||
self._active_kid = active_kid
|
||||
|
||||
@classmethod
|
||||
def from_shared_secret(cls) -> "KeySet":
|
||||
secret = dify_config.SECRET_KEY
|
||||
if not secret:
|
||||
raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
|
||||
return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1)
|
||||
|
||||
@classmethod
|
||||
def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> "KeySet":
|
||||
return cls(entries, active_kid)
|
||||
|
||||
@property
|
||||
def active_kid(self) -> str:
|
||||
return self._active_kid
|
||||
|
||||
def lookup(self, kid: str) -> bytes | None:
|
||||
return self._entries.get(kid)
|
||||
|
||||
|
||||
def sign(keyset: KeySet, payload: dict, aud: str, ttl_seconds: int) -> str:
    """Produce a compact HS256 JWS over ``payload``.

    ``iat`` + ``exp`` are injected here; callers must not set them (nor
    ``aud``) — a reserved claim in the payload raises ValueError.
    """
    for reserved in ("aud", "iat", "exp"):
        if reserved in payload:
            raise ValueError("reserved claim present in payload (aud/iat/exp)")
    if ttl_seconds <= 0:
        raise ValueError("ttl_seconds must be positive")

    active = keyset.active_kid
    secret = keyset.lookup(active)
    if secret is None:
        raise KeySetError(f"active kid {active!r} lookup miss")

    issued = datetime.now(UTC)
    claims = dict(payload)
    claims.update(aud=aud, iat=issued, exp=issued + timedelta(seconds=ttl_seconds))
    # kid in the header lets verify() pick the right secret after rotation.
    return jwt.encode(
        claims,
        secret,
        algorithm="HS256",
        headers={"kid": active, "typ": "JWT"},
    )
|
||||
|
||||
|
||||
class VerifyError(Exception):
    """Any verification failure: bad header, unknown kid, bad sig/aud/exp."""

    pass
|
||||
|
||||
|
||||
def verify(keyset: KeySet, token: str, expected_aud: str) -> dict:
    """Verify a compact JWS and return its claims dict.

    Unknown kid is rejected — never fall back to the active kid, since
    a past kid value would otherwise be forgeable by anyone who saw it.
    """
    try:
        unverified_header = jwt.get_unverified_header(token)
    except jwt.PyJWTError as e:
        raise VerifyError(f"decode header: {e}") from e

    kid = unverified_header.get("kid")
    if not kid:
        raise VerifyError("no kid in header")

    secret = keyset.lookup(kid)
    if secret is None:
        raise VerifyError(f"unknown kid {kid!r}")

    try:
        return jwt.decode(token, secret, algorithms=["HS256"], audience=expected_aud)
    except jwt.ExpiredSignatureError as e:
        # Specific handlers first — both are subclasses of PyJWTError.
        raise VerifyError("token expired") from e
    except jwt.InvalidAudienceError as e:
        raise VerifyError("aud mismatch") from e
    except jwt.PyJWTError as e:
        raise VerifyError(f"decode: {e}") from e
|
||||
425
api/libs/oauth_bearer.py
Normal file
425
api/libs/oauth_bearer.py
Normal file
@ -0,0 +1,425 @@
|
||||
"""OAuth bearer primitives.
|
||||
|
||||
To add a token kind: write a Resolver, add a SubjectType + Accepts member,
|
||||
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
|
||||
Authenticator + validate_bearer stay untouched.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from functools import wraps
|
||||
from typing import Callable, Iterable, Literal, Protocol
|
||||
|
||||
from flask import g, request
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized
|
||||
|
||||
from models import OAuthAccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Contract — types, enums, protocols
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SubjectType(StrEnum):
    # Token bound to a local Dify account (dfoa_ prefix in build_registry).
    ACCOUNT = "account"
    # Token bound to an external-SSO subject (dfoe_ prefix).
    EXTERNAL_SSO = "external_sso"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class AuthContext:
    """Attached to ``g.auth_ctx``. ``scopes`` / ``subject_type`` / ``source``
    come from the TokenKind, not the DB — corrupt rows can't elevate scope.
    """

    subject_type: SubjectType
    # External-SSO subject identity; None for account-bound tokens.
    subject_email: str | None
    subject_issuer: str | None
    # Local account id; None for external-SSO tokens.
    account_id: uuid.UUID | None
    scopes: frozenset[str]
    token_id: uuid.UUID
    # Registry-stamped origin label (e.g. "oauth_account").
    source: str
    expires_at: datetime | None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class ResolvedRow:
    """Storage-side projection of a live token row, as returned by a Resolver."""

    subject_email: str | None
    subject_issuer: str | None
    account_id: uuid.UUID | None
    token_id: uuid.UUID
    expires_at: datetime | None
|
||||
|
||||
|
||||
class Resolver(Protocol):
    """Maps a sha256 token hash to a live row, or None if unknown/revoked."""

    def resolve(self, token_hash: str) -> ResolvedRow | None: # pragma: no cover - contract
        ...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class TokenKind:
    """One registered bearer-token family, dispatched by string prefix."""

    prefix: str
    subject_type: SubjectType
    # Authoritative scope set — stamped into AuthContext, never read from DB.
    scopes: frozenset[str]
    source: str
    resolver: Resolver

    def matches(self, token: str) -> bool:
        # Prefix dispatch; the registry guarantees prefixes are unique.
        return token.startswith(self.prefix)
|
||||
|
||||
|
||||
class InvalidBearer(Exception):
    """Token missing, unknown prefix, or no live row.

    Mapped to 401 Unauthorized by ``validate_bearer``.
    """
|
||||
|
||||
|
||||
class TokenExpired(Exception):
    """Hard-expire bookkeeping is the resolver's job before raising.

    NOTE(review): nothing in this module raises this — the OAuth resolver
    hard-expires and returns None instead. Confirm it is still needed.
    """
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Registry
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenKindRegistry:
    """Ordered collection of TokenKinds with unique prefixes."""

    def __init__(self, kinds: Iterable[TokenKind]) -> None:
        self._kinds: tuple[TokenKind, ...] = tuple(kinds)
        prefixes = [kind.prefix for kind in self._kinds]
        # Duplicate prefixes would make dispatch order-dependent; refuse.
        if len(prefixes) != len(set(prefixes)):
            raise ValueError(f"duplicate prefix in registry: {prefixes}")

    def find(self, token: str) -> TokenKind | None:
        """First kind whose prefix matches ``token``, or None."""
        return next((kind for kind in self._kinds if kind.matches(token)), None)

    def kinds(self) -> tuple[TokenKind, ...]:
        return self._kinds
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authenticator
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def sha256_hex(token: str) -> str:
    """Hex SHA-256 of ``token`` — the only form ever passed to a resolver."""
    digest = hashlib.sha256(token.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
class BearerAuthenticator:
    """Front door for bearer validation: prefix dispatch, then row resolve."""

    def __init__(self, registry: TokenKindRegistry) -> None:
        self._registry = registry

    def authenticate(self, token: str) -> AuthContext:
        """Resolve ``token`` to an AuthContext, or raise InvalidBearer."""
        matched = self._registry.find(token)
        if matched is None:
            raise InvalidBearer("unknown token prefix")

        resolved = matched.resolver.resolve(sha256_hex(token))
        if resolved is None:
            raise InvalidBearer("token unknown or revoked")

        # Scope/subject-type/source come from the registry entry (code),
        # row identity comes from storage.
        return AuthContext(
            subject_type=matched.subject_type,
            subject_email=resolved.subject_email,
            subject_issuer=resolved.subject_issuer,
            account_id=resolved.account_id,
            scopes=matched.scopes,
            token_id=resolved.token_id,
            source=matched.source,
            expires_at=resolved.expires_at,
        )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth access token resolver (PAT resolver would be a sibling class)
|
||||
# ============================================================================
|
||||
|
||||
# Redis key for the per-token resolve cache (keyed by sha256 hex hash).
TOKEN_CACHE_KEY_FMT = "auth:token:{hash}"
# Hits cache for 60s; misses only 10s so revocations and fresh tokens
# propagate quickly.
POSITIVE_TTL_SECONDS = 60
NEGATIVE_TTL_SECONDS = 10
AUDIT_OAUTH_EXPIRED = "oauth.token_expired"

# The two token families sharing OAuthAccessTokenResolver's plumbing.
ScopeVariant = Literal["account", "external_sso"]
|
||||
|
||||
|
||||
class OAuthAccessTokenResolver:
    """``.for_account()`` / ``.for_external_sso()`` are variant-scoped views
    sharing DB + cache plumbing.

    The Redis cache holds either a JSON-serialized ResolvedRow (hit) or
    the literal string "invalid" (negative entry), under different TTLs.
    """

    def __init__(
        self,
        session_factory,
        redis_client,
        positive_ttl: int = POSITIVE_TTL_SECONDS,
        negative_ttl: int = NEGATIVE_TTL_SECONDS,
    ) -> None:
        self._session_factory = session_factory
        self._redis = redis_client
        self._positive_ttl = positive_ttl
        self._negative_ttl = negative_ttl

    def for_account(self) -> Resolver:
        return _VariantResolver(self, variant="account")

    def for_external_sso(self) -> Resolver:
        return _VariantResolver(self, variant="external_sso")

    def _cache_key(self, token_hash: str) -> str:
        return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)

    def _cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
        # Three-valued result: None = miss, "invalid" = cached negative,
        # ResolvedRow = cached hit.
        raw = self._redis.get(self._cache_key(token_hash))
        if raw is None:
            return None
        text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
        if text == "invalid":
            return "invalid"
        try:
            data = json.loads(text)
            return _row_from_cache(data)
        except (ValueError, KeyError):
            # A corrupt entry degrades to a DB lookup, never an error.
            logger.warning("auth:token cache entry malformed; treating as miss")
            return None

    def _cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
        self._redis.setex(
            self._cache_key(token_hash),
            self._positive_ttl,
            json.dumps(_row_to_cache(row)),
        )

    def _cache_set_negative(self, token_hash: str) -> None:
        self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")

    def _hard_expire(self, session: Session, row_id: uuid.UUID, token_hash: str) -> None:
        """Atomic CAS — only the worker that flips revoked_at emits audit;
        replays are idempotent. Spec: tokens.md §Detection + hard-expire.

        Also nulls token_hash so the hash can never match a lookup again.
        """
        stmt = (
            update(OAuthAccessToken)
            .where(OAuthAccessToken.id == row_id, OAuthAccessToken.revoked_at.is_(None))
            .values(revoked_at=datetime.now(UTC), token_hash=None)
        )
        result = session.execute(stmt)
        session.commit()
        # rowcount == 1 means this worker won the CAS; losers skip the audit.
        if result.rowcount == 1:
            logger.warning(
                "audit: %s token_id=%s", AUDIT_OAUTH_EXPIRED, row_id,
                extra={"audit": True, "token_id": str(row_id)},
            )
        # NOTE(review): the delete is immediately overwritten by the negative
        # entry below — it looks redundant; confirm before removing.
        self._redis.delete(self._cache_key(token_hash))
        self._cache_set_negative(token_hash)
|
||||
|
||||
|
||||
class _VariantResolver:
    """Variant-scoped view over OAuthAccessTokenResolver.

    Resolution order: cache → DB → expiry hard-expire → variant/prefix
    invariant → cache write. The ordering is part of the contract.
    """

    def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
        self._parent = parent
        self._variant = variant

    def resolve(self, token_hash: str) -> ResolvedRow | None:
        cached = self._parent._cache_get(token_hash)
        if cached == "invalid":
            return None
        # A cached row for the wrong variant (e.g. an account token hitting
        # the external-sso resolver) is treated as no-match, not an error.
        if cached is not None and not isinstance(cached, str):
            if not self._matches_variant(cached):
                return None
            return cached

        # _session_factory returns Flask-SQLAlchemy's scoped_session, which is
        # request-bound and not a context manager; use it directly.
        session = self._parent._session_factory()
        row = self._load_from_db(session, token_hash)
        if row is None:
            self._parent._cache_set_negative(token_hash)
            return None

        # Lazy expiry: first resolver to see a stale row revokes it.
        now = datetime.now(UTC)
        if row.expires_at is not None and row.expires_at <= now:
            self._parent._hard_expire(session, row.id, token_hash)
            return None

        # Defense-in-depth: a row whose account_id/prefix combination
        # contradicts the variant is rejected and logged, never served.
        if not self._matches_variant_model(row):
            logger.error(
                "internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
                row.id, row.prefix,
            )
            return None

        resolved = ResolvedRow(
            subject_email=row.subject_email,
            subject_issuer=row.subject_issuer,
            account_id=uuid.UUID(str(row.account_id)) if row.account_id else None,
            token_id=uuid.UUID(str(row.id)),
            expires_at=row.expires_at,
        )
        self._parent._cache_set_positive(token_hash, resolved)
        return resolved

    def _matches_variant(self, row: ResolvedRow) -> bool:
        # Cache entries carry no prefix, so only account presence is checked.
        has_account = row.account_id is not None
        if self._variant == "account":
            return has_account
        return not has_account

    def _matches_variant_model(self, row: OAuthAccessToken) -> bool:
        # DB rows additionally carry the prefix, so both facts are checked.
        has_account = row.account_id is not None
        if self._variant == "account":
            return has_account and row.prefix == "dfoa_"
        return (not has_account) and row.prefix == "dfoe_"

    def _load_from_db(self, session: Session, token_hash: str) -> OAuthAccessToken | None:
        return (
            session.query(OAuthAccessToken)
            .filter(
                OAuthAccessToken.token_hash == token_hash,
                OAuthAccessToken.revoked_at.is_(None),
            )
            .one_or_none()
        )
|
||||
|
||||
|
||||
def _row_to_cache(row: ResolvedRow) -> dict:
    """JSON-safe projection of a ResolvedRow (UUIDs/datetimes → strings)."""
    account = str(row.account_id) if row.account_id else None
    expiry = row.expires_at.isoformat() if row.expires_at else None
    return {
        "subject_email": row.subject_email,
        "subject_issuer": row.subject_issuer,
        "account_id": account,
        "token_id": str(row.token_id),
        "expires_at": expiry,
    }
|
||||
|
||||
|
||||
def _row_from_cache(data: dict) -> ResolvedRow:
    """Inverse of ``_row_to_cache``; KeyError/ValueError mark a malformed entry."""
    raw_account = data["account_id"]
    raw_expiry = data["expires_at"]
    return ResolvedRow(
        subject_email=data["subject_email"],
        subject_issuer=data["subject_issuer"],
        account_id=uuid.UUID(raw_account) if raw_account else None,
        token_id=uuid.UUID(data["token_id"]),
        expires_at=datetime.fromisoformat(raw_expiry) if raw_expiry else None,
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Decorator — route-level bearer gate
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class Accepts(StrEnum):
    """Token families a route opts into via ``validate_bearer(accept=...)``."""

    USER_ACCOUNT = "user_account"
    USER_EXT_SSO = "user_ext_sso"
    # Legacy tenant+app "app-" keys; resolved in service_api/wraps.py.
    APP = "app"
|
||||
|
||||
|
||||
# Convenience set: any user-level bearer, regardless of branch.
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})


# Maps an authenticated subject type to the Accepts member that gates it.
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
    SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
    SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
}
|
||||
|
||||
|
||||
# Process-global singleton, bound once at startup by the extension.
_authenticator: BearerAuthenticator | None = None


def bind_authenticator(authenticator: BearerAuthenticator) -> None:
    # Called from build_and_bind during app initialization.
    global _authenticator
    _authenticator = authenticator


def get_authenticator() -> BearerAuthenticator:
    # Fail loudly rather than silently treating every token as invalid.
    if _authenticator is None:
        raise RuntimeError("BearerAuthenticator not bound; call bind_authenticator at startup")
    return _authenticator
|
||||
|
||||
|
||||
def _extract_bearer(req) -> str | None:
|
||||
header = req.headers.get("Authorization", "")
|
||||
scheme, _, value = header.partition(" ")
|
||||
if scheme.lower() != "bearer" or not value:
|
||||
return None
|
||||
return value.strip()
|
||||
|
||||
|
||||
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable:
    """Opt-in: omitting it leaves the route unauthenticated.

    Coexists with legacy ``app-`` keys (tenant+app scoped, resolved in
    ``service_api/wraps.py``) and user-level OAuth bearers (resolved here).
    Status mapping: missing/invalid token → 401, wrong subject type → 403,
    authenticator unbound → 503.
    """

    def wrap(fn: Callable) -> Callable:
        @wraps(fn)
        def inner(*args, **kwargs):
            token = _extract_bearer(request)
            if token is None:
                raise Unauthorized("missing bearer token")

            # app- keys bypass the OAuth authenticator (work even when disabled).
            if token.startswith("app-"):
                if Accepts.APP not in accept:
                    raise Unauthorized("app-scoped keys not accepted here")
                # No g.auth_ctx here — the downstream service-api wrapper
                # resolves app- keys itself.
                return fn(*args, **kwargs)

            # OAuth bearers require the bound authenticator; without it the
            # token cannot be validated at all.
            if _authenticator is None:
                raise ServiceUnavailable(
                    "bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable"
                )

            try:
                ctx = get_authenticator().authenticate(token)
            except InvalidBearer as e:
                raise Unauthorized(str(e))

            # 403 (not 401): the token is valid, just not for this route.
            if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
                raise Forbidden("token subject type not accepted here")

            g.auth_ctx = ctx
            return fn(*args, **kwargs)

        return inner

    return wrap
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Wiring — called once from the app factory
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
    """Wire the two OAuth token families onto one shared resolver."""
    oauth = OAuthAccessTokenResolver(session_factory, redis_client)
    account_kind = TokenKind(
        prefix="dfoa_",
        subject_type=SubjectType.ACCOUNT,
        scopes=frozenset({"full"}),
        source="oauth_account",
        resolver=oauth.for_account(),
    )
    external_kind = TokenKind(
        prefix="dfoe_",
        subject_type=SubjectType.EXTERNAL_SSO,
        scopes=frozenset({"apps:run"}),
        source="oauth_external_sso",
        resolver=oauth.for_external_sso(),
    )
    return TokenKindRegistry([account_kind, external_kind])
|
||||
|
||||
|
||||
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
    """Construct registry + authenticator and install it process-globally."""
    authenticator = BearerAuthenticator(build_registry(session_factory, redis_client))
    bind_authenticator(authenticator)
    return authenticator
|
||||
109
api/libs/rate_limit.py
Normal file
109
api/libs/rate_limit.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""Typed rate-limit decorator over ``libs.helper.RateLimiter`` (sliding-
|
||||
window Redis ZSET). Apply after auth decorators so scopes can read
|
||||
``g.auth_ctx``. Use :func:`enforce` when the bucket key is computed
|
||||
in-handler. RFC-8628 ``slow_down`` is inline — its response shape isn't
|
||||
generic 429. Spec: docs/specs/v1.0/server/security.md.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from flask import g, request, session
|
||||
from werkzeug.exceptions import TooManyRequests
|
||||
|
||||
from libs.helper import RateLimiter, extract_remote_ip
|
||||
|
||||
|
||||
class RateLimitScope(StrEnum):
    """Dimension a bucket key is derived from (see ``_one_key``)."""

    IP = "ip"
    SESSION = "session"
    ACCOUNT = "account"
    SUBJECT_EMAIL = "subject_email"
    TOKEN_ID = "token_id"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class RateLimit:
    """Declarative spec: ``limit`` events per ``window`` per composite scope key."""

    limit: int
    window: timedelta
    scopes: tuple[RateLimitScope, ...]
|
||||
|
||||
|
||||
# Named policies wired onto the device-flow / bearer endpoints.
LIMIT_DEVICE_CODE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
LIMIT_SSO_INITIATE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
LIMIT_APPROVE_EXT_PER_EMAIL = RateLimit(10, timedelta(hours=1), (RateLimitScope.SUBJECT_EMAIL,))
LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSION,))
LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,))
LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,))
LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,))
|
||||
|
||||
|
||||
def _one_key(scope: RateLimitScope) -> str:
|
||||
match scope:
|
||||
case RateLimitScope.IP:
|
||||
return f"ip:{extract_remote_ip(request) or 'unknown'}"
|
||||
case RateLimitScope.SESSION:
|
||||
return f"session:{session.get('_id', 'anon')}"
|
||||
case RateLimitScope.ACCOUNT:
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx and ctx.account_id:
|
||||
return f"account:{ctx.account_id}"
|
||||
return "account:anon"
|
||||
case RateLimitScope.SUBJECT_EMAIL:
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx and ctx.subject_email:
|
||||
return f"subject:{ctx.subject_email}"
|
||||
return "subject:anon"
|
||||
case RateLimitScope.TOKEN_ID:
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx and ctx.token_id:
|
||||
return f"token:{ctx.token_id}"
|
||||
return "token:anon"
|
||||
|
||||
|
||||
def _composite_key(scopes: tuple[RateLimitScope, ...]) -> str:
    """Join each scope's fragment with '|' to form the full bucket key."""
    fragments = [_one_key(scope) for scope in scopes]
    return "|".join(fragments)
|
||||
|
||||
|
||||
def _limiter_prefix(scopes: tuple[RateLimitScope, ...]) -> str:
    """Redis key prefix for a policy, e.g. 'rl:ip' or 'rl:ip+account'."""
    joined = "+".join(scope.value for scope in scopes)
    return f"rl:{joined}"
|
||||
|
||||
|
||||
def _build_limiter(spec: RateLimit) -> RateLimiter:
    """Translate a RateLimit policy into a sliding-window RateLimiter."""
    window_seconds = int(spec.window.total_seconds())
    return RateLimiter(
        prefix=_limiter_prefix(spec.scopes),
        max_attempts=spec.limit,
        time_window=window_seconds,
    )
|
||||
|
||||
|
||||
def rate_limit(spec: RateLimit) -> Callable:
    """Decorator form. Apply after auth decorators that the scopes read from.

    Raises werkzeug TooManyRequests("rate_limited") when the bucket is full;
    otherwise increments the bucket and calls through.
    """
    # One limiter per decorated endpoint, built at decoration time.
    limiter = _build_limiter(spec)

    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs):
            bucket = _composite_key(spec.scopes)
            if limiter.is_rate_limited(bucket):
                raise TooManyRequests("rate_limited")
            limiter.increment_rate_limit(bucket)
            return fn(*args, **kwargs)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
def enforce(spec: RateLimit, *, key: str) -> None:
    """Imperative form — the caller composes the bucket key to match scope
    semantics; the key is opaque here. Raises TooManyRequests when full.
    """
    bucket_limiter = _build_limiter(spec)
    if bucket_limiter.is_rate_limited(key):
        raise TooManyRequests("rate_limited")
    bucket_limiter.increment_rate_limit(key)
|
||||
@ -0,0 +1,102 @@
|
||||
"""add oauth_access_tokens table
|
||||
|
||||
Revision ID: d4a5e1f3c9b7
|
||||
Revises: 227822d22895, b69ca54b9208, 2a3aebbbf4bb
|
||||
Create Date: 2026-04-23 22:00:00.000000
|
||||
|
||||
Merges the three open heads at time of authoring (add_workflow_comments_table,
|
||||
add_chatbot_color_theme, add_app_tracing) into a single parent so the new
|
||||
oauth_access_tokens table sits on a definite linear chain thereafter.
|
||||
|
||||
Table stores user-level OAuth bearer tokens minted via the device-flow grant
|
||||
(difyctl auth login). PAT storage (personal_access_tokens) is a separate
|
||||
table not added in this migration.
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d4a5e1f3c9b7"
|
||||
down_revision = ("227822d22895", "b69ca54b9208", "2a3aebbbf4bb")
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
    # User-level OAuth bearer tokens minted by the device-flow grant.
    # account branch: account_id set, subject_issuer NULL (dfoa_);
    # external-SSO branch: account_id NULL, subject_issuer set (dfoe_).
    op.create_table(
        "oauth_access_tokens",
        sa.Column(
            "id",
            postgresql.UUID(as_uuid=True),
            server_default=sa.text("gen_random_uuid()"),
            nullable=False,
            primary_key=True,
        ),
        sa.Column("subject_email", sa.Text(), nullable=False),
        sa.Column("subject_issuer", sa.Text(), nullable=True),
        sa.Column("account_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("client_id", sa.String(length=64), nullable=False),
        sa.Column("device_label", sa.Text(), nullable=False),
        sa.Column("prefix", sa.String(length=8), nullable=False),
        # NOTE(review): unique=True here creates a table-wide unique constraint
        # over token_hash (including revoked rows), in addition to the partial
        # idx_oauth_token_hash below — confirm both are intended.
        sa.Column("token_hash", sa.String(length=64), nullable=True, unique=True),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            server_default=sa.text("NOW()"),
            nullable=False,
        ),
        sa.Column("last_used_at", sa.TIMESTAMP(timezone=True), nullable=True),
        sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
        sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True),
        # Keep the audit row when the account is deleted: FK nulls out instead
        # of cascading the delete.
        sa.ForeignKeyConstraint(
            ["account_id"],
            ["accounts.id"],
            name="fk_oauth_access_tokens_account_id",
            ondelete="SET NULL",
        ),
    )

    # Lookup paths; all partial on live (non-revoked) rows.
    op.create_index(
        "idx_oauth_subject_email",
        "oauth_access_tokens",
        ["subject_email"],
        postgresql_where=sa.text("revoked_at IS NULL"),
    )
    op.create_index(
        "idx_oauth_account",
        "oauth_access_tokens",
        ["account_id"],
        postgresql_where=sa.text("revoked_at IS NULL AND account_id IS NOT NULL"),
    )
    op.create_index(
        "idx_oauth_client",
        "oauth_access_tokens",
        ["subject_email", "client_id"],
        postgresql_where=sa.text("revoked_at IS NULL"),
    )
    op.create_index(
        "idx_oauth_token_hash",
        "oauth_access_tokens",
        ["token_hash"],
        postgresql_where=sa.text("revoked_at IS NULL"),
    )
    # Partial unique index — rotate-in-place keyed on (subject, client, device).
    # subject_issuer NULL vs populated distinguishes account vs external-SSO rows
    # for the same email, because Postgres treats NULL as distinct.
    # NOTE(review): that same NULLS-DISTINCT behavior means two live account-branch
    # rows (subject_issuer NULL) with identical email/client/device do NOT collide,
    # so the ON CONFLICT upsert in services/oauth_device_flow._upsert never fires
    # for dfoa_ tokens — duplicates accumulate and the prior token stays live.
    # Confirm: use NULLS NOT DISTINCT (PG15+) or index COALESCE(subject_issuer, '').
    op.create_index(
        "uq_oauth_active_per_device",
        "oauth_access_tokens",
        ["subject_email", "subject_issuer", "client_id", "device_label"],
        unique=True,
        postgresql_where=sa.text("revoked_at IS NULL"),
    )
|
||||
|
||||
|
||||
def downgrade():
    # Drop the dependent indexes first (reverse creation order), then the table.
    for index_name in (
        "uq_oauth_active_per_device",
        "idx_oauth_token_hash",
        "idx_oauth_client",
        "idx_oauth_account",
        "idx_oauth_subject_email",
    ):
        op.drop_index(index_name, table_name="oauth_access_tokens")
    op.drop_table("oauth_access_tokens")
|
||||
@ -73,7 +73,7 @@ from .model import (
|
||||
TrialApp,
|
||||
UploadFile,
|
||||
)
|
||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider, OAuthAccessToken
|
||||
from .provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
@ -177,6 +177,7 @@ __all__ = [
|
||||
"MessageChain",
|
||||
"MessageFeedback",
|
||||
"MessageFile",
|
||||
"OAuthAccessToken",
|
||||
"OperationLog",
|
||||
"PinnedConversation",
|
||||
"Provider",
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import func
|
||||
@ -84,3 +84,38 @@ class DatasourceOauthTenantParamConfig(TypeBase):
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccessToken(TypeBase):
    """Device-flow bearer. account_id NOT NULL ⇒ dfoa_ (Dify account);
    account_id NULL + subject_issuer ⇒ dfoe_ (external SSO, EE-only).
    Partial unique index on (subject_email, subject_issuer, client_id,
    device_label) WHERE revoked_at IS NULL lets re-login rotate in place.

    NOTE(review): the partial unique index is NULLS DISTINCT by default, so
    the account branch (subject_issuer NULL) is not actually deduplicated —
    confirm against the migration before relying on rotate-in-place.
    """

    __tablename__ = "oauth_access_tokens"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),
    )

    # App-side UUIDv7 default; the migration also has a DB-side gen_random_uuid().
    id: Mapped[str] = mapped_column(
        StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
    )
    subject_email: Mapped[str] = mapped_column(sa.Text, nullable=False)
    client_id: Mapped[str] = mapped_column(sa.String(64), nullable=False)
    device_label: Mapped[str] = mapped_column(sa.Text, nullable=False)
    # Token kind discriminator: "dfoa_" or "dfoe_".
    prefix: Mapped[str] = mapped_column(sa.String(8), nullable=False)
    expires_at: Mapped[datetime] = mapped_column(sa.DateTime(timezone=True), nullable=False)
    # Set only for the external-SSO branch.
    subject_issuer: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True, default=None)
    # Set only for the account branch; FK nulls out on account deletion.
    account_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True, default=None)
    # SHA-256 hex of the plaintext token; the plaintext itself is never stored.
    token_hash: Mapped[Optional[str]] = mapped_column(sa.String(64), nullable=True, default=None)
    last_used_at: Mapped[Optional[datetime]] = mapped_column(
        sa.DateTime(timezone=True), nullable=True, default=None
    )
    # Revocation is an UPDATE, not a DELETE, so token_id survives for audits.
    revoked_at: Mapped[Optional[datetime]] = mapped_column(
        sa.DateTime(timezone=True), nullable=True, default=None
    )

    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True), nullable=False, server_default=func.now(), init=False
    )
|
||||
|
||||
57
api/schedule/clean_oauth_access_tokens_task.py
Normal file
57
api/schedule/clean_oauth_access_tokens_task.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""DELETE oauth_access_tokens past retention. Revocation is UPDATE
|
||||
(token_id stays for audits) so rows accumulate across re-logins, and
|
||||
expired-but-never-presented rows have no hard-expire trigger — both get
|
||||
pruned here. Spec: docs/specs/v1.0/server/tokens.md §Hard-expire.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import click
|
||||
from sqlalchemy import delete, or_, select
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.oauth import OAuthAccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DELETE_BATCH_SIZE = 500
|
||||
|
||||
|
||||
@app.celery.task(queue="retention")
def clean_oauth_access_tokens_task():
    """Hard-delete oauth_access_tokens rows past the retention window.

    Two populations are pruned: rows revoked before the cutoff, and
    "zombie" rows that expired before the cutoff but were never revoked.
    Deletes run in id-batches of DELETE_BATCH_SIZE, committing per batch
    to keep transactions short.
    """
    click.echo(click.style("Start clean oauth_access_tokens.", fg="green"))
    retention_days = int(dify_config.OAUTH_ACCESS_TOKEN_RETENTION_DAYS)
    cutoff = datetime.now(UTC) - timedelta(days=retention_days)
    start_at = time.perf_counter()

    candidates = or_(
        OAuthAccessToken.revoked_at < cutoff,
        # Zombies: expired but never re-presented, so middleware never flipped them.
        (OAuthAccessToken.revoked_at.is_(None))
        & (OAuthAccessToken.expires_at < cutoff),
    )

    total = 0
    while True:
        # Select a batch of ids, delete exactly those, commit, repeat until empty.
        ids = db.session.scalars(
            select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)
        ).all()
        if not ids:
            break
        db.session.execute(
            delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids))
        )
        db.session.commit()
        total += len(ids)

    end_at = time.perf_counter()
    click.echo(click.style(
        f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d "
        f"in {end_at - start_at:.2f}s",
        fg="green",
    ))
|
||||
@ -106,6 +106,15 @@ class EnterpriseService:
|
||||
def get_workspace_info(cls, tenant_id: str):
|
||||
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
|
||||
|
||||
    @classmethod
    def initiate_device_flow_sso(cls, signed_state: str) -> dict:
        """Hand the signed SSOState envelope to the Enterprise service.

        POSTs to the EE /device-flow/sso-initiate endpoint; with
        ``raise_for_status=True`` an EE-side HTTP failure surfaces as an
        exception rather than an error payload. Returns the parsed response
        body (presumably the external ACS hand-off details — confirm
        against the EE contract).
        """
        return EnterpriseRequest.send_request(
            "POST",
            "/device-flow/sso-initiate",
            json={"signed_state": signed_state},
            raise_for_status=True,
        )
|
||||
|
||||
@classmethod
|
||||
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
|
||||
"""
|
||||
|
||||
417
api/services/oauth_device_flow.py
Normal file
417
api/services/oauth_device_flow.py
Normal file
@ -0,0 +1,417 @@
|
||||
"""Device-flow service layer: Redis state machine, OAuth token mint
|
||||
(DB upsert + plaintext generation), and TTL policy. Specs:
|
||||
docs/specs/v1.0/server/{device-flow.md, tokens.md}.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import StrEnum
|
||||
|
||||
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT
|
||||
from models.oauth import OAuthAccessToken
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Redis state machine — device_code + user_code ephemeral state
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Redis key formats for the two halves of one flow.
DEVICE_CODE_KEY_FMT = "device_code:{code}"
USER_CODE_KEY_FMT = "user_code:{code}"

DEVICE_FLOW_TTL_SECONDS = 15 * 60  # RFC 8628 expires_in
APPROVED_TTL_SECONDS_MIN = 60  # plaintext-token lifetime floor

USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXY3456789"  # ambiguous chars dropped
USER_CODE_SEGMENT_LEN = 4  # user_code shape: XXXX-XXXX
USER_CODE_MAX_CLAIM_ATTEMPTS = 5  # SET NX retries before giving up

DEFAULT_POLL_INTERVAL_SECONDS = 5  # RFC 8628 minimum
|
||||
|
||||
|
||||
class DeviceFlowStatus(StrEnum):
    """Lifecycle of one grant, stored inside the Redis state JSON."""

    PENDING = "pending"  # waiting for the user to approve or deny
    APPROVED = "approved"  # token minted; next poll consumes the state
    DENIED = "denied"  # user rejected the request
|
||||
|
||||
|
||||
class SlowDownDecision(StrEnum):
    """Poll-pacing verdict from ``DeviceFlowRedis.record_poll``."""

    OK = "ok"
    SLOW_DOWN = "slow_down"  # RFC 8628 slow_down: client polled too fast
|
||||
|
||||
|
||||
@dataclass
class DeviceFlowState:
    """Per-device_code record serialized to JSON in Redis.

    ``minted_token`` is plaintext between approve and the next poll;
    DEL'd after the poll reads it.
    """

    user_code: str
    client_id: str
    device_label: str
    status: DeviceFlowStatus
    subject_email: str | None = None
    account_id: str | None = None
    subject_issuer: str | None = None
    minted_token: str | None = None
    token_id: str | None = None
    created_at: str = ""
    created_ip: str = ""
    last_poll_at: str = ""
    poll_payload: dict | None = None

    def to_json(self) -> str:
        # StrEnum members serialize as their string values through asdict+dumps.
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, raw: str) -> "DeviceFlowState":
        payload = json.loads(raw)
        # Re-hydrate the enum; unknown values raise ValueError for the caller.
        if "status" in payload:
            payload["status"] = DeviceFlowStatus(payload["status"])
        return cls(**payload)
|
||||
|
||||
|
||||
def _random_device_code() -> str:
    """Opaque device_code: 'dc_' plus 24 bytes of URL-safe randomness."""
    return "dc_" + secrets.token_urlsafe(24)


def _random_user_code_segment() -> str:
    """One 4-char segment from the ambiguity-reduced alphabet."""
    picks = [secrets.choice(USER_CODE_ALPHABET) for _ in range(USER_CODE_SEGMENT_LEN)]
    return "".join(picks)


def _random_user_code() -> str:
    """Human-typable code in XXXX-XXXX form."""
    left = _random_user_code_segment()
    right = _random_user_code_segment()
    return f"{left}-{right}"
|
||||
|
||||
|
||||
class StateNotFound(Exception):
    """No live Redis state for the device_code (expired or consumed)."""

    pass


class InvalidTransition(Exception):
    """approve()/deny() called on a state that is no longer PENDING."""

    pass


class UserCodeExhausted(Exception):
    """user_code allocation lost the SET NX race on every attempt."""

    pass
|
||||
|
||||
|
||||
class DeviceFlowRedis:
    """Redis-backed RFC 8628 state machine for one device-flow grant.

    Key layout (all TTL-bound to the flow lifetime):
      - ``device_code:{code}``             DeviceFlowState JSON
      - ``user_code:{code}``               -> device_code (claimed with SET NX)
      - ``device_code:{code}:last_poll``   unix timestamp of the latest poll
    """

    def __init__(self, redis_client) -> None:
        self._redis = redis_client

    def start(self, client_id: str, device_label: str, created_ip: str) -> tuple[str, str, int]:
        """Allocate codes and persist a PENDING state.

        Returns ``(device_code, user_code, expires_in_seconds)`` for the
        RFC 8628 device-authorization response.
        """
        device_code = _random_device_code()
        user_code = self._claim_user_code(device_code)
        state = DeviceFlowState(
            user_code=user_code,
            client_id=client_id,
            device_label=device_label,
            status=DeviceFlowStatus.PENDING,
            created_at=datetime.now(UTC).isoformat(),
            created_ip=created_ip,
        )
        self._redis.setex(
            DEVICE_CODE_KEY_FMT.format(code=device_code),
            DEVICE_FLOW_TTL_SECONDS,
            state.to_json(),
        )
        return device_code, user_code, DEVICE_FLOW_TTL_SECONDS

    def _claim_user_code(self, device_code: str) -> str:
        # SET NX is the atomic claim: two concurrent starts can never be
        # handed the same user_code.
        for _ in range(USER_CODE_MAX_CLAIM_ATTEMPTS):
            user_code = _random_user_code()
            key = USER_CODE_KEY_FMT.format(code=user_code)
            if self._redis.set(key, device_code, nx=True, ex=DEVICE_FLOW_TTL_SECONDS):
                return user_code
        raise UserCodeExhausted("could not allocate a unique user_code in 5 attempts")

    def load_by_user_code(self, user_code: str) -> tuple[str, DeviceFlowState] | None:
        """Resolve user_code -> (device_code, state); None if either half expired."""
        raw_dc = self._redis.get(USER_CODE_KEY_FMT.format(code=user_code))
        if not raw_dc:
            return None
        device_code = raw_dc.decode() if isinstance(raw_dc, (bytes, bytearray)) else raw_dc
        state = self._load_state(device_code)
        if state is None:
            return None
        return device_code, state

    def load_by_device_code(self, device_code: str) -> DeviceFlowState | None:
        """Read-only fetch; None when the flow is missing or expired."""
        return self._load_state(device_code)

    def _load_state(self, device_code: str) -> DeviceFlowState | None:
        raw = self._redis.get(DEVICE_CODE_KEY_FMT.format(code=device_code))
        if not raw:
            return None
        text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
        try:
            return DeviceFlowState.from_json(text_)
        except (ValueError, KeyError):
            # Corrupt JSON / unknown enum value: treat as absent, but log it.
            logger.error("device_flow: corrupt state for %s", device_code)
            return None

    def approve(
        self,
        device_code: str,
        subject_email: str,
        account_id: str | None,
        minted_token: str,
        token_id: str,
        subject_issuer: str | None = None,
        poll_payload: dict | None = None,
    ) -> None:
        """PENDING -> APPROVED; stashes the plaintext token for one poll.

        Raises StateNotFound when the flow expired and InvalidTransition
        when the state is already settled.
        """
        state = self._load_state(device_code)
        if state is None:
            raise StateNotFound(device_code)
        if state.status is not DeviceFlowStatus.PENDING:
            raise InvalidTransition(f"cannot approve {state.status}")

        state.status = DeviceFlowStatus.APPROVED
        state.subject_email = subject_email
        state.account_id = account_id
        state.subject_issuer = subject_issuer
        state.minted_token = minted_token
        state.token_id = token_id
        state.poll_payload = poll_payload

        # Floor the TTL so the CLI always gets at least one more poll window
        # even when approval lands near flow expiry.
        new_ttl = self._remaining_ttl(device_code, floor=APPROVED_TTL_SECONDS_MIN)
        self._redis.setex(DEVICE_CODE_KEY_FMT.format(code=device_code), new_ttl, state.to_json())

    def deny(self, device_code: str) -> None:
        """PENDING -> DENIED; same precondition errors as approve()."""
        state = self._load_state(device_code)
        if state is None:
            raise StateNotFound(device_code)
        if state.status is not DeviceFlowStatus.PENDING:
            raise InvalidTransition(f"cannot deny {state.status}")
        state.status = DeviceFlowStatus.DENIED
        self._redis.setex(
            DEVICE_CODE_KEY_FMT.format(code=device_code),
            self._remaining_ttl(device_code, floor=1),
            state.to_json(),
        )

    def consume_on_poll(self, device_code: str) -> DeviceFlowState | None:
        """Single-use read of a settled (approved/denied) state.

        Race-safety fix: Redis DEL returns the number of keys it removed,
        and that count is now the arbiter. When two polls race, both may
        read the settled state, but only the caller whose DEL actually
        removed the device_code key returns it (and the plaintext token);
        the loser gets None, which the caller maps to expired_token.
        (The previous version ignored the DEL result, so both racing
        pollers could receive the plaintext token.)
        """
        state = self._load_state(device_code)
        if state is None:
            return None
        if state.status is DeviceFlowStatus.PENDING:
            return None
        removed = self._redis.delete(DEVICE_CODE_KEY_FMT.format(code=device_code))
        # user_code cleanup is best-effort; the key expires on its own TTL.
        self._redis.delete(USER_CODE_KEY_FMT.format(code=state.user_code))
        if not removed:
            return None
        return state

    def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision:
        """RFC 8628 pacing: SLOW_DOWN when polls arrive faster than interval."""
        now = time.time()
        key = f"device_code:{device_code}:last_poll"
        prev_raw = self._redis.get(key)
        self._redis.setex(key, DEVICE_FLOW_TTL_SECONDS, str(now))
        if prev_raw is None:
            return SlowDownDecision.OK
        prev_s = prev_raw.decode() if isinstance(prev_raw, (bytes, bytearray)) else prev_raw
        try:
            prev = float(prev_s)
        except ValueError:
            # Unparseable marker: fail open rather than punish the client.
            return SlowDownDecision.OK
        if now - prev < interval_seconds:
            return SlowDownDecision.SLOW_DOWN
        return SlowDownDecision.OK

    def _remaining_ttl(self, device_code: str, floor: int) -> int:
        """``max(remaining, floor)`` — guarantees the CLI has at least
        ``floor`` seconds to poll after a near-expiry approve.
        """
        ttl = self._redis.ttl(DEVICE_CODE_KEY_FMT.format(code=device_code))
        if ttl is None or ttl < 0:
            # -2 (missing) / -1 (no expiry) / None: fall back to the floor.
            return floor
        return max(int(ttl), floor)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token mint — generate + upsert
|
||||
# ============================================================================
|
||||
|
||||
|
||||
OAUTH_BODY_BYTES = 32  # bytes fed to token_urlsafe -> ~256 bits entropy
PREFIX_OAUTH_ACCOUNT = "dfoa_"  # account branch (console-approved)
PREFIX_OAUTH_EXTERNAL_SSO = "dfoe_"  # external-SSO branch (EE-only)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class MintResult:
    """Plaintext token surfaces to the caller once."""

    token: str  # full plaintext bearer, prefix included — never persisted
    token_id: uuid.UUID  # DB row id of the minted/rotated token
    expires_at: datetime  # absolute expiry stamped into the row
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class UpsertOutcome:
    """Result of _upsert: row id plus rotation info for cache invalidation."""

    token_id: uuid.UUID
    rotated: bool  # True when a prior live row existed for this device
    old_hash: str | None  # prior token_hash, used to purge the Redis cache
|
||||
|
||||
|
||||
def generate_token(prefix: str) -> str:
    """Mint a fresh plaintext bearer: kind prefix + URL-safe random body."""
    body = secrets.token_urlsafe(OAUTH_BODY_BYTES)
    return f"{prefix}{body}"
|
||||
|
||||
|
||||
def sha256_hex(token: str) -> str:
    """Hex SHA-256 of the token's UTF-8 bytes (the DB/cache lookup key)."""
    digest = hashlib.sha256()
    digest.update(token.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def mint_oauth_token(
    session: Session,
    redis_client,
    *,
    subject_email: str,
    subject_issuer: str | None,
    account_id: str | None,
    client_id: str,
    device_label: str,
    prefix: str,
    ttl_days: int,
) -> MintResult:
    """Generate a plaintext token, upsert its hash, and invalidate the cache.

    Live row rotates in place via partial unique index
    ``uq_oauth_active_per_device``; hard-expired rows are excluded by the
    index predicate so re-login INSERTs fresh. Pre-rotate Redis entry is
    deleted so stale AuthContext drops immediately.

    Raises ValueError for a prefix outside the two known OAuth kinds.
    """
    if prefix not in (PREFIX_OAUTH_ACCOUNT, PREFIX_OAUTH_EXTERNAL_SSO):
        raise ValueError(f"unknown oauth prefix: {prefix!r}")

    # Only the SHA-256 of the token is persisted; plaintext lives solely
    # in the returned MintResult.
    token = generate_token(prefix)
    new_hash = sha256_hex(token)
    expires_at = datetime.now(UTC) + timedelta(days=ttl_days)

    outcome = _upsert(
        session,
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        account_id=account_id,
        client_id=client_id,
        device_label=device_label,
        prefix=prefix,
        new_hash=new_hash,
        expires_at=expires_at,
    )

    # Drop the rotated-out token's cache entry so any AuthContext cached
    # under the old hash stops resolving immediately.
    if outcome.rotated and outcome.old_hash:
        redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=outcome.old_hash))

    return MintResult(token=token, token_id=outcome.token_id, expires_at=expires_at)
|
||||
|
||||
|
||||
def _upsert(
    session: Session,
    *,
    subject_email: str,
    subject_issuer: str | None,
    account_id: str | None,
    client_id: str,
    device_label: str,
    prefix: str,
    new_hash: str,
    expires_at: datetime,
) -> UpsertOutcome:
    """INSERT ... ON CONFLICT against the partial unique index, committing.

    NOTE(review): ``uq_oauth_active_per_device`` is a default
    (NULLS DISTINCT) partial unique index, so for the account branch
    (subject_issuer IS NULL) two live rows never conflict — ON CONFLICT
    will not fire, a second live row is inserted, and the prior dfoa_
    token stays valid in the DB even though its Redis cache entry is
    purged. Confirm: the index likely needs NULLS NOT DISTINCT (PG15+)
    or a COALESCE(subject_issuer, '') expression.
    """
    # Snapshot prior live row's hash for Redis invalidation post-rotate.
    # is_not_distinct_from gives NULL-safe equality on subject_issuer.
    prior = session.execute(
        select(OAuthAccessToken.id, OAuthAccessToken.token_hash)
        .where(
            OAuthAccessToken.subject_email == subject_email,
            OAuthAccessToken.subject_issuer.is_not_distinct_from(subject_issuer),
            OAuthAccessToken.client_id == client_id,
            OAuthAccessToken.device_label == device_label,
            OAuthAccessToken.revoked_at.is_(None),
        )
        .limit(1)
    ).first()
    old_hash = prior.token_hash if prior else None

    insert_stmt = pg_insert(OAuthAccessToken).values(
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        account_id=account_id,
        client_id=client_id,
        device_label=device_label,
        prefix=prefix,
        token_hash=new_hash,
        expires_at=expires_at,
    )
    # Rotate in place: on conflict with the live-row index, overwrite the
    # credential fields and reset the bookkeeping timestamps.
    upsert_stmt = insert_stmt.on_conflict_do_update(
        index_elements=["subject_email", "subject_issuer", "client_id", "device_label"],
        index_where=OAuthAccessToken.revoked_at.is_(None),
        set_={
            "token_hash": insert_stmt.excluded.token_hash,
            "prefix": insert_stmt.excluded.prefix,
            "account_id": insert_stmt.excluded.account_id,
            "expires_at": insert_stmt.excluded.expires_at,
            "created_at": func.now(),
            "last_used_at": None,
        },
    ).returning(OAuthAccessToken.id)
    row = session.execute(upsert_stmt).first()
    session.commit()

    token_id = uuid.UUID(str(row.id))
    return UpsertOutcome(
        token_id=token_id,
        rotated=prior is not None,
        old_hash=old_hash,
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TTL policy — days new OAuth tokens live
|
||||
# ============================================================================
|
||||
|
||||
|
||||
DEFAULT_OAUTH_TTL_DAYS = 14
|
||||
MIN_TTL_DAYS = 1
|
||||
MAX_TTL_DAYS = 365
|
||||
|
||||
_TTL_ENV_VAR = "OAUTH_TTL_DAYS"
|
||||
|
||||
|
||||
def oauth_ttl_days(tenant_id: str | None = None) -> int:
|
||||
"""``OAUTH_TTL_DAYS`` env, else default. EE tenant-level lookup
|
||||
is deferred; when it lands it wins over the env (Redis-cached 60s).
|
||||
"""
|
||||
_ = tenant_id
|
||||
|
||||
raw = os.environ.get(_TTL_ENV_VAR)
|
||||
if raw is None:
|
||||
return DEFAULT_OAUTH_TTL_DAYS
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"%s=%r is not an int; falling back to %d",
|
||||
_TTL_ENV_VAR, raw, DEFAULT_OAUTH_TTL_DAYS,
|
||||
)
|
||||
return DEFAULT_OAUTH_TTL_DAYS
|
||||
if value < MIN_TTL_DAYS:
|
||||
logger.warning("%s=%d below min %d; clamping", _TTL_ENV_VAR, value, MIN_TTL_DAYS)
|
||||
return MIN_TTL_DAYS
|
||||
if value > MAX_TTL_DAYS:
|
||||
logger.warning("%s=%d above max %d; clamping", _TTL_ENV_VAR, value, MAX_TTL_DAYS)
|
||||
return MAX_TTL_DAYS
|
||||
return value
|
||||
96
web/app/device/components/authorize-account.tsx
Normal file
96
web/app/device/components/authorize-account.tsx
Normal file
@ -0,0 +1,96 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import { useState } from 'react'
|
||||
import { deviceApproveAccount, deviceDenyAccount } from '@/service/device-flow'
|
||||
|
||||
type Props = {
  /** RFC 8628 user_code being approved on behalf of the CLI. */
  userCode: string
  /** Email of the signed-in console account, when known. */
  accountEmail?: string
  /** Workspace shown as the token's default, when known. */
  defaultWorkspace?: string
  /** Fired after the approve POST succeeds. */
  onApproved: () => void
  /** Fired after the deny POST succeeds. */
  onDenied: () => void
  /** Fired with a human-readable message when either POST fails. */
  onError: (message: string) => void
}
|
||||
|
||||
/**
|
||||
* AuthorizeAccount is the account-branch authorize screen. Called with a
|
||||
* live console session already established (user bounced through /signin).
|
||||
* Posts to /console/api/oauth/device/{approve,deny}; these endpoints mint
|
||||
* the dfoa_ token server-side.
|
||||
*/
|
||||
const AuthorizeAccount: FC<Props> = ({
|
||||
userCode, accountEmail, defaultWorkspace, onApproved, onDenied, onError,
|
||||
}) => {
|
||||
const [busy, setBusy] = useState(false)
|
||||
|
||||
const approve = async () => {
|
||||
setBusy(true)
|
||||
try {
|
||||
await deviceApproveAccount(userCode)
|
||||
onApproved()
|
||||
}
|
||||
catch (e: any) {
|
||||
onError(e?.message || 'Approve failed')
|
||||
}
|
||||
finally {
|
||||
setBusy(false)
|
||||
}
|
||||
}
|
||||
|
||||
const deny = async () => {
|
||||
setBusy(true)
|
||||
try {
|
||||
await deviceDenyAccount(userCode)
|
||||
onDenied()
|
||||
}
|
||||
catch (e: any) {
|
||||
onError(e?.message || 'Deny failed')
|
||||
}
|
||||
finally {
|
||||
setBusy(false)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-6">
|
||||
<div>
|
||||
<h2 className="text-2xl font-semibold text-text-primary">Authorize Dify CLI</h2>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
Dify CLI (difyctl) is requesting access to your account.
|
||||
{' '}If you did not start this from your terminal, click Cancel.
|
||||
</p>
|
||||
</div>
|
||||
<div className="rounded-lg border border-components-panel-border bg-components-panel-bg px-4 py-3">
|
||||
{accountEmail && (
|
||||
<p className="text-sm text-text-secondary">
|
||||
Signed in as <span className="font-medium text-text-primary">{accountEmail}</span>
|
||||
</p>
|
||||
)}
|
||||
{defaultWorkspace && (
|
||||
<p className="mt-1 text-sm text-text-secondary">
|
||||
Default workspace: <span className="font-medium text-text-primary">{defaultWorkspace}</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex gap-3">
|
||||
<button
|
||||
onClick={approve}
|
||||
disabled={busy}
|
||||
className="flex-1 rounded-lg bg-components-button-primary-bg px-4 py-3 text-components-button-primary-text font-medium hover:bg-components-button-primary-bg-hover disabled:opacity-50"
|
||||
>
|
||||
Authorize
|
||||
</button>
|
||||
<button
|
||||
onClick={deny}
|
||||
disabled={busy}
|
||||
className="flex-1 rounded-lg border border-components-button-secondary-border bg-components-button-secondary-bg px-4 py-3 text-components-button-secondary-text font-medium hover:bg-components-button-secondary-bg-hover disabled:opacity-50"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default AuthorizeAccount
|
||||
96
web/app/device/components/authorize-sso.tsx
Normal file
96
web/app/device/components/authorize-sso.tsx
Normal file
@ -0,0 +1,96 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import { useEffect, useState } from 'react'
|
||||
import type { ApprovalContext } from '@/service/device-flow'
|
||||
import { approveExternal, fetchApprovalContext } from '@/service/device-flow'
|
||||
|
||||
type Props = {
  /** Fired after the approve-external POST succeeds. */
  onApproved: () => void
  /** Fired with a human-readable message when the approve POST fails. */
  onError: (message: string) => void
}
|
||||
|
||||
/**
|
||||
* AuthorizeSSO is the external-SSO branch authorize screen. On mount it
|
||||
* fetches /v1/oauth/device/approval-context to learn subject_email, issuer,
|
||||
* user_code, and csrf_token from the device_approval_grant cookie. On
|
||||
* Approve click, posts /v1/oauth/device/approve-external with the CSRF header.
|
||||
*
|
||||
* The user_code in state is bound to the cookie by server; we do not accept
|
||||
* one from the URL because the SSO branch deliberately detaches from the
|
||||
* pre-SSO ?user_code=... query param.
|
||||
*/
|
||||
const AuthorizeSSO: FC<Props> = ({ onApproved, onError }) => {
|
||||
const [ctx, setCtx] = useState<ApprovalContext | null>(null)
|
||||
const [busy, setBusy] = useState(false)
|
||||
const [loadErr, setLoadErr] = useState<string | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false
|
||||
fetchApprovalContext()
|
||||
.then((c) => { if (!cancelled) setCtx(c) })
|
||||
.catch((e: any) => {
|
||||
if (!cancelled)
|
||||
setLoadErr(e?.message || 'Failed to load session')
|
||||
})
|
||||
return () => { cancelled = true }
|
||||
}, [])
|
||||
|
||||
const approve = async () => {
|
||||
if (!ctx) return
|
||||
setBusy(true)
|
||||
try {
|
||||
await approveExternal(ctx, ctx.user_code)
|
||||
onApproved()
|
||||
}
|
||||
catch (e: any) {
|
||||
onError(e?.message || 'Approve failed')
|
||||
}
|
||||
finally {
|
||||
setBusy(false)
|
||||
}
|
||||
}
|
||||
|
||||
if (loadErr) {
|
||||
return (
|
||||
<div>
|
||||
<h2 className="text-2xl font-semibold text-text-primary">This session is no longer valid</h2>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
Run <code className="rounded bg-components-panel-bg px-1">difyctl auth login</code> again to start a new sign-in.
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
if (!ctx) {
|
||||
return <div className="text-sm text-text-secondary">Loading session…</div>
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-6">
|
||||
<div>
|
||||
<h2 className="text-2xl font-semibold text-text-primary">Authorize Dify CLI</h2>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
Dify CLI (difyctl) is requesting access via SSO. If you did not start
|
||||
this from your terminal, close this tab.
|
||||
</p>
|
||||
</div>
|
||||
<div className="rounded-lg border border-components-panel-border bg-components-panel-bg px-4 py-3">
|
||||
<p className="text-sm text-text-secondary">
|
||||
Signed in as <span className="font-medium text-text-primary">{ctx.subject_email}</span>
|
||||
</p>
|
||||
<p className="mt-1 text-sm text-text-secondary">
|
||||
Issuer: <span className="font-medium text-text-primary">{ctx.subject_issuer}</span>
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={approve}
|
||||
disabled={busy}
|
||||
className="rounded-lg bg-components-button-primary-bg px-4 py-3 text-components-button-primary-text font-medium hover:bg-components-button-primary-bg-hover disabled:opacity-50"
|
||||
>
|
||||
Authorize
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default AuthorizeSSO
|
||||
60
web/app/device/components/chooser.tsx
Normal file
60
web/app/device/components/chooser.tsx
Normal file
@ -0,0 +1,60 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import { useRouter } from '@/next/navigation'
|
||||
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
|
||||
|
||||
type Props = {
|
||||
userCode: string
|
||||
ssoAvailable: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Chooser renders the two-button device-auth login selector. Account button
|
||||
* seeds postLoginRedirect + navigates to /signin so every existing account
|
||||
* login method (password / email-code / social OAuth / account-SSO) flows
|
||||
* through its usual plumbing. SSO button hits /v1/oauth/device/sso-initiate
|
||||
* directly — the SSO branch skips /signin entirely.
|
||||
*
|
||||
* v1.0 scope: only account-SSO honours postLoginRedirect (via sso-auth's
|
||||
* return_to plumbing). Password / email-code / social-OAuth users land on
|
||||
* /signin's default post-login target and manually return to the /device
|
||||
* URL printed by the CLI. That's not great UX; a follow-up milestone
|
||||
* generalises post-signin redirect to all methods.
|
||||
*/
|
||||
const Chooser: FC<Props> = ({ userCode, ssoAvailable }) => {
|
||||
const router = useRouter()
|
||||
|
||||
const onAccount = () => {
|
||||
setPostLoginRedirect(`/device?user_code=${encodeURIComponent(userCode)}`)
|
||||
router.push('/signin')
|
||||
}
|
||||
|
||||
const onSSO = () => {
|
||||
// Full-page navigation, not router.push — /v1/oauth/device/sso-initiate
|
||||
// issues a 302 to the IdP. Next's client router can't follow cross-
|
||||
// origin redirects; a plain window.location assignment handles it.
|
||||
window.location.href = `/v1/oauth/device/sso-initiate?user_code=${encodeURIComponent(userCode)}`
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-3">
|
||||
<button
|
||||
onClick={onAccount}
|
||||
className="rounded-lg bg-components-button-primary-bg px-4 py-3 text-components-button-primary-text font-medium hover:bg-components-button-primary-bg-hover"
|
||||
>
|
||||
Sign in with Dify account
|
||||
</button>
|
||||
{ssoAvailable && (
|
||||
<button
|
||||
onClick={onSSO}
|
||||
className="rounded-lg border border-components-button-secondary-border bg-components-button-secondary-bg px-4 py-3 text-components-button-secondary-text font-medium hover:bg-components-button-secondary-bg-hover"
|
||||
>
|
||||
Sign in with SSO
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default Chooser
|
||||
45
web/app/device/components/code-input.tsx
Normal file
45
web/app/device/components/code-input.tsx
Normal file
@ -0,0 +1,45 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { normaliseUserCodeInput } from '../utils/user-code'
|
||||
|
||||
type Props = {
|
||||
value: string
|
||||
onChange: (normalised: string) => void
|
||||
disabled?: boolean
|
||||
autoFocus?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* CodeInput renders the user_code text field with live normalisation
|
||||
* (uppercase, reduced alphabet, XXXX-XXXX hyphenation).
|
||||
*
|
||||
* The onChange callback receives the normalised value only — the parent does
|
||||
* not need to run validation itself.
|
||||
*/
|
||||
const CodeInput: FC<Props> = ({ value, onChange, disabled, autoFocus }) => {
|
||||
const handle = useCallback((raw: string) => {
|
||||
onChange(normaliseUserCodeInput(raw))
|
||||
}, [onChange])
|
||||
|
||||
return (
|
||||
<input
|
||||
type="text"
|
||||
inputMode="text"
|
||||
autoCapitalize="characters"
|
||||
autoComplete="off"
|
||||
spellCheck={false}
|
||||
placeholder="ABCD-1234"
|
||||
maxLength={9}
|
||||
aria-label="one-time code"
|
||||
className="w-full rounded-lg border border-components-input-border-normal bg-components-input-bg-normal px-4 py-3 text-center text-2xl font-mono tracking-wider text-text-primary focus:border-components-input-border-active focus:outline-none"
|
||||
value={value}
|
||||
disabled={disabled}
|
||||
autoFocus={autoFocus}
|
||||
onChange={e => handle(e.target.value)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default CodeInput
|
||||
173
web/app/device/page.tsx
Normal file
173
web/app/device/page.tsx
Normal file
@ -0,0 +1,173 @@
|
||||
'use client'
|
||||
|
||||
import { useEffect, useState } from 'react'
|
||||
import { useSearchParams } from '@/next/navigation'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { systemFeaturesQueryOptions } from '@/service/system-features'
|
||||
import { commonQueryKeys, userProfileQueryOptions } from '@/service/use-common'
|
||||
import { post } from '@/service/base'
|
||||
import type { ICurrentWorkspace } from '@/models/common'
|
||||
import { deviceLookup } from '@/service/device-flow'
|
||||
import CodeInput from './components/code-input'
|
||||
import Chooser from './components/chooser'
|
||||
import AuthorizeAccount from './components/authorize-account'
|
||||
import AuthorizeSSO from './components/authorize-sso'
|
||||
import { isValidUserCode } from './utils/user-code'
|
||||
|
||||
type View =
|
||||
| { kind: 'code_entry' }
|
||||
| { kind: 'chooser'; userCode: string }
|
||||
| { kind: 'authorize_account'; userCode: string }
|
||||
| { kind: 'authorize_sso' }
|
||||
| { kind: 'success' }
|
||||
| { kind: 'error_expired' }
|
||||
|
||||
export default function DevicePage() {
|
||||
const searchParams = useSearchParams()
|
||||
const urlUserCode = (searchParams.get('user_code') || '').trim().toUpperCase()
|
||||
const ssoVerified = searchParams.get('sso_verified') === '1'
|
||||
|
||||
const [typed, setTyped] = useState('')
|
||||
const [view, setView] = useState<View>({ kind: 'code_entry' })
|
||||
const [errMsg, setErrMsg] = useState<string | null>(null)
|
||||
|
||||
// Account subject + workspace identity (for the authorize-account screen).
|
||||
// Logged-out is a valid landing state on /device — disable refetch storms
|
||||
// and skip workspace probe until profile resolves (avoids /current + chained
|
||||
// /refresh-token 401 loops while the user is still entering the code).
|
||||
const { data: userResp, isError: profileErr } = useQuery({
|
||||
...userProfileQueryOptions(),
|
||||
throwOnError: false,
|
||||
retry: false,
|
||||
refetchOnWindowFocus: false,
|
||||
refetchOnMount: false,
|
||||
})
|
||||
const account = userResp?.profile
|
||||
const { data: currentWorkspace } = useQuery<ICurrentWorkspace>({
|
||||
queryKey: commonQueryKeys.currentWorkspace,
|
||||
queryFn: () => post<ICurrentWorkspace>('/workspaces/current'),
|
||||
enabled: !!account && !profileErr,
|
||||
retry: false,
|
||||
refetchOnWindowFocus: false,
|
||||
})
|
||||
const { data: sys } = useQuery(systemFeaturesQueryOptions())
|
||||
// Device-flow SSO branch uses external-user (webapp) SSO, not console SSO —
|
||||
// backend mints EXTERNAL_SSO tokens via Enterprise's external ACS. Gate on
|
||||
// webapp_auth.{enabled, allow_sso} + a configured webapp SSO protocol.
|
||||
const ssoAvailable = !!sys?.webapp_auth?.enabled
|
||||
&& !!sys?.webapp_auth?.allow_sso
|
||||
&& (sys?.webapp_auth?.sso_config?.protocol || '') !== ''
|
||||
|
||||
// URL-driven view transitions. Only advances while the user is still on
|
||||
// the entry/chooser screens — never clobbers terminal views (success /
|
||||
// error_expired / authorize_*) when userProfile refetches.
|
||||
useEffect(() => {
|
||||
if (view.kind !== 'code_entry' && view.kind !== 'chooser') return
|
||||
if (ssoVerified) {
|
||||
setView({ kind: 'authorize_sso' })
|
||||
return
|
||||
}
|
||||
if (urlUserCode && isValidUserCode(urlUserCode)) {
|
||||
if (account)
|
||||
setView({ kind: 'authorize_account', userCode: urlUserCode })
|
||||
else
|
||||
setView({ kind: 'chooser', userCode: urlUserCode })
|
||||
}
|
||||
}, [urlUserCode, ssoVerified, account, view.kind])
|
||||
|
||||
const onContinue = async () => {
|
||||
if (!isValidUserCode(typed)) return
|
||||
try {
|
||||
const reply = await deviceLookup(typed)
|
||||
if (!reply.valid) {
|
||||
setView({ kind: 'error_expired' })
|
||||
return
|
||||
}
|
||||
}
|
||||
catch {
|
||||
setView({ kind: 'error_expired' })
|
||||
return
|
||||
}
|
||||
if (account) setView({ kind: 'authorize_account', userCode: typed })
|
||||
else setView({ kind: 'chooser', userCode: typed })
|
||||
}
|
||||
|
||||
return (
|
||||
<main className="mx-auto flex min-h-screen max-w-lg flex-col items-center justify-center px-6 py-10">
|
||||
<div className="w-full rounded-xl border border-components-panel-border bg-components-panel-bg p-8 shadow-sm">
|
||||
{view.kind === 'code_entry' && (
|
||||
<div className="flex flex-col gap-5">
|
||||
<div>
|
||||
<h1 className="text-2xl font-semibold text-text-primary">Authorize Dify CLI</h1>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
Enter the code shown in your terminal.
|
||||
</p>
|
||||
</div>
|
||||
<CodeInput value={typed} onChange={setTyped} autoFocus />
|
||||
<button
|
||||
onClick={onContinue}
|
||||
disabled={!isValidUserCode(typed)}
|
||||
className="rounded-lg bg-components-button-primary-bg px-4 py-3 text-components-button-primary-text font-medium hover:bg-components-button-primary-bg-hover disabled:opacity-50"
|
||||
>
|
||||
Continue
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{view.kind === 'chooser' && (
|
||||
<div className="flex flex-col gap-5">
|
||||
<div>
|
||||
<h1 className="text-2xl font-semibold text-text-primary">Sign in to authorize</h1>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
Code <span className="font-mono">{view.userCode}</span> is valid. Choose how to sign in.
|
||||
</p>
|
||||
</div>
|
||||
<Chooser userCode={view.userCode} ssoAvailable={ssoAvailable} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{view.kind === 'authorize_account' && (
|
||||
<AuthorizeAccount
|
||||
userCode={view.userCode}
|
||||
accountEmail={account?.email}
|
||||
defaultWorkspace={currentWorkspace?.name}
|
||||
onApproved={() => setView({ kind: 'success' })}
|
||||
onDenied={() => setView({ kind: 'error_expired' })}
|
||||
onError={e => setErrMsg(e)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{view.kind === 'authorize_sso' && (
|
||||
<AuthorizeSSO
|
||||
onApproved={() => setView({ kind: 'success' })}
|
||||
onError={e => setErrMsg(e)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{view.kind === 'success' && (
|
||||
<div>
|
||||
<h1 className="text-2xl font-semibold text-text-primary">You're signed in</h1>
|
||||
<p className="mt-2 text-sm text-text-secondary">Return to your terminal to continue.</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{view.kind === 'error_expired' && (
|
||||
<div>
|
||||
<h1 className="text-2xl font-semibold text-text-primary">This code is no longer valid</h1>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
The code may have expired or already been used. Run
|
||||
{' '}
|
||||
<code className="rounded bg-components-panel-bg px-1">difyctl auth login</code>
|
||||
{' '}
|
||||
again to get a new one.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{errMsg && (
|
||||
<p className="mt-4 text-sm text-text-destructive">{errMsg}</p>
|
||||
)}
|
||||
</div>
|
||||
</main>
|
||||
)
|
||||
}
|
||||
37
web/app/device/utils/user-code.ts
Normal file
37
web/app/device/utils/user-code.ts
Normal file
@ -0,0 +1,37 @@
|
||||
// user-code.ts — input normalisation + validation for the RFC 8628
|
||||
// 8-character user_code format the CLI prints to stderr.
|
||||
//
|
||||
// Format: XXXX-XXXX, uppercase, reduced alphabet (no 0/O, 1/I/l, 2/Z). Low
|
||||
// entropy by design — humans type it — so the server-side rate-limit + TTL +
|
||||
// single-use properties are what defend it, not the alphabet.
|
||||
|
||||
export const USER_CODE_ALPHABET = 'ABCDEFGHJKLMNPQRSTUVWXY3456789' // excludes 0 O 1 I L 2 Z
|
||||
|
||||
/**
|
||||
* normaliseUserCodeInput prepares raw input for display in the code field:
|
||||
* strips non-alphanumerics, uppercases, drops disallowed characters, and
|
||||
* inserts the hyphen after the fourth accepted char.
|
||||
*
|
||||
* Returns at most 9 chars ("XXXX-XXXX"); longer input is truncated.
|
||||
*/
|
||||
export function normaliseUserCodeInput(raw: string): string {
|
||||
const cleaned: string[] = []
|
||||
for (const ch of raw.toUpperCase()) {
|
||||
if (USER_CODE_ALPHABET.includes(ch))
|
||||
cleaned.push(ch)
|
||||
if (cleaned.length === 8)
|
||||
break
|
||||
}
|
||||
if (cleaned.length <= 4)
|
||||
return cleaned.join('')
|
||||
return `${cleaned.slice(0, 4).join('')}-${cleaned.slice(4).join('')}`
|
||||
}
|
||||
|
||||
/**
|
||||
* isValidUserCode tests whether the normalised form is a complete XXXX-XXXX
|
||||
* token suitable for submission to /console/api/oauth/device/lookup.
|
||||
*/
|
||||
export function isValidUserCode(normalised: string): boolean {
|
||||
return /^[A-Z0-9]{4}-[A-Z0-9]{4}$/.test(normalised)
|
||||
&& [...normalised.replace('-', '')].every(c => USER_CODE_ALPHABET.includes(c))
|
||||
}
|
||||
@ -1,15 +1,68 @@
|
||||
let postLoginRedirect: string | null = null
|
||||
// Persists target across full-page redirects within the same tab (social
|
||||
// OAuth, SSO IdP bounce). sessionStorage is tab-scoped so concurrent
|
||||
// /device tabs don't clobber each other. 15-min TTL drops stale values.
|
||||
// Same-origin + exact-path whitelist prevents open-redirect.
|
||||
//
|
||||
// Signup-via-email-link opening in a new tab is out of scope — that tab
|
||||
// starts with an empty sessionStorage and falls to /apps default.
|
||||
|
||||
const KEY = 'dify_post_login_redirect'
|
||||
const TTL_MS = 15 * 60 * 1000
|
||||
|
||||
const ALLOWED: Record<string, ReadonlySet<string>> = {
|
||||
'/device': new Set(['user_code', 'sso_verified']),
|
||||
'/account/oauth/authorize': new Set(['client_id', 'scope', 'state', 'redirect_uri']),
|
||||
}
|
||||
|
||||
function validate(target: string): string | null {
|
||||
if (typeof window === 'undefined') return null
|
||||
try {
|
||||
const url = new URL(target, window.location.origin)
|
||||
if (url.origin !== window.location.origin) return null
|
||||
const allowedKeys = ALLOWED[url.pathname]
|
||||
if (!allowedKeys) return null
|
||||
for (const key of url.searchParams.keys()) {
|
||||
if (!allowedKeys.has(key)) return null
|
||||
}
|
||||
return url.pathname + (url.search || '')
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export const setPostLoginRedirect = (value: string | null) => {
|
||||
postLoginRedirect = value
|
||||
}
|
||||
|
||||
export const resolvePostLoginRedirect = () => {
|
||||
if (postLoginRedirect) {
|
||||
const redirectUrl = postLoginRedirect
|
||||
postLoginRedirect = null
|
||||
return redirectUrl
|
||||
if (typeof window === 'undefined') return
|
||||
if (value === null) {
|
||||
try { sessionStorage.removeItem(KEY) } catch {}
|
||||
return
|
||||
}
|
||||
const safe = validate(value)
|
||||
if (!safe) return
|
||||
try {
|
||||
sessionStorage.setItem(KEY, JSON.stringify({ target: safe, ts: Date.now() }))
|
||||
}
|
||||
catch {}
|
||||
}
|
||||
|
||||
export const resolvePostLoginRedirect = (): string | null => {
|
||||
if (typeof window === 'undefined') return null
|
||||
let raw: string | null = null
|
||||
try {
|
||||
raw = sessionStorage.getItem(KEY)
|
||||
sessionStorage.removeItem(KEY)
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
if (!raw) return null
|
||||
try {
|
||||
const parsed = JSON.parse(raw)
|
||||
if (typeof parsed?.target !== 'string' || typeof parsed?.ts !== 'number') return null
|
||||
if (Date.now() - parsed.ts > TTL_MS) return null
|
||||
return validate(parsed.target)
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
@ -30,6 +30,20 @@ const nextConfig: NextConfig = {
|
||||
},
|
||||
]
|
||||
},
|
||||
// Anti-framing for device-flow surfaces. A framed /device page could UI-trick
|
||||
// a victim with a valid device_approval_grant cookie into approving a
|
||||
// device_code — functionally CSRF, bypasses the double-submit token. Deny
|
||||
// framing outright on every device-flow route; no trusted embedder exists.
|
||||
async headers() {
|
||||
const antiFrame = [
|
||||
{ key: 'X-Frame-Options', value: 'DENY' },
|
||||
{ key: 'Content-Security-Policy', value: "frame-ancestors 'none'" },
|
||||
]
|
||||
return [
|
||||
{ source: '/device', headers: antiFrame },
|
||||
{ source: '/device/:path*', headers: antiFrame },
|
||||
]
|
||||
},
|
||||
output: 'standalone',
|
||||
compiler: {
|
||||
removeConsole: isDev ? false : { exclude: ['warn', 'error'] },
|
||||
|
||||
@ -794,6 +794,11 @@ export const request = async<T>(url: string, options = {}, otherOptions?: IOther
|
||||
const [refreshErr] = await asyncRunSafe(refreshAccessTokenOrReLogin(TIME_OUT))
|
||||
if (refreshErr === null)
|
||||
return baseFetch<T>(url, options, otherOptionsForBaseFetch)
|
||||
// /device is the device-flow chooser; logged-out is a valid state
|
||||
// there. Redirecting to /signin loses the user_code context and
|
||||
// the post-login flow lands on /apps instead of returning here.
|
||||
if (location.pathname === `${basePath}/device`)
|
||||
return Promise.reject(err)
|
||||
if (location.pathname !== `${basePath}/signin` || !IS_CE_EDITION) {
|
||||
jumpTo(loginUrl)
|
||||
return Promise.reject(err)
|
||||
|
||||
84
web/service/device-flow.ts
Normal file
84
web/service/device-flow.ts
Normal file
@ -0,0 +1,84 @@
|
||||
// Web-side calls into the Dify device-flow endpoints:
|
||||
//
|
||||
// /v1/oauth/device/lookup (public — GET, no auth, IP-rate-limited)
|
||||
// /v1/oauth/device/approval-context (cookie-authed — GET)
|
||||
// /v1/oauth/device/approve-external (cookie-authed + CSRF — POST)
|
||||
// /console/api/oauth/device/approve (session-authed — POST)
|
||||
// /console/api/oauth/device/deny (session-authed — POST)
|
||||
//
|
||||
// Approve/deny use the standard service/base helpers so they get console-
|
||||
// session cookies automatically. Lookup + SSO-branch endpoints sit under
|
||||
// /v1 so they ride the existing service-API gateway route.
|
||||
|
||||
import { del, post } from './base'
|
||||
|
||||
const DEVICE_BASE = '/v1/oauth/device'
|
||||
|
||||
// ----- Account branch --------------------------------------------------------
|
||||
|
||||
export type DeviceLookupReply = {
|
||||
valid: boolean
|
||||
expires_in_remaining: number
|
||||
client_id: string
|
||||
}
|
||||
|
||||
export async function deviceLookup(user_code: string): Promise<DeviceLookupReply> {
|
||||
const res = await fetch(`${DEVICE_BASE}/lookup?user_code=${encodeURIComponent(user_code)}`, {
|
||||
method: 'GET',
|
||||
})
|
||||
if (!res.ok) {
|
||||
const body = await res.text().catch(() => '')
|
||||
throw new Error(`lookup ${res.status}: ${body}`)
|
||||
}
|
||||
return res.json()
|
||||
}
|
||||
|
||||
// Account-branch approve: POSTs the user_code to the console API
// (/console/api/oauth/device/approve). Uses service/base's `post`, so the
// console session cookie rides along automatically.
export const deviceApproveAccount = (user_code: string) =>
  post<{ status: 'approved' }>('/oauth/device/approve', { body: { user_code } })

// Account-branch deny: same transport as approve, against
// /console/api/oauth/device/deny.
export const deviceDenyAccount = (user_code: string) =>
  post<{ status: 'denied' }>('/oauth/device/deny', { body: { user_code } })
|
||||
|
||||
// ----- SSO branch (cookie-authed via /v1/oauth/device/*) --------------------
|
||||
|
||||
export type ApprovalContext = {
|
||||
subject_email: string
|
||||
subject_issuer: string
|
||||
user_code: string
|
||||
csrf_token: string
|
||||
expires_at: string
|
||||
}
|
||||
|
||||
export async function fetchApprovalContext(): Promise<ApprovalContext> {
|
||||
const res = await fetch(`${DEVICE_BASE}/approval-context`, {
|
||||
method: 'GET',
|
||||
credentials: 'include',
|
||||
})
|
||||
if (!res.ok) {
|
||||
const body = await res.text().catch(() => '')
|
||||
throw new Error(`approval-context ${res.status}: ${body}`)
|
||||
}
|
||||
return res.json()
|
||||
}
|
||||
|
||||
export async function approveExternal(ctx: ApprovalContext, user_code: string): Promise<void> {
|
||||
const res = await fetch(`${DEVICE_BASE}/approve-external`, {
|
||||
method: 'POST',
|
||||
credentials: 'include',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-CSRF-Token': ctx.csrf_token,
|
||||
},
|
||||
body: JSON.stringify({ user_code }),
|
||||
})
|
||||
if (!res.ok) {
|
||||
const body = await res.text().catch(() => '')
|
||||
throw new Error(`approve-external ${res.status}: ${body}`)
|
||||
}
|
||||
}
|
||||
|
||||
// ----- Export for future PAT revoke; noop in v1.0 --------------------------
|
||||
|
||||
// Intentionally left out: personal_access_tokens endpoints are not in this
|
||||
// milestone; see docs/specs/v1.0/README.md.
|
||||
void del // keep import live for the TypeScript linter without surfacing usage
|
||||
Loading…
Reference in New Issue
Block a user