From fe8510ad1a0ff428803f29546a06fc2043fbd813 Mon Sep 17 00:00:00 2001 From: GareArc Date: Sun, 26 Apr 2026 20:06:43 -0700 Subject: [PATCH] feat(api,web): OAuth 2.0 device flow + bearer auth (RFC 8628) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a CLI-friendly authorization flow so difyctl (and future non-browser clients) can obtain user-scoped tokens without copy- pasting cookies or raw API keys. Two grant paths share one device flow surface: 1. Account branch — user signs in via the existing /signin methods, /device page calls console-authed approve, mints a dfoa_ token tied to (account_id, tenant). 2. External-SSO branch (EE) — /v1/oauth/device/sso-initiate signs an SSOState envelope, hands off to Enterprise's external ACS, receives a signed external-subject assertion, mints a dfoe_ token tied to (subject_email, subject_issuer). API surface (all under /v1, EE-only endpoints 404 on CE): POST /v1/oauth/device/code — RFC 8628 start POST /v1/oauth/device/token — RFC 8628 poll GET /v1/oauth/device/lookup — pre-validate user_code GET /v1/oauth/device/sso-initiate — SSO branch entry GET /v1/device/sso-complete — SSO callback sink GET /v1/oauth/device/approval-context — /device cookie probe POST /v1/oauth/device/approve-external — SSO approve GET /v1/me — bearer subject lookup DELETE /v1/oauth/authorizations/self — self-revoke POST /console/api/oauth/device/approve — account approve POST /console/api/oauth/device/deny — account deny Core primitives: - libs/oauth_bearer.py: prefix-keyed TokenKindRegistry + BearerAuthenticator + validate_bearer decorator. Two-tier scope (full vs apps:run) stamped from the registry, never from the DB. - libs/jws.py: HS256 compact JWS keyed on the shared Dify SECRET_KEY — same key-set verifies the SSOState envelope, the external-subject assertion (minted by Enterprise), and the approval-grant cookie. 
- libs/device_flow_security.py: enterprise_only gate, approval- grant cookie mint/verify/consume (Path=/v1/oauth/device, HttpOnly, SameSite=Lax, Secure follows is_secure()), anti- framing headers. - libs/rate_limit.py: typed RateLimit / RateLimitScope dispatch with composite-key buckets; both decorator + imperative form. - services/oauth_device_flow.py: Redis state machine (PENDING -> APPROVED|DENIED with atomic consume-on-poll), token mint via partial unique index uq_oauth_active_per_device (rotates in place), env-driven TTL policy. Storage: oauth_access_tokens table with partial unique index on (subject_email, subject_issuer, client_id, device_label) WHERE revoked_at IS NULL. account_id NULL distinguishes external-SSO rows. Hard-expire is CAS UPDATE (revoked_at + nullify token_hash) so audit events keep their token_id. Retention pruner DELETEs revoked + zombie-expired rows past OAUTH_ACCESS_TOKEN_RETENTION_DAYS. Frontend: /device page with code-entry, chooser (account vs SSO), authorize-account, authorize-sso views. SSO branch detaches from the URL user_code and reads everything from the cookie via /approval-context. Anti-framing headers on all responses. Wiring: ENABLE_OAUTH_BEARER feature flag; ext_oauth_bearer binds the authenticator at startup; clean_oauth_access_tokens_task scheduled in ext_celery. 
Spec: docs/specs/v1.0/server/{device-flow,tokens,middleware,security}.md --- api/app_factory.py | 2 + api/configs/feature/__init__.py | 13 + api/controllers/console/__init__.py | 2 + api/controllers/console/auth/oauth_device.py | 221 +++++++++ api/controllers/oauth_device_sso.py | 264 +++++++++++ api/controllers/service_api/__init__.py | 3 +- api/controllers/service_api/oauth.py | 302 +++++++++++++ api/extensions/ext_blueprints.py | 9 + api/extensions/ext_celery.py | 6 + api/extensions/ext_oauth_bearer.py | 22 + api/libs/device_flow_security.py | 187 ++++++++ api/libs/jws.py | 106 +++++ api/libs/oauth_bearer.py | 425 ++++++++++++++++++ api/libs/rate_limit.py | 109 +++++ ...00-d4a5e1f3c9b7_add_oauth_access_tokens.py | 102 +++++ api/models/__init__.py | 3 +- api/models/oauth.py | 37 +- .../clean_oauth_access_tokens_task.py | 57 +++ api/services/enterprise/enterprise_service.py | 9 + api/services/oauth_device_flow.py | 417 +++++++++++++++++ .../device/components/authorize-account.tsx | 96 ++++ web/app/device/components/authorize-sso.tsx | 96 ++++ web/app/device/components/chooser.tsx | 60 +++ web/app/device/components/code-input.tsx | 45 ++ web/app/device/page.tsx | 173 +++++++ web/app/device/utils/user-code.ts | 37 ++ web/app/signin/utils/post-login-redirect.ts | 75 +++- web/next.config.ts | 14 + web/service/base.ts | 5 + web/service/device-flow.ts | 84 ++++ 30 files changed, 2967 insertions(+), 14 deletions(-) create mode 100644 api/controllers/console/auth/oauth_device.py create mode 100644 api/controllers/oauth_device_sso.py create mode 100644 api/controllers/service_api/oauth.py create mode 100644 api/extensions/ext_oauth_bearer.py create mode 100644 api/libs/device_flow_security.py create mode 100644 api/libs/jws.py create mode 100644 api/libs/oauth_bearer.py create mode 100644 api/libs/rate_limit.py create mode 100644 api/migrations/versions/2026_04_23_2200-d4a5e1f3c9b7_add_oauth_access_tokens.py create mode 100644 api/schedule/clean_oauth_access_tokens_task.py 
create mode 100644 api/services/oauth_device_flow.py create mode 100644 web/app/device/components/authorize-account.tsx create mode 100644 web/app/device/components/authorize-sso.tsx create mode 100644 web/app/device/components/chooser.tsx create mode 100644 web/app/device/components/code-input.tsx create mode 100644 web/app/device/page.tsx create mode 100644 web/app/device/utils/user-code.ts create mode 100644 web/service/device-flow.ts diff --git a/api/app_factory.py b/api/app_factory.py index 48e50ceae9..d6fb70ab2e 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -159,6 +159,7 @@ def initialize_extensions(app: DifyApp): ext_logstore, ext_mail, ext_migrate, + ext_oauth_bearer, ext_orjson, ext_otel, ext_proxy_fix, @@ -203,6 +204,7 @@ def initialize_extensions(app: DifyApp): ext_enterprise_telemetry, ext_request_logging, ext_session_factory, + ext_oauth_bearer, ] for ext in extensions: short_name = ext.__name__.split(".")[-1] diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index ae49ae47d0..dca44f2f32 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -874,6 +874,11 @@ class AuthConfig(BaseSettings): default=86400, ) + ENABLE_OAUTH_BEARER: bool = Field( + description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).", + default=True, + ) + class ModerationConfig(BaseSettings): """ @@ -1148,6 +1153,14 @@ class CeleryScheduleTasksConfig(BaseSettings): description="Enable scheduled workflow run cleanup task", default=False, ) + ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field( + description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.", + default=True, + ) + OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field( + description="Days to retain revoked OAuth access-token rows before deletion.", + default=30, + ) ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field( description="Enable mail clean document notify 
task", default=False, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 980e828945..09224d7d43 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -81,6 +81,7 @@ from .auth import ( forgot_password, login, oauth, + oauth_device, oauth_server, ) @@ -189,6 +190,7 @@ __all__ = [ "models", "notification", "oauth", + "oauth_device", "oauth_server", "ops_trace", "parameter", diff --git a/api/controllers/console/auth/oauth_device.py b/api/controllers/console/auth/oauth_device.py new file mode 100644 index 0000000000..beef3698de --- /dev/null +++ b/api/controllers/console/auth/oauth_device.py @@ -0,0 +1,221 @@ +"""Console-session-authed device-flow approve/deny. Called by the +/device page after the user signs in. Public lookup is in service_api/oauth.py. +""" +from __future__ import annotations + +import logging + +from functools import wraps + +from flask_login import login_required +from flask_restx import Resource, reqparse +from werkzeug.exceptions import ServiceUnavailable + +from configs import dify_config +from controllers.console import console_ns +from controllers.console.wraps import account_initialization_required, setup_required +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from libs.login import current_account_with_tenant +from libs.oauth_bearer import SubjectType +from libs.rate_limit import LIMIT_APPROVE_CONSOLE, rate_limit + + +def bearer_feature_required(fn): + """503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable + without the authenticator, so fail fast instead of approving silently. 
+ """ + + @wraps(fn) + def inner(*args, **kwargs): + if not dify_config.ENABLE_OAUTH_BEARER: + raise ServiceUnavailable( + "bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable" + ) + return fn(*args, **kwargs) + + return inner +from services.oauth_device_flow import ( + PREFIX_OAUTH_ACCOUNT, + DeviceFlowRedis, + DeviceFlowStatus, + InvalidTransition, + StateNotFound, + mint_oauth_token, + oauth_ttl_days, +) + +logger = logging.getLogger(__name__) + + +_mutate_parser = reqparse.RequestParser() +_mutate_parser.add_argument("user_code", type=str, required=True, location="json") + + +_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving" +_APPROVE_GUARD_TTL_SECONDS = 10 + + +@console_ns.route("/oauth/device/approve") +class DeviceApproveApi(Resource): + @setup_required + @login_required + @account_initialization_required + @bearer_feature_required + @rate_limit(LIMIT_APPROVE_CONSOLE) + def post(self): + args = _mutate_parser.parse_args() + user_code = args["user_code"].strip().upper() + + account, tenant = current_account_with_tenant() + store = DeviceFlowRedis(redis_client) + + found = store.load_by_user_code(user_code) + if found is None: + return {"error": "expired_or_unknown"}, 404 + device_code, state = found + if state.status is not DeviceFlowStatus.PENDING: + return {"error": "already_resolved"}, 409 + + # SET NX guard — without it, two in-flight approves both pass + # PENDING, both mint, and the second upsert silently rotates the + # first caller into an already-revoked token. 
+ guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code) + if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS): + return {"error": "approve_in_progress"}, 409 + + try: + ttl_days = oauth_ttl_days(tenant_id=tenant) + mint = mint_oauth_token( + db.session, + redis_client, + subject_email=account.email, + subject_issuer=None, + account_id=str(account.id), + client_id=state.client_id, + device_label=state.device_label, + prefix=PREFIX_OAUTH_ACCOUNT, + ttl_days=ttl_days, + ) + + poll_payload = _build_account_poll_payload(account, tenant, mint) + try: + store.approve( + device_code, + subject_email=account.email, + account_id=str(account.id), + subject_issuer=None, + minted_token=mint.token, + token_id=str(mint.token_id), + poll_payload=poll_payload, + ) + except (StateNotFound, InvalidTransition) as e: + # Row minted but state vanished — roll forward; the orphan + # token is revocable via auth devices list / Authorized Apps. + logger.error("device_flow: approve raced on %s: %s", device_code, e) + return {"error": "state_lost"}, 409 + finally: + redis_client.delete(guard_key) + + _emit_approve_audit(state, account, tenant, mint) + return {"status": "approved"}, 200 + + +@console_ns.route("/oauth/device/deny") +class DeviceDenyApi(Resource): + @setup_required + @login_required + @account_initialization_required + @bearer_feature_required + @rate_limit(LIMIT_APPROVE_CONSOLE) + def post(self): + args = _mutate_parser.parse_args() + user_code = args["user_code"].strip().upper() + + store = DeviceFlowRedis(redis_client) + found = store.load_by_user_code(user_code) + if found is None: + return {"error": "expired_or_unknown"}, 404 + device_code, state = found + if state.status is not DeviceFlowStatus.PENDING: + return {"error": "already_resolved"}, 409 + + try: + store.deny(device_code) + except (StateNotFound, InvalidTransition) as e: + logger.error("device_flow: deny raced on %s: %s", device_code, e) + return {"error": "state_lost"}, 409 + + 
_emit_deny_audit(state) + return {"status": "denied"}, 200 + + +def _build_account_poll_payload(account, tenant, mint) -> dict: + """Pre-render the poll-response body so the unauthenticated poll + handler doesn't re-query accounts/tenants for authz data. + """ + from models import Tenant, TenantAccountJoin + rows = ( + db.session.query(Tenant, TenantAccountJoin) + .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.account_id == account.id) + .all() + ) + workspaces = [ + {"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")} + for t, m in rows + ] + # Prefer active session tenant → DB-flagged current join → first membership. + default_ws_id = None + if tenant and any(w["id"] == str(tenant) for w in workspaces): + default_ws_id = str(tenant) + if default_ws_id is None: + for _t, m in rows: + if getattr(m, "current", False): + default_ws_id = str(m.tenant_id) + break + if default_ws_id is None and workspaces: + default_ws_id = workspaces[0]["id"] + + return { + "token": mint.token, + "expires_at": mint.expires_at.isoformat(), + "subject_type": SubjectType.ACCOUNT, + "account": {"id": str(account.id), "email": account.email, "name": account.name}, + "workspaces": workspaces, + "default_workspace_id": default_ws_id, + "token_id": str(mint.token_id), + } + + +def _emit_approve_audit(state, account, tenant, mint) -> None: + logger.warning( + "audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? 
expires_at=%s", + mint.token_id, account.email, state.client_id, state.device_label, mint.expires_at, + extra={ + "audit": True, + "event": "oauth.device_flow_approved", + "token_id": str(mint.token_id), + "subject_type": SubjectType.ACCOUNT, + "subject_email": account.email, + "account_id": str(account.id), + "tenant_id": tenant, + "client_id": state.client_id, + "device_label": state.device_label, + "scopes": ["full"], + "expires_at": mint.expires_at.isoformat(), + }, + ) + + +def _emit_deny_audit(state) -> None: + logger.warning( + "audit: oauth.device_flow_denied client_id=%s device_label=%s", + state.client_id, state.device_label, + extra={ + "audit": True, + "event": "oauth.device_flow_denied", + "client_id": state.client_id, + "device_label": state.device_label, + }, + ) diff --git a/api/controllers/oauth_device_sso.py b/api/controllers/oauth_device_sso.py new file mode 100644 index 0000000000..6f6555ac8f --- /dev/null +++ b/api/controllers/oauth_device_sso.py @@ -0,0 +1,264 @@ +"""SSO-branch device-flow endpoints. Browser hits sso-initiate → API +signs an SSOState envelope → Enterprise inner-API returns IdP authorize +URL → 302. IdP → Enterprise ACS → DeviceFlowDispatcher mints a signed +external-subject assertion → 302 to /v1/device/sso-complete → API mints +the approval-grant cookie → /device → user clicks Approve → /approve- +external mints the OAuth token. All four endpoints are EE-only. 
+""" +from __future__ import annotations + +import logging +import secrets + +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from flask import Blueprint, jsonify, make_response, redirect, request +from libs import jws +from libs.oauth_bearer import SubjectType +from libs.rate_limit import ( + LIMIT_APPROVE_EXT_PER_EMAIL, + LIMIT_SSO_INITIATE_PER_IP, + enforce, + rate_limit, +) +from libs.device_flow_security import (APPROVAL_GRANT_COOKIE_NAME, ApprovalGrantClaims, + approval_grant_cleared_cookie_kwargs, + approval_grant_cookie_kwargs, + attach_anti_framing, + consume_approval_grant_nonce, + consume_sso_assertion_nonce, + enterprise_only, mint_approval_grant, + verify_approval_grant) +from services.enterprise.enterprise_service import EnterpriseService +from services.oauth_device_flow import (PREFIX_OAUTH_EXTERNAL_SSO, + DeviceFlowRedis, DeviceFlowStatus, + InvalidTransition, StateNotFound, + mint_oauth_token, oauth_ttl_days) +from werkzeug.exceptions import (BadGateway, BadRequest, Conflict, Forbidden, + NotFound, Unauthorized) + +logger = logging.getLogger(__name__) + + +bp = Blueprint("oauth_device_sso", __name__, url_prefix="/v1") +attach_anti_framing(bp) + + +# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the +# device_code it references. 
+STATE_ENVELOPE_TTL_SECONDS = 15 * 60 + + +@bp.route("/oauth/device/sso-initiate", methods=["GET"]) +@enterprise_only +@rate_limit(LIMIT_SSO_INITIATE_PER_IP) +def sso_initiate(): + user_code = (request.args.get("user_code") or "").strip().upper() + if not user_code: + raise BadRequest("user_code required") + + store = DeviceFlowRedis(redis_client) + found = store.load_by_user_code(user_code) + if found is None: + raise BadRequest("invalid_user_code") + _, state = found + if state.status is not DeviceFlowStatus.PENDING: + raise BadRequest("invalid_user_code") + + keyset = jws.KeySet.from_shared_secret() + signed_state = jws.sign( + keyset, + payload={ + "redirect_url": "", + "app_code": "", + "intent": "device_flow", + "user_code": user_code, + "nonce": secrets.token_urlsafe(16), + "return_to": "", + "idp_callback_url": f"{request.host_url.rstrip('/')}/v1/device/sso-complete", + }, + aud=jws.AUD_STATE_ENVELOPE, + ttl_seconds=STATE_ENVELOPE_TTL_SECONDS, + ) + + try: + reply = EnterpriseService.initiate_device_flow_sso(signed_state) + except Exception as e: + logger.warning("sso-initiate: enterprise call failed: %s", e) + raise BadGateway("sso_initiate_failed") from e + + url = (reply or {}).get("url") + if not url: + raise BadGateway("sso_initiate_missing_url") + + # Clear stale approval-grant — defends against cross-tab/back-button mixing. 
+ resp = redirect(url, code=302) + resp.set_cookie(**approval_grant_cleared_cookie_kwargs()) + return resp + + +@bp.route("/device/sso-complete", methods=["GET"]) +@enterprise_only +def sso_complete(): + blob = request.args.get("sso_assertion") + if not blob: + raise BadRequest("sso_assertion required") + + keyset = jws.KeySet.from_shared_secret() + + try: + claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION) + except jws.VerifyError as e: + logger.warning("sso-complete: rejected assertion: %s", e) + raise BadRequest("invalid_sso_assertion") from e + + if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")): + raise BadRequest("invalid_sso_assertion") + + user_code = (claims.get("user_code") or "").strip().upper() + store = DeviceFlowRedis(redis_client) + found = store.load_by_user_code(user_code) + if found is None: + raise Conflict("user_code_not_pending") + _, state = found + if state.status is not DeviceFlowStatus.PENDING: + raise Conflict("user_code_not_pending") + + iss = request.host_url.rstrip("/") + cookie_value, _ = mint_approval_grant( + keyset=keyset, + iss=iss, + subject_email=claims["email"], + subject_issuer=claims["issuer"], + user_code=user_code, + ) + + resp = redirect("/device?sso_verified=1", code=302) + resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value)) + return resp + + +@bp.route("/oauth/device/approval-context", methods=["GET"]) +@enterprise_only +def approval_context(): + token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME) + if not token: + raise Unauthorized("no_session") + + keyset = jws.KeySet.from_shared_secret() + try: + claims = verify_approval_grant(keyset, token) + except jws.VerifyError as e: + logger.warning("approval-context: bad cookie: %s", e) + raise Unauthorized("no_session") from e + + return jsonify({ + "subject_email": claims.subject_email, + "subject_issuer": claims.subject_issuer, + "user_code": claims.user_code, + "csrf_token": claims.csrf_token, + "expires_at": 
claims.expires_at.isoformat(), + }), 200 + + + +@bp.route("/oauth/device/approve-external", methods=["POST"]) +@enterprise_only +def approve_external(): + token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME) + if not token: + raise Unauthorized("invalid_session") + + keyset = jws.KeySet.from_shared_secret() + try: + claims: ApprovalGrantClaims = verify_approval_grant(keyset, token) + except jws.VerifyError as e: + logger.warning("approve-external: bad cookie: %s", e) + raise Unauthorized("invalid_session") from e + + enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}") + + csrf_header = request.headers.get("X-CSRF-Token", "") + if not csrf_header or csrf_header != claims.csrf_token: + raise Forbidden("csrf_mismatch") + + data = request.get_json(silent=True) or {} + body_user_code = (data.get("user_code") or "").strip().upper() + if body_user_code != claims.user_code: + raise BadRequest("user_code_mismatch") + + store = DeviceFlowRedis(redis_client) + found = store.load_by_user_code(claims.user_code) + if found is None: + raise NotFound("user_code_not_pending") + device_code, state = found + if state.status is not DeviceFlowStatus.PENDING: + raise Conflict("user_code_not_pending") + + if not consume_approval_grant_nonce(redis_client, claims.nonce): + raise Unauthorized("session_already_consumed") + + ttl_days = oauth_ttl_days(tenant_id=None) + mint = mint_oauth_token( + db.session, + redis_client, + subject_email=claims.subject_email, + subject_issuer=claims.subject_issuer, + account_id=None, + client_id=state.client_id, + device_label=state.device_label, + prefix=PREFIX_OAUTH_EXTERNAL_SSO, + ttl_days=ttl_days, + ) + + poll_payload = { + "token": mint.token, + "expires_at": mint.expires_at.isoformat(), + "subject_type": SubjectType.EXTERNAL_SSO, + "subject_email": claims.subject_email, + "subject_issuer": claims.subject_issuer, + "account": None, + "workspaces": [], + "default_workspace_id": None, + "token_id": str(mint.token_id), + } + + 
try: + store.approve( + device_code, + subject_email=claims.subject_email, + account_id=None, + subject_issuer=claims.subject_issuer, + minted_token=mint.token, + token_id=str(mint.token_id), + poll_payload=poll_payload, + ) + except (StateNotFound, InvalidTransition) as e: + logger.error("approve-external: state transition raced: %s", e) + raise Conflict("state_lost") from e + + _emit_approve_external_audit(state, claims, mint) + + resp = make_response(jsonify({"status": "approved"}), 200) + resp.set_cookie(**approval_grant_cleared_cookie_kwargs()) + return resp + + +def _emit_approve_external_audit(state, claims, mint) -> None: + logger.warning( + "audit: oauth.device_flow_approved subject_type=%s " + "subject_email=%s subject_issuer=%s token_id=%s", + SubjectType.EXTERNAL_SSO, claims.subject_email, claims.subject_issuer, mint.token_id, + extra={ + "audit": True, + "event": "oauth.device_flow_approved", + "subject_type": SubjectType.EXTERNAL_SSO, + "subject_email": claims.subject_email, + "subject_issuer": claims.subject_issuer, + "token_id": str(mint.token_id), + "client_id": state.client_id, + "device_label": state.device_label, + "scopes": ["apps:run"], + "expires_at": mint.expires_at.isoformat(), + }, + ) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 4f7f7d9a98..544dfbbfef 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -14,7 +14,7 @@ api = ExternalApi( service_api_ns = Namespace("service_api", description="Service operations", path="/") -from . import index +from . 
import index, oauth from .app import ( annotation, app, @@ -54,6 +54,7 @@ __all__ = [ "message", "metadata", "models", + "oauth", "rag_pipeline_workflow", "segment", "site", diff --git a/api/controllers/service_api/oauth.py b/api/controllers/service_api/oauth.py new file mode 100644 index 0000000000..e1ab831d16 --- /dev/null +++ b/api/controllers/service_api/oauth.py @@ -0,0 +1,302 @@ +"""``/v1`` OAuth bearer + device-flow endpoints. ``/me`` and self-revoke +are bearer-authed; the device-flow trio (code/token/lookup) is public — +code/token per RFC 8628, lookup so the /device page can pre-validate +before the user has a console session. +""" +from __future__ import annotations + +import logging +from datetime import UTC, datetime + +from flask import g, request +from flask_restx import Resource, reqparse +from sqlalchemy import update +from werkzeug.exceptions import BadRequest + +from controllers.service_api import service_api_ns +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from libs.helper import extract_remote_ip +from libs.oauth_bearer import ( + ACCEPT_USER_ANY, + SubjectType, + TOKEN_CACHE_KEY_FMT, + validate_bearer, +) +from libs.rate_limit import ( + LIMIT_DEVICE_CODE_PER_IP, + LIMIT_LOOKUP_PUBLIC, + LIMIT_ME_PER_ACCOUNT, + LIMIT_ME_PER_EMAIL, + enforce, + rate_limit, +) +from models import Account, OAuthAccessToken, Tenant, TenantAccountJoin +from services.oauth_device_flow import ( + DEFAULT_POLL_INTERVAL_SECONDS, + DEVICE_FLOW_TTL_SECONDS, + DeviceFlowRedis, + DeviceFlowStatus, + SlowDownDecision, +) + +logger = logging.getLogger(__name__) + +KNOWN_CLIENT_IDS = frozenset({"difyctl"}) + + +# ============================================================================ +# GET /v1/me +# ============================================================================ + + +@service_api_ns.route("/me") +class MeApi(Resource): + @validate_bearer(accept=ACCEPT_USER_ANY) + def get(self): + ctx = g.auth_ctx + + if 
ctx.subject_type == SubjectType.EXTERNAL_SSO: + enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}") + else: + enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}") + + if ctx.subject_type == SubjectType.EXTERNAL_SSO: + return { + "subject_type": ctx.subject_type, + "subject_email": ctx.subject_email, + "subject_issuer": ctx.subject_issuer, + "account": None, + "workspaces": [], + "default_workspace_id": None, + } + + account = ( + db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none() + if ctx.account_id else None + ) + memberships = _load_memberships(ctx.account_id) if ctx.account_id else [] + default_ws_id = _pick_default_workspace(memberships) + + return { + "subject_type": ctx.subject_type, + "subject_email": ctx.subject_email or (account.email if account else None), + "account": _account_payload(account) if account else None, + "workspaces": [_workspace_payload(m) for m in memberships], + "default_workspace_id": default_ws_id, + } + + +def _load_memberships(account_id): + return ( + db.session.query(TenantAccountJoin, Tenant) + .join(Tenant, Tenant.id == TenantAccountJoin.tenant_id) + .filter(TenantAccountJoin.account_id == account_id) + .all() + ) + + +def _pick_default_workspace(memberships) -> str | None: + if not memberships: + return None + for join, tenant in memberships: + if getattr(join, "current", False): + return str(tenant.id) + return str(memberships[0][1].id) + + +def _workspace_payload(row) -> dict: + join, tenant = row + return {"id": str(tenant.id), "name": tenant.name, "role": getattr(join, "role", "")} + + +def _account_payload(account) -> dict: + return {"id": str(account.id), "email": account.email, "name": account.name} + + +# ============================================================================ +# DELETE /v1/oauth/authorizations/self +# ============================================================================ + + +@service_api_ns.route("/oauth/authorizations/self") +class 
@service_api_ns.route("/oauth/authorizations/self")
class OAuthAuthorizationsSelfApi(Resource):
    """Self-revoke the presented OAuth bearer token (idempotent)."""

    @validate_bearer(accept=ACCEPT_USER_ANY)
    def delete(self):
        ctx = g.auth_ctx

        if not ctx.source.startswith("oauth"):
            raise BadRequest(
                "this endpoint revokes OAuth bearer tokens; "
                "use /v1/personal-access-tokens/self for PATs"
            )

        # Snapshot the live hash first so the Redis cache entry can be
        # dropped after the revoke; the guarded UPDATE below makes a
        # second DELETE call a harmless no-op.
        pre_revoke_hash = (
            db.session.query(OAuthAccessToken.token_hash)
            .filter(
                OAuthAccessToken.id == str(ctx.token_id),
                OAuthAccessToken.revoked_at.is_(None),
            )
            .scalar()
        )

        db.session.execute(
            update(OAuthAccessToken)
            .where(
                OAuthAccessToken.id == str(ctx.token_id),
                OAuthAccessToken.revoked_at.is_(None),
            )
            .values(revoked_at=datetime.now(UTC), token_hash=None)
        )
        db.session.commit()

        if pre_revoke_hash:
            redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=pre_revoke_hash))

        return {"status": "revoked"}, 200


# ============================================================================
# POST /v1/oauth/device/code (unauthenticated — CLI starts a flow)
# ============================================================================


_code_parser = reqparse.RequestParser()
_code_parser.add_argument("client_id", type=str, required=True, location="json")
_code_parser.add_argument("device_label", type=str, required=True, location="json")


@service_api_ns.route("/oauth/device/code")
class OAuthDeviceCodeApi(Resource):
    """RFC 8628 start: allocate a (device_code, user_code) pair."""

    @rate_limit(LIMIT_DEVICE_CODE_PER_IP)
    def post(self):
        args = _code_parser.parse_args()
        if args["client_id"] not in KNOWN_CLIENT_IDS:
            return {"error": "unsupported_client"}, 400

        store = DeviceFlowRedis(redis_client)
        device_code, user_code, expires_in = store.start(
            args["client_id"],
            args["device_label"],
            created_ip=extract_remote_ip(request),
        )

        return {
            "device_code": device_code,
            "user_code": user_code,
            "verification_uri": _verification_uri(),
            "expires_in": expires_in,
            "interval": DEFAULT_POLL_INTERVAL_SECONDS,
        }, 200


def _verification_uri() -> str:
    """Console /device page when CONSOLE_WEB_URL is set; else this host."""
    from configs import dify_config

    base = getattr(dify_config, "CONSOLE_WEB_URL", None)
    root = base.rstrip("/") if base else request.host_url.rstrip("/")
    return f"{root}/device"


# ============================================================================
# POST /v1/oauth/device/token (unauthenticated — CLI polls)
# ============================================================================


_poll_parser = reqparse.RequestParser()
_poll_parser.add_argument("device_code", type=str, required=True, location="json")
_poll_parser.add_argument("client_id", type=str, required=True, location="json")
@service_api_ns.route("/oauth/device/token")
class OAuthDeviceTokenApi(Resource):
    """RFC 8628 poll."""

    def post(self):
        args = _poll_parser.parse_args()
        device_code = args["device_code"]
        store = DeviceFlowRedis(redis_client)

        # slow_down beats every other branch — polling-too-fast clients
        # see only that response regardless of underlying state.
        if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
            return {"error": "slow_down"}, 400

        snapshot = store.load_by_device_code(device_code)
        if snapshot is None:
            return {"error": "expired_token"}, 400
        if snapshot.status is DeviceFlowStatus.PENDING:
            return {"error": "authorization_pending"}, 400

        # Terminal state: atomically consume so the result is single-use.
        final = store.consume_on_poll(device_code)
        if final is None:
            return {"error": "expired_token"}, 400
        if final.status is DeviceFlowStatus.DENIED:
            return {"error": "access_denied"}, 400

        payload = final.poll_payload or {}
        if "token" not in payload:
            logger.error("device_flow: approved state missing poll_payload for %s", device_code)
            return {"error": "expired_token"}, 400

        _audit_cross_ip_if_needed(snapshot)
        return payload, 200


# ============================================================================
# GET /v1/oauth/device/lookup (unauthenticated — /device page pre-validates)
# ============================================================================


_lookup_parser = reqparse.RequestParser()
_lookup_parser.add_argument("user_code", type=str, required=True, location="args")


@service_api_ns.route("/oauth/device/lookup")
class OAuthDeviceLookupApi(Resource):
    """Read-only — public for pre-validate before login. user_code is
    high-entropy + short-TTL; per-IP rate limit blocks enumeration.
    """

    @rate_limit(LIMIT_LOOKUP_PUBLIC)
    def get(self):
        user_code = _lookup_parser.parse_args()["user_code"].strip().upper()

        found = DeviceFlowRedis(redis_client).load_by_user_code(user_code)
        if found is None:
            return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200

        _device_code, state = found
        if state.status is not DeviceFlowStatus.PENDING:
            return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200

        # NOTE(review): reports the full configured TTL, not the actual
        # seconds remaining for this specific code — confirm intended.
        return {
            "valid": True,
            "expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
            "client_id": state.client_id,
        }, 200


def _audit_cross_ip_if_needed(state) -> None:
    """Audit-log a poll arriving from a different IP than the flow's creator."""
    poll_ip = extract_remote_ip(request)
    if state.created_ip and poll_ip and poll_ip != state.created_ip:
        logger.warning(
            "audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
            state.token_id, state.created_ip, poll_ip,
            extra={
                "audit": True,
                "token_id": state.token_id,
                "creation_ip": state.created_ip,
                "poll_ip": poll_ip,
            },
        )
# --- api/extensions/ext_blueprints.py (excerpt, inside init_app) ---
# SSO-branch device-flow routes. No CORS config — user-interactive,
# same-origin, cookie-authed. Gated on ENABLE_OAUTH_BEARER: without the
# bearer authenticator the minted tokens cannot be validated, so skip
# the blueprint entirely.
if dify_config.ENABLE_OAUTH_BEARER:
    from controllers.oauth_device_sso import bp as oauth_device_sso_bp

    app.register_blueprint(oauth_device_sso_bp)


# --- api/extensions/ext_celery.py (excerpt, inside init_app) ---
if dify_config.ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK:
    imports.append("schedule.clean_oauth_access_tokens_task")
    beat_schedule["clean_oauth_access_tokens_task"] = {
        "task": "schedule.clean_oauth_access_tokens_task.clean_oauth_access_tokens_task",
        # FIX: the original wrote day_of_month=f"*/{day}" but no `day`
        # binding is visible in this hunk — a NameError at startup unless
        # ext_celery defines it elsewhere (TODO confirm). Run daily at
        # 05:00; the task itself applies OAUTH_ACCESS_TOKEN_RETENTION_DAYS.
        "schedule": crontab(minute="0", hour="5"),
    }


# --- api/extensions/ext_oauth_bearer.py ---
"""Bind the bearer authenticator at startup. Must run after ext_database
and ext_redis (needs both factories).
"""
from configs import dify_config
from dify_app import DifyApp
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import build_and_bind


def is_enabled() -> bool:
    """Extension toggle consulted by the extension loader."""
    return dify_config.ENABLE_OAUTH_BEARER


def init_app(app: DifyApp) -> None:
    """Build the token-kind registry and bind the process-wide authenticator.

    scoped_session isn't a context manager; request teardown closes it,
    so the factory just hands back the current session.
    """

    def session_factory():
        return db.session

    build_and_bind(session_factory=session_factory, redis_client=redis_client)
+ """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + settings = FeatureService.get_system_features() + if settings.license.status in _CE_LIKE_STATUSES: + raise NotFound() + return view(*args, **kwargs) + + return decorated + + +# ============================================================================ +# approval_grant cookie +# ============================================================================ + + +APPROVAL_GRANT_COOKIE_NAME = "device_approval_grant" +APPROVAL_GRANT_COOKIE_PATH = "/v1/oauth/device" +APPROVAL_GRANT_COOKIE_TTL_SECONDS = 300 # 5 min +NONCE_TTL_SECONDS = 600 # 2x cookie TTL — defeats clock-skew late replay +NONCE_KEY_FMT = "device_approval_grant_nonce:{nonce}" +SSO_ASSERTION_NONCE_KEY_FMT = "sso_assertion_nonce:{nonce}" + + +@dataclass(frozen=True, slots=True) +class ApprovalGrantClaims: + subject_email: str + subject_issuer: str + user_code: str + nonce: str + csrf_token: str + expires_at: datetime + + +def mint_approval_grant( + *, + keyset: jws.KeySet, + iss: str, + subject_email: str, + subject_issuer: str, + user_code: str, +) -> tuple[str, ApprovalGrantClaims]: + """Use ``approval_grant_cookie_kwargs`` to set the cookie — single + source of truth for Path/HttpOnly/Secure/SameSite. 
+ """ + now = datetime.now(UTC) + exp = now + timedelta(seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS) + nonce = _random_opaque() + csrf_token = _random_opaque() + + payload = { + "iss": iss, + "subject_email": subject_email, + "subject_issuer": subject_issuer, + "user_code": user_code, + "nonce": nonce, + "csrf_token": csrf_token, + } + token = jws.sign(keyset, payload, aud=jws.AUD_APPROVAL_GRANT, ttl_seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS) + + return token, ApprovalGrantClaims( + subject_email=subject_email, + subject_issuer=subject_issuer, + user_code=user_code, + nonce=nonce, + csrf_token=csrf_token, + expires_at=exp, + ) + + +def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims: + """Sig + aud + exp only — nonce consumption is the caller's job.""" + data = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT) + return ApprovalGrantClaims( + subject_email=data["subject_email"], + subject_issuer=data["subject_issuer"], + user_code=data["user_code"], + nonce=data["nonce"], + csrf_token=data["csrf_token"], + expires_at=datetime.fromtimestamp(data["exp"], tz=UTC), + ) + + +def consume_approval_grant_nonce(redis_client, nonce: str) -> bool: + if not nonce: + return False + return bool( + redis_client.set( + NONCE_KEY_FMT.format(nonce=nonce), "1", nx=True, ex=NONCE_TTL_SECONDS, + ) + ) + + +def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool: + if not nonce: + return False + return bool( + redis_client.set( + SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce), "1", nx=True, ex=NONCE_TTL_SECONDS, + ) + ) + + +def approval_grant_cookie_kwargs(value: str) -> dict: + """``secure`` follows is_secure() so HTTP-only deployments don't + silently drop the cookie. 
+ """ + return { + "key": APPROVAL_GRANT_COOKIE_NAME, + "value": value, + "max_age": APPROVAL_GRANT_COOKIE_TTL_SECONDS, + "path": APPROVAL_GRANT_COOKIE_PATH, + "secure": is_secure(), + "httponly": True, + "samesite": "Lax", + } + + +def approval_grant_cleared_cookie_kwargs() -> dict: + return { + "key": APPROVAL_GRANT_COOKIE_NAME, + "value": "", + "max_age": 0, + "path": APPROVAL_GRANT_COOKIE_PATH, + "secure": is_secure(), + "httponly": True, + "samesite": "Lax", + } + + +def _random_opaque() -> str: + return secrets.token_urlsafe(16) + + +# ============================================================================ +# Anti-framing headers +# ============================================================================ + + +_ANTI_FRAMING_HEADERS = { + "X-Frame-Options": "DENY", + "Content-Security-Policy": "frame-ancestors 'none'", +} + + +def attach_anti_framing(bp: Blueprint) -> None: + """X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4).""" + + @bp.after_request + def _apply_headers(response): + for name, value in _ANTI_FRAMING_HEADERS.items(): + response.headers.setdefault(name, value) + return response diff --git a/api/libs/jws.py b/api/libs/jws.py new file mode 100644 index 0000000000..f66811aabd --- /dev/null +++ b/api/libs/jws.py @@ -0,0 +1,106 @@ +"""HS256 compact JWS keyed on the shared Dify SECRET_KEY. Used by the SSO +state envelope, external subject assertion, and approval-grant cookie — +all three share one key-set so api ↔ enterprise can verify each other. 
+""" +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +import jwt +from configs import dify_config + +AUD_STATE_ENVELOPE = "api.sso.state_envelope" +AUD_EXT_SUBJECT_ASSERTION = "api.device_flow.external_subject_assertion" +AUD_APPROVAL_GRANT = "api.device_flow.approval_grant" + +ACTIVE_KID_V1 = "dify-shared-v1" + + +class KeySetError(Exception): + pass + + +class KeySet: + """``from_entries`` reserves multi-kid construction for rotation slots.""" + + def __init__(self, entries: dict[str, bytes], active_kid: str) -> None: + if active_kid not in entries: + raise KeySetError(f"active kid {active_kid!r} missing from key-set") + if not entries[active_kid]: + raise KeySetError(f"active kid {active_kid!r} has empty secret") + self._entries: dict[str, bytes] = {k: bytes(v) for k, v in entries.items()} + self._active_kid = active_kid + + @classmethod + def from_shared_secret(cls) -> "KeySet": + secret = dify_config.SECRET_KEY + if not secret: + raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set") + return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1) + + @classmethod + def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> "KeySet": + return cls(entries, active_kid) + + @property + def active_kid(self) -> str: + return self._active_kid + + def lookup(self, kid: str) -> bytes | None: + return self._entries.get(kid) + + +def sign(keyset: KeySet, payload: dict, aud: str, ttl_seconds: int) -> str: + """``iat`` + ``exp`` are injected here; callers must not set them.""" + if "aud" in payload or "iat" in payload or "exp" in payload: + raise ValueError("reserved claim present in payload (aud/iat/exp)") + if ttl_seconds <= 0: + raise ValueError("ttl_seconds must be positive") + + kid = keyset.active_kid + secret = keyset.lookup(kid) + if secret is None: + raise KeySetError(f"active kid {kid!r} lookup miss") + + iat = datetime.now(UTC) + exp = iat + timedelta(seconds=ttl_seconds) + claims = 
{**payload, "aud": aud, "iat": iat, "exp": exp} + return jwt.encode( + claims, + secret, + algorithm="HS256", + headers={"kid": kid, "typ": "JWT"}, + ) + + +class VerifyError(Exception): + pass + + +def verify(keyset: KeySet, token: str, expected_aud: str) -> dict: + """Unknown kid is rejected — never fall back to the active kid, since + a past kid value would otherwise be forgeable by anyone who saw it. + """ + try: + header = jwt.get_unverified_header(token) + except jwt.PyJWTError as e: + raise VerifyError(f"decode header: {e}") from e + kid = header.get("kid") + if not kid: + raise VerifyError("no kid in header") + secret = keyset.lookup(kid) + if secret is None: + raise VerifyError(f"unknown kid {kid!r}") + try: + return jwt.decode( + token, + secret, + algorithms=["HS256"], + audience=expected_aud, + ) + except jwt.ExpiredSignatureError as e: + raise VerifyError("token expired") from e + except jwt.InvalidAudienceError as e: + raise VerifyError("aud mismatch") from e + except jwt.PyJWTError as e: + raise VerifyError(f"decode: {e}") from e diff --git a/api/libs/oauth_bearer.py b/api/libs/oauth_bearer.py new file mode 100644 index 0000000000..d82250a622 --- /dev/null +++ b/api/libs/oauth_bearer.py @@ -0,0 +1,425 @@ +"""OAuth bearer primitives. + +To add a token kind: write a Resolver, add a SubjectType + Accepts member, +append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT. +Authenticator + validate_bearer stay untouched. 
+""" +from __future__ import annotations + +import hashlib +import json +import logging +import uuid +from dataclasses import dataclass +from datetime import UTC, datetime +from enum import StrEnum +from functools import wraps +from typing import Callable, Iterable, Literal, Protocol + +from flask import g, request +from sqlalchemy import update +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized + +from models import OAuthAccessToken + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Contract — types, enums, protocols +# ============================================================================ + + +class SubjectType(StrEnum): + ACCOUNT = "account" + EXTERNAL_SSO = "external_sso" + + +@dataclass(frozen=True, slots=True) +class AuthContext: + """Attached to ``g.auth_ctx``. ``scopes`` / ``subject_type`` / ``source`` + come from the TokenKind, not the DB — corrupt rows can't elevate scope. + """ + + subject_type: SubjectType + subject_email: str | None + subject_issuer: str | None + account_id: uuid.UUID | None + scopes: frozenset[str] + token_id: uuid.UUID + source: str + expires_at: datetime | None + + +@dataclass(frozen=True, slots=True) +class ResolvedRow: + subject_email: str | None + subject_issuer: str | None + account_id: uuid.UUID | None + token_id: uuid.UUID + expires_at: datetime | None + + +class Resolver(Protocol): + def resolve(self, token_hash: str) -> ResolvedRow | None: # pragma: no cover - contract + ... 
+ + +@dataclass(frozen=True, slots=True) +class TokenKind: + prefix: str + subject_type: SubjectType + scopes: frozenset[str] + source: str + resolver: Resolver + + def matches(self, token: str) -> bool: + return token.startswith(self.prefix) + + +class InvalidBearer(Exception): + """Token missing, unknown prefix, or no live row.""" + + +class TokenExpired(Exception): + """Hard-expire bookkeeping is the resolver's job before raising.""" + + +# ============================================================================ +# Registry +# ============================================================================ + + +class TokenKindRegistry: + def __init__(self, kinds: Iterable[TokenKind]) -> None: + self._kinds: tuple[TokenKind, ...] = tuple(kinds) + prefixes = [k.prefix for k in self._kinds] + if len(set(prefixes)) != len(prefixes): + raise ValueError(f"duplicate prefix in registry: {prefixes}") + + def find(self, token: str) -> TokenKind | None: + for k in self._kinds: + if k.matches(token): + return k + return None + + def kinds(self) -> tuple[TokenKind, ...]: + return self._kinds + + +# ============================================================================ +# Authenticator +# ============================================================================ + + +def sha256_hex(token: str) -> str: + return hashlib.sha256(token.encode("utf-8")).hexdigest() + + +class BearerAuthenticator: + def __init__(self, registry: TokenKindRegistry) -> None: + self._registry = registry + + def authenticate(self, token: str) -> AuthContext: + kind = self._registry.find(token) + if kind is None: + raise InvalidBearer("unknown token prefix") + row = kind.resolver.resolve(sha256_hex(token)) + if row is None: + raise InvalidBearer("token unknown or revoked") + return AuthContext( + subject_type=kind.subject_type, + subject_email=row.subject_email, + subject_issuer=row.subject_issuer, + account_id=row.account_id, + scopes=kind.scopes, + token_id=row.token_id, + source=kind.source, + 
expires_at=row.expires_at, + ) + + +# ============================================================================ +# OAuth access token resolver (PAT resolver would be a sibling class) +# ============================================================================ + +TOKEN_CACHE_KEY_FMT = "auth:token:{hash}" +POSITIVE_TTL_SECONDS = 60 +NEGATIVE_TTL_SECONDS = 10 +AUDIT_OAUTH_EXPIRED = "oauth.token_expired" + +ScopeVariant = Literal["account", "external_sso"] + + +class OAuthAccessTokenResolver: + """``.for_account()`` / ``.for_external_sso()`` are variant-scoped views + sharing DB + cache plumbing. + """ + + def __init__( + self, + session_factory, + redis_client, + positive_ttl: int = POSITIVE_TTL_SECONDS, + negative_ttl: int = NEGATIVE_TTL_SECONDS, + ) -> None: + self._session_factory = session_factory + self._redis = redis_client + self._positive_ttl = positive_ttl + self._negative_ttl = negative_ttl + + def for_account(self) -> Resolver: + return _VariantResolver(self, variant="account") + + def for_external_sso(self) -> Resolver: + return _VariantResolver(self, variant="external_sso") + + def _cache_key(self, token_hash: str) -> str: + return TOKEN_CACHE_KEY_FMT.format(hash=token_hash) + + def _cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]: + raw = self._redis.get(self._cache_key(token_hash)) + if raw is None: + return None + text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw + if text == "invalid": + return "invalid" + try: + data = json.loads(text) + return _row_from_cache(data) + except (ValueError, KeyError): + logger.warning("auth:token cache entry malformed; treating as miss") + return None + + def _cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None: + self._redis.setex( + self._cache_key(token_hash), + self._positive_ttl, + json.dumps(_row_to_cache(row)), + ) + + def _cache_set_negative(self, token_hash: str) -> None: + self._redis.setex(self._cache_key(token_hash), self._negative_ttl, 
"invalid") + + def _hard_expire(self, session: Session, row_id: uuid.UUID, token_hash: str) -> None: + """Atomic CAS — only the worker that flips revoked_at emits audit; + replays are idempotent. Spec: tokens.md §Detection + hard-expire. + """ + stmt = ( + update(OAuthAccessToken) + .where(OAuthAccessToken.id == row_id, OAuthAccessToken.revoked_at.is_(None)) + .values(revoked_at=datetime.now(UTC), token_hash=None) + ) + result = session.execute(stmt) + session.commit() + if result.rowcount == 1: + logger.warning( + "audit: %s token_id=%s", AUDIT_OAUTH_EXPIRED, row_id, + extra={"audit": True, "token_id": str(row_id)}, + ) + self._redis.delete(self._cache_key(token_hash)) + self._cache_set_negative(token_hash) + + +class _VariantResolver: + + def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None: + self._parent = parent + self._variant = variant + + def resolve(self, token_hash: str) -> ResolvedRow | None: + cached = self._parent._cache_get(token_hash) + if cached == "invalid": + return None + if cached is not None and not isinstance(cached, str): + if not self._matches_variant(cached): + return None + return cached + + # _session_factory returns Flask-SQLAlchemy's scoped_session, which is + # request-bound and not a context manager; use it directly. 
+ session = self._parent._session_factory() + row = self._load_from_db(session, token_hash) + if row is None: + self._parent._cache_set_negative(token_hash) + return None + + now = datetime.now(UTC) + if row.expires_at is not None and row.expires_at <= now: + self._parent._hard_expire(session, row.id, token_hash) + return None + + if not self._matches_variant_model(row): + logger.error( + "internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s", + row.id, row.prefix, + ) + return None + + resolved = ResolvedRow( + subject_email=row.subject_email, + subject_issuer=row.subject_issuer, + account_id=uuid.UUID(str(row.account_id)) if row.account_id else None, + token_id=uuid.UUID(str(row.id)), + expires_at=row.expires_at, + ) + self._parent._cache_set_positive(token_hash, resolved) + return resolved + + def _matches_variant(self, row: ResolvedRow) -> bool: + has_account = row.account_id is not None + if self._variant == "account": + return has_account + return not has_account + + def _matches_variant_model(self, row: OAuthAccessToken) -> bool: + has_account = row.account_id is not None + if self._variant == "account": + return has_account and row.prefix == "dfoa_" + return (not has_account) and row.prefix == "dfoe_" + + def _load_from_db(self, session: Session, token_hash: str) -> OAuthAccessToken | None: + return ( + session.query(OAuthAccessToken) + .filter( + OAuthAccessToken.token_hash == token_hash, + OAuthAccessToken.revoked_at.is_(None), + ) + .one_or_none() + ) + + +def _row_to_cache(row: ResolvedRow) -> dict: + return { + "subject_email": row.subject_email, + "subject_issuer": row.subject_issuer, + "account_id": str(row.account_id) if row.account_id else None, + "token_id": str(row.token_id), + "expires_at": row.expires_at.isoformat() if row.expires_at else None, + } + + +def _row_from_cache(data: dict) -> ResolvedRow: + return ResolvedRow( + subject_email=data["subject_email"], + subject_issuer=data["subject_issuer"], + 
def _row_from_cache(data: dict) -> ResolvedRow:
    """Inverse of ``_row_to_cache``; KeyError/ValueError mean a corrupt entry."""
    return ResolvedRow(
        subject_email=data["subject_email"],
        subject_issuer=data["subject_issuer"],
        account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
        token_id=uuid.UUID(data["token_id"]),
        expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
    )


# ============================================================================
# Decorator — route-level bearer gate
# ============================================================================


class Accepts(StrEnum):
    """What a route accepts: account bearer, external-SSO bearer, or app- key."""

    USER_ACCOUNT = "user_account"
    USER_EXT_SSO = "user_ext_sso"
    APP = "app"


ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})


_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
    SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
    SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
}


_authenticator: BearerAuthenticator | None = None


def bind_authenticator(authenticator: BearerAuthenticator) -> None:
    """Process-wide singleton bind; called once from the app factory."""
    global _authenticator
    _authenticator = authenticator


def get_authenticator() -> BearerAuthenticator:
    """The bound authenticator; RuntimeError when startup wiring was skipped."""
    if _authenticator is None:
        raise RuntimeError("BearerAuthenticator not bound; call bind_authenticator at startup")
    return _authenticator


def _extract_bearer(req) -> str | None:
    """The value of an ``Authorization: Bearer <x>`` header, else None."""
    header = req.headers.get("Authorization", "")
    scheme, _, value = header.partition(" ")
    if scheme.lower() != "bearer" or not value:
        return None
    return value.strip()


def validate_bearer(*, accept: frozenset[Accepts]) -> Callable:
    """Opt-in: omitting it leaves the route unauthenticated.

    Coexists with legacy ``app-`` keys (tenant+app scoped, resolved in
    ``service_api/wraps.py``) and user-level OAuth bearers (resolved here).
    """

    def wrap(fn: Callable) -> Callable:
        @wraps(fn)
        def inner(*args, **kwargs):
            token = _extract_bearer(request)
            if token is None:
                raise Unauthorized("missing bearer token")

            # app- keys bypass the OAuth authenticator (work even when disabled).
            if token.startswith("app-"):
                if Accepts.APP not in accept:
                    raise Unauthorized("app-scoped keys not accepted here")
                return fn(*args, **kwargs)

            if _authenticator is None:
                raise ServiceUnavailable(
                    "bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable"
                )

            try:
                ctx = get_authenticator().authenticate(token)
            except InvalidBearer as e:
                # FIX: chain the cause (`from e`) — consistent with every
                # other re-raise in this feature; preserves the traceback.
                raise Unauthorized(str(e)) from e

            if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
                raise Forbidden("token subject type not accepted here")

            g.auth_ctx = ctx
            return fn(*args, **kwargs)

        return inner

    return wrap


# ============================================================================
# Wiring — called once from the app factory
# ============================================================================


def build_registry(session_factory, redis_client) -> TokenKindRegistry:
    """The two shipped kinds: dfoa_ (account, full) and dfoe_ (SSO, apps:run)."""
    oauth = OAuthAccessTokenResolver(session_factory, redis_client)
    return TokenKindRegistry([
        TokenKind(
            prefix="dfoa_",
            subject_type=SubjectType.ACCOUNT,
            scopes=frozenset({"full"}),
            source="oauth_account",
            resolver=oauth.for_account(),
        ),
        TokenKind(
            prefix="dfoe_",
            subject_type=SubjectType.EXTERNAL_SSO,
            scopes=frozenset({"apps:run"}),
            source="oauth_external_sso",
            resolver=oauth.for_external_sso(),
        ),
    ])


def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
    """Construct the registry + authenticator and bind the module singleton."""
    registry = build_registry(session_factory, redis_client)
    auth = BearerAuthenticator(registry)
    bind_authenticator(auth)
    return auth
RFC-8628 ``slow_down`` is inline — its response shape isn't +generic 429. Spec: docs/specs/v1.0/server/security.md. +""" +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timedelta +from enum import StrEnum +from functools import wraps +from typing import Callable + +from flask import g, request, session +from werkzeug.exceptions import TooManyRequests + +from libs.helper import RateLimiter, extract_remote_ip + + +class RateLimitScope(StrEnum): + IP = "ip" + SESSION = "session" + ACCOUNT = "account" + SUBJECT_EMAIL = "subject_email" + TOKEN_ID = "token_id" + + +@dataclass(frozen=True, slots=True) +class RateLimit: + limit: int + window: timedelta + scopes: tuple[RateLimitScope, ...] + + +LIMIT_DEVICE_CODE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,)) +LIMIT_SSO_INITIATE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,)) +LIMIT_APPROVE_EXT_PER_EMAIL = RateLimit(10, timedelta(hours=1), (RateLimitScope.SUBJECT_EMAIL,)) +LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSION,)) +LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,)) +LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,)) +LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,)) + + +def _one_key(scope: RateLimitScope) -> str: + match scope: + case RateLimitScope.IP: + return f"ip:{extract_remote_ip(request) or 'unknown'}" + case RateLimitScope.SESSION: + return f"session:{session.get('_id', 'anon')}" + case RateLimitScope.ACCOUNT: + ctx = getattr(g, "auth_ctx", None) + if ctx and ctx.account_id: + return f"account:{ctx.account_id}" + return "account:anon" + case RateLimitScope.SUBJECT_EMAIL: + ctx = getattr(g, "auth_ctx", None) + if ctx and ctx.subject_email: + return f"subject:{ctx.subject_email}" + return "subject:anon" + case RateLimitScope.TOKEN_ID: + ctx = getattr(g, "auth_ctx", None) + if ctx and 
ctx.token_id: + return f"token:{ctx.token_id}" + return "token:anon" + + +def _composite_key(scopes: tuple[RateLimitScope, ...]) -> str: + return "|".join(_one_key(s) for s in scopes) + + +def _limiter_prefix(scopes: tuple[RateLimitScope, ...]) -> str: + return "rl:" + "+".join(s.value for s in scopes) + + +def _build_limiter(spec: RateLimit) -> RateLimiter: + return RateLimiter( + prefix=_limiter_prefix(spec.scopes), + max_attempts=spec.limit, + time_window=int(spec.window.total_seconds()), + ) + + +def rate_limit(spec: RateLimit) -> Callable: + """Apply after auth decorators that the scopes read from.""" + limiter = _build_limiter(spec) + + def wrap(fn: Callable) -> Callable: + @wraps(fn) + def inner(*args, **kwargs): + key = _composite_key(spec.scopes) + if limiter.is_rate_limited(key): + raise TooManyRequests("rate_limited") + limiter.increment_rate_limit(key) + return fn(*args, **kwargs) + + return inner + + return wrap + + +def enforce(spec: RateLimit, *, key: str) -> None: + """Imperative form — caller composes the bucket key to match scope + semantics (the key is opaque here). + """ + limiter = _build_limiter(spec) + if limiter.is_rate_limited(key): + raise TooManyRequests("rate_limited") + limiter.increment_rate_limit(key) diff --git a/api/migrations/versions/2026_04_23_2200-d4a5e1f3c9b7_add_oauth_access_tokens.py b/api/migrations/versions/2026_04_23_2200-d4a5e1f3c9b7_add_oauth_access_tokens.py new file mode 100644 index 0000000000..a0e34b9a17 --- /dev/null +++ b/api/migrations/versions/2026_04_23_2200-d4a5e1f3c9b7_add_oauth_access_tokens.py @@ -0,0 +1,102 @@ +"""add oauth_access_tokens table + +Revision ID: d4a5e1f3c9b7 +Revises: 227822d22895, b69ca54b9208, 2a3aebbbf4bb +Create Date: 2026-04-23 22:00:00.000000 + +Merges the three open heads at time of authoring (add_workflow_comments_table, +add_chatbot_color_theme, add_app_tracing) into a single parent so the new +oauth_access_tokens table sits on a definite linear chain thereafter. 
+ +Table stores user-level OAuth bearer tokens minted via the device-flow grant +(difyctl auth login). PAT storage (personal_access_tokens) is a separate +table not added in this migration. +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "d4a5e1f3c9b7" +down_revision = ("227822d22895", "b69ca54b9208", "2a3aebbbf4bb") +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "oauth_access_tokens", + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + server_default=sa.text("gen_random_uuid()"), + nullable=False, + primary_key=True, + ), + sa.Column("subject_email", sa.Text(), nullable=False), + sa.Column("subject_issuer", sa.Text(), nullable=True), + sa.Column("account_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("client_id", sa.String(length=64), nullable=False), + sa.Column("device_label", sa.Text(), nullable=False), + sa.Column("prefix", sa.String(length=8), nullable=False), + sa.Column("token_hash", sa.String(length=64), nullable=True, unique=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("NOW()"), + nullable=False, + ), + sa.Column("last_used_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["account_id"], + ["accounts.id"], + name="fk_oauth_access_tokens_account_id", + ondelete="SET NULL", + ), + ) + + op.create_index( + "idx_oauth_subject_email", + "oauth_access_tokens", + ["subject_email"], + postgresql_where=sa.text("revoked_at IS NULL"), + ) + op.create_index( + "idx_oauth_account", + "oauth_access_tokens", + ["account_id"], + postgresql_where=sa.text("revoked_at IS NULL AND account_id IS NOT NULL"), + ) + op.create_index( + "idx_oauth_client", + "oauth_access_tokens", + ["subject_email", 
"client_id"], + postgresql_where=sa.text("revoked_at IS NULL"), + ) + op.create_index( + "idx_oauth_token_hash", + "oauth_access_tokens", + ["token_hash"], + postgresql_where=sa.text("revoked_at IS NULL"), + ) + # Partial unique index — rotate-in-place keyed on (subject, client, device). + # subject_issuer NULL vs populated distinguishes account vs external-SSO rows + # for the same email, because Postgres treats NULL as distinct. + op.create_index( + "uq_oauth_active_per_device", + "oauth_access_tokens", + ["subject_email", "subject_issuer", "client_id", "device_label"], + unique=True, + postgresql_where=sa.text("revoked_at IS NULL"), + ) + + +def downgrade(): + op.drop_index("uq_oauth_active_per_device", table_name="oauth_access_tokens") + op.drop_index("idx_oauth_token_hash", table_name="oauth_access_tokens") + op.drop_index("idx_oauth_client", table_name="oauth_access_tokens") + op.drop_index("idx_oauth_account", table_name="oauth_access_tokens") + op.drop_index("idx_oauth_subject_email", table_name="oauth_access_tokens") + op.drop_table("oauth_access_tokens") diff --git a/api/models/__init__.py b/api/models/__init__.py index 85be9ca3bd..4880f94779 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -73,7 +73,7 @@ from .model import ( TrialApp, UploadFile, ) -from .oauth import DatasourceOauthParamConfig, DatasourceProvider +from .oauth import DatasourceOauthParamConfig, DatasourceProvider, OAuthAccessToken from .provider import ( LoadBalancingModelConfig, Provider, @@ -177,6 +177,7 @@ __all__ = [ "MessageChain", "MessageFeedback", "MessageFile", + "OAuthAccessToken", "OperationLog", "PinnedConversation", "Provider", diff --git a/api/models/oauth.py b/api/models/oauth.py index bd04d890d3..a88dd9345d 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any +from typing import Any, Optional import sqlalchemy as sa from sqlalchemy import func @@ -84,3 +84,38 @@ class 
DatasourceOauthTenantParamConfig(TypeBase): onupdate=func.current_timestamp(), init=False, ) + + +class OAuthAccessToken(TypeBase): + """Device-flow bearer. account_id NOT NULL ⇒ dfoa_ (Dify account); + account_id NULL + subject_issuer ⇒ dfoe_ (external SSO, EE-only). + Partial unique index on (subject_email, subject_issuer, client_id, + device_label) WHERE revoked_at IS NULL lets re-login rotate in place. + """ + + __tablename__ = "oauth_access_tokens" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"), + ) + + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False + ) + subject_email: Mapped[str] = mapped_column(sa.Text, nullable=False) + client_id: Mapped[str] = mapped_column(sa.String(64), nullable=False) + device_label: Mapped[str] = mapped_column(sa.Text, nullable=False) + prefix: Mapped[str] = mapped_column(sa.String(8), nullable=False) + expires_at: Mapped[datetime] = mapped_column(sa.DateTime(timezone=True), nullable=False) + subject_issuer: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True, default=None) + account_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True, default=None) + token_hash: Mapped[Optional[str]] = mapped_column(sa.String(64), nullable=True, default=None) + last_used_at: Mapped[Optional[datetime]] = mapped_column( + sa.DateTime(timezone=True), nullable=True, default=None + ) + revoked_at: Mapped[Optional[datetime]] = mapped_column( + sa.DateTime(timezone=True), nullable=True, default=None + ) + + created_at: Mapped[datetime] = mapped_column( + sa.DateTime(timezone=True), nullable=False, server_default=func.now(), init=False + ) diff --git a/api/schedule/clean_oauth_access_tokens_task.py b/api/schedule/clean_oauth_access_tokens_task.py new file mode 100644 index 0000000000..b4b7dc0236 --- /dev/null +++ b/api/schedule/clean_oauth_access_tokens_task.py @@ -0,0 +1,57 @@ +"""DELETE oauth_access_tokens 
past retention. Revocation is UPDATE +(token_id stays for audits) so rows accumulate across re-logins, and +expired-but-never-presented rows have no hard-expire trigger — both get +pruned here. Spec: docs/specs/v1.0/server/tokens.md §Hard-expire. +""" +from __future__ import annotations + +import logging +import time +from datetime import UTC, datetime, timedelta + +import click +from sqlalchemy import delete, or_, select + +import app +from configs import dify_config +from extensions.ext_database import db +from models.oauth import OAuthAccessToken + +logger = logging.getLogger(__name__) + +DELETE_BATCH_SIZE = 500 + + +@app.celery.task(queue="retention") +def clean_oauth_access_tokens_task(): + click.echo(click.style("Start clean oauth_access_tokens.", fg="green")) + retention_days = int(dify_config.OAUTH_ACCESS_TOKEN_RETENTION_DAYS) + cutoff = datetime.now(UTC) - timedelta(days=retention_days) + start_at = time.perf_counter() + + candidates = or_( + OAuthAccessToken.revoked_at < cutoff, + # Zombies: expired but never re-presented, so middleware never flipped them. 
+ (OAuthAccessToken.revoked_at.is_(None)) + & (OAuthAccessToken.expires_at < cutoff), + ) + + total = 0 + while True: + ids = db.session.scalars( + select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE) + ).all() + if not ids: + break + db.session.execute( + delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids)) + ) + db.session.commit() + total += len(ids) + + end_at = time.perf_counter() + click.echo(click.style( + f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d " + f"in {end_at - start_at:.2f}s", + fg="green", + )) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 5040fcc7e3..6d61a3c3e5 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -106,6 +106,15 @@ class EnterpriseService: def get_workspace_info(cls, tenant_id: str): return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info") + @classmethod + def initiate_device_flow_sso(cls, signed_state: str) -> dict: + return EnterpriseRequest.send_request( + "POST", + "/device-flow/sso-initiate", + json={"signed_state": signed_state}, + raise_for_status=True, + ) + @classmethod def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult: """ diff --git a/api/services/oauth_device_flow.py b/api/services/oauth_device_flow.py new file mode 100644 index 0000000000..381d6d6a85 --- /dev/null +++ b/api/services/oauth_device_flow.py @@ -0,0 +1,417 @@ +"""Device-flow service layer: Redis state machine, OAuth token mint +(DB upsert + plaintext generation), and TTL policy. Specs: +docs/specs/v1.0/server/{device-flow.md, tokens.md}. 
+""" +from __future__ import annotations + +import hashlib +import json +import logging +import os +import secrets +import time +import uuid +from dataclasses import asdict, dataclass, field +from datetime import UTC, datetime, timedelta +from enum import StrEnum + +from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT +from models.oauth import OAuthAccessToken +from sqlalchemy import func, select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Redis state machine — device_code + user_code ephemeral state +# ============================================================================ + + +DEVICE_CODE_KEY_FMT = "device_code:{code}" +USER_CODE_KEY_FMT = "user_code:{code}" + +DEVICE_FLOW_TTL_SECONDS = 15 * 60 # RFC 8628 expires_in +APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor + +USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXY3456789" # ambiguous chars dropped +USER_CODE_SEGMENT_LEN = 4 +USER_CODE_MAX_CLAIM_ATTEMPTS = 5 + +DEFAULT_POLL_INTERVAL_SECONDS = 5 # RFC 8628 minimum + + +class DeviceFlowStatus(StrEnum): + PENDING = "pending" + APPROVED = "approved" + DENIED = "denied" + + +class SlowDownDecision(StrEnum): + OK = "ok" + SLOW_DOWN = "slow_down" + + +@dataclass +class DeviceFlowState: + """``minted_token`` is plaintext between approve and the next poll; + DEL'd after the poll reads it. 
+ """ + + user_code: str + client_id: str + device_label: str + status: DeviceFlowStatus + subject_email: str | None = None + account_id: str | None = None + subject_issuer: str | None = None + minted_token: str | None = None + token_id: str | None = None + created_at: str = "" + created_ip: str = "" + last_poll_at: str = "" + poll_payload: dict | None = field(default=None) + + def to_json(self) -> str: + return json.dumps(asdict(self)) + + @classmethod + def from_json(cls, raw: str) -> "DeviceFlowState": + data = json.loads(raw) + if "status" in data: + data["status"] = DeviceFlowStatus(data["status"]) + return cls(**data) + + +def _random_device_code() -> str: + return "dc_" + secrets.token_urlsafe(24) + + +def _random_user_code_segment() -> str: + return "".join(secrets.choice(USER_CODE_ALPHABET) for _ in range(USER_CODE_SEGMENT_LEN)) + + +def _random_user_code() -> str: + return f"{_random_user_code_segment()}-{_random_user_code_segment()}" + + +class StateNotFound(Exception): + pass + + +class InvalidTransition(Exception): + pass + + +class UserCodeExhausted(Exception): + pass + + +class DeviceFlowRedis: + + def __init__(self, redis_client) -> None: + self._redis = redis_client + + def start(self, client_id: str, device_label: str, created_ip: str) -> tuple[str, str, int]: + device_code = _random_device_code() + user_code = self._claim_user_code(device_code) + state = DeviceFlowState( + user_code=user_code, + client_id=client_id, + device_label=device_label, + status=DeviceFlowStatus.PENDING, + created_at=datetime.now(UTC).isoformat(), + created_ip=created_ip, + ) + self._redis.setex( + DEVICE_CODE_KEY_FMT.format(code=device_code), + DEVICE_FLOW_TTL_SECONDS, + state.to_json(), + ) + return device_code, user_code, DEVICE_FLOW_TTL_SECONDS + + def _claim_user_code(self, device_code: str) -> str: + for _ in range(USER_CODE_MAX_CLAIM_ATTEMPTS): + user_code = _random_user_code() + key = USER_CODE_KEY_FMT.format(code=user_code) + ok = self._redis.set(key, 
device_code, nx=True, ex=DEVICE_FLOW_TTL_SECONDS) + if ok: + return user_code + raise UserCodeExhausted("could not allocate a unique user_code in 5 attempts") + + def load_by_user_code(self, user_code: str) -> tuple[str, DeviceFlowState] | None: + raw_dc = self._redis.get(USER_CODE_KEY_FMT.format(code=user_code)) + if not raw_dc: + return None + device_code = raw_dc.decode() if isinstance(raw_dc, (bytes, bytearray)) else raw_dc + state = self._load_state(device_code) + if state is None: + return None + return device_code, state + + def load_by_device_code(self, device_code: str) -> DeviceFlowState | None: + return self._load_state(device_code) + + def _load_state(self, device_code: str) -> DeviceFlowState | None: + raw = self._redis.get(DEVICE_CODE_KEY_FMT.format(code=device_code)) + if not raw: + return None + text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw + try: + return DeviceFlowState.from_json(text_) + except (ValueError, KeyError): + logger.error("device_flow: corrupt state for %s", device_code) + return None + + def approve( + self, + device_code: str, + subject_email: str, + account_id: str | None, + minted_token: str, + token_id: str, + subject_issuer: str | None = None, + poll_payload: dict | None = None, + ) -> None: + state = self._load_state(device_code) + if state is None: + raise StateNotFound(device_code) + if state.status is not DeviceFlowStatus.PENDING: + raise InvalidTransition(f"cannot approve {state.status}") + + state.status = DeviceFlowStatus.APPROVED + state.subject_email = subject_email + state.account_id = account_id + state.subject_issuer = subject_issuer + state.minted_token = minted_token + state.token_id = token_id + state.poll_payload = poll_payload + + new_ttl = self._remaining_ttl(device_code, floor=APPROVED_TTL_SECONDS_MIN) + self._redis.setex(DEVICE_CODE_KEY_FMT.format(code=device_code), new_ttl, state.to_json()) + + def deny(self, device_code: str) -> None: + state = self._load_state(device_code) + if state 
is None: + raise StateNotFound(device_code) + if state.status is not DeviceFlowStatus.PENDING: + raise InvalidTransition(f"cannot deny {state.status}") + state.status = DeviceFlowStatus.DENIED + self._redis.setex( + DEVICE_CODE_KEY_FMT.format(code=device_code), + self._remaining_ttl(device_code, floor=1), + state.to_json(), + ) + + def consume_on_poll(self, device_code: str) -> DeviceFlowState | None: + """Race-safe via DEL: concurrent polls — one wins, the other gets + None and the caller maps that to expired_token. + """ + state = self._load_state(device_code) + if state is None: + return None + if state.status is DeviceFlowStatus.PENDING: + return None + self._redis.delete( + DEVICE_CODE_KEY_FMT.format(code=device_code), + USER_CODE_KEY_FMT.format(code=state.user_code), + ) + return state + + def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision: + now = time.time() + key = f"device_code:{device_code}:last_poll" + prev_raw = self._redis.get(key) + self._redis.setex(key, DEVICE_FLOW_TTL_SECONDS, str(now)) + if prev_raw is None: + return SlowDownDecision.OK + prev_s = prev_raw.decode() if isinstance(prev_raw, (bytes, bytearray)) else prev_raw + try: + prev = float(prev_s) + except ValueError: + return SlowDownDecision.OK + if now - prev < interval_seconds: + return SlowDownDecision.SLOW_DOWN + return SlowDownDecision.OK + + def _remaining_ttl(self, device_code: str, floor: int) -> int: + """``max(remaining, floor)`` — guarantees the CLI has at least + ``floor`` seconds to poll after a near-expiry approve. 
+ """ + ttl = self._redis.ttl(DEVICE_CODE_KEY_FMT.format(code=device_code)) + if ttl is None or ttl < 0: + return floor + return max(int(ttl), floor) + + +# ============================================================================ +# Token mint — generate + upsert +# ============================================================================ + + +OAUTH_BODY_BYTES = 32 # ~256 bits entropy +PREFIX_OAUTH_ACCOUNT = "dfoa_" +PREFIX_OAUTH_EXTERNAL_SSO = "dfoe_" + + +@dataclass(frozen=True, slots=True) +class MintResult: + """Plaintext token surfaces to the caller once.""" + token: str + token_id: uuid.UUID + expires_at: datetime + + +@dataclass(frozen=True, slots=True) +class UpsertOutcome: + token_id: uuid.UUID + rotated: bool + old_hash: str | None + + +def generate_token(prefix: str) -> str: + return prefix + secrets.token_urlsafe(OAUTH_BODY_BYTES) + + +def sha256_hex(token: str) -> str: + return hashlib.sha256(token.encode("utf-8")).hexdigest() + + +def mint_oauth_token( + session: Session, + redis_client, + *, + subject_email: str, + subject_issuer: str | None, + account_id: str | None, + client_id: str, + device_label: str, + prefix: str, + ttl_days: int, +) -> MintResult: + """Live row rotates in place via partial unique index + ``uq_oauth_active_per_device``; hard-expired rows are excluded by the + index predicate so re-login INSERTs fresh. Pre-rotate Redis entry is + deleted so stale AuthContext drops immediately. 
+ """ + if prefix not in (PREFIX_OAUTH_ACCOUNT, PREFIX_OAUTH_EXTERNAL_SSO): + raise ValueError(f"unknown oauth prefix: {prefix!r}") + + token = generate_token(prefix) + new_hash = sha256_hex(token) + expires_at = datetime.now(UTC) + timedelta(days=ttl_days) + + outcome = _upsert( + session, + subject_email=subject_email, + subject_issuer=subject_issuer, + account_id=account_id, + client_id=client_id, + device_label=device_label, + prefix=prefix, + new_hash=new_hash, + expires_at=expires_at, + ) + + if outcome.rotated and outcome.old_hash: + redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=outcome.old_hash)) + + return MintResult(token=token, token_id=outcome.token_id, expires_at=expires_at) + + +def _upsert( + session: Session, + *, + subject_email: str, + subject_issuer: str | None, + account_id: str | None, + client_id: str, + device_label: str, + prefix: str, + new_hash: str, + expires_at: datetime, +) -> UpsertOutcome: + # Snapshot prior live row's hash for Redis invalidation post-rotate. 
+ prior = session.execute( + select(OAuthAccessToken.id, OAuthAccessToken.token_hash) + .where( + OAuthAccessToken.subject_email == subject_email, + OAuthAccessToken.subject_issuer.is_not_distinct_from(subject_issuer), + OAuthAccessToken.client_id == client_id, + OAuthAccessToken.device_label == device_label, + OAuthAccessToken.revoked_at.is_(None), + ) + .limit(1) + ).first() + old_hash = prior.token_hash if prior else None + + insert_stmt = pg_insert(OAuthAccessToken).values( + subject_email=subject_email, + subject_issuer=subject_issuer, + account_id=account_id, + client_id=client_id, + device_label=device_label, + prefix=prefix, + token_hash=new_hash, + expires_at=expires_at, + ) + upsert_stmt = insert_stmt.on_conflict_do_update( + index_elements=["subject_email", "subject_issuer", "client_id", "device_label"], + index_where=OAuthAccessToken.revoked_at.is_(None), + set_={ + "token_hash": insert_stmt.excluded.token_hash, + "prefix": insert_stmt.excluded.prefix, + "account_id": insert_stmt.excluded.account_id, + "expires_at": insert_stmt.excluded.expires_at, + "created_at": func.now(), + "last_used_at": None, + }, + ).returning(OAuthAccessToken.id) + row = session.execute(upsert_stmt).first() + session.commit() + + token_id = uuid.UUID(str(row.id)) + return UpsertOutcome( + token_id=token_id, + rotated=prior is not None, + old_hash=old_hash, + ) + + +# ============================================================================ +# TTL policy — days new OAuth tokens live +# ============================================================================ + + +DEFAULT_OAUTH_TTL_DAYS = 14 +MIN_TTL_DAYS = 1 +MAX_TTL_DAYS = 365 + +_TTL_ENV_VAR = "OAUTH_TTL_DAYS" + + +def oauth_ttl_days(tenant_id: str | None = None) -> int: + """``OAUTH_TTL_DAYS`` env, else default. EE tenant-level lookup + is deferred; when it lands it wins over the env (Redis-cached 60s). 
+ """ + _ = tenant_id + + raw = os.environ.get(_TTL_ENV_VAR) + if raw is None: + return DEFAULT_OAUTH_TTL_DAYS + try: + value = int(raw) + except ValueError: + logger.warning( + "%s=%r is not an int; falling back to %d", + _TTL_ENV_VAR, raw, DEFAULT_OAUTH_TTL_DAYS, + ) + return DEFAULT_OAUTH_TTL_DAYS + if value < MIN_TTL_DAYS: + logger.warning("%s=%d below min %d; clamping", _TTL_ENV_VAR, value, MIN_TTL_DAYS) + return MIN_TTL_DAYS + if value > MAX_TTL_DAYS: + logger.warning("%s=%d above max %d; clamping", _TTL_ENV_VAR, value, MAX_TTL_DAYS) + return MAX_TTL_DAYS + return value diff --git a/web/app/device/components/authorize-account.tsx b/web/app/device/components/authorize-account.tsx new file mode 100644 index 0000000000..a02088c48f --- /dev/null +++ b/web/app/device/components/authorize-account.tsx @@ -0,0 +1,96 @@ +'use client' + +import type { FC } from 'react' +import { useState } from 'react' +import { deviceApproveAccount, deviceDenyAccount } from '@/service/device-flow' + +type Props = { + userCode: string + accountEmail?: string + defaultWorkspace?: string + onApproved: () => void + onDenied: () => void + onError: (message: string) => void +} + +/** + * AuthorizeAccount is the account-branch authorize screen. Called with a + * live console session already established (user bounced through /signin). + * Posts to /console/api/oauth/device/{approve,deny}; these endpoints mint + * the dfoa_ token server-side. 
+ */ +const AuthorizeAccount: FC = ({ + userCode, accountEmail, defaultWorkspace, onApproved, onDenied, onError, +}) => { + const [busy, setBusy] = useState(false) + + const approve = async () => { + setBusy(true) + try { + await deviceApproveAccount(userCode) + onApproved() + } + catch (e: any) { + onError(e?.message || 'Approve failed') + } + finally { + setBusy(false) + } + } + + const deny = async () => { + setBusy(true) + try { + await deviceDenyAccount(userCode) + onDenied() + } + catch (e: any) { + onError(e?.message || 'Deny failed') + } + finally { + setBusy(false) + } + } + + return ( +
+
+

Authorize Dify CLI

+

+ Dify CLI (difyctl) is requesting access to your account. + {' '}If you did not start this from your terminal, click Cancel. +

+
+
+ {accountEmail && ( +

+ Signed in as {accountEmail} +

+ )} + {defaultWorkspace && ( +

+ Default workspace: {defaultWorkspace} +

+ )} +
+
+ + +
+
+ ) +} + +export default AuthorizeAccount diff --git a/web/app/device/components/authorize-sso.tsx b/web/app/device/components/authorize-sso.tsx new file mode 100644 index 0000000000..a327c54858 --- /dev/null +++ b/web/app/device/components/authorize-sso.tsx @@ -0,0 +1,96 @@ +'use client' + +import type { FC } from 'react' +import { useEffect, useState } from 'react' +import type { ApprovalContext } from '@/service/device-flow' +import { approveExternal, fetchApprovalContext } from '@/service/device-flow' + +type Props = { + onApproved: () => void + onError: (message: string) => void +} + +/** + * AuthorizeSSO is the external-SSO branch authorize screen. On mount it + * fetches /v1/oauth/device/approval-context to learn subject_email, issuer, + * user_code, and csrf_token from the device_approval_grant cookie. On + * Approve click, posts /v1/oauth/device/approve-external with the CSRF header. + * + * The user_code in state is bound to the cookie by server; we do not accept + * one from the URL because the SSO branch deliberately detaches from the + * pre-SSO ?user_code=... query param. + */ +const AuthorizeSSO: FC = ({ onApproved, onError }) => { + const [ctx, setCtx] = useState(null) + const [busy, setBusy] = useState(false) + const [loadErr, setLoadErr] = useState(null) + + useEffect(() => { + let cancelled = false + fetchApprovalContext() + .then((c) => { if (!cancelled) setCtx(c) }) + .catch((e: any) => { + if (!cancelled) + setLoadErr(e?.message || 'Failed to load session') + }) + return () => { cancelled = true } + }, []) + + const approve = async () => { + if (!ctx) return + setBusy(true) + try { + await approveExternal(ctx, ctx.user_code) + onApproved() + } + catch (e: any) { + onError(e?.message || 'Approve failed') + } + finally { + setBusy(false) + } + } + + if (loadErr) { + return ( +
+

This session is no longer valid

+

+ Run difyctl auth login again to start a new sign-in. +

+
+ ) + } + if (!ctx) { + return
Loading session…
+ } + + return ( +
+
+

Authorize Dify CLI

+

+ Dify CLI (difyctl) is requesting access via SSO. If you did not start + this from your terminal, close this tab. +

+
+
+

+ Signed in as {ctx.subject_email} +

+

+ Issuer: {ctx.subject_issuer} +

+
+ +
+ ) +} + +export default AuthorizeSSO diff --git a/web/app/device/components/chooser.tsx b/web/app/device/components/chooser.tsx new file mode 100644 index 0000000000..751d19b897 --- /dev/null +++ b/web/app/device/components/chooser.tsx @@ -0,0 +1,60 @@ +'use client' + +import type { FC } from 'react' +import { useRouter } from '@/next/navigation' +import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect' + +type Props = { + userCode: string + ssoAvailable: boolean +} + +/** + * Chooser renders the two-button device-auth login selector. Account button + * seeds postLoginRedirect + navigates to /signin so every existing account + * login method (password / email-code / social OAuth / account-SSO) flows + * through its usual plumbing. SSO button hits /v1/oauth/device/sso-initiate + * directly — the SSO branch skips /signin entirely. + * + * v1.0 scope: only account-SSO honours postLoginRedirect (via sso-auth's + * return_to plumbing). Password / email-code / social-OAuth users land on + * /signin's default post-login target and manually return to the /device + * URL printed by the CLI. That's not great UX; a follow-up milestone + * generalises post-signin redirect to all methods. + */ +const Chooser: FC = ({ userCode, ssoAvailable }) => { + const router = useRouter() + + const onAccount = () => { + setPostLoginRedirect(`/device?user_code=${encodeURIComponent(userCode)}`) + router.push('/signin') + } + + const onSSO = () => { + // Full-page navigation, not router.push — /v1/oauth/device/sso-initiate + // issues a 302 to the IdP. Next's client router can't follow cross- + // origin redirects; a plain window.location assignment handles it. + window.location.href = `/v1/oauth/device/sso-initiate?user_code=${encodeURIComponent(userCode)}` + } + + return ( +
+ + {ssoAvailable && ( + + )} +
+ ) +} + +export default Chooser diff --git a/web/app/device/components/code-input.tsx b/web/app/device/components/code-input.tsx new file mode 100644 index 0000000000..1d358f782b --- /dev/null +++ b/web/app/device/components/code-input.tsx @@ -0,0 +1,45 @@ +'use client' + +import type { FC } from 'react' +import { useCallback } from 'react' +import { normaliseUserCodeInput } from '../utils/user-code' + +type Props = { + value: string + onChange: (normalised: string) => void + disabled?: boolean + autoFocus?: boolean +} + +/** + * CodeInput renders the user_code text field with live normalisation + * (uppercase, reduced alphabet, XXXX-XXXX hyphenation). + * + * The onChange callback receives the normalised value only — the parent does + * not need to run validation itself. + */ +const CodeInput: FC = ({ value, onChange, disabled, autoFocus }) => { + const handle = useCallback((raw: string) => { + onChange(normaliseUserCodeInput(raw)) + }, [onChange]) + + return ( + handle(e.target.value)} + /> + ) +} + +export default CodeInput diff --git a/web/app/device/page.tsx b/web/app/device/page.tsx new file mode 100644 index 0000000000..0d19448fd7 --- /dev/null +++ b/web/app/device/page.tsx @@ -0,0 +1,173 @@ +'use client' + +import { useEffect, useState } from 'react' +import { useSearchParams } from '@/next/navigation' +import { useQuery } from '@tanstack/react-query' +import { systemFeaturesQueryOptions } from '@/service/system-features' +import { commonQueryKeys, userProfileQueryOptions } from '@/service/use-common' +import { post } from '@/service/base' +import type { ICurrentWorkspace } from '@/models/common' +import { deviceLookup } from '@/service/device-flow' +import CodeInput from './components/code-input' +import Chooser from './components/chooser' +import AuthorizeAccount from './components/authorize-account' +import AuthorizeSSO from './components/authorize-sso' +import { isValidUserCode } from './utils/user-code' + +type View = + | { kind: 'code_entry' } + | { 
kind: 'chooser'; userCode: string } + | { kind: 'authorize_account'; userCode: string } + | { kind: 'authorize_sso' } + | { kind: 'success' } + | { kind: 'error_expired' } + +export default function DevicePage() { + const searchParams = useSearchParams() + const urlUserCode = (searchParams.get('user_code') || '').trim().toUpperCase() + const ssoVerified = searchParams.get('sso_verified') === '1' + + const [typed, setTyped] = useState('') + const [view, setView] = useState({ kind: 'code_entry' }) + const [errMsg, setErrMsg] = useState(null) + + // Account subject + workspace identity (for the authorize-account screen). + // Logged-out is a valid landing state on /device — disable refetch storms + // and skip workspace probe until profile resolves (avoids /current + chained + // /refresh-token 401 loops while the user is still entering the code). + const { data: userResp, isError: profileErr } = useQuery({ + ...userProfileQueryOptions(), + throwOnError: false, + retry: false, + refetchOnWindowFocus: false, + refetchOnMount: false, + }) + const account = userResp?.profile + const { data: currentWorkspace } = useQuery({ + queryKey: commonQueryKeys.currentWorkspace, + queryFn: () => post('/workspaces/current'), + enabled: !!account && !profileErr, + retry: false, + refetchOnWindowFocus: false, + }) + const { data: sys } = useQuery(systemFeaturesQueryOptions()) + // Device-flow SSO branch uses external-user (webapp) SSO, not console SSO — + // backend mints EXTERNAL_SSO tokens via Enterprise's external ACS. Gate on + // webapp_auth.{enabled, allow_sso} + a configured webapp SSO protocol. + const ssoAvailable = !!sys?.webapp_auth?.enabled + && !!sys?.webapp_auth?.allow_sso + && (sys?.webapp_auth?.sso_config?.protocol || '') !== '' + + // URL-driven view transitions. Only advances while the user is still on + // the entry/chooser screens — never clobbers terminal views (success / + // error_expired / authorize_*) when userProfile refetches. 
+  useEffect(() => {
+    if (view.kind !== 'code_entry' && view.kind !== 'chooser') return
+    if (ssoVerified) {
+      setView({ kind: 'authorize_sso' })
+      return
+    }
+    if (urlUserCode && isValidUserCode(urlUserCode)) {
+      if (account)
+        setView({ kind: 'authorize_account', userCode: urlUserCode })
+      else
+        setView({ kind: 'chooser', userCode: urlUserCode })
+    }
+  }, [urlUserCode, ssoVerified, account, view.kind])
+
+  // onContinue: manual code-entry path. Pre-validates the typed code against
+  // the public lookup endpoint before advancing to the authorize/chooser view.
+  // NOTE(review): every lookup failure — an invalid reply, a network error,
+  // or a rate-limit rejection (deviceLookup throws on any non-OK status) —
+  // lands on the same 'error_expired' terminal view; confirm that transient
+  // failures shouldn't instead keep the user on the entry form with a retry.
+  const onContinue = async () => {
+    if (!isValidUserCode(typed)) return
+    try {
+      const reply = await deviceLookup(typed)
+      if (!reply.valid) {
+        setView({ kind: 'error_expired' })
+        return
+      }
+    }
+    catch {
+      setView({ kind: 'error_expired' })
+      return
+    }
+    if (account) setView({ kind: 'authorize_account', userCode: typed })
+    else setView({ kind: 'chooser', userCode: typed })
+  }
+
+  return (
+
+
+ {view.kind === 'code_entry' && ( +
+
+

Authorize Dify CLI

+

+ Enter the code shown in your terminal. +

+
+ + +
+ )} + + {view.kind === 'chooser' && ( +
+
+

Sign in to authorize

+

+ Code {view.userCode} is valid. Choose how to sign in. +

+
+ +
+ )} + + {view.kind === 'authorize_account' && ( + setView({ kind: 'success' })} + onDenied={() => setView({ kind: 'error_expired' })} + onError={e => setErrMsg(e)} + /> + )} + + {view.kind === 'authorize_sso' && ( + setView({ kind: 'success' })} + onError={e => setErrMsg(e)} + /> + )} + + {view.kind === 'success' && ( +
+

You're signed in

+

Return to your terminal to continue.

+
+ )} + + {view.kind === 'error_expired' && ( +
+

This code is no longer valid

+

+ The code may have expired or already been used. Run + {' '} + difyctl auth login + {' '} + again to get a new one. +

+
+ )} + + {errMsg && ( +

{errMsg}

+ )} +
+
+  )
+}
diff --git a/web/app/device/utils/user-code.ts b/web/app/device/utils/user-code.ts
new file mode 100644
index 0000000000..1753da16f3
--- /dev/null
+++ b/web/app/device/utils/user-code.ts
@@ -0,0 +1,37 @@
+// user-code.ts — input normalisation + validation for the RFC 8628
+// 8-character user_code format the CLI prints to stderr.
+//
+// Format: XXXX-XXXX, uppercase, reduced alphabet (no 0/O, 1/I, 2/Z — see
+// NOTE at USER_CODE_ALPHABET regarding L). Low
+// entropy by design — humans type it — so the server-side rate-limit + TTL +
+// single-use properties are what defend it, not the alphabet.
+
+export const USER_CODE_ALPHABET = 'ABCDEFGHJKLMNPQRSTUVWXY3456789' // excludes 0 O 1 I 2 Z — NOTE(review): the commit intent ("no 1/I/l") implies L should be excluded too, but 'L' IS present in this alphabet; confirm against the server-side code generator's alphabet
+
+/**
+ * normaliseUserCodeInput prepares raw input for display in the code field:
+ * strips non-alphanumerics, uppercases, drops disallowed characters, and
+ * inserts the hyphen after the fourth accepted char.
+ *
+ * Returns at most 9 chars ("XXXX-XXXX"); longer input is truncated.
+ */
+export function normaliseUserCodeInput(raw: string): string {
+  const cleaned: string[] = []
+  for (const ch of raw.toUpperCase()) {
+    if (USER_CODE_ALPHABET.includes(ch))
+      cleaned.push(ch)
+    if (cleaned.length === 8)
+      break
+  }
+  if (cleaned.length <= 4)
+    return cleaned.join('')
+  return `${cleaned.slice(0, 4).join('')}-${cleaned.slice(4).join('')}`
+}
+
+/**
+ * isValidUserCode tests whether the normalised form is a complete XXXX-XXXX
+ * token suitable for submission to /v1/oauth/device/lookup.
+ */ +export function isValidUserCode(normalised: string): boolean { + return /^[A-Z0-9]{4}-[A-Z0-9]{4}$/.test(normalised) + && [...normalised.replace('-', '')].every(c => USER_CODE_ALPHABET.includes(c)) +} diff --git a/web/app/signin/utils/post-login-redirect.ts b/web/app/signin/utils/post-login-redirect.ts index a94fb2ad79..291661a87c 100644 --- a/web/app/signin/utils/post-login-redirect.ts +++ b/web/app/signin/utils/post-login-redirect.ts @@ -1,15 +1,68 @@ -let postLoginRedirect: string | null = null +// Persists target across full-page redirects within the same tab (social +// OAuth, SSO IdP bounce). sessionStorage is tab-scoped so concurrent +// /device tabs don't clobber each other. 15-min TTL drops stale values. +// Same-origin + exact-path whitelist prevents open-redirect. +// +// Signup-via-email-link opening in a new tab is out of scope — that tab +// starts with an empty sessionStorage and falls to /apps default. + +const KEY = 'dify_post_login_redirect' +const TTL_MS = 15 * 60 * 1000 + +const ALLOWED: Record> = { + '/device': new Set(['user_code', 'sso_verified']), + '/account/oauth/authorize': new Set(['client_id', 'scope', 'state', 'redirect_uri']), +} + +function validate(target: string): string | null { + if (typeof window === 'undefined') return null + try { + const url = new URL(target, window.location.origin) + if (url.origin !== window.location.origin) return null + const allowedKeys = ALLOWED[url.pathname] + if (!allowedKeys) return null + for (const key of url.searchParams.keys()) { + if (!allowedKeys.has(key)) return null + } + return url.pathname + (url.search || '') + } + catch { + return null + } +} export const setPostLoginRedirect = (value: string | null) => { - postLoginRedirect = value -} - -export const resolvePostLoginRedirect = () => { - if (postLoginRedirect) { - const redirectUrl = postLoginRedirect - postLoginRedirect = null - return redirectUrl + if (typeof window === 'undefined') return + if (value === null) { + try { 
sessionStorage.removeItem(KEY) } catch {} + return + } + const safe = validate(value) + if (!safe) return + try { + sessionStorage.setItem(KEY, JSON.stringify({ target: safe, ts: Date.now() })) + } + catch {} +} + +export const resolvePostLoginRedirect = (): string | null => { + if (typeof window === 'undefined') return null + let raw: string | null = null + try { + raw = sessionStorage.getItem(KEY) + sessionStorage.removeItem(KEY) + } + catch { + return null + } + if (!raw) return null + try { + const parsed = JSON.parse(raw) + if (typeof parsed?.target !== 'string' || typeof parsed?.ts !== 'number') return null + if (Date.now() - parsed.ts > TTL_MS) return null + return validate(parsed.target) + } + catch { + return null } - - return null } diff --git a/web/next.config.ts b/web/next.config.ts index db44f5b9ed..741cd0afc1 100644 --- a/web/next.config.ts +++ b/web/next.config.ts @@ -30,6 +30,20 @@ const nextConfig: NextConfig = { }, ] }, + // Anti-framing for device-flow surfaces. A framed /device page could UI-trick + // a victim with a valid device_approval_grant cookie into approving a + // device_code — functionally CSRF, bypasses the double-submit token. Deny + // framing outright on every device-flow route; no trusted embedder exists. + async headers() { + const antiFrame = [ + { key: 'X-Frame-Options', value: 'DENY' }, + { key: 'Content-Security-Policy', value: "frame-ancestors 'none'" }, + ] + return [ + { source: '/device', headers: antiFrame }, + { source: '/device/:path*', headers: antiFrame }, + ] + }, output: 'standalone', compiler: { removeConsole: isDev ? 
false : { exclude: ['warn', 'error'] }, diff --git a/web/service/base.ts b/web/service/base.ts index 64d13ef59a..e278771db5 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -794,6 +794,11 @@ export const request = async(url: string, options = {}, otherOptions?: IOther const [refreshErr] = await asyncRunSafe(refreshAccessTokenOrReLogin(TIME_OUT)) if (refreshErr === null) return baseFetch(url, options, otherOptionsForBaseFetch) + // /device is the device-flow chooser; logged-out is a valid state + // there. Redirecting to /signin loses the user_code context and + // the post-login flow lands on /apps instead of returning here. + if (location.pathname === `${basePath}/device`) + return Promise.reject(err) if (location.pathname !== `${basePath}/signin` || !IS_CE_EDITION) { jumpTo(loginUrl) return Promise.reject(err) diff --git a/web/service/device-flow.ts b/web/service/device-flow.ts new file mode 100644 index 0000000000..b64cea0331 --- /dev/null +++ b/web/service/device-flow.ts @@ -0,0 +1,84 @@ +// Web-side calls into the Dify device-flow endpoints: +// +// /v1/oauth/device/lookup (public — GET, no auth, IP-rate-limited) +// /v1/oauth/device/approval-context (cookie-authed — GET) +// /v1/oauth/device/approve-external (cookie-authed + CSRF — POST) +// /console/api/oauth/device/approve (session-authed — POST) +// /console/api/oauth/device/deny (session-authed — POST) +// +// Approve/deny use the standard service/base helpers so they get console- +// session cookies automatically. Lookup + SSO-branch endpoints sit under +// /v1 so they ride the existing service-API gateway route. 
+ +import { del, post } from './base' + +const DEVICE_BASE = '/v1/oauth/device' + +// ----- Account branch -------------------------------------------------------- + +export type DeviceLookupReply = { + valid: boolean + expires_in_remaining: number + client_id: string +} + +export async function deviceLookup(user_code: string): Promise { + const res = await fetch(`${DEVICE_BASE}/lookup?user_code=${encodeURIComponent(user_code)}`, { + method: 'GET', + }) + if (!res.ok) { + const body = await res.text().catch(() => '') + throw new Error(`lookup ${res.status}: ${body}`) + } + return res.json() +} + +export const deviceApproveAccount = (user_code: string) => + post<{ status: 'approved' }>('/oauth/device/approve', { body: { user_code } }) + +export const deviceDenyAccount = (user_code: string) => + post<{ status: 'denied' }>('/oauth/device/deny', { body: { user_code } }) + +// ----- SSO branch (cookie-authed via /v1/oauth/device/*) -------------------- + +export type ApprovalContext = { + subject_email: string + subject_issuer: string + user_code: string + csrf_token: string + expires_at: string +} + +export async function fetchApprovalContext(): Promise { + const res = await fetch(`${DEVICE_BASE}/approval-context`, { + method: 'GET', + credentials: 'include', + }) + if (!res.ok) { + const body = await res.text().catch(() => '') + throw new Error(`approval-context ${res.status}: ${body}`) + } + return res.json() +} + +export async function approveExternal(ctx: ApprovalContext, user_code: string): Promise { + const res = await fetch(`${DEVICE_BASE}/approve-external`, { + method: 'POST', + credentials: 'include', + headers: { + 'Content-Type': 'application/json', + 'X-CSRF-Token': ctx.csrf_token, + }, + body: JSON.stringify({ user_code }), + }) + if (!res.ok) { + const body = await res.text().catch(() => '') + throw new Error(`approve-external ${res.status}: ${body}`) + } +} + +// ----- Export for future PAT revoke; noop in v1.0 -------------------------- + +// 
Intentionally left out: personal_access_tokens endpoints are not in this +// milestone; see docs/specs/v1.0/README.md. +void del // keep import live for the TypeScript linter without surfacing usage