diff --git a/api/controllers/oauth_device_sso.py b/api/controllers/oauth_device_sso.py index 37dfb4579b..76208e2793 100644 --- a/api/controllers/oauth_device_sso.py +++ b/api/controllers/oauth_device_sso.py @@ -1,7 +1,7 @@ """Legacy /v1/* mounts for SSO-branch device-flow endpoints. Canonical -handlers live in controllers/openapi/oauth_device/. This file just -re-registers them on the legacy blueprint until Phase F retires the -legacy paths entirely. +handlers live in controllers/openapi/oauth_device_sso.py. This file +just re-registers them on the legacy blueprint until Phase F retires +the legacy paths entirely. Note: /v1/device/sso-complete (no /oauth/ in the path) is the existing ACS callback. Its canonical home is /openapi/v1/oauth/device/sso-complete. @@ -11,16 +11,18 @@ from __future__ import annotations from flask import Blueprint -from controllers.openapi.oauth_device.approval_context import approval_context -from controllers.openapi.oauth_device.approve_external import approve_external -from controllers.openapi.oauth_device.sso_complete import sso_complete -from controllers.openapi.oauth_device.sso_initiate import sso_initiate +from controllers.openapi.oauth_device_sso import ( + approval_context, + approve_external, + sso_complete, + sso_initiate, +) from libs.device_flow_security import attach_anti_framing bp = Blueprint("oauth_device_sso", __name__, url_prefix="/v1") attach_anti_framing(bp) -# Legacy /v1/* mounts — handlers live in controllers/openapi/oauth_device/. +# Legacy /v1/* mounts — handlers live in controllers/openapi/oauth_device_sso.py. # Removed in Phase F. bp.add_url_rule( "/oauth/device/sso-initiate", diff --git a/api/controllers/openapi/__init__.py b/api/controllers/openapi/__init__.py index 05618902f6..a5b30311f3 100644 --- a/api/controllers/openapi/__init__.py +++ b/api/controllers/openapi/__init__.py @@ -16,29 +16,13 @@ api = ExternalApi( openapi_ns = Namespace("openapi", description="User-scoped operations", path="/") -from . import account, index -from .oauth_device import approval_context as oauth_device_approval_context -from .oauth_device import approve as oauth_device_approve -from .oauth_device import approve_external as oauth_device_approve_external -from .oauth_device import code as oauth_device_code -from .oauth_device import deny as oauth_device_deny -from .oauth_device import lookup as oauth_device_lookup -from .oauth_device import sso_complete as oauth_device_sso_complete -from .oauth_device import sso_initiate as oauth_device_sso_initiate -from .oauth_device import token as oauth_device_token +from . import account, index, oauth_device, oauth_device_sso __all__ = [ "account", "index", - "oauth_device_approval_context", - "oauth_device_approve", - "oauth_device_approve_external", - "oauth_device_code", - "oauth_device_deny", - "oauth_device_lookup", - "oauth_device_sso_complete", - "oauth_device_sso_initiate", - "oauth_device_token", + "oauth_device", + "oauth_device_sso", ] api.add_namespace(openapi_ns) diff --git a/api/controllers/openapi/oauth_device.py b/api/controllers/openapi/oauth_device.py new file mode 100644 index 0000000000..457700aa75 --- /dev/null +++ b/api/controllers/openapi/oauth_device.py @@ -0,0 +1,385 @@ +"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two +sub-groups in one module: + + Protocol (RFC 8628, public + rate-limited): + POST /oauth/device/code + POST /oauth/device/token + GET /oauth/device/lookup + + Approval (account branch, console-cookie authed): + POST /oauth/device/approve + POST /oauth/device/deny + +The five Resource classes are also re-registered on legacy mounts: +service_api_ns at /v1/oauth/device/{code,token,lookup} (from +service_api/oauth.py) and console_ns at /console/api/oauth/device/{approve,deny} +(from the deferred _register_legacy_console_mount() at module bottom). +All legacy mounts retire in Phase F. SSO branch lives in oauth_device_sso.py. +""" +from __future__ import annotations + +import logging + +from flask import request +from flask_login import login_required +from flask_restx import Resource, reqparse + +from configs import dify_config +from controllers.console.wraps import account_initialization_required, setup_required +from controllers.openapi import openapi_ns +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from libs.helper import extract_remote_ip +from libs.login import current_account_with_tenant +from libs.oauth_bearer import SubjectType, bearer_feature_required +from libs.rate_limit import ( + LIMIT_APPROVE_CONSOLE, + LIMIT_DEVICE_CODE_PER_IP, + LIMIT_LOOKUP_PUBLIC, + rate_limit, +) +from services.oauth_device_flow import ( + ACCOUNT_ISSUER_SENTINEL, + DEFAULT_POLL_INTERVAL_SECONDS, + DEVICE_FLOW_TTL_SECONDS, + PREFIX_OAUTH_ACCOUNT, + DeviceFlowRedis, + DeviceFlowStatus, + InvalidTransition, + SlowDownDecision, + StateNotFound, + mint_oauth_token, + oauth_ttl_days, +) + +logger = logging.getLogger(__name__) + + +# ========================================================================= +# Parsers +# ========================================================================= + +_code_parser = reqparse.RequestParser() +_code_parser.add_argument("client_id", type=str, required=True, location="json") +_code_parser.add_argument("device_label", type=str, required=True, location="json") + +_poll_parser = reqparse.RequestParser() +_poll_parser.add_argument("device_code", type=str, required=True, location="json") +_poll_parser.add_argument("client_id", type=str, required=True, location="json") + +_lookup_parser = reqparse.RequestParser() +_lookup_parser.add_argument("user_code", type=str, required=True, location="args") + +_mutate_parser = reqparse.RequestParser() +_mutate_parser.add_argument("user_code", type=str, required=True, location="json") + + +# ========================================================================= +# Protocol endpoints — RFC 8628 (public + per-IP rate limit) +# ========================================================================= + + +@openapi_ns.route("/oauth/device/code") +class OAuthDeviceCodeApi(Resource): + @rate_limit(LIMIT_DEVICE_CODE_PER_IP) + def post(self): + args = _code_parser.parse_args() + client_id = args["client_id"] + device_label = args["device_label"] + + if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS: + return {"error": "unsupported_client"}, 400 + + store = DeviceFlowRedis(redis_client) + ip = extract_remote_ip(request) + device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip) + + return { + "device_code": device_code, + "user_code": user_code, + "verification_uri": _verification_uri(), + "expires_in": expires_in, + "interval": DEFAULT_POLL_INTERVAL_SECONDS, + }, 200 + + +@openapi_ns.route("/oauth/device/token") +class OAuthDeviceTokenApi(Resource): + """RFC 8628 poll.""" + + def post(self): + args = _poll_parser.parse_args() + device_code = args["device_code"] + + store = DeviceFlowRedis(redis_client) + + # slow_down beats every other branch — polling-too-fast clients + # see only that response regardless of underlying state. + if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN: + return {"error": "slow_down"}, 400 + + state = store.load_by_device_code(device_code) + if state is None: + return {"error": "expired_token"}, 400 + + if state.status is DeviceFlowStatus.PENDING: + return {"error": "authorization_pending"}, 400 + + terminal = store.consume_on_poll(device_code) + if terminal is None: + return {"error": "expired_token"}, 400 + + if terminal.status is DeviceFlowStatus.DENIED: + return {"error": "access_denied"}, 400 + + poll_payload = terminal.poll_payload or {} + if "token" not in poll_payload: + logger.error("device_flow: approved state missing poll_payload for %s", device_code) + return {"error": "expired_token"}, 400 + + _audit_cross_ip_if_needed(state) + return poll_payload, 200 + + +@openapi_ns.route("/oauth/device/lookup") +class OAuthDeviceLookupApi(Resource): + """Read-only — public for pre-validate before login. user_code is + high-entropy + short-TTL; per-IP rate limit blocks enumeration. + """ + + @rate_limit(LIMIT_LOOKUP_PUBLIC) + def get(self): + args = _lookup_parser.parse_args() + user_code = args["user_code"].strip().upper() + + store = DeviceFlowRedis(redis_client) + found = store.load_by_user_code(user_code) + if found is None: + return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200 + + _device_code, state = found + if state.status is not DeviceFlowStatus.PENDING: + return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200 + + return { + "valid": True, + "expires_in_remaining": DEVICE_FLOW_TTL_SECONDS, + "client_id": state.client_id, + }, 200 + + +# ========================================================================= +# Approval endpoints — account branch (cookie-authed) +# ========================================================================= + + +_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving" +_APPROVE_GUARD_TTL_SECONDS = 10 + + +@openapi_ns.route("/oauth/device/approve") +class DeviceApproveApi(Resource): + @setup_required + @login_required + @account_initialization_required + @bearer_feature_required + @rate_limit(LIMIT_APPROVE_CONSOLE) + def post(self): + args = _mutate_parser.parse_args() + user_code = args["user_code"].strip().upper() + + account, tenant = current_account_with_tenant() + store = DeviceFlowRedis(redis_client) + + found = store.load_by_user_code(user_code) + if found is None: + return {"error": "expired_or_unknown"}, 404 + device_code, state = found + if state.status is not DeviceFlowStatus.PENDING: + return {"error": "already_resolved"}, 409 + + # SET NX guard — without it, two in-flight approves both pass + # PENDING, both mint, and the second upsert silently rotates the + # first caller into an already-revoked token. + guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code) + if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS): + return {"error": "approve_in_progress"}, 409 + + try: + ttl_days = oauth_ttl_days(tenant_id=tenant) + mint = mint_oauth_token( + db.session, + redis_client, + subject_email=account.email, + subject_issuer=ACCOUNT_ISSUER_SENTINEL, + account_id=str(account.id), + client_id=state.client_id, + device_label=state.device_label, + prefix=PREFIX_OAUTH_ACCOUNT, + ttl_days=ttl_days, + ) + + poll_payload = _build_account_poll_payload(account, tenant, mint) + try: + store.approve( + device_code, + subject_email=account.email, + account_id=str(account.id), + subject_issuer=ACCOUNT_ISSUER_SENTINEL, + minted_token=mint.token, + token_id=str(mint.token_id), + poll_payload=poll_payload, + ) + except (StateNotFound, InvalidTransition) as e: + # Row minted but state vanished — roll forward; the orphan + # token is revocable via auth devices list / Authorized Apps. + logger.error("device_flow: approve raced on %s: %s", device_code, e) + return {"error": "state_lost"}, 409 + finally: + redis_client.delete(guard_key) + + _emit_approve_audit(state, account, tenant, mint) + return {"status": "approved"}, 200 + + +@openapi_ns.route("/oauth/device/deny") +class DeviceDenyApi(Resource): + @setup_required + @login_required + @account_initialization_required + @bearer_feature_required + @rate_limit(LIMIT_APPROVE_CONSOLE) + def post(self): + args = _mutate_parser.parse_args() + user_code = args["user_code"].strip().upper() + + store = DeviceFlowRedis(redis_client) + found = store.load_by_user_code(user_code) + if found is None: + return {"error": "expired_or_unknown"}, 404 + device_code, state = found + if state.status is not DeviceFlowStatus.PENDING: + return {"error": "already_resolved"}, 409 + + try: + store.deny(device_code) + except (StateNotFound, InvalidTransition) as e: + logger.error("device_flow: deny raced on %s: %s", device_code, e) + return {"error": "state_lost"}, 409 + + _emit_deny_audit(state) + return {"status": "denied"}, 200 + + +# ========================================================================= +# Helpers +# ========================================================================= + + +def _verification_uri() -> str: + base = getattr(dify_config, "CONSOLE_WEB_URL", None) + if base: + return f"{base.rstrip('/')}/device" + return f"{request.host_url.rstrip('/')}/device" + + +def _audit_cross_ip_if_needed(state) -> None: + poll_ip = extract_remote_ip(request) + if state.created_ip and poll_ip and poll_ip != state.created_ip: + logger.warning( + "audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s", + state.token_id, state.created_ip, poll_ip, + extra={ + "audit": True, + "token_id": state.token_id, + "creation_ip": state.created_ip, + "poll_ip": poll_ip, + }, + ) + + +def _build_account_poll_payload(account, tenant, mint) -> dict: + """Pre-render the poll-response body so the unauthenticated poll + handler doesn't re-query accounts/tenants for authz data. + """ + from models import Tenant, TenantAccountJoin + rows = ( + db.session.query(Tenant, TenantAccountJoin) + .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.account_id == account.id) + .all() + ) + workspaces = [ + {"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")} + for t, m in rows + ] + # Prefer active session tenant → DB-flagged current join → first membership. + default_ws_id = None + if tenant and any(w["id"] == str(tenant) for w in workspaces): + default_ws_id = str(tenant) + if default_ws_id is None: + for _t, m in rows: + if getattr(m, "current", False): + default_ws_id = str(m.tenant_id) + break + if default_ws_id is None and workspaces: + default_ws_id = workspaces[0]["id"] + + return { + "token": mint.token, + "expires_at": mint.expires_at.isoformat(), + "subject_type": SubjectType.ACCOUNT, + "account": {"id": str(account.id), "email": account.email, "name": account.name}, + "workspaces": workspaces, + "default_workspace_id": default_ws_id, + "token_id": str(mint.token_id), + } + + +def _emit_approve_audit(state, account, tenant, mint) -> None: + logger.warning( + "audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s", + mint.token_id, account.email, state.client_id, state.device_label, mint.expires_at, + extra={ + "audit": True, + "event": "oauth.device_flow_approved", + "token_id": str(mint.token_id), + "subject_type": SubjectType.ACCOUNT, + "subject_email": account.email, + "account_id": str(account.id), + "tenant_id": tenant, + "client_id": state.client_id, + "device_label": state.device_label, + "scopes": ["full"], + "expires_at": mint.expires_at.isoformat(), + }, + ) + + +def _emit_deny_audit(state) -> None: + logger.warning( + "audit: oauth.device_flow_denied client_id=%s device_label=%s", + state.client_id, state.device_label, + extra={ + "audit": True, + "event": "oauth.device_flow_denied", + "client_id": state.client_id, + "device_label": state.device_label, + }, + ) + + +# ========================================================================= +# Legacy console-side mount — deferred import breaks a cycle that would +# form between this module (imports controllers.console.wraps) and +# controllers.console.__init__ (loads .auth.oauth_device). +# ========================================================================= + + +def _register_legacy_console_mount() -> None: + from controllers.console import console_ns + console_ns.add_resource(DeviceApproveApi, "/oauth/device/approve") + console_ns.add_resource(DeviceDenyApi, "/oauth/device/deny") + + +_register_legacy_console_mount() diff --git a/api/controllers/openapi/oauth_device/__init__.py b/api/controllers/openapi/oauth_device/__init__.py deleted file mode 100644 index 5d55c7ebc1..0000000000 --- a/api/controllers/openapi/oauth_device/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""User-scoped device-flow protocol endpoints (RFC 8628). Public — -unauthenticated, per-IP rate-limited. Approval/deny + SSO branch land -here in Phase D. -""" diff --git a/api/controllers/openapi/oauth_device/approval_context.py b/api/controllers/openapi/oauth_device/approval_context.py deleted file mode 100644 index 3de6bdc221..0000000000 --- a/api/controllers/openapi/oauth_device/approval_context.py +++ /dev/null @@ -1,46 +0,0 @@ -"""GET /openapi/v1/oauth/device/approval-context — EE-only. SPA reads -the device_approval_grant cookie claims (subject email/issuer, csrf -token, user_code, expiry). Idempotent — does not consume the nonce. - -Also registered on the legacy /v1/oauth/device/approval-context path -from controllers/oauth_device_sso.py until Phase F retires that mount. -""" -from __future__ import annotations - -import logging - -from flask import jsonify, request -from werkzeug.exceptions import Unauthorized - -from controllers.openapi import bp -from libs import jws -from libs.device_flow_security import ( - APPROVAL_GRANT_COOKIE_NAME, - enterprise_only, - verify_approval_grant, -) - -logger = logging.getLogger(__name__) - - -@bp.route("/oauth/device/approval-context", methods=["GET"]) -@enterprise_only -def approval_context(): - token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME) - if not token: - raise Unauthorized("no_session") - - keyset = jws.KeySet.from_shared_secret() - try: - claims = verify_approval_grant(keyset, token) - except jws.VerifyError as e: - logger.warning("approval-context: bad cookie: %s", e) - raise Unauthorized("no_session") from e - - return jsonify({ - "subject_email": claims.subject_email, - "subject_issuer": claims.subject_issuer, - "user_code": claims.user_code, - "csrf_token": claims.csrf_token, - "expires_at": claims.expires_at.isoformat(), - }), 200 diff --git a/api/controllers/openapi/oauth_device/approve.py b/api/controllers/openapi/oauth_device/approve.py deleted file mode 100644 index 681bfd3c53..0000000000 --- a/api/controllers/openapi/oauth_device/approve.py +++ /dev/null @@ -1,175 +0,0 @@ -"""POST /openapi/v1/oauth/device/approve — user approves a pending -device flow from the /device page. Console-session authed (the user is -signed in via cookie when they hit Approve in the SPA). - -The class is also registered on console_ns at /console/api/oauth/device/approve -from console/auth/oauth_device.py until Phase F retires that mount. -""" -from __future__ import annotations - -import logging - -from flask_login import login_required -from flask_restx import Resource, reqparse - -from controllers.console.wraps import account_initialization_required, setup_required -from controllers.openapi import openapi_ns -from extensions.ext_database import db -from extensions.ext_redis import redis_client -from libs.login import current_account_with_tenant -from libs.oauth_bearer import SubjectType, bearer_feature_required -from libs.rate_limit import LIMIT_APPROVE_CONSOLE, rate_limit -from services.oauth_device_flow import ( - ACCOUNT_ISSUER_SENTINEL, - PREFIX_OAUTH_ACCOUNT, - DeviceFlowRedis, - DeviceFlowStatus, - InvalidTransition, - StateNotFound, - mint_oauth_token, - oauth_ttl_days, -) - -logger = logging.getLogger(__name__) - - -_mutate_parser = reqparse.RequestParser() -_mutate_parser.add_argument("user_code", type=str, required=True, location="json") - - -_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving" -_APPROVE_GUARD_TTL_SECONDS = 10 - - -@openapi_ns.route("/oauth/device/approve") -class DeviceApproveApi(Resource): - @setup_required - @login_required - @account_initialization_required - @bearer_feature_required - @rate_limit(LIMIT_APPROVE_CONSOLE) - def post(self): - args = _mutate_parser.parse_args() - user_code = args["user_code"].strip().upper() - - account, tenant = current_account_with_tenant() - store = DeviceFlowRedis(redis_client) - - found = store.load_by_user_code(user_code) - if found is None: - return {"error": "expired_or_unknown"}, 404 - device_code, state = found - if state.status is not DeviceFlowStatus.PENDING: - return {"error": "already_resolved"}, 409 - - # SET NX guard — without it, two in-flight approves both pass - # PENDING, both mint, and the second upsert silently rotates the - # first caller into an already-revoked token. - guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code) - if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS): - return {"error": "approve_in_progress"}, 409 - - try: - ttl_days = oauth_ttl_days(tenant_id=tenant) - mint = mint_oauth_token( - db.session, - redis_client, - subject_email=account.email, - subject_issuer=ACCOUNT_ISSUER_SENTINEL, - account_id=str(account.id), - client_id=state.client_id, - device_label=state.device_label, - prefix=PREFIX_OAUTH_ACCOUNT, - ttl_days=ttl_days, - ) - - poll_payload = _build_account_poll_payload(account, tenant, mint) - try: - store.approve( - device_code, - subject_email=account.email, - account_id=str(account.id), - subject_issuer=ACCOUNT_ISSUER_SENTINEL, - minted_token=mint.token, - token_id=str(mint.token_id), - poll_payload=poll_payload, - ) - except (StateNotFound, InvalidTransition) as e: - # Row minted but state vanished — roll forward; the orphan - # token is revocable via auth devices list / Authorized Apps. - logger.error("device_flow: approve raced on %s: %s", device_code, e) - return {"error": "state_lost"}, 409 - finally: - redis_client.delete(guard_key) - - _emit_approve_audit(state, account, tenant, mint) - return {"status": "approved"}, 200 - - -def _build_account_poll_payload(account, tenant, mint) -> dict: - """Pre-render the poll-response body so the unauthenticated poll - handler doesn't re-query accounts/tenants for authz data. - """ - from models import Tenant, TenantAccountJoin - rows = ( - db.session.query(Tenant, TenantAccountJoin) - .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id) - .filter(TenantAccountJoin.account_id == account.id) - .all() - ) - workspaces = [ - {"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")} - for t, m in rows - ] - # Prefer active session tenant → DB-flagged current join → first membership. - default_ws_id = None - if tenant and any(w["id"] == str(tenant) for w in workspaces): - default_ws_id = str(tenant) - if default_ws_id is None: - for _t, m in rows: - if getattr(m, "current", False): - default_ws_id = str(m.tenant_id) - break - if default_ws_id is None and workspaces: - default_ws_id = workspaces[0]["id"] - - return { - "token": mint.token, - "expires_at": mint.expires_at.isoformat(), - "subject_type": SubjectType.ACCOUNT, - "account": {"id": str(account.id), "email": account.email, "name": account.name}, - "workspaces": workspaces, - "default_workspace_id": default_ws_id, - "token_id": str(mint.token_id), - } - - -def _emit_approve_audit(state, account, tenant, mint) -> None: - logger.warning( - "audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s", - mint.token_id, account.email, state.client_id, state.device_label, mint.expires_at, - extra={ - "audit": True, - "event": "oauth.device_flow_approved", - "token_id": str(mint.token_id), - "subject_type": SubjectType.ACCOUNT, - "subject_email": account.email, - "account_id": str(account.id), - "tenant_id": tenant, - "client_id": state.client_id, - "device_label": state.device_label, - "scopes": ["full"], - "expires_at": mint.expires_at.isoformat(), - }, - ) - - -# Legacy /console/api/oauth/device/approve mount — handler defined above. -# Removed in Phase F. The console_ns import is local to defer past -# circular-import resolution between this module and controllers.console. -def _register_legacy_console_mount() -> None: - from controllers.console import console_ns - console_ns.add_resource(DeviceApproveApi, "/oauth/device/approve") - - -_register_legacy_console_mount() diff --git a/api/controllers/openapi/oauth_device/approve_external.py b/api/controllers/openapi/oauth_device/approve_external.py deleted file mode 100644 index fb1b214105..0000000000 --- a/api/controllers/openapi/oauth_device/approve_external.py +++ /dev/null @@ -1,141 +0,0 @@ -"""POST /openapi/v1/oauth/device/approve-external — EE-only. User -clicks Approve in the SPA after federated SSO; cookie + CSRF gate -the request, then we mint a dfoe_ token and approve the device flow. - -Also registered on the legacy /v1/oauth/device/approve-external path -from controllers/oauth_device_sso.py until Phase F retires that mount. -""" -from __future__ import annotations - -import logging - -from flask import jsonify, make_response, request -from werkzeug.exceptions import BadRequest, Conflict, Forbidden, NotFound, Unauthorized - -from controllers.openapi import bp -from extensions.ext_database import db -from extensions.ext_redis import redis_client -from libs import jws -from libs.device_flow_security import ( - APPROVAL_GRANT_COOKIE_NAME, - ApprovalGrantClaims, - approval_grant_cleared_cookie_kwargs, - consume_approval_grant_nonce, - enterprise_only, - verify_approval_grant, -) -from libs.oauth_bearer import SubjectType -from libs.rate_limit import LIMIT_APPROVE_EXT_PER_EMAIL, enforce -from services.oauth_device_flow import ( - PREFIX_OAUTH_EXTERNAL_SSO, - DeviceFlowRedis, - DeviceFlowStatus, - InvalidTransition, - StateNotFound, - mint_oauth_token, - oauth_ttl_days, -) - -logger = logging.getLogger(__name__) - - -@bp.route("/oauth/device/approve-external", methods=["POST"]) -@enterprise_only -def approve_external(): - token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME) - if not token: - raise Unauthorized("invalid_session") - - keyset = jws.KeySet.from_shared_secret() - try: - claims: ApprovalGrantClaims = verify_approval_grant(keyset, token) - except jws.VerifyError as e: - logger.warning("approve-external: bad cookie: %s", e) - raise Unauthorized("invalid_session") from e - - enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}") - - csrf_header = request.headers.get("X-CSRF-Token", "") - if not csrf_header or csrf_header != claims.csrf_token: - raise Forbidden("csrf_mismatch") - - data = request.get_json(silent=True) or {} - body_user_code = (data.get("user_code") or "").strip().upper() - if body_user_code != claims.user_code: - raise BadRequest("user_code_mismatch") - - store = DeviceFlowRedis(redis_client) - found = store.load_by_user_code(claims.user_code) - if found is None: - raise NotFound("user_code_not_pending") - device_code, state = found - if state.status is not DeviceFlowStatus.PENDING: - raise Conflict("user_code_not_pending") - - if not consume_approval_grant_nonce(redis_client, claims.nonce): - raise Unauthorized("session_already_consumed") - - ttl_days = oauth_ttl_days(tenant_id=None) - mint = mint_oauth_token( - db.session, - redis_client, - subject_email=claims.subject_email, - subject_issuer=claims.subject_issuer, - account_id=None, - client_id=state.client_id, - device_label=state.device_label, - prefix=PREFIX_OAUTH_EXTERNAL_SSO, - ttl_days=ttl_days, - ) - - poll_payload = { - "token": mint.token, - "expires_at": mint.expires_at.isoformat(), - "subject_type": SubjectType.EXTERNAL_SSO, - "subject_email": claims.subject_email, - "subject_issuer": claims.subject_issuer, - "account": None, - "workspaces": [], - "default_workspace_id": None, - "token_id": str(mint.token_id), - } - - try: - store.approve( - device_code, - subject_email=claims.subject_email, - account_id=None, - subject_issuer=claims.subject_issuer, - minted_token=mint.token, - token_id=str(mint.token_id), - poll_payload=poll_payload, - ) - except (StateNotFound, InvalidTransition) as e: - logger.error("approve-external: state transition raced: %s", e) - raise Conflict("state_lost") from e - - _emit_approve_external_audit(state, claims, mint) - - resp = make_response(jsonify({"status": "approved"}), 200) - resp.set_cookie(**approval_grant_cleared_cookie_kwargs()) - return resp - - -def _emit_approve_external_audit(state, claims, mint) -> None: - logger.warning( - "audit: oauth.device_flow_approved subject_type=%s " - "subject_email=%s subject_issuer=%s token_id=%s", - SubjectType.EXTERNAL_SSO, claims.subject_email, claims.subject_issuer, mint.token_id, - extra={ - "audit": True, - "event": "oauth.device_flow_approved", - "subject_type": SubjectType.EXTERNAL_SSO, - "subject_email": claims.subject_email, - "subject_issuer": claims.subject_issuer, - "token_id": str(mint.token_id), - "client_id": state.client_id, - "device_label": state.device_label, - "scopes": ["apps:run"], - "expires_at": mint.expires_at.isoformat(), - }, - ) diff --git a/api/controllers/openapi/oauth_device/code.py b/api/controllers/openapi/oauth_device/code.py deleted file mode 100644 index f6d4139010..0000000000 --- a/api/controllers/openapi/oauth_device/code.py +++ /dev/null @@ -1,56 +0,0 @@ -"""POST /openapi/v1/oauth/device/code — RFC 8628 device authorization request. - -Public + per-IP rate-limited. The CLI starts a device flow here; the -returned `verification_uri` is what the user opens in a browser. The -class is also registered on the legacy /v1/ namespace from -service_api/oauth.py until Phase F retires that mount. -""" -from __future__ import annotations - -from flask import request -from flask_restx import Resource, reqparse - -from configs import dify_config -from controllers.openapi import openapi_ns -from extensions.ext_redis import redis_client -from libs.helper import extract_remote_ip -from libs.rate_limit import LIMIT_DEVICE_CODE_PER_IP, rate_limit -from services.oauth_device_flow import ( - DEFAULT_POLL_INTERVAL_SECONDS, - DeviceFlowRedis, -) - -_code_parser = reqparse.RequestParser() -_code_parser.add_argument("client_id", type=str, required=True, location="json") -_code_parser.add_argument("device_label", type=str, required=True, location="json") - - -@openapi_ns.route("/oauth/device/code") -class OAuthDeviceCodeApi(Resource): - @rate_limit(LIMIT_DEVICE_CODE_PER_IP) - def post(self): - args = _code_parser.parse_args() - client_id = args["client_id"] - device_label = args["device_label"] - - if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS: - return {"error": "unsupported_client"}, 400 - - store = DeviceFlowRedis(redis_client) - ip = extract_remote_ip(request) - device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip) - - return { - "device_code": device_code, - "user_code": user_code, - "verification_uri": _verification_uri(), - "expires_in": expires_in, - "interval": DEFAULT_POLL_INTERVAL_SECONDS, - }, 200 - - -def _verification_uri() -> str: - base = getattr(dify_config, "CONSOLE_WEB_URL", None) - if base: - return f"{base.rstrip('/')}/device" - return f"{request.host_url.rstrip('/')}/device" diff --git a/api/controllers/openapi/oauth_device/deny.py b/api/controllers/openapi/oauth_device/deny.py deleted file mode 100644 index 598d66b412..0000000000 --- a/api/controllers/openapi/oauth_device/deny.py +++ /dev/null @@ -1,83 +0,0 @@ -"""POST /openapi/v1/oauth/device/deny — user denies a pending device -flow from the /device page. Console-session authed. - -The class is also registered on console_ns at /console/api/oauth/device/deny -from console/auth/oauth_device.py until Phase F retires that mount. -""" -from __future__ import annotations - -import logging - -from flask_login import login_required -from flask_restx import Resource, reqparse - -from controllers.console.wraps import account_initialization_required, setup_required -from controllers.openapi import openapi_ns -from extensions.ext_redis import redis_client -from libs.oauth_bearer import bearer_feature_required -from libs.rate_limit import LIMIT_APPROVE_CONSOLE, rate_limit -from services.oauth_device_flow import ( - DeviceFlowRedis, - DeviceFlowStatus, - InvalidTransition, - StateNotFound, -) - -logger = logging.getLogger(__name__) - - -_mutate_parser = reqparse.RequestParser() -_mutate_parser.add_argument("user_code", type=str, required=True, location="json") - - -@openapi_ns.route("/oauth/device/deny") -class DeviceDenyApi(Resource): - @setup_required - @login_required - @account_initialization_required - @bearer_feature_required - @rate_limit(LIMIT_APPROVE_CONSOLE) - def post(self): - args = _mutate_parser.parse_args() - user_code = args["user_code"].strip().upper() - - store = DeviceFlowRedis(redis_client) - found = store.load_by_user_code(user_code) - if found is None: - return {"error": "expired_or_unknown"}, 404 - device_code, state = found - if state.status is not DeviceFlowStatus.PENDING: - return {"error": "already_resolved"}, 409 - - try: - store.deny(device_code) - except (StateNotFound, InvalidTransition) as e: - logger.error("device_flow: deny raced on %s: %s", device_code, e) - return {"error": "state_lost"}, 409 - - _emit_deny_audit(state) - return {"status": "denied"}, 200 - - -def _emit_deny_audit(state) -> None: - logger.warning( - "audit: oauth.device_flow_denied client_id=%s device_label=%s", - state.client_id, state.device_label, - extra={ - "audit": True, - "event": "oauth.device_flow_denied", - "client_id": state.client_id, - "device_label": state.device_label, - }, - ) - - -# Legacy /console/api/oauth/device/deny mount — handler defined above. -# Removed in Phase F. The console_ns import is local to defer past -# circular-import resolution between this module and controllers.console. -def _register_legacy_console_mount() -> None: - from controllers.console import console_ns - console_ns.add_resource(DeviceDenyApi, "/oauth/device/deny") - - -_register_legacy_console_mount() diff --git a/api/controllers/openapi/oauth_device/lookup.py b/api/controllers/openapi/oauth_device/lookup.py deleted file mode 100644 index 7546ba78cc..0000000000 --- a/api/controllers/openapi/oauth_device/lookup.py +++ /dev/null @@ -1,49 +0,0 @@ -"""GET /openapi/v1/oauth/device/lookup — pre-validate user_code from -the /device page before the user signs in. Public; user_code is -high-entropy + short-TTL, per-IP rate limit blocks enumeration. - -The class is also registered on the legacy /v1/ namespace from -service_api/oauth.py until Phase F retires that mount. -""" -from __future__ import annotations - -from flask_restx import Resource, reqparse - -from controllers.openapi import openapi_ns -from extensions.ext_redis import redis_client -from libs.rate_limit import LIMIT_LOOKUP_PUBLIC, rate_limit -from services.oauth_device_flow import ( - DEVICE_FLOW_TTL_SECONDS, - DeviceFlowRedis, - DeviceFlowStatus, -) - -_lookup_parser = reqparse.RequestParser() -_lookup_parser.add_argument("user_code", type=str, required=True, location="args") - - -@openapi_ns.route("/oauth/device/lookup") -class OAuthDeviceLookupApi(Resource): - """Read-only — public for pre-validate before login. user_code is - high-entropy + short-TTL; per-IP rate limit blocks enumeration. - """ - - @rate_limit(LIMIT_LOOKUP_PUBLIC) - def get(self): - args = _lookup_parser.parse_args() - user_code = args["user_code"].strip().upper() - - store = DeviceFlowRedis(redis_client) - found = store.load_by_user_code(user_code) - if found is None: - return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200 - - _device_code, state = found - if state.status is not DeviceFlowStatus.PENDING: - return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200 - - return { - "valid": True, - "expires_in_remaining": DEVICE_FLOW_TTL_SECONDS, - "client_id": state.client_id, - }, 200 diff --git a/api/controllers/openapi/oauth_device/sso_complete.py b/api/controllers/openapi/oauth_device/sso_complete.py deleted file mode 100644 index 48a94be28f..0000000000 --- a/api/controllers/openapi/oauth_device/sso_complete.py +++ /dev/null @@ -1,69 +0,0 @@ -"""GET /openapi/v1/oauth/device/sso-complete — EE-only ACS callback. -The IdP redirects here with a signed external-subject assertion; -we verify, mint the approval-grant cookie, and redirect to /device. - -The handler is also registered on the legacy /v1/device/sso-complete -path from controllers/oauth_device_sso.py until Phase F retires that mount. -The legacy path lived under /v1/device/, not /v1/oauth/device/, so -existing IdP ACS configs need re-registration to the canonical path. -""" -from __future__ import annotations - -import logging - -from flask import redirect, request -from werkzeug.exceptions import BadRequest, Conflict - -from controllers.openapi import bp -from extensions.ext_redis import redis_client -from libs import jws -from libs.device_flow_security import ( - approval_grant_cookie_kwargs, - consume_sso_assertion_nonce, - enterprise_only, - mint_approval_grant, -) -from services.oauth_device_flow import DeviceFlowRedis, DeviceFlowStatus - -logger = logging.getLogger(__name__) - - -@bp.route("/oauth/device/sso-complete", methods=["GET"]) -@enterprise_only -def sso_complete(): - blob = request.args.get("sso_assertion") - if not blob: - raise BadRequest("sso_assertion required") - - keyset = jws.KeySet.from_shared_secret() - - try: - claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION) - except jws.VerifyError as e: - logger.warning("sso-complete: rejected assertion: %s", e) - raise BadRequest("invalid_sso_assertion") from e - - if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")): - raise BadRequest("invalid_sso_assertion") - - user_code = (claims.get("user_code") or "").strip().upper() - store = DeviceFlowRedis(redis_client) - found = store.load_by_user_code(user_code) - if found is None: - raise Conflict("user_code_not_pending") - _, state = found - if state.status is not DeviceFlowStatus.PENDING: - raise Conflict("user_code_not_pending") - - iss = request.host_url.rstrip("/") - cookie_value, _ = mint_approval_grant( - keyset=keyset, - iss=iss, - subject_email=claims["email"], - subject_issuer=claims["issuer"], - user_code=user_code, - ) - - resp = redirect("/device?sso_verified=1", code=302) - resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value)) - return resp diff --git a/api/controllers/openapi/oauth_device/sso_initiate.py b/api/controllers/openapi/oauth_device/sso_initiate.py deleted file mode 100644 index a4d30cce18..0000000000 --- a/api/controllers/openapi/oauth_device/sso_initiate.py +++ /dev/null @@ -1,83 +0,0 @@ -"""GET /openapi/v1/oauth/device/sso-initiate — EE-only. Browser hits -this with a user_code; we sign an SSOState envelope and call the -Enterprise inner API to get the IdP authorize URL, then 302 to the IdP. - -The handler is also registered on the legacy /v1/oauth/device/sso-initiate -path from controllers/oauth_device_sso.py until Phase F retires that mount. -""" -from __future__ import annotations - -import logging -import secrets - -from flask import redirect, request -from werkzeug.exceptions import BadGateway, BadRequest - -from controllers.openapi import bp -from extensions.ext_redis import redis_client -from libs import jws -from libs.device_flow_security import ( - approval_grant_cleared_cookie_kwargs, - enterprise_only, -) -from libs.rate_limit import LIMIT_SSO_INITIATE_PER_IP, rate_limit -from services.enterprise.enterprise_service import EnterpriseService -from services.oauth_device_flow import DeviceFlowRedis, DeviceFlowStatus - -logger = logging.getLogger(__name__) - - -# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the -# device_code it references. -STATE_ENVELOPE_TTL_SECONDS = 15 * 60 - -# Canonical sso-complete path. IdP-side ACS callback URL must point here. -_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete" - - -@bp.route("/oauth/device/sso-initiate", methods=["GET"]) -@enterprise_only -@rate_limit(LIMIT_SSO_INITIATE_PER_IP) -def sso_initiate(): - user_code = (request.args.get("user_code") or "").strip().upper() - if not user_code: - raise BadRequest("user_code required") - - store = DeviceFlowRedis(redis_client) - found = store.load_by_user_code(user_code) - if found is None: - raise BadRequest("invalid_user_code") - _, state = found - if state.status is not DeviceFlowStatus.PENDING: - raise BadRequest("invalid_user_code") - - keyset = jws.KeySet.from_shared_secret() - signed_state = jws.sign( - keyset, - payload={ - "redirect_url": "", - "app_code": "", - "intent": "device_flow", - "user_code": user_code, - "nonce": secrets.token_urlsafe(16), - "return_to": "", - "idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}", - }, - aud=jws.AUD_STATE_ENVELOPE, - ttl_seconds=STATE_ENVELOPE_TTL_SECONDS, - ) - - try: - reply = EnterpriseService.initiate_device_flow_sso(signed_state) - except Exception as e: - logger.warning("sso-initiate: enterprise call failed: %s", e) - raise BadGateway("sso_initiate_failed") from e - - url = (reply or {}).get("url") - if not url: - raise BadGateway("sso_initiate_missing_url") - - # Clear stale approval-grant — defends against cross-tab/back-button mixing. - resp = redirect(url, code=302) - resp.set_cookie(**approval_grant_cleared_cookie_kwargs()) - return resp diff --git a/api/controllers/openapi/oauth_device/token.py b/api/controllers/openapi/oauth_device/token.py deleted file mode 100644 index e3c4fe1e88..0000000000 --- a/api/controllers/openapi/oauth_device/token.py +++ /dev/null @@ -1,82 +0,0 @@ -"""POST /openapi/v1/oauth/device/token — RFC 8628 device authorization -poll. Public; the CLI polls until the user completes approval at -/device. - -The class is also registered on the legacy /v1/ namespace from -service_api/oauth.py until Phase F retires that mount. -""" -from __future__ import annotations - -import logging - -from flask import request -from flask_restx import Resource, reqparse - -from controllers.openapi import openapi_ns -from extensions.ext_redis import redis_client -from libs.helper import extract_remote_ip -from services.oauth_device_flow import ( - DEFAULT_POLL_INTERVAL_SECONDS, - DeviceFlowRedis, - DeviceFlowStatus, - SlowDownDecision, -) - -logger = logging.getLogger(__name__) - -_poll_parser = reqparse.RequestParser() -_poll_parser.add_argument("device_code", type=str, required=True, location="json") -_poll_parser.add_argument("client_id", type=str, required=True, location="json") - - -@openapi_ns.route("/oauth/device/token") -class OAuthDeviceTokenApi(Resource): - """RFC 8628 poll.""" - - def post(self): - args = _poll_parser.parse_args() - device_code = args["device_code"] - - store = DeviceFlowRedis(redis_client) - - # slow_down beats every other branch — polling-too-fast clients - # see only that response regardless of underlying state. - if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN: - return {"error": "slow_down"}, 400 - - state = store.load_by_device_code(device_code) - if state is None: - return {"error": "expired_token"}, 400 - - if state.status is DeviceFlowStatus.PENDING: - return {"error": "authorization_pending"}, 400 - - terminal = store.consume_on_poll(device_code) - if terminal is None: - return {"error": "expired_token"}, 400 - - if terminal.status is DeviceFlowStatus.DENIED: - return {"error": "access_denied"}, 400 - - poll_payload = terminal.poll_payload or {} - if "token" not in poll_payload: - logger.error("device_flow: approved state missing poll_payload for %s", device_code) - return {"error": "expired_token"}, 400 - - _audit_cross_ip_if_needed(state) - return poll_payload, 200 - - -def _audit_cross_ip_if_needed(state) -> None: - poll_ip = extract_remote_ip(request) - if state.created_ip and poll_ip and poll_ip != state.created_ip: - logger.warning( - "audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s", - state.token_id, state.created_ip, poll_ip, - extra={ - "audit": True, - "token_id": state.token_id, - "creation_ip": state.created_ip, - "poll_ip": poll_ip, - }, - ) diff --git a/api/controllers/openapi/oauth_device_sso.py b/api/controllers/openapi/oauth_device_sso.py new file mode 100644 index 0000000000..9d82b6f591 --- /dev/null +++ b/api/controllers/openapi/oauth_device_sso.py @@ -0,0 +1,284 @@ +"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*. +EE-only. Browser flow: + + GET /oauth/device/sso-initiate → 302 to IdP authorize URL + GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie + GET /oauth/device/approval-context → SPA reads cookie claims (idempotent) + POST /oauth/device/approve-external → mints dfoe_ token + clears cookie + +Function-based (raw @bp.route) rather than Resource classes because the +handlers do redirects + cookie kwargs that don't fit the Resource shape. +Same handlers are also re-registered on the legacy /v1/* paths from +controllers/oauth_device_sso.py until Phase F retires the legacy mount. +""" +from __future__ import annotations + +import logging +import secrets + +from flask import jsonify, make_response, redirect, request +from werkzeug.exceptions import ( + BadGateway, + BadRequest, + Conflict, + Forbidden, + NotFound, + Unauthorized, +) + +from controllers.openapi import bp +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from libs import jws +from libs.device_flow_security import ( + APPROVAL_GRANT_COOKIE_NAME, + ApprovalGrantClaims, + approval_grant_cleared_cookie_kwargs, + approval_grant_cookie_kwargs, + consume_approval_grant_nonce, + consume_sso_assertion_nonce, + enterprise_only, + mint_approval_grant, + verify_approval_grant, +) +from libs.oauth_bearer import SubjectType +from libs.rate_limit import ( + LIMIT_APPROVE_EXT_PER_EMAIL, + LIMIT_SSO_INITIATE_PER_IP, + enforce, + rate_limit, +) +from services.enterprise.enterprise_service import EnterpriseService +from services.oauth_device_flow import ( + PREFIX_OAUTH_EXTERNAL_SSO, + DeviceFlowRedis, + DeviceFlowStatus, + InvalidTransition, + StateNotFound, + mint_oauth_token, + oauth_ttl_days, +) + +logger = logging.getLogger(__name__) + + +# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the +# device_code it references. +STATE_ENVELOPE_TTL_SECONDS = 15 * 60 + +# Canonical sso-complete path. IdP-side ACS callback URL must point here. +_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete" + + +@bp.route("/oauth/device/sso-initiate", methods=["GET"]) +@enterprise_only +@rate_limit(LIMIT_SSO_INITIATE_PER_IP) +def sso_initiate(): + user_code = (request.args.get("user_code") or "").strip().upper() + if not user_code: + raise BadRequest("user_code required") + + store = DeviceFlowRedis(redis_client) + found = store.load_by_user_code(user_code) + if found is None: + raise BadRequest("invalid_user_code") + _, state = found + if state.status is not DeviceFlowStatus.PENDING: + raise BadRequest("invalid_user_code") + + keyset = jws.KeySet.from_shared_secret() + signed_state = jws.sign( + keyset, + payload={ + "redirect_url": "", + "app_code": "", + "intent": "device_flow", + "user_code": user_code, + "nonce": secrets.token_urlsafe(16), + "return_to": "", + "idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}", + }, + aud=jws.AUD_STATE_ENVELOPE, + ttl_seconds=STATE_ENVELOPE_TTL_SECONDS, + ) + + try: + reply = EnterpriseService.initiate_device_flow_sso(signed_state) + except Exception as e: + logger.warning("sso-initiate: enterprise call failed: %s", e) + raise BadGateway("sso_initiate_failed") from e + + url = (reply or {}).get("url") + if not url: + raise BadGateway("sso_initiate_missing_url") + + # Clear stale approval-grant — defends against cross-tab/back-button mixing. + resp = redirect(url, code=302) + resp.set_cookie(**approval_grant_cleared_cookie_kwargs()) + return resp + + +@bp.route("/oauth/device/sso-complete", methods=["GET"]) +@enterprise_only +def sso_complete(): + blob = request.args.get("sso_assertion") + if not blob: + raise BadRequest("sso_assertion required") + + keyset = jws.KeySet.from_shared_secret() + + try: + claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION) + except jws.VerifyError as e: + logger.warning("sso-complete: rejected assertion: %s", e) + raise BadRequest("invalid_sso_assertion") from e + + if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")): + raise BadRequest("invalid_sso_assertion") + + user_code = (claims.get("user_code") or "").strip().upper() + store = DeviceFlowRedis(redis_client) + found = store.load_by_user_code(user_code) + if found is None: + raise Conflict("user_code_not_pending") + _, state = found + if state.status is not DeviceFlowStatus.PENDING: + raise Conflict("user_code_not_pending") + + iss = request.host_url.rstrip("/") + cookie_value, _ = mint_approval_grant( + keyset=keyset, + iss=iss, + subject_email=claims["email"], + subject_issuer=claims["issuer"], + user_code=user_code, + ) + + resp = redirect("/device?sso_verified=1", code=302) + resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value)) + return resp + + +@bp.route("/oauth/device/approval-context", methods=["GET"]) +@enterprise_only +def approval_context(): + token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME) + if not token: + raise Unauthorized("no_session") + + keyset = jws.KeySet.from_shared_secret() + try: + claims = verify_approval_grant(keyset, token) + except jws.VerifyError as e: + logger.warning("approval-context: bad cookie: %s", e) + raise Unauthorized("no_session") from e + + return jsonify({ + "subject_email": claims.subject_email, + "subject_issuer": claims.subject_issuer, + "user_code": claims.user_code, + "csrf_token": claims.csrf_token, + "expires_at": claims.expires_at.isoformat(), + }), 200 + + +@bp.route("/oauth/device/approve-external", methods=["POST"]) +@enterprise_only +def approve_external(): + token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME) + if not token: + raise Unauthorized("invalid_session") + + keyset = jws.KeySet.from_shared_secret() + try: + claims: ApprovalGrantClaims = verify_approval_grant(keyset, token) + except jws.VerifyError as e: + logger.warning("approve-external: bad cookie: %s", e) + raise Unauthorized("invalid_session") from e + + enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}") + + csrf_header = request.headers.get("X-CSRF-Token", "") + if not csrf_header or csrf_header != claims.csrf_token: + raise Forbidden("csrf_mismatch") + + data = request.get_json(silent=True) or {} + body_user_code = (data.get("user_code") or "").strip().upper() + if body_user_code != claims.user_code: + raise BadRequest("user_code_mismatch") + + store = DeviceFlowRedis(redis_client) + found = store.load_by_user_code(claims.user_code) + if found is None: + raise NotFound("user_code_not_pending") + device_code, state = found + if state.status is not DeviceFlowStatus.PENDING: + raise Conflict("user_code_not_pending") + + if not consume_approval_grant_nonce(redis_client, claims.nonce): + raise Unauthorized("session_already_consumed") + + ttl_days = oauth_ttl_days(tenant_id=None) + mint = mint_oauth_token( + db.session, + redis_client, + subject_email=claims.subject_email, + subject_issuer=claims.subject_issuer, + account_id=None, + client_id=state.client_id, + device_label=state.device_label, + prefix=PREFIX_OAUTH_EXTERNAL_SSO, + ttl_days=ttl_days, + ) + + poll_payload = { + "token": mint.token, + "expires_at": mint.expires_at.isoformat(), + "subject_type": SubjectType.EXTERNAL_SSO, + "subject_email": claims.subject_email, + "subject_issuer": claims.subject_issuer, + "account": None, + "workspaces": [], + "default_workspace_id": None, + "token_id": str(mint.token_id), + } + + try: + store.approve( + device_code, + subject_email=claims.subject_email, + account_id=None, + subject_issuer=claims.subject_issuer, + minted_token=mint.token, + token_id=str(mint.token_id), + poll_payload=poll_payload, + ) + except (StateNotFound, InvalidTransition) as e: + logger.error("approve-external: state transition raced: %s", e) + raise Conflict("state_lost") from e + + _emit_approve_external_audit(state, claims, mint) + + resp = make_response(jsonify({"status": "approved"}), 200) + resp.set_cookie(**approval_grant_cleared_cookie_kwargs()) + return resp + + +def _emit_approve_external_audit(state, claims, mint) -> None: + logger.warning( + "audit: oauth.device_flow_approved subject_type=%s " + "subject_email=%s subject_issuer=%s token_id=%s", + SubjectType.EXTERNAL_SSO, claims.subject_email, claims.subject_issuer, mint.token_id, + extra={ + "audit": True, + "event": "oauth.device_flow_approved", + "subject_type": SubjectType.EXTERNAL_SSO, + "subject_email": claims.subject_email, + "subject_issuer": claims.subject_issuer, + "token_id": str(mint.token_id), + "client_id": state.client_id, + "device_label": state.device_label, + "scopes": ["apps:run"], + "expires_at": mint.expires_at.isoformat(), + }, + ) diff --git a/api/controllers/service_api/oauth.py b/api/controllers/service_api/oauth.py index d10e8bf8eb..fb182d423a 100644 --- a/api/controllers/service_api/oauth.py +++ b/api/controllers/service_api/oauth.py @@ -1,14 +1,16 @@ """Legacy /v1/* mounts for the OAuth bearer + device-flow endpoints. Canonical handlers live in controllers/openapi/. This file just -re-registers them on the service_api_ns until Phase F retires the +re-registers them on service_api_ns until Phase F retires the legacy paths entirely. """ from __future__ import annotations from controllers.openapi.account import AccountApi, AccountSessionsSelfApi -from controllers.openapi.oauth_device.code import OAuthDeviceCodeApi -from controllers.openapi.oauth_device.lookup import OAuthDeviceLookupApi -from controllers.openapi.oauth_device.token import OAuthDeviceTokenApi +from controllers.openapi.oauth_device import ( + OAuthDeviceCodeApi, + OAuthDeviceLookupApi, + OAuthDeviceTokenApi, +) from controllers.service_api import service_api_ns # Legacy /v1/* mounts — handlers live in controllers/openapi/. diff --git a/api/tests/unit_tests/controllers/openapi/test_device_approve_deny.py b/api/tests/unit_tests/controllers/openapi/test_device_approve_deny.py index 1718b71a7e..11b42e4ae0 100644 --- a/api/tests/unit_tests/controllers/openapi/test_device_approve_deny.py +++ b/api/tests/unit_tests/controllers/openapi/test_device_approve_deny.py @@ -10,8 +10,7 @@ from flask.views import MethodView from controllers.console import bp as console_bp from controllers.openapi import bp as openapi_bp -from controllers.openapi.oauth_device.approve import DeviceApproveApi -from controllers.openapi.oauth_device.deny import DeviceDenyApi +from controllers.openapi.oauth_device import DeviceApproveApi, DeviceDenyApi if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/openapi/test_device_code.py b/api/tests/unit_tests/controllers/openapi/test_device_code.py index 54ba90a81c..79a65fa9f1 100644 --- a/api/tests/unit_tests/controllers/openapi/test_device_code.py +++ b/api/tests/unit_tests/controllers/openapi/test_device_code.py @@ -13,7 +13,7 @@ from flask import Flask from flask.views import MethodView from controllers.openapi import bp as openapi_bp -from controllers.openapi.oauth_device.code import OAuthDeviceCodeApi +from controllers.openapi.oauth_device import OAuthDeviceCodeApi from controllers.service_api import bp as service_api_bp if not hasattr(builtins, "MethodView"): diff --git a/api/tests/unit_tests/controllers/openapi/test_device_lookup.py b/api/tests/unit_tests/controllers/openapi/test_device_lookup.py index b28e881d10..7d1ae3b640 100644 --- a/api/tests/unit_tests/controllers/openapi/test_device_lookup.py +++ b/api/tests/unit_tests/controllers/openapi/test_device_lookup.py @@ -9,7 +9,7 @@ from flask import Flask from flask.views import MethodView from controllers.openapi import bp as openapi_bp -from controllers.openapi.oauth_device.lookup import OAuthDeviceLookupApi +from controllers.openapi.oauth_device import OAuthDeviceLookupApi from controllers.service_api import bp as service_api_bp if not hasattr(builtins, "MethodView"): diff --git a/api/tests/unit_tests/controllers/openapi/test_device_sso.py b/api/tests/unit_tests/controllers/openapi/test_device_sso.py index d699be9385..bdd337c62b 100644 --- a/api/tests/unit_tests/controllers/openapi/test_device_sso.py +++ b/api/tests/unit_tests/controllers/openapi/test_device_sso.py @@ -10,10 +10,12 @@ from flask.views import MethodView from controllers.oauth_device_sso import bp as legacy_sso_bp from controllers.openapi import bp as openapi_bp -from controllers.openapi.oauth_device.approval_context import approval_context -from controllers.openapi.oauth_device.approve_external import approve_external -from controllers.openapi.oauth_device.sso_complete import sso_complete -from controllers.openapi.oauth_device.sso_initiate import sso_initiate +from controllers.openapi.oauth_device_sso import ( + approval_context, + approve_external, + sso_complete, + sso_initiate, +) if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] @@ -115,6 +117,6 @@ def test_sso_complete_idp_callback_url_uses_canonical_path(): canonical /openapi/v1/ path so IdPs are configured against the forward-looking ACS endpoint, not the legacy alias. """ - from controllers.openapi.oauth_device import sso_initiate as si + from controllers.openapi import oauth_device_sso - assert si._SSO_COMPLETE_PATH == "/openapi/v1/oauth/device/sso-complete" + assert oauth_device_sso._SSO_COMPLETE_PATH == "/openapi/v1/oauth/device/sso-complete" diff --git a/api/tests/unit_tests/controllers/openapi/test_device_token.py b/api/tests/unit_tests/controllers/openapi/test_device_token.py index 3b47fd3ecb..31d769314b 100644 --- a/api/tests/unit_tests/controllers/openapi/test_device_token.py +++ b/api/tests/unit_tests/controllers/openapi/test_device_token.py @@ -9,7 +9,7 @@ from flask import Flask from flask.views import MethodView from controllers.openapi import bp as openapi_bp -from controllers.openapi.oauth_device.token import OAuthDeviceTokenApi +from controllers.openapi.oauth_device import OAuthDeviceTokenApi from controllers.service_api import bp as service_api_bp if not hasattr(builtins, "MethodView"):