diff --git a/api/controllers/openapi/__init__.py b/api/controllers/openapi/__init__.py index 67a1169317..8ca9c668fe 100644 --- a/api/controllers/openapi/__init__.py +++ b/api/controllers/openapi/__init__.py @@ -16,10 +16,12 @@ openapi_ns = Namespace("openapi", description="User-scoped operations", path="/" from . import index from .oauth_device import code as oauth_device_code +from .oauth_device import token as oauth_device_token __all__ = [ "index", "oauth_device_code", + "oauth_device_token", ] api.add_namespace(openapi_ns) diff --git a/api/controllers/openapi/oauth_device/token.py b/api/controllers/openapi/oauth_device/token.py new file mode 100644 index 0000000000..e3c4fe1e88 --- /dev/null +++ b/api/controllers/openapi/oauth_device/token.py @@ -0,0 +1,82 @@ +"""POST /openapi/v1/oauth/device/token — RFC 8628 device authorization +poll. Public; the CLI polls until the user completes approval at +/device. + +The class is also registered on the legacy /v1/ namespace from +service_api/oauth.py until Phase F retires that mount. +""" +from __future__ import annotations + +import logging + +from flask import request +from flask_restx import Resource, reqparse + +from controllers.openapi import openapi_ns +from extensions.ext_redis import redis_client +from libs.helper import extract_remote_ip +from services.oauth_device_flow import ( + DEFAULT_POLL_INTERVAL_SECONDS, + DeviceFlowRedis, + DeviceFlowStatus, + SlowDownDecision, +) + +logger = logging.getLogger(__name__) + +_poll_parser = reqparse.RequestParser() +_poll_parser.add_argument("device_code", type=str, required=True, location="json") +_poll_parser.add_argument("client_id", type=str, required=True, location="json") + + +@openapi_ns.route("/oauth/device/token") +class OAuthDeviceTokenApi(Resource): + """RFC 8628 poll.""" + + def post(self): + args = _poll_parser.parse_args() + device_code = args["device_code"] + + store = DeviceFlowRedis(redis_client) + + # slow_down beats every other branch — polling-too-fast clients + # see only that response regardless of underlying state. + if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN: + return {"error": "slow_down"}, 400 + + state = store.load_by_device_code(device_code) + if state is None: + return {"error": "expired_token"}, 400 + + if state.status is DeviceFlowStatus.PENDING: + return {"error": "authorization_pending"}, 400 + + terminal = store.consume_on_poll(device_code) + if terminal is None: + return {"error": "expired_token"}, 400 + + if terminal.status is DeviceFlowStatus.DENIED: + return {"error": "access_denied"}, 400 + + poll_payload = terminal.poll_payload or {} + if "token" not in poll_payload: + logger.error("device_flow: approved state missing poll_payload for %s", device_code) + return {"error": "expired_token"}, 400 + + _audit_cross_ip_if_needed(state) + return poll_payload, 200 + + +def _audit_cross_ip_if_needed(state) -> None: + poll_ip = extract_remote_ip(request) + if state.created_ip and poll_ip and poll_ip != state.created_ip: + logger.warning( + "audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s", + state.token_id, state.created_ip, poll_ip, + extra={ + "audit": True, + "token_id": state.token_id, + "creation_ip": state.created_ip, + "poll_ip": poll_ip, + }, + ) diff --git a/api/controllers/service_api/oauth.py b/api/controllers/service_api/oauth.py index b0e8a867ac..61d984002a 100644 --- a/api/controllers/service_api/oauth.py +++ b/api/controllers/service_api/oauth.py @@ -14,10 +14,10 @@ from sqlalchemy import update from werkzeug.exceptions import BadRequest from controllers.openapi.oauth_device.code import OAuthDeviceCodeApi +from controllers.openapi.oauth_device.token import OAuthDeviceTokenApi from controllers.service_api import service_api_ns from extensions.ext_database import db from extensions.ext_redis import redis_client -from libs.helper import extract_remote_ip from libs.oauth_bearer import ( ACCEPT_USER_ANY, SubjectType, @@ -33,18 +33,17 @@ from libs.rate_limit import ( ) from models import Account, OAuthAccessToken, Tenant, TenantAccountJoin from services.oauth_device_flow import ( - DEFAULT_POLL_INTERVAL_SECONDS, DEVICE_FLOW_TTL_SECONDS, DeviceFlowRedis, DeviceFlowStatus, - SlowDownDecision, ) logger = logging.getLogger(__name__) -# Legacy /v1/oauth/device/code mount — handler lives in -# controllers/openapi/oauth_device/code.py. Removed in Phase F. +# Legacy /v1/* mounts — handlers live in controllers/openapi/oauth_device/. +# Removed in Phase F. service_api_ns.add_resource(OAuthDeviceCodeApi, "/oauth/device/code") +service_api_ns.add_resource(OAuthDeviceTokenApi, "/oauth/device/token") # ============================================================================ @@ -162,54 +161,6 @@ class OAuthAuthorizationsSelfApi(Resource): return {"status": "revoked"}, 200 -# ============================================================================ -# POST /v1/oauth/device/token (unauthenticated — CLI polls) -# ============================================================================ - - -_poll_parser = reqparse.RequestParser() -_poll_parser.add_argument("device_code", type=str, required=True, location="json") -_poll_parser.add_argument("client_id", type=str, required=True, location="json") - - -@service_api_ns.route("/oauth/device/token") -class OAuthDeviceTokenApi(Resource): - """RFC 8628 poll.""" - - def post(self): - args = _poll_parser.parse_args() - device_code = args["device_code"] - - store = DeviceFlowRedis(redis_client) - - # slow_down beats every other branch — polling-too-fast clients - # see only that response regardless of underlying state. - if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN: - return {"error": "slow_down"}, 400 - - state = store.load_by_device_code(device_code) - if state is None: - return {"error": "expired_token"}, 400 - - if state.status is DeviceFlowStatus.PENDING: - return {"error": "authorization_pending"}, 400 - - terminal = store.consume_on_poll(device_code) - if terminal is None: - return {"error": "expired_token"}, 400 - - if terminal.status is DeviceFlowStatus.DENIED: - return {"error": "access_denied"}, 400 - - poll_payload = terminal.poll_payload or {} - if "token" not in poll_payload: - logger.error("device_flow: approved state missing poll_payload for %s", device_code) - return {"error": "expired_token"}, 400 - - _audit_cross_ip_if_needed(state) - return poll_payload, 200 - - # ============================================================================ # GET /v1/oauth/device/lookup (unauthenticated — /device page pre-validates) # ============================================================================ @@ -246,16 +197,3 @@ class OAuthDeviceLookupApi(Resource): }, 200 -def _audit_cross_ip_if_needed(state) -> None: - poll_ip = extract_remote_ip(request) - if state.created_ip and poll_ip and poll_ip != state.created_ip: - logger.warning( - "audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s", - state.token_id, state.created_ip, poll_ip, - extra={ - "audit": True, - "token_id": state.token_id, - "creation_ip": state.created_ip, - "poll_ip": poll_ip, - }, - ) diff --git a/api/tests/unit_tests/controllers/openapi/test_device_token.py b/api/tests/unit_tests/controllers/openapi/test_device_token.py new file mode 100644 index 0000000000..3b47fd3ecb --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_device_token.py @@ -0,0 +1,46 @@ +"""Phase B step 7: POST /openapi/v1/oauth/device/token mounted via the +canonical class. Legacy /v1/oauth/device/token re-registered. Both +paths must dispatch to the same class. +""" +import builtins + +import pytest +from flask import Flask +from flask.views import MethodView + +from controllers.openapi import bp as openapi_bp +from controllers.openapi.oauth_device.token import OAuthDeviceTokenApi +from controllers.service_api import bp as service_api_bp + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +@pytest.fixture +def dual_app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + app.register_blueprint(service_api_bp) + app.register_blueprint(openapi_bp) + return app + + +def test_openapi_route_registered(dual_app: Flask): + rules = {r.rule for r in dual_app.url_map.iter_rules()} + assert "/openapi/v1/oauth/device/token" in rules + + +def test_legacy_v1_route_registered(dual_app: Flask): + rules = {r.rule for r in dual_app.url_map.iter_rules()} + assert "/v1/oauth/device/token" in rules + + +def test_both_paths_dispatch_to_same_class(dual_app: Flask): + new = next( + r for r in dual_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/token" + ) + legacy = next( + r for r in dual_app.url_map.iter_rules() if r.rule == "/v1/oauth/device/token" + ) + assert dual_app.view_functions[new.endpoint].view_class is OAuthDeviceTokenApi + assert dual_app.view_functions[legacy.endpoint].view_class is OAuthDeviceTokenApi