diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 41af96de50..00ac6ecb46 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -513,6 +513,20 @@ class HttpConfig(BaseSettings): def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]: return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o] + inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field( + description=( + "Comma-separated client_id values accepted at " + "POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here " + "without code changes. Unknown client_id returns 400 unsupported_client." + ), + validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"), + default="difyctl", + ) + + @computed_field + def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]: + return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c) + HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field( ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10 ) diff --git a/api/controllers/openapi/__init__.py b/api/controllers/openapi/__init__.py index d94e9e0140..67a1169317 100644 --- a/api/controllers/openapi/__init__.py +++ b/api/controllers/openapi/__init__.py @@ -15,9 +15,11 @@ api = ExternalApi( openapi_ns = Namespace("openapi", description="User-scoped operations", path="/") from . import index +from .oauth_device import code as oauth_device_code __all__ = [ "index", + "oauth_device_code", ] api.add_namespace(openapi_ns) diff --git a/api/controllers/openapi/oauth_device/__init__.py b/api/controllers/openapi/oauth_device/__init__.py new file mode 100644 index 0000000000..5d55c7ebc1 --- /dev/null +++ b/api/controllers/openapi/oauth_device/__init__.py @@ -0,0 +1,4 @@ +"""User-scoped device-flow protocol endpoints (RFC 8628). Public — +unauthenticated, per-IP rate-limited. Approval/deny + SSO branch land +here in Phase D. +""" diff --git a/api/controllers/openapi/oauth_device/code.py b/api/controllers/openapi/oauth_device/code.py new file mode 100644 index 0000000000..f6d4139010 --- /dev/null +++ b/api/controllers/openapi/oauth_device/code.py @@ -0,0 +1,56 @@ +"""POST /openapi/v1/oauth/device/code — RFC 8628 device authorization request. + +Public + per-IP rate-limited. The CLI starts a device flow here; the +returned `verification_uri` is what the user opens in a browser. The +class is also registered on the legacy /v1/ namespace from +service_api/oauth.py until Phase F retires that mount. +""" +from __future__ import annotations + +from flask import request +from flask_restx import Resource, reqparse + +from configs import dify_config +from controllers.openapi import openapi_ns +from extensions.ext_redis import redis_client +from libs.helper import extract_remote_ip +from libs.rate_limit import LIMIT_DEVICE_CODE_PER_IP, rate_limit +from services.oauth_device_flow import ( + DEFAULT_POLL_INTERVAL_SECONDS, + DeviceFlowRedis, +) + +_code_parser = reqparse.RequestParser() +_code_parser.add_argument("client_id", type=str, required=True, location="json") +_code_parser.add_argument("device_label", type=str, required=True, location="json") + + +@openapi_ns.route("/oauth/device/code") +class OAuthDeviceCodeApi(Resource): + @rate_limit(LIMIT_DEVICE_CODE_PER_IP) + def post(self): + args = _code_parser.parse_args() + client_id = args["client_id"] + device_label = args["device_label"] + + if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS: + return {"error": "unsupported_client"}, 400 + + store = DeviceFlowRedis(redis_client) + ip = extract_remote_ip(request) + device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip) + + return { + "device_code": device_code, + "user_code": user_code, + "verification_uri": _verification_uri(), + "expires_in": expires_in, + "interval": DEFAULT_POLL_INTERVAL_SECONDS, + }, 200 + + +def _verification_uri() -> str: + base = getattr(dify_config, "CONSOLE_WEB_URL", None) + if base: + return f"{base.rstrip('/')}/device" + return f"{request.host_url.rstrip('/')}/device" diff --git a/api/controllers/service_api/oauth.py b/api/controllers/service_api/oauth.py index e1ab831d16..b0e8a867ac 100644 --- a/api/controllers/service_api/oauth.py +++ b/api/controllers/service_api/oauth.py @@ -13,6 +13,7 @@ from flask_restx import Resource, reqparse from sqlalchemy import update from werkzeug.exceptions import BadRequest +from controllers.openapi.oauth_device.code import OAuthDeviceCodeApi from controllers.service_api import service_api_ns from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -24,7 +25,6 @@ from libs.oauth_bearer import ( validate_bearer, ) from libs.rate_limit import ( - LIMIT_DEVICE_CODE_PER_IP, LIMIT_LOOKUP_PUBLIC, LIMIT_ME_PER_ACCOUNT, LIMIT_ME_PER_EMAIL, @@ -42,7 +42,9 @@ from services.oauth_device_flow import ( logger = logging.getLogger(__name__) -KNOWN_CLIENT_IDS = frozenset({"difyctl"}) +# Legacy /v1/oauth/device/code mount — handler lives in +# controllers/openapi/oauth_device/code.py. Removed in Phase F. +service_api_ns.add_resource(OAuthDeviceCodeApi, "/oauth/device/code") # ============================================================================ @@ -160,49 +162,6 @@ class OAuthAuthorizationsSelfApi(Resource): return {"status": "revoked"}, 200 -# ============================================================================ -# POST /v1/oauth/device/code (unauthenticated — CLI starts a flow) -# ============================================================================ - - -_code_parser = reqparse.RequestParser() -_code_parser.add_argument("client_id", type=str, required=True, location="json") -_code_parser.add_argument("device_label", type=str, required=True, location="json") - - -@service_api_ns.route("/oauth/device/code") -class OAuthDeviceCodeApi(Resource): - @rate_limit(LIMIT_DEVICE_CODE_PER_IP) - def post(self): - args = _code_parser.parse_args() - client_id = args["client_id"] - device_label = args["device_label"] - - if client_id not in KNOWN_CLIENT_IDS: - return {"error": "unsupported_client"}, 400 - - store = DeviceFlowRedis(redis_client) - ip = extract_remote_ip(request) - device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip) - - return { - "device_code": device_code, - "user_code": user_code, - "verification_uri": _verification_uri(), - "expires_in": expires_in, - "interval": DEFAULT_POLL_INTERVAL_SECONDS, - }, 200 - - -def _verification_uri() -> str: - from configs import dify_config - - base = getattr(dify_config, "CONSOLE_WEB_URL", None) - if base: - return f"{base.rstrip('/')}/device" - return f"{request.host_url.rstrip('/')}/device" - - # ============================================================================ # POST /v1/oauth/device/token (unauthenticated — CLI polls) # ============================================================================ diff --git a/api/tests/unit_tests/controllers/openapi/test_device_code.py b/api/tests/unit_tests/controllers/openapi/test_device_code.py new file mode 100644 index 0000000000..54ba90a81c --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_device_code.py @@ -0,0 +1,81 @@ +"""Phase B step 6: POST /openapi/v1/oauth/device/code is the canonical +RFC 8628 device authorization endpoint. The legacy /v1/oauth/device/code +mount stays until Phase F; both paths must dispatch to the same class. + +Tests verify URL routing and re-registration without invoking the +handler — invoking would require Redis, which the unit-test runtime +does not initialise. +""" +import builtins + +import pytest +from flask import Flask +from flask.views import MethodView + +from controllers.openapi import bp as openapi_bp +from controllers.openapi.oauth_device.code import OAuthDeviceCodeApi +from controllers.service_api import bp as service_api_bp + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +@pytest.fixture +def dual_app() -> Flask: + """Both blueprints registered, mirroring production layout.""" + app = Flask(__name__) + app.config["TESTING"] = True + app.register_blueprint(service_api_bp) + app.register_blueprint(openapi_bp) + return app + + +def test_openapi_route_registered(dual_app: Flask): + rules = {r.rule for r in dual_app.url_map.iter_rules()} + assert "/openapi/v1/oauth/device/code" in rules + + +def test_legacy_v1_route_still_registered(dual_app: Flask): + """service_api/oauth.py re-registers the lifted class on /v1/.""" + rules = {r.rule for r in dual_app.url_map.iter_rules()} + assert "/v1/oauth/device/code" in rules + + +def test_both_paths_dispatch_to_same_class(dual_app: Flask): + """Single source of truth — no duplicated handler logic.""" + new = next( + r for r in dual_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code" + ) + legacy = next( + r for r in dual_app.url_map.iter_rules() if r.rule == "/v1/oauth/device/code" + ) + + new_view = dual_app.view_functions[new.endpoint] + legacy_view = dual_app.view_functions[legacy.endpoint] + # Flask-RESTX wraps Resource classes in a `view_class` attribute. + assert new_view.view_class is OAuthDeviceCodeApi + assert legacy_view.view_class is OAuthDeviceCodeApi + + +def test_route_accepts_post_and_options(dual_app: Flask): + new = next( + r for r in dual_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code" + ) + legacy = next( + r for r in dual_app.url_map.iter_rules() if r.rule == "/v1/oauth/device/code" + ) + assert "POST" in new.methods + assert "POST" in legacy.methods + + +def test_handler_class_imports_match(): + """service_api re-uses the openapi class, not a copy.""" + from controllers.service_api import oauth as service_api_oauth + + assert service_api_oauth.OAuthDeviceCodeApi is OAuthDeviceCodeApi + + +def test_known_client_ids_default_includes_difyctl(): + from configs import dify_config + + assert "difyctl" in dify_config.OPENAPI_KNOWN_CLIENT_IDS