mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 04:36:31 +08:00
Type and lint pass over the openapi controllers, auth pipeline, and
oauth bearer/device-flow plumbing. Down from 36 pyright errors and 16
ruff errors to 0/0; 93 openapi unit tests pass.
Logic fixes:
- libs/oauth_bearer.py: drop private-naming on the friend-API methods
consumed by _VariantResolver (cache_get / cache_set_positive /
cache_set_negative / hard_expire / session_factory). They were always
cross-class accessors — leading underscore was misleading. Add public
registry property on BearerAuthenticator. _hard_expire row_id widened
to UUID | str (matches the StringUUID column type).
- libs/oauth_bearer.py: type validate_bearer / bearer_feature_required
with ParamSpec / PEP-695 so wrapped routes preserve their signature.
- libs/rate_limit.py: same — typed rate_limit decorator.
- services/oauth_device_flow.py: mint_oauth_token / _upsert accept
Session | scoped_session (Flask-SQLAlchemy proxy). Guard row-is-None
after upsert.
- controllers/openapi/{chat,completion,workflow}_messages.py: tuple-vs-
Mapping shape narrowing on AppGenerateService.generate return —
production returns Mapping, tests mock as (body, status). Validate
through Pydantic Response model in both shapes.
- controllers/openapi/oauth_device.py: replace flask_restx.reqparse (banned)
with Pydantic Request/Query models — DeviceCodeRequest, DevicePollRequest,
DeviceLookupQuery, DeviceMutateRequest. Two PEP-695 generic helpers
(_validate_json / _validate_query) translate ValidationError to BadRequest.
- controllers/openapi/auth/strategies.py: Protocol param-name match
(subject_type), Optional narrowing on app/tenant/account_id/subject_email.
- controllers/openapi/auth/steps.py: subject_type-is-None guard before
mounter dispatch.
- core/app/apps/workflow/generate_task_pipeline.py + models/workflow.py:
add WorkflowAppLogCreatedFrom.OPENAPI + matching match-case branch.
Fixes match-exhaustiveness and possibly-unbound created_from.
- libs/device_flow_security.py: add a pyright ignore on the Flask after_request
hook (the framework registers it, but pyright sees it as unused).
- services/oauth_device_flow.py: rename Exceptions to *Error suffix
(StateNotFoundError / InvalidTransitionError / UserCodeExhaustedError);
same for libs/oauth_bearer.py (InvalidBearerError / TokenExpiredError).
Update all callers across openapi controllers.
- controllers/openapi/{oauth_device,oauth_device_sso}.py +
services/oauth_device_flow.py: switch logger.error in except blocks
to logger.exception (TRY400) — keeps the traceback for ops.
- configs/feature/__init__.py: the OPENAPI_KNOWN_CLIENT_IDS computed_field
needs an @property alongside it so pyright treats it as a value rather than a
method. This matches the existing pattern at line 451 of that file.
Plus ruff format + import-sort across the openapi tree (pure formatting).
393 lines
13 KiB
Python
393 lines
13 KiB
Python
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
|
|
sub-groups in one module:

Protocol (RFC 8628, public):
    POST /oauth/device/code    — per-IP rate limit
    POST /oauth/device/token   — throttled by the RFC 8628 slow_down response
    GET  /oauth/device/lookup  — per-IP rate limit

Approval (account branch, console-cookie authed):
    POST /oauth/device/approve
    POST /oauth/device/deny

The SSO branch lives in oauth_device_sso.py.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
from flask import request
|
|
from flask_login import login_required
|
|
from flask_restx import Resource
|
|
from pydantic import BaseModel, ValidationError
|
|
from werkzeug.exceptions import BadRequest
|
|
|
|
from configs import dify_config
|
|
from controllers.console.wraps import account_initialization_required, setup_required
|
|
from controllers.openapi import openapi_ns
|
|
from extensions.ext_database import db
|
|
from extensions.ext_redis import redis_client
|
|
from libs.helper import extract_remote_ip
|
|
from libs.login import current_account_with_tenant
|
|
from libs.oauth_bearer import SubjectType, bearer_feature_required
|
|
from libs.rate_limit import (
|
|
LIMIT_APPROVE_CONSOLE,
|
|
LIMIT_DEVICE_CODE_PER_IP,
|
|
LIMIT_LOOKUP_PUBLIC,
|
|
rate_limit,
|
|
)
|
|
from services.oauth_device_flow import (
|
|
ACCOUNT_ISSUER_SENTINEL,
|
|
DEFAULT_POLL_INTERVAL_SECONDS,
|
|
DEVICE_FLOW_TTL_SECONDS,
|
|
PREFIX_OAUTH_ACCOUNT,
|
|
DeviceFlowRedis,
|
|
DeviceFlowStatus,
|
|
InvalidTransitionError,
|
|
SlowDownDecision,
|
|
StateNotFoundError,
|
|
mint_oauth_token,
|
|
oauth_ttl_days,
|
|
)
|
|
|
|
# Module logger. The audit helpers at the bottom of this module emit their
# records through it with ``extra={"audit": True, ...}``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# =========================================================================
|
|
# Request / query schemas
|
|
# =========================================================================
|
|
|
|
|
|
class DeviceCodeRequest(BaseModel):
    """JSON body accepted by ``POST /oauth/device/code``."""

    # Must appear in dify_config.OPENAPI_KNOWN_CLIENT_IDS or the handler
    # answers ``unsupported_client``.
    client_id: str
    # Free-form label handed to DeviceFlowRedis.start and later read back as
    # ``state.device_label`` (approval, audit records).
    device_label: str
|
|
|
|
|
|
class DevicePollRequest(BaseModel):
    """JSON body accepted by ``POST /oauth/device/token`` (RFC 8628 poll)."""

    # Opaque code previously returned by ``POST /oauth/device/code``.
    device_code: str
    # Identifier of the polling client.
    client_id: str
|
|
|
|
|
|
class DeviceLookupQuery(BaseModel):
    """Query string accepted by ``GET /oauth/device/lookup``."""

    # Human-entered code; the handler strips and upper-cases it before lookup.
    user_code: str
|
|
|
|
|
|
class DeviceMutateRequest(BaseModel):
    """JSON body shared by ``POST /oauth/device/approve`` and ``/deny``."""

    # Human-entered code; handlers strip and upper-case it before lookup.
    user_code: str
|
|
|
|
|
|
def _validate_json[M: BaseModel](model: type[M]) -> M:
|
|
body = request.get_json(silent=True) or {}
|
|
try:
|
|
return model.model_validate(body)
|
|
except ValidationError as exc:
|
|
raise BadRequest(str(exc))
|
|
|
|
|
|
def _validate_query[M: BaseModel](model: type[M]) -> M:
|
|
try:
|
|
return model.model_validate(request.args.to_dict(flat=True))
|
|
except ValidationError as exc:
|
|
raise BadRequest(str(exc))
|
|
|
|
|
|
# =========================================================================
|
|
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
|
|
# =========================================================================
|
|
|
|
|
|
@openapi_ns.route("/oauth/device/code")
class OAuthDeviceCodeApi(Resource):
    """RFC 8628 device authorization request: open a new device flow."""

    @rate_limit(LIMIT_DEVICE_CODE_PER_IP)
    def post(self):
        """Issue a device_code / user_code pair to a known client.

        Answers ``{"error": "unsupported_client"}`` with 400 when the
        submitted ``client_id`` is not in ``OPENAPI_KNOWN_CLIENT_IDS``.
        """
        req = _validate_json(DeviceCodeRequest)
        if req.client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
            return {"error": "unsupported_client"}, 400

        # The creating IP is recorded so the token poll can flag cross-IP use.
        device_code, user_code, expires_in = DeviceFlowRedis(redis_client).start(
            req.client_id,
            req.device_label,
            created_ip=extract_remote_ip(request),
        )

        response_body = {
            "device_code": device_code,
            "user_code": user_code,
            "verification_uri": _verification_uri(),
            "expires_in": expires_in,
            "interval": DEFAULT_POLL_INTERVAL_SECONDS,
        }
        return response_body, 200
|
|
|
|
|
|
@openapi_ns.route("/oauth/device/token")
class OAuthDeviceTokenApi(Resource):
    """RFC 8628 poll."""

    def post(self):
        """Exchange a device_code for the minted token once the user approves.

        Every non-success answer is ``{"error": <code>}`` with HTTP 400:
        ``slow_down``, ``expired_token``, ``invalid_grant`` (the device_code
        was issued to a different client_id), ``authorization_pending`` or
        ``access_denied``. On approval, the poll payload that the approve
        endpoint stored with the state is returned with HTTP 200.
        """
        payload = _validate_json(DevicePollRequest)
        device_code = payload.device_code

        store = DeviceFlowRedis(redis_client)

        # slow_down beats every other branch — polling-too-fast clients
        # see only that response regardless of underlying state.
        if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
            return {"error": "slow_down"}, 400

        state = store.load_by_device_code(device_code)
        if state is None:
            return {"error": "expired_token"}, 400

        # Bind the device_code to the client that opened the flow. RFC 6749
        # §5.2 "invalid_grant" covers a grant that was issued to another
        # client. This check runs before any status branch, so a foreign
        # client cannot learn the flow's progress or consume its terminal
        # state.
        if payload.client_id != state.client_id:
            return {"error": "invalid_grant"}, 400

        if state.status is DeviceFlowStatus.PENDING:
            return {"error": "authorization_pending"}, 400

        terminal = store.consume_on_poll(device_code)
        if terminal is None:
            return {"error": "expired_token"}, 400

        if terminal.status is DeviceFlowStatus.DENIED:
            return {"error": "access_denied"}, 400

        poll_payload = terminal.poll_payload or {}
        if "token" not in poll_payload:
            logger.error("device_flow: approved state missing poll_payload for %s", device_code)
            return {"error": "expired_token"}, 400

        _audit_cross_ip_if_needed(state)
        return poll_payload, 200
|
|
|
|
|
|
@openapi_ns.route("/oauth/device/lookup")
class OAuthDeviceLookupApi(Resource):
    """Read-only — public for pre-validate before login. user_code is
    high-entropy + short-TTL; per-IP rate limit blocks enumeration.
    """

    @rate_limit(LIMIT_LOOKUP_PUBLIC)
    def get(self):
        """Report whether a user_code still refers to a PENDING device flow.

        Always answers 200. Unknown codes yield ``client_id: None``; known but
        already-resolved codes still echo their ``client_id`` with
        ``valid: False``.
        """
        query = _validate_query(DeviceLookupQuery)
        normalized_code = query.user_code.strip().upper()

        found = DeviceFlowRedis(redis_client).load_by_user_code(normalized_code)
        if found is None:
            return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200

        _, state = found
        is_pending = state.status is DeviceFlowStatus.PENDING
        return {
            "valid": is_pending,
            # NOTE(review): a pending code reports the full
            # DEVICE_FLOW_TTL_SECONDS, not the time actually left on it —
            # confirm clients treat this as an upper bound.
            "expires_in_remaining": DEVICE_FLOW_TTL_SECONDS if is_pending else 0,
            "client_id": state.client_id,
        }, 200
|
|
|
|
|
|
# =========================================================================
|
|
# Approval endpoints — account branch (cookie-authed)
|
|
# =========================================================================
|
|
|
|
|
|
# Redis key for the per-device_code SET NX guard taken by DeviceApproveApi,
# so two concurrent approvals of the same flow cannot both mint a token.
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
# Guard expiry in seconds. The approve handler deletes the key itself in a
# `finally`; this expiry only matters if that delete never runs.
_APPROVE_GUARD_TTL_SECONDS = 10
|
|
|
|
|
|
@openapi_ns.route("/oauth/device/approve")
class DeviceApproveApi(Resource):
    """Approve a pending device flow as the logged-in console account."""

    @setup_required
    @login_required
    @account_initialization_required
    @bearer_feature_required
    @rate_limit(LIMIT_APPROVE_CONSOLE)
    def post(self):
        """Mint an account OAuth token and attach it to the pending flow.

        Responses:
            404 ``expired_or_unknown`` — no flow for the submitted user_code.
            409 ``already_resolved`` — the flow is no longer PENDING.
            409 ``approve_in_progress`` — another approve holds the guard key.
            409 ``state_lost`` — the flow vanished or changed state after the
                token was minted; the minted token is left in place.
            200 ``{"status": "approved"}`` on success.
        """
        payload = _validate_json(DeviceMutateRequest)
        # Normalise the code the same way the lookup and deny endpoints do.
        user_code = payload.user_code.strip().upper()

        # NOTE(review): ``tenant`` is passed as ``tenant_id=`` and compared as
        # ``str(tenant)`` against workspace ids below, so it is presumably the
        # current tenant id — confirm.
        account, tenant = current_account_with_tenant()
        store = DeviceFlowRedis(redis_client)

        found = store.load_by_user_code(user_code)
        if found is None:
            return {"error": "expired_or_unknown"}, 404
        device_code, state = found
        if state.status is not DeviceFlowStatus.PENDING:
            return {"error": "already_resolved"}, 409

        # SET NX guard — without it, two in-flight approves both pass
        # PENDING, both mint, and the second upsert silently rotates the
        # first caller into an already-revoked token.
        guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
        if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
            return {"error": "approve_in_progress"}, 409

        try:
            ttl_days = oauth_ttl_days(tenant_id=tenant)
            mint = mint_oauth_token(
                db.session,
                redis_client,
                subject_email=account.email,
                subject_issuer=ACCOUNT_ISSUER_SENTINEL,
                account_id=str(account.id),
                client_id=state.client_id,
                device_label=state.device_label,
                prefix=PREFIX_OAUTH_ACCOUNT,
                ttl_days=ttl_days,
            )

            # Rendered here and stored with the approved state, so the
            # unauthenticated token poll can return it without querying
            # accounts/tenants.
            poll_payload = _build_account_poll_payload(account, tenant, mint)
            try:
                store.approve(
                    device_code,
                    subject_email=account.email,
                    account_id=str(account.id),
                    subject_issuer=ACCOUNT_ISSUER_SENTINEL,
                    minted_token=mint.token,
                    token_id=str(mint.token_id),
                    poll_payload=poll_payload,
                )
            except (StateNotFoundError, InvalidTransitionError):
                # Row minted but state vanished — roll forward; the orphan
                # token is revocable via auth devices list / Authorized Apps.
                logger.exception("device_flow: approve raced on %s", device_code)
                return {"error": "state_lost"}, 409
        finally:
            # Release the guard on every exit: success, the state_lost return,
            # or an exception from minting.
            # NOTE(review): the guard value is the constant "1", so if this
            # request outlived _APPROVE_GUARD_TTL_SECONDS this delete could drop
            # a newer request's guard — confirm minting finishes well within it.
            redis_client.delete(guard_key)

        _emit_approve_audit(state, account, tenant, mint)
        return {"status": "approved"}, 200
|
|
|
|
|
|
@openapi_ns.route("/oauth/device/deny")
class DeviceDenyApi(Resource):
    """Deny a pending device flow as the logged-in console account."""

    @setup_required
    @login_required
    @account_initialization_required
    @bearer_feature_required
    @rate_limit(LIMIT_APPROVE_CONSOLE)
    def post(self):
        """Move a PENDING flow to DENIED.

        Mirrors the approve endpoint's 404 ``expired_or_unknown``, 409
        ``already_resolved`` and 409 ``state_lost`` answers; returns
        ``{"status": "denied"}`` with 200 on success.
        """
        code = _validate_json(DeviceMutateRequest).user_code.strip().upper()

        store = DeviceFlowRedis(redis_client)
        found = store.load_by_user_code(code)
        if found is None:
            return {"error": "expired_or_unknown"}, 404

        device_code, state = found
        if state.status is not DeviceFlowStatus.PENDING:
            return {"error": "already_resolved"}, 409

        try:
            store.deny(device_code)
        except (StateNotFoundError, InvalidTransitionError):
            logger.exception("device_flow: deny raced on %s", device_code)
            return {"error": "state_lost"}, 409
        else:
            _emit_deny_audit(state)
            return {"status": "denied"}, 200
|
|
|
|
|
|
# =========================================================================
|
|
# Helpers
|
|
# =========================================================================
|
|
|
|
|
|
def _verification_uri() -> str:
    """Return the URL where users enter their user_code.

    Uses ``CONSOLE_WEB_URL`` when it is set and non-empty; otherwise uses the
    current request's ``host_url``. Any trailing slash on the base is removed
    before ``/device`` is appended.
    """
    base = getattr(dify_config, "CONSOLE_WEB_URL", None) or request.host_url
    return f"{base.rstrip('/')}/device"
|
|
|
|
|
|
def _audit_cross_ip_if_needed(state) -> None:
    """Audit-log a token poll arriving from a different IP than flow creation.

    ``state.created_ip`` is the IP recorded when the device code was issued.
    Nothing is logged unless both IPs are known and they differ.
    """
    poll_ip = extract_remote_ip(request)
    creation_ip = state.created_ip
    if not (creation_ip and poll_ip) or poll_ip == creation_ip:
        return

    token_id = state.token_id
    logger.warning(
        "audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
        token_id,
        creation_ip,
        poll_ip,
        extra={
            "audit": True,
            "token_id": token_id,
            "creation_ip": creation_ip,
            "poll_ip": poll_ip,
        },
    )
|
|
|
|
|
|
def _build_account_poll_payload(account, tenant, mint) -> dict:
    """Pre-render the poll-response body so the unauthenticated poll
    handler doesn't re-query accounts/tenants for authz data.

    The default workspace is chosen in this order: the active session
    tenant (if the account belongs to it), then the membership flagged
    ``current``, then the first membership returned by the query.
    """
    from models import Tenant, TenantAccountJoin

    memberships = (
        db.session.query(Tenant, TenantAccountJoin)
        .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
        .filter(TenantAccountJoin.account_id == account.id)
        .all()
    )
    workspaces = [
        {"id": str(ws.id), "name": ws.name, "role": getattr(join, "role", "")}
        for ws, join in memberships
    ]
    member_ids = {entry["id"] for entry in workspaces}

    default_ws_id = None
    if tenant and str(tenant) in member_ids:
        default_ws_id = str(tenant)
    if default_ws_id is None:
        default_ws_id = next(
            (str(join.tenant_id) for _ws, join in memberships if getattr(join, "current", False)),
            None,
        )
    if default_ws_id is None and workspaces:
        default_ws_id = workspaces[0]["id"]

    return {
        "token": mint.token,
        "expires_at": mint.expires_at.isoformat(),
        "subject_type": SubjectType.ACCOUNT,
        "account": {"id": str(account.id), "email": account.email, "name": account.name},
        "workspaces": workspaces,
        "default_workspace_id": default_ws_id,
        "token_id": str(mint.token_id),
    }
|
|
|
|
|
|
def _emit_approve_audit(state, account, tenant, mint) -> None:
    """Log the ``oauth.device_flow_approved`` audit record.

    The message carries the key identifiers; the same values (plus tenant,
    scopes and subject type) are attached to the record through ``extra``.
    """
    structured = {
        "audit": True,
        "event": "oauth.device_flow_approved",
        "token_id": str(mint.token_id),
        "subject_type": SubjectType.ACCOUNT,
        "subject_email": account.email,
        "account_id": str(account.id),
        "tenant_id": tenant,
        "client_id": state.client_id,
        "device_label": state.device_label,
        "scopes": ["full"],
        "expires_at": mint.expires_at.isoformat(),
    }
    message_args = (
        mint.token_id,
        account.email,
        state.client_id,
        state.device_label,
        mint.expires_at,
    )
    # NOTE(review): "rotated=?" is a literal placeholder; nothing fills it.
    logger.warning(
        "audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
        *message_args,
        extra=structured,
    )
|
|
|
|
|
|
def _emit_deny_audit(state) -> None:
    """Log the ``oauth.device_flow_denied`` audit record for ``state``."""
    client_id = state.client_id
    device_label = state.device_label
    structured = {
        "audit": True,
        "event": "oauth.device_flow_denied",
        "client_id": client_id,
        "device_label": device_label,
    }
    logger.warning(
        "audit: oauth.device_flow_denied client_id=%s device_label=%s",
        client_id,
        device_label,
        extra=structured,
    )
|