feat(api): lift POST /oauth/device/token to /openapi/v1 (Phase B.7)

Same pattern as B.6: OAuthDeviceTokenApi moves to
controllers/openapi/oauth_device/token.py and is re-registered on
service_api_ns to keep /v1/oauth/device/token serving until Phase F.

_audit_cross_ip_if_needed helper moves with the handler. Now-unused
imports removed from service_api/oauth.py.

Plan: docs/superpowers/plans/2026-04-26-openapi-migration.md (in difyctl repo).
This commit is contained in:
GareArc 2026-04-26 23:42:27 -07:00
parent fe9412af5d
commit 9408759954
No known key found for this signature in database
4 changed files with 134 additions and 66 deletions

View File

@ -16,10 +16,12 @@ openapi_ns = Namespace("openapi", description="User-scoped operations", path="/"
from . import index
from .oauth_device import code as oauth_device_code
from .oauth_device import token as oauth_device_token
__all__ = [
"index",
"oauth_device_code",
"oauth_device_token",
]
api.add_namespace(openapi_ns)

View File

@ -0,0 +1,82 @@
"""POST /openapi/v1/oauth/device/token — RFC 8628 device authorization
poll. Public; the CLI polls until the user completes approval at
/device.
The class is also registered on the legacy /v1/ namespace from
service_api/oauth.py until Phase F retires that mount.
"""
from __future__ import annotations
import logging
from flask import request
from flask_restx import Resource, reqparse
from controllers.openapi import openapi_ns
from extensions.ext_redis import redis_client
from libs.helper import extract_remote_ip
from services.oauth_device_flow import (
DEFAULT_POLL_INTERVAL_SECONDS,
DeviceFlowRedis,
DeviceFlowStatus,
SlowDownDecision,
)
logger = logging.getLogger(__name__)
_poll_parser = reqparse.RequestParser()
_poll_parser.add_argument("device_code", type=str, required=True, location="json")
_poll_parser.add_argument("client_id", type=str, required=True, location="json")
@openapi_ns.route("/oauth/device/token")
class OAuthDeviceTokenApi(Resource):
"""RFC 8628 poll."""
def post(self):
args = _poll_parser.parse_args()
device_code = args["device_code"]
store = DeviceFlowRedis(redis_client)
# slow_down beats every other branch — polling-too-fast clients
# see only that response regardless of underlying state.
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
return {"error": "slow_down"}, 400
state = store.load_by_device_code(device_code)
if state is None:
return {"error": "expired_token"}, 400
if state.status is DeviceFlowStatus.PENDING:
return {"error": "authorization_pending"}, 400
terminal = store.consume_on_poll(device_code)
if terminal is None:
return {"error": "expired_token"}, 400
if terminal.status is DeviceFlowStatus.DENIED:
return {"error": "access_denied"}, 400
poll_payload = terminal.poll_payload or {}
if "token" not in poll_payload:
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
return {"error": "expired_token"}, 400
_audit_cross_ip_if_needed(state)
return poll_payload, 200
def _audit_cross_ip_if_needed(state) -> None:
poll_ip = extract_remote_ip(request)
if state.created_ip and poll_ip and poll_ip != state.created_ip:
logger.warning(
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
state.token_id, state.created_ip, poll_ip,
extra={
"audit": True,
"token_id": state.token_id,
"creation_ip": state.created_ip,
"poll_ip": poll_ip,
},
)

View File

@ -14,10 +14,10 @@ from sqlalchemy import update
from werkzeug.exceptions import BadRequest
from controllers.openapi.oauth_device.code import OAuthDeviceCodeApi
from controllers.openapi.oauth_device.token import OAuthDeviceTokenApi
from controllers.service_api import service_api_ns
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import extract_remote_ip
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
SubjectType,
@ -33,18 +33,17 @@ from libs.rate_limit import (
)
from models import Account, OAuthAccessToken, Tenant, TenantAccountJoin
from services.oauth_device_flow import (
DEFAULT_POLL_INTERVAL_SECONDS,
DEVICE_FLOW_TTL_SECONDS,
DeviceFlowRedis,
DeviceFlowStatus,
SlowDownDecision,
)
logger = logging.getLogger(__name__)
# Legacy /v1/oauth/device/code mount — handler lives in
# controllers/openapi/oauth_device/code.py. Removed in Phase F.
# Legacy /v1/* mounts — handlers live in controllers/openapi/oauth_device/.
# Removed in Phase F.
service_api_ns.add_resource(OAuthDeviceCodeApi, "/oauth/device/code")
service_api_ns.add_resource(OAuthDeviceTokenApi, "/oauth/device/token")
# ============================================================================
@ -162,54 +161,6 @@ class OAuthAuthorizationsSelfApi(Resource):
return {"status": "revoked"}, 200
# ============================================================================
# POST /v1/oauth/device/token (unauthenticated — CLI polls)
# ============================================================================
_poll_parser = reqparse.RequestParser()
_poll_parser.add_argument("device_code", type=str, required=True, location="json")
_poll_parser.add_argument("client_id", type=str, required=True, location="json")
@service_api_ns.route("/oauth/device/token")
class OAuthDeviceTokenApi(Resource):
"""RFC 8628 poll."""
def post(self):
args = _poll_parser.parse_args()
device_code = args["device_code"]
store = DeviceFlowRedis(redis_client)
# slow_down beats every other branch — polling-too-fast clients
# see only that response regardless of underlying state.
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
return {"error": "slow_down"}, 400
state = store.load_by_device_code(device_code)
if state is None:
return {"error": "expired_token"}, 400
if state.status is DeviceFlowStatus.PENDING:
return {"error": "authorization_pending"}, 400
terminal = store.consume_on_poll(device_code)
if terminal is None:
return {"error": "expired_token"}, 400
if terminal.status is DeviceFlowStatus.DENIED:
return {"error": "access_denied"}, 400
poll_payload = terminal.poll_payload or {}
if "token" not in poll_payload:
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
return {"error": "expired_token"}, 400
_audit_cross_ip_if_needed(state)
return poll_payload, 200
# ============================================================================
# GET /v1/oauth/device/lookup (unauthenticated — /device page pre-validates)
# ============================================================================
@ -246,16 +197,3 @@ class OAuthDeviceLookupApi(Resource):
}, 200
def _audit_cross_ip_if_needed(state) -> None:
poll_ip = extract_remote_ip(request)
if state.created_ip and poll_ip and poll_ip != state.created_ip:
logger.warning(
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
state.token_id, state.created_ip, poll_ip,
extra={
"audit": True,
"token_id": state.token_id,
"creation_ip": state.created_ip,
"poll_ip": poll_ip,
},
)

View File

@ -0,0 +1,46 @@
"""Phase B step 7: POST /openapi/v1/oauth/device/token mounted via the
canonical class. Legacy /v1/oauth/device/token re-registered. Both
paths must dispatch to the same class.
"""
import builtins
import pytest
from flask import Flask
from flask.views import MethodView
from controllers.openapi import bp as openapi_bp
from controllers.openapi.oauth_device.token import OAuthDeviceTokenApi
from controllers.service_api import bp as service_api_bp
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@pytest.fixture
def dual_app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
app.register_blueprint(service_api_bp)
app.register_blueprint(openapi_bp)
return app
def test_openapi_route_registered(dual_app: Flask):
rules = {r.rule for r in dual_app.url_map.iter_rules()}
assert "/openapi/v1/oauth/device/token" in rules
def test_legacy_v1_route_registered(dual_app: Flask):
rules = {r.rule for r in dual_app.url_map.iter_rules()}
assert "/v1/oauth/device/token" in rules
def test_both_paths_dispatch_to_same_class(dual_app: Flask):
new = next(
r for r in dual_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/token"
)
legacy = next(
r for r in dual_app.url_map.iter_rules() if r.rule == "/v1/oauth/device/token"
)
assert dual_app.view_functions[new.endpoint].view_class is OAuthDeviceTokenApi
assert dual_app.view_functions[legacy.endpoint].view_class is OAuthDeviceTokenApi