chore(api): pyright + ruff cleanup for openapi/cli surface

Type and lint pass over the openapi controllers, auth pipeline, and
oauth bearer/device-flow plumbing. Down from 36 pyright errors and 16
ruff errors to 0/0; 93 openapi unit tests pass.

Logic fixes:
- libs/oauth_bearer.py: drop private-naming on the friend-API methods
  consumed by _VariantResolver (cache_get / cache_set_positive /
  cache_set_negative / hard_expire / session_factory). They were always
  cross-class accessors — leading underscore was misleading. Add public
  registry property on BearerAuthenticator. _hard_expire row_id widened
  to UUID | str (matches the StringUUID column type).
- libs/oauth_bearer.py: type validate_bearer / bearer_feature_required
  with ParamSpec / PEP-695 so wrapped routes preserve their signature.
- libs/rate_limit.py: same — typed rate_limit decorator.
- services/oauth_device_flow.py: mint_oauth_token / _upsert accept
  Session | scoped_session (Flask-SQLAlchemy proxy). Guard row-is-None
  after upsert.
- controllers/openapi/{chat,completion,workflow}_messages.py: tuple-vs-
  Mapping shape narrowing on AppGenerateService.generate return —
  production returns Mapping, tests mock as (body, status). Validate
  through Pydantic Response model in both shapes.
- controllers/openapi/oauth_device.py: replace flask_restx.reqparse (banned)
  with Pydantic Request/Query models — DeviceCodeRequest, DevicePollRequest,
  DeviceLookupQuery, DeviceMutateRequest. Two PEP-695 generic helpers
  (_validate_json / _validate_query) translate ValidationError to BadRequest.
- controllers/openapi/auth/strategies.py: Protocol param-name match
  (subject_type), Optional narrowing on app/tenant/account_id/subject_email.
- controllers/openapi/auth/steps.py: subject_type-is-None guard before
  mounter dispatch.
- core/app/apps/workflow/generate_task_pipeline.py + models/workflow.py:
  add WorkflowAppLogCreatedFrom.OPENAPI + matching match-case branch.
  Fixes match-exhaustiveness and possibly-unbound created_from.
- libs/device_flow_security.py: pyright ignore on flask after_request
  hook (registered by the framework, pyright sees as unused).
- services/oauth_device_flow.py: rename Exceptions to *Error suffix
  (StateNotFoundError / InvalidTransitionError / UserCodeExhaustedError);
  same for libs/oauth_bearer.py (InvalidBearerError / TokenExpiredError).
  Update all callers across openapi controllers.
- controllers/openapi/{oauth_device,oauth_device_sso}.py +
  services/oauth_device_flow.py: switch logger.error in except blocks
  to logger.exception (TRY400) — keeps the traceback for ops.
- configs/feature/__init__.py: OPENAPI_KNOWN_CLIENT_IDS computed_field
  needs an @property alongside for pyright to see it as a value, not a
  method. Matches the existing line-451 pattern.

Plus ruff format + import-sort across the openapi tree (pure formatting).
This commit is contained in:
GareArc 2026-04-28 21:44:54 -07:00
parent b083c910b3
commit 8a62c1d915
No known key found for this signature in database
40 changed files with 337 additions and 234 deletions

View File

@ -523,7 +523,8 @@ class HttpConfig(BaseSettings):
default="difyctl",
)
@computed_field
@computed_field # type: ignore[misc]
@property
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)

View File

@ -4,6 +4,7 @@ Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
matches the existing oauth_device convention. The EE OTel exporter consults
its own allowlist to decide whether to ship the line.
"""
from __future__ import annotations
import logging

View File

@ -1,4 +1,5 @@
"""Shared response substructures for openapi endpoints."""
from __future__ import annotations
from typing import Any

View File

@ -2,6 +2,7 @@
identity read; /account/sessions and /account/sessions/<id> manage
the user's active OAuth tokens.
"""
from __future__ import annotations
from datetime import UTC, datetime
@ -16,9 +17,9 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
TOKEN_CACHE_KEY_FMT,
AuthContext,
SubjectType,
TOKEN_CACHE_KEY_FMT,
validate_bearer,
)
from libs.rate_limit import (
@ -51,8 +52,7 @@ class AccountApi(Resource):
}
account = (
db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none()
if ctx.account_id else None
db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none() if ctx.account_id else None
)
memberships = _load_memberships(ctx.account_id) if ctx.account_id else []
default_ws_id = _pick_default_workspace(memberships)
@ -129,8 +129,7 @@ class AccountSessionByIdApi(Resource):
# Subject-match guard. 404 (not 403) on cross-subject so the
# endpoint doesn't leak token IDs that belong to other subjects.
owns = db.session.execute(
select(OAuthAccessToken.id)
.where(
select(OAuthAccessToken.id).where(
and_(
OAuthAccessToken.id == session_id,
*_subject_match(ctx),
@ -160,8 +159,7 @@ def _subject_match(ctx: AuthContext) -> tuple:
def _require_oauth_subject(ctx: AuthContext) -> None:
if not ctx.source.startswith("oauth"):
raise BadRequest(
"this endpoint revokes OAuth bearer tokens; "
"use /openapi/v1/personal-access-tokens/self for PATs"
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
)

View File

@ -1,4 +1,5 @@
"""GET /openapi/v1/apps/<app_id>/info — port of service_api/app/app.py:AppInfoApi."""
from __future__ import annotations
from flask_restx import Resource

View File

@ -2,6 +2,7 @@
Endpoints attach via @APP_PIPELINE.guard(scope=). No alternative paths.
"""
from __future__ import annotations
from controllers.openapi.auth.pipeline import Pipeline

View File

@ -4,6 +4,7 @@ Every field starts None / empty and is filled in by a step. The pipeline
is the only thing that should construct or mutate Context handlers
read populated values via the decorator's kwargs unpacking.
"""
from __future__ import annotations
from dataclasses import dataclass, field

View File

@ -4,6 +4,7 @@
that is the design lock-in: forgetting an auth layer is structurally
impossible because there is no "sometimes wrap, sometimes don't" choice.
"""
from __future__ import annotations
from functools import wraps

View File

@ -3,21 +3,22 @@
BearerCheck is the only step that touches the token registry; downstream
steps see only the populated Context.
"""
from __future__ import annotations
from typing import Callable
from collections.abc import Callable
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
from extensions.ext_database import db
from libs.oauth_bearer import TokenExpired, get_authenticator, sha256_hex
from libs.oauth_bearer import TokenExpiredError, get_authenticator, sha256_hex
from models import App, Tenant, TenantStatus
def _registry():
return get_authenticator()._registry # noqa: SLF001
return get_authenticator().registry
def _extract_bearer(req) -> str | None:
@ -45,7 +46,7 @@ class BearerCheck:
try:
row = kind.resolver.resolve(_hash_token(token))
except TokenExpired:
except TokenExpiredError:
raise Unauthorized("token expired")
if row is None:
raise Unauthorized("invalid bearer")
@ -105,6 +106,8 @@ class CallerMount:
self._mounters = mounters
def __call__(self, ctx: Context) -> None:
if ctx.subject_type is None:
raise Unauthorized("subject_type unset — BearerCheck did not run")
for m in self._mounters:
if m.applies_to(ctx.subject_type):
m.mount(ctx)

View File

@ -4,6 +4,7 @@ App authorization (Acl/Membership) and caller mounting (Account/EndUser)
vary along independent axes; each strategy is one class so the pipeline
composition stays a flat list.
"""
from __future__ import annotations
from typing import Protocol
@ -33,9 +34,11 @@ class AclStrategy:
"""
def authorize(self, ctx: Context) -> bool:
if ctx.subject_email is None or ctx.app is None:
return False
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=ctx.subject_email,
app_id=ctx.app.id,
app_id=ctx.app.id, # type: ignore[attr-defined]
)
@ -50,7 +53,9 @@ class MembershipStrategy:
def authorize(self, ctx: Context) -> bool:
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return False
return _has_tenant_membership(ctx.account_id, ctx.tenant.id)
if ctx.tenant is None:
return False
return _has_tenant_membership(ctx.account_id, ctx.tenant.id) # type: ignore[attr-defined]
def _has_tenant_membership(account_id: str | None, tenant_id: str) -> bool:
@ -67,8 +72,8 @@ def _has_tenant_membership(account_id: str | None, tenant_id: str) -> bool:
def _login_as(user) -> None:
"""Set Flask-Login request user so downstream services see the caller."""
current_app.login_manager._update_request_context_with_user(user) # noqa: SLF001
user_logged_in.send(current_app._get_current_object(), user=user) # noqa: SLF001
current_app.login_manager._update_request_context_with_user(user)
user_logged_in.send(current_app._get_current_object(), user=user)
class CallerMounter(Protocol):
@ -78,25 +83,31 @@ class CallerMounter(Protocol):
class AccountMounter:
def applies_to(self, st: SubjectType) -> bool:
return st == SubjectType.ACCOUNT
def applies_to(self, subject_type: SubjectType) -> bool:
return subject_type == SubjectType.ACCOUNT
def mount(self, ctx: Context) -> None:
if ctx.account_id is None:
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
account = db.session.get(Account, ctx.account_id)
account.current_tenant = ctx.tenant
if account is None:
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
account.current_tenant = ctx.tenant # type: ignore[assignment]
_login_as(account)
ctx.caller, ctx.caller_kind = account, "account"
class EndUserMounter:
def applies_to(self, st: SubjectType) -> bool:
return st == SubjectType.EXTERNAL_SSO
def applies_to(self, subject_type: SubjectType) -> bool:
return subject_type == SubjectType.EXTERNAL_SSO
def mount(self, ctx: Context) -> None:
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
end_user = EndUserService.get_or_create_end_user_by_type(
InvokeFrom.OPENAPI,
tenant_id=ctx.tenant.id,
app_id=ctx.app.id,
tenant_id=ctx.tenant.id, # type: ignore[attr-defined]
app_id=ctx.app.id, # type: ignore[attr-defined]
user_id=ctx.subject_email,
)
_login_as(end_user)

View File

@ -8,9 +8,11 @@ Differences from service_api:
- Typed Request and Response models.
- invoke_from = InvokeFrom.OPENAPI.
"""
from __future__ import annotations
import logging
from collections.abc import Mapping
from typing import Any, Literal
from uuid import UUID
@ -163,5 +165,12 @@ class ChatMessagesApi(Resource):
if streaming:
return helper.compact_generate_response(response)
body_dict = response[0] if isinstance(response, tuple) else response
return ChatMessageResponse.model_validate(body_dict).model_dump(mode="json"), 200
# Some upstream paths (and tests) return (body, status); production
# generate returns Mapping. Accept both, then validate.
if isinstance(response, tuple):
body_dict: Any = response[0] # pyright: ignore[reportArgumentType]
else:
body_dict = response
if not isinstance(body_dict, Mapping):
raise InternalServerError("blocking generate returned non-mapping response")
return ChatMessageResponse.model_validate(dict(body_dict)).model_dump(mode="json"), 200

View File

@ -1,8 +1,10 @@
"""POST /openapi/v1/apps/<app_id>/completion-messages — port of
service_api/app/completion.py:CompletionApi."""
from __future__ import annotations
import logging
from collections.abc import Mapping
from typing import Any, Literal
from flask import request
@ -120,5 +122,10 @@ class CompletionMessagesApi(Resource):
if streaming:
return helper.compact_generate_response(response)
body_dict = response[0] if isinstance(response, tuple) else response
return CompletionMessageResponse.model_validate(body_dict).model_dump(mode="json"), 200
if isinstance(response, tuple):
body_dict: Any = response[0] # pyright: ignore[reportArgumentType]
else:
body_dict = response
if not isinstance(body_dict, Mapping):
raise InternalServerError("blocking generate returned non-mapping response")
return CompletionMessageResponse.model_validate(dict(body_dict)).model_dump(mode="json"), 200

View File

@ -12,13 +12,16 @@ sub-groups in one module:
SSO branch lives in oauth_device_sso.py.
"""
from __future__ import annotations
import logging
from flask import request
from flask_login import login_required
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, ValidationError
from werkzeug.exceptions import BadRequest
from configs import dify_config
from controllers.console.wraps import account_initialization_required, setup_required
@ -41,9 +44,9 @@ from services.oauth_device_flow import (
PREFIX_OAUTH_ACCOUNT,
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransition,
InvalidTransitionError,
SlowDownDecision,
StateNotFound,
StateNotFoundError,
mint_oauth_token,
oauth_ttl_days,
)
@ -52,22 +55,41 @@ logger = logging.getLogger(__name__)
# =========================================================================
# Parsers
# Request / query schemas
# =========================================================================
_code_parser = reqparse.RequestParser()
_code_parser.add_argument("client_id", type=str, required=True, location="json")
_code_parser.add_argument("device_label", type=str, required=True, location="json")
_poll_parser = reqparse.RequestParser()
_poll_parser.add_argument("device_code", type=str, required=True, location="json")
_poll_parser.add_argument("client_id", type=str, required=True, location="json")
class DeviceCodeRequest(BaseModel):
client_id: str
device_label: str
_lookup_parser = reqparse.RequestParser()
_lookup_parser.add_argument("user_code", type=str, required=True, location="args")
_mutate_parser = reqparse.RequestParser()
_mutate_parser.add_argument("user_code", type=str, required=True, location="json")
class DevicePollRequest(BaseModel):
device_code: str
client_id: str
class DeviceLookupQuery(BaseModel):
user_code: str
class DeviceMutateRequest(BaseModel):
user_code: str
def _validate_json[M: BaseModel](model: type[M]) -> M:
body = request.get_json(silent=True) or {}
try:
return model.model_validate(body)
except ValidationError as exc:
raise BadRequest(str(exc))
def _validate_query[M: BaseModel](model: type[M]) -> M:
try:
return model.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise BadRequest(str(exc))
# =========================================================================
@ -79,9 +101,9 @@ _mutate_parser.add_argument("user_code", type=str, required=True, location="json
class OAuthDeviceCodeApi(Resource):
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
def post(self):
args = _code_parser.parse_args()
client_id = args["client_id"]
device_label = args["device_label"]
payload = _validate_json(DeviceCodeRequest)
client_id = payload.client_id
device_label = payload.device_label
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
return {"error": "unsupported_client"}, 400
@ -104,8 +126,8 @@ class OAuthDeviceTokenApi(Resource):
"""RFC 8628 poll."""
def post(self):
args = _poll_parser.parse_args()
device_code = args["device_code"]
payload = _validate_json(DevicePollRequest)
device_code = payload.device_code
store = DeviceFlowRedis(redis_client)
@ -145,8 +167,8 @@ class OAuthDeviceLookupApi(Resource):
@rate_limit(LIMIT_LOOKUP_PUBLIC)
def get(self):
args = _lookup_parser.parse_args()
user_code = args["user_code"].strip().upper()
payload = _validate_query(DeviceLookupQuery)
user_code = payload.user_code.strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
@ -181,8 +203,8 @@ class DeviceApproveApi(Resource):
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
def post(self):
args = _mutate_parser.parse_args()
user_code = args["user_code"].strip().upper()
payload = _validate_json(DeviceMutateRequest)
user_code = payload.user_code.strip().upper()
account, tenant = current_account_with_tenant()
store = DeviceFlowRedis(redis_client)
@ -226,10 +248,10 @@ class DeviceApproveApi(Resource):
token_id=str(mint.token_id),
poll_payload=poll_payload,
)
except (StateNotFound, InvalidTransition) as e:
except (StateNotFoundError, InvalidTransitionError):
# Row minted but state vanished — roll forward; the orphan
# token is revocable via auth devices list / Authorized Apps.
logger.error("device_flow: approve raced on %s: %s", device_code, e)
logger.exception("device_flow: approve raced on %s", device_code)
return {"error": "state_lost"}, 409
finally:
redis_client.delete(guard_key)
@ -246,8 +268,8 @@ class DeviceDenyApi(Resource):
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
def post(self):
args = _mutate_parser.parse_args()
user_code = args["user_code"].strip().upper()
payload = _validate_json(DeviceMutateRequest)
user_code = payload.user_code.strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
@ -259,8 +281,8 @@ class DeviceDenyApi(Resource):
try:
store.deny(device_code)
except (StateNotFound, InvalidTransition) as e:
logger.error("device_flow: deny raced on %s: %s", device_code, e)
except (StateNotFoundError, InvalidTransitionError):
logger.exception("device_flow: deny raced on %s", device_code)
return {"error": "state_lost"}, 409
_emit_deny_audit(state)
@ -284,7 +306,9 @@ def _audit_cross_ip_if_needed(state) -> None:
if state.created_ip and poll_ip and poll_ip != state.created_ip:
logger.warning(
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
state.token_id, state.created_ip, poll_ip,
state.token_id,
state.created_ip,
poll_ip,
extra={
"audit": True,
"token_id": state.token_id,
@ -299,16 +323,14 @@ def _build_account_poll_payload(account, tenant, mint) -> dict:
handler doesn't re-query accounts/tenants for authz data.
"""
from models import Tenant, TenantAccountJoin
rows = (
db.session.query(Tenant, TenantAccountJoin)
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.account_id == account.id)
.all()
)
workspaces = [
{"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")}
for t, m in rows
]
workspaces = [{"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")} for t, m in rows]
# Prefer active session tenant → DB-flagged current join → first membership.
default_ws_id = None
if tenant and any(w["id"] == str(tenant) for w in workspaces):
@ -335,7 +357,11 @@ def _build_account_poll_payload(account, tenant, mint) -> dict:
def _emit_approve_audit(state, account, tenant, mint) -> None:
logger.warning(
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
mint.token_id, account.email, state.client_id, state.device_label, mint.expires_at,
mint.token_id,
account.email,
state.client_id,
state.device_label,
mint.expires_at,
extra={
"audit": True,
"event": "oauth.device_flow_approved",
@ -355,7 +381,8 @@ def _emit_approve_audit(state, account, tenant, mint) -> None:
def _emit_deny_audit(state) -> None:
logger.warning(
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
state.client_id, state.device_label,
state.client_id,
state.device_label,
extra={
"audit": True,
"event": "oauth.device_flow_denied",
@ -363,5 +390,3 @@ def _emit_deny_audit(state) -> None:
"device_label": state.device_label,
},
)

View File

@ -9,6 +9,7 @@ EE-only. Browser flow:
Function-based (raw @bp.route) rather than Resource classes because the
handlers do redirects + cookie kwargs that don't fit the Resource shape.
"""
from __future__ import annotations
import logging
@ -51,8 +52,8 @@ from services.oauth_device_flow import (
PREFIX_OAUTH_EXTERNAL_SSO,
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransition,
StateNotFound,
InvalidTransitionError,
StateNotFoundError,
mint_oauth_token,
oauth_ttl_days,
)
@ -171,13 +172,15 @@ def approval_context():
logger.warning("approval-context: bad cookie: %s", e)
raise Unauthorized("no_session") from e
return jsonify({
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"user_code": claims.user_code,
"csrf_token": claims.csrf_token,
"expires_at": claims.expires_at.isoformat(),
}), 200
return jsonify(
{
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"user_code": claims.user_code,
"csrf_token": claims.csrf_token,
"expires_at": claims.expires_at.isoformat(),
}
), 200
@bp.route("/oauth/device/approve-external", methods=["POST"])
@ -251,8 +254,8 @@ def approve_external():
token_id=str(mint.token_id),
poll_payload=poll_payload,
)
except (StateNotFound, InvalidTransition) as e:
logger.error("approve-external: state transition raced: %s", e)
except (StateNotFoundError, InvalidTransitionError) as e:
logger.exception("approve-external: state transition raced")
raise Conflict("state_lost") from e
_emit_approve_external_audit(state, claims, mint)
@ -264,9 +267,11 @@ def approve_external():
def _emit_approve_external_audit(state, claims, mint) -> None:
logger.warning(
"audit: oauth.device_flow_approved subject_type=%s "
"subject_email=%s subject_issuer=%s token_id=%s",
SubjectType.EXTERNAL_SSO, claims.subject_email, claims.subject_issuer, mint.token_id,
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
SubjectType.EXTERNAL_SSO,
claims.subject_email,
claims.subject_issuer,
mint.token_id,
extra={
"audit": True,
"event": "oauth.device_flow_approved",

View File

@ -1,8 +1,10 @@
"""POST /openapi/v1/apps/<app_id>/workflows/run — port of
service_api/app/workflow.py:WorkflowRunApi."""
from __future__ import annotations
import logging
from collections.abc import Mapping
from typing import Any, Literal
from flask import request
@ -128,5 +130,10 @@ class WorkflowRunApi(Resource):
if streaming:
return helper.compact_generate_response(response)
body_dict = response[0] if isinstance(response, tuple) else response
return WorkflowRunResponse.model_validate(body_dict).model_dump(mode="json"), 200
if isinstance(response, tuple):
body_dict: Any = response[0] # pyright: ignore[reportArgumentType]
else:
body_dict = response
if not isinstance(body_dict, Mapping):
raise InternalServerError("blocking generate returned non-mapping response")
return WorkflowRunResponse.model_validate(dict(body_dict)).model_dump(mode="json"), 200

View File

@ -5,8 +5,11 @@ Account bearers (dfoa_) see every tenant they're a member of. External
SSO bearers (dfoe_) have no account_id and so see an empty list that
matches /openapi/v1/account.
"""
from __future__ import annotations
from itertools import starmap
from flask import g
from flask_restx import Resource
from sqlalchemy import select
@ -37,7 +40,7 @@ class WorkspacesApi(Resource):
.order_by(Tenant.created_at.asc())
).all()
return {"workspaces": [_workspace_summary(t, m) for t, m in rows]}, 200
return {"workspaces": list(starmap(_workspace_summary, rows))}, 200
@openapi_ns.route("/workspaces/<string:workspace_id>")

View File

@ -685,6 +685,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
match invoke_from:
case InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
case InvokeFrom.OPENAPI:
created_from = WorkflowAppLogCreatedFrom.OPENAPI
case InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
case InvokeFrom.WEB_APP:

View File

@ -1,6 +1,7 @@
"""Bind the bearer authenticator at startup. Must run after ext_database
and ext_redis (needs both factories).
"""
from __future__ import annotations
from configs import dify_config

View File

@ -1,14 +1,15 @@
"""Device-flow security primitives: enterprise_only gate, approval-grant
cookie mint/verify/consume, and anti-framing headers.
"""
from __future__ import annotations
import logging
import secrets
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from functools import wraps
from typing import Callable
from flask import Blueprint
from werkzeug.exceptions import NotFound
@ -122,7 +123,10 @@ def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
return False
return bool(
redis_client.set(
NONCE_KEY_FMT.format(nonce=nonce), "1", nx=True, ex=NONCE_TTL_SECONDS,
NONCE_KEY_FMT.format(nonce=nonce),
"1",
nx=True,
ex=NONCE_TTL_SECONDS,
)
)
@ -132,7 +136,10 @@ def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
return False
return bool(
redis_client.set(
SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce), "1", nx=True, ex=NONCE_TTL_SECONDS,
SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce),
"1",
nx=True,
ex=NONCE_TTL_SECONDS,
)
)
@ -183,7 +190,7 @@ def attach_anti_framing(bp: Blueprint) -> None:
"""X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4)."""
@bp.after_request
def _apply_headers(response):
def _apply_headers(response): # pyright: ignore[reportUnusedFunction]
for name, value in _ANTI_FRAMING_HEADERS.items():
response.headers.setdefault(name, value)
return response

View File

@ -2,11 +2,13 @@
state envelope, external subject assertion, and approval-grant cookie
all three share one key-set so api enterprise can verify each other.
"""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
import jwt
from configs import dify_config
AUD_STATE_ENVELOPE = "api.sso.state_envelope"
@ -32,14 +34,14 @@ class KeySet:
self._active_kid = active_kid
@classmethod
def from_shared_secret(cls) -> "KeySet":
def from_shared_secret(cls) -> KeySet:
secret = dify_config.SECRET_KEY
if not secret:
raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1)
@classmethod
def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> "KeySet":
def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> KeySet:
return cls(entries, active_kid)
@property

View File

@ -4,17 +4,19 @@ To add a token kind: write a Resolver, add a SubjectType + Accepts member,
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
Authenticator + validate_bearer stay untouched.
"""
from __future__ import annotations
import hashlib
import json
import logging
import uuid
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from datetime import UTC, datetime
from enum import StrEnum
from functools import wraps
from typing import Callable, Iterable, Literal, Protocol
from typing import Literal, ParamSpec, Protocol, TypeVar
from flask import g, request
from sqlalchemy import update
@ -79,11 +81,11 @@ class TokenKind:
return token.startswith(self.prefix)
class InvalidBearer(Exception):
class InvalidBearerError(Exception):
"""Token missing, unknown prefix, or no live row."""
class TokenExpired(Exception):
class TokenExpiredError(Exception):
"""Hard-expire bookkeeping is the resolver's job before raising."""
@ -122,13 +124,17 @@ class BearerAuthenticator:
def __init__(self, registry: TokenKindRegistry) -> None:
self._registry = registry
@property
def registry(self) -> TokenKindRegistry:
return self._registry
def authenticate(self, token: str) -> AuthContext:
kind = self._registry.find(token)
if kind is None:
raise InvalidBearer("unknown token prefix")
raise InvalidBearerError("unknown token prefix")
row = kind.resolver.resolve(sha256_hex(token))
if row is None:
raise InvalidBearer("token unknown or revoked")
raise InvalidBearerError("token unknown or revoked")
return AuthContext(
subject_type=kind.subject_type,
subject_email=row.subject_email,
@ -165,7 +171,9 @@ class OAuthAccessTokenResolver:
positive_ttl: int = POSITIVE_TTL_SECONDS,
negative_ttl: int = NEGATIVE_TTL_SECONDS,
) -> None:
self._session_factory = session_factory
# session_factory and the cache helpers below are friend-API for
# _VariantResolver in this module — kept public-named on purpose.
self.session_factory = session_factory
self._redis = redis_client
self._positive_ttl = positive_ttl
self._negative_ttl = negative_ttl
@ -179,7 +187,7 @@ class OAuthAccessTokenResolver:
def _cache_key(self, token_hash: str) -> str:
return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
def _cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
def cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
raw = self._redis.get(self._cache_key(token_hash))
if raw is None:
return None
@ -193,17 +201,17 @@ class OAuthAccessTokenResolver:
logger.warning("auth:token cache entry malformed; treating as miss")
return None
def _cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
def cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
self._redis.setex(
self._cache_key(token_hash),
self._positive_ttl,
json.dumps(_row_to_cache(row)),
)
def _cache_set_negative(self, token_hash: str) -> None:
def cache_set_negative(self, token_hash: str) -> None:
self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")
def _hard_expire(self, session: Session, row_id: uuid.UUID, token_hash: str) -> None:
def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None:
"""Atomic CAS — only the worker that flips revoked_at emits audit;
replays are idempotent. Spec: tokens.md §Detection + hard-expire.
"""
@ -216,21 +224,22 @@ class OAuthAccessTokenResolver:
session.commit()
if result.rowcount == 1:
logger.warning(
"audit: %s token_id=%s", AUDIT_OAUTH_EXPIRED, row_id,
"audit: %s token_id=%s",
AUDIT_OAUTH_EXPIRED,
row_id,
extra={"audit": True, "token_id": str(row_id)},
)
self._redis.delete(self._cache_key(token_hash))
self._cache_set_negative(token_hash)
self.cache_set_negative(token_hash)
class _VariantResolver:
def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
self._parent = parent
self._variant = variant
def resolve(self, token_hash: str) -> ResolvedRow | None:
cached = self._parent._cache_get(token_hash)
cached = self._parent.cache_get(token_hash)
if cached == "invalid":
return None
if cached is not None and not isinstance(cached, str):
@ -238,23 +247,24 @@ class _VariantResolver:
return None
return cached
# _session_factory returns Flask-SQLAlchemy's scoped_session, which is
# session_factory returns Flask-SQLAlchemy's scoped_session, which is
# request-bound and not a context manager; use it directly.
session = self._parent._session_factory()
session = self._parent.session_factory()
row = self._load_from_db(session, token_hash)
if row is None:
self._parent._cache_set_negative(token_hash)
self._parent.cache_set_negative(token_hash)
return None
now = datetime.now(UTC)
if row.expires_at is not None and row.expires_at <= now:
self._parent._hard_expire(session, row.id, token_hash)
self._parent.hard_expire(session, row.id, token_hash)
return None
if not self._matches_variant_model(row):
logger.error(
"internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
row.id, row.prefix,
row.id,
row.prefix,
)
return None
@ -265,7 +275,7 @@ class _VariantResolver:
token_id=uuid.UUID(str(row.id)),
expires_at=row.expires_at,
)
self._parent._cache_set_positive(token_hash, resolved)
self._parent.cache_set_positive(token_hash, resolved)
return resolved
def _matches_variant(self, row: ResolvedRow) -> bool:
@ -352,7 +362,11 @@ def _extract_bearer(req) -> str | None:
return value.strip()
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable:
_DP = ParamSpec("_DP")
_DR = TypeVar("_DR")
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable[[Callable[_DP, _DR]], Callable[_DP, _DR]]:
"""Opt-in: omitting it leaves the route unauthenticated.
Resolves user-level OAuth bearers (``dfoa_`` / ``dfoe_``). Legacy
@ -360,21 +374,19 @@ def validate_bearer(*, accept: frozenset[Accepts]) -> Callable:
and are rejected here as the wrong auth scheme for this surface.
"""
def wrap(fn: Callable) -> Callable:
def wrap(fn: Callable[_DP, _DR]) -> Callable[_DP, _DR]:
@wraps(fn)
def inner(*args, **kwargs):
def inner(*args: _DP.args, **kwargs: _DP.kwargs) -> _DR:
token = _extract_bearer(request)
if token is None:
raise Unauthorized("missing bearer token")
if _authenticator is None:
raise ServiceUnavailable(
"bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable"
)
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
try:
ctx = get_authenticator().authenticate(token)
except InvalidBearer as e:
except InvalidBearerError as e:
raise Unauthorized(str(e))
if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
@ -388,17 +400,15 @@ def validate_bearer(*, accept: frozenset[Accepts]) -> Callable:
return wrap
def bearer_feature_required(fn: Callable) -> Callable:
def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]:
"""503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable
without the authenticator, so fail fast instead of approving silently.
"""
@wraps(fn)
def inner(*args, **kwargs):
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.ENABLE_OAUTH_BEARER:
raise ServiceUnavailable(
"bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable"
)
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
return fn(*args, **kwargs)
return inner
@ -423,8 +433,7 @@ def require_scope(scope: str) -> Callable:
ctx = getattr(g, "auth_ctx", None)
if ctx is None:
raise RuntimeError(
"require_scope used without validate_bearer; "
"stack @validate_bearer above @require_scope"
"require_scope used without validate_bearer; stack @validate_bearer above @require_scope"
)
if SCOPE_FULL not in ctx.scopes and scope not in ctx.scopes:
raise Forbidden(f"insufficient_scope: {scope}")
@ -442,22 +451,24 @@ def require_scope(scope: str) -> Callable:
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
oauth = OAuthAccessTokenResolver(session_factory, redis_client)
return TokenKindRegistry([
TokenKind(
prefix="dfoa_",
subject_type=SubjectType.ACCOUNT,
scopes=frozenset({"full"}),
source="oauth_account",
resolver=oauth.for_account(),
),
TokenKind(
prefix="dfoe_",
subject_type=SubjectType.EXTERNAL_SSO,
scopes=frozenset({"apps:run"}),
source="oauth_external_sso",
resolver=oauth.for_external_sso(),
),
])
return TokenKindRegistry(
[
TokenKind(
prefix="dfoa_",
subject_type=SubjectType.ACCOUNT,
scopes=frozenset({"full"}),
source="oauth_account",
resolver=oauth.for_account(),
),
TokenKind(
prefix="dfoe_",
subject_type=SubjectType.EXTERNAL_SSO,
scopes=frozenset({"apps:run"}),
source="oauth_external_sso",
resolver=oauth.for_external_sso(),
),
]
)
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:

View File

@ -4,13 +4,15 @@ window Redis ZSET). Apply after auth decorators so scopes can read
in-handler. RFC-8628 ``slow_down`` is inline its response shape isn't
generic 429. Spec: docs/specs/v1.0/server/security.md.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from enum import StrEnum
from functools import wraps
from typing import Callable
from typing import ParamSpec, TypeVar
from flask import g, request, session
from werkzeug.exceptions import TooManyRequests
@ -81,13 +83,17 @@ def _build_limiter(spec: RateLimit) -> RateLimiter:
)
def rate_limit(spec: RateLimit) -> Callable:
_P = ParamSpec("_P")
_R = TypeVar("_R")
def rate_limit(spec: RateLimit) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
"""Apply after auth decorators that the scopes read from."""
limiter = _build_limiter(spec)
def wrap(fn: Callable) -> Callable:
def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn)
def inner(*args, **kwargs):
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
key = _composite_key(spec.scopes)
if limiter.is_rate_limited(key):
raise TooManyRequests("rate_limited")

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, Optional
from typing import Any
import sqlalchemy as sa
from sqlalchemy import func
@ -97,9 +97,7 @@ class OAuthAccessToken(TypeBase):
"""
__tablename__ = "oauth_access_tokens"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),
)
__table_args__ = (sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
@ -109,15 +107,11 @@ class OAuthAccessToken(TypeBase):
device_label: Mapped[str] = mapped_column(sa.Text, nullable=False)
prefix: Mapped[str] = mapped_column(sa.String(8), nullable=False)
expires_at: Mapped[datetime] = mapped_column(sa.DateTime(timezone=True), nullable=False)
subject_issuer: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True, default=None)
account_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True, default=None)
token_hash: Mapped[Optional[str]] = mapped_column(sa.String(64), nullable=True, default=None)
last_used_at: Mapped[Optional[datetime]] = mapped_column(
sa.DateTime(timezone=True), nullable=True, default=None
)
revoked_at: Mapped[Optional[datetime]] = mapped_column(
sa.DateTime(timezone=True), nullable=True, default=None
)
subject_issuer: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
token_hash: Mapped[str | None] = mapped_column(sa.String(64), nullable=True, default=None)
last_used_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
revoked_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), nullable=False, server_default=func.now(), init=False

View File

@ -1206,6 +1206,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
SERVICE_API = "service-api"
WEB_APP = "web-app"
INSTALLED_APP = "installed-app"
OPENAPI = "openapi"
@classmethod
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":

View File

@ -3,6 +3,7 @@
expired-but-never-presented rows have no hard-expire trigger both get
pruned here. Spec: docs/specs/v1.0/server/tokens.md §Hard-expire.
"""
from __future__ import annotations
import logging
@ -32,26 +33,22 @@ def clean_oauth_access_tokens_task():
candidates = or_(
OAuthAccessToken.revoked_at < cutoff,
# Zombies: expired but never re-presented, so middleware never flipped them.
(OAuthAccessToken.revoked_at.is_(None))
& (OAuthAccessToken.expires_at < cutoff),
(OAuthAccessToken.revoked_at.is_(None)) & (OAuthAccessToken.expires_at < cutoff),
)
total = 0
while True:
ids = db.session.scalars(
select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)
).all()
ids = db.session.scalars(select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)).all()
if not ids:
break
db.session.execute(
delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids))
)
db.session.execute(delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids)))
db.session.commit()
total += len(ids)
end_at = time.perf_counter()
click.echo(click.style(
f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d "
f"in {end_at - start_at:.2f}s",
fg="green",
))
click.echo(
click.style(
f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d in {end_at - start_at:.2f}s",
fg="green",
)
)

View File

@ -2,6 +2,7 @@
(DB upsert + plaintext generation), and TTL policy. Specs:
docs/specs/v1.0/server/{device-flow.md, tokens.md}.
"""
from __future__ import annotations
import hashlib
@ -15,11 +16,12 @@ from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime, timedelta
from enum import StrEnum
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT
from models.oauth import OAuthAccessToken
from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, scoped_session
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT
from models.oauth import OAuthAccessToken
logger = logging.getLogger(__name__)
@ -51,7 +53,7 @@ return raw
"""
DEVICE_FLOW_TTL_SECONDS = 15 * 60 # RFC 8628 expires_in
APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor
APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor
USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXY3456789" # ambiguous chars dropped
USER_CODE_SEGMENT_LEN = 4
@ -95,7 +97,7 @@ class DeviceFlowState:
return json.dumps(asdict(self))
@classmethod
def from_json(cls, raw: str) -> "DeviceFlowState":
def from_json(cls, raw: str) -> DeviceFlowState:
data = json.loads(raw)
if "status" in data:
data["status"] = DeviceFlowStatus(data["status"])
@ -114,20 +116,19 @@ def _random_user_code() -> str:
return f"{_random_user_code_segment()}-{_random_user_code_segment()}"
class StateNotFound(Exception):
class StateNotFoundError(Exception):
pass
class InvalidTransition(Exception):
class InvalidTransitionError(Exception):
pass
class UserCodeExhausted(Exception):
class UserCodeExhaustedError(Exception):
pass
class DeviceFlowRedis:
def __init__(self, redis_client) -> None:
self._redis = redis_client
self._consume_on_poll_script = redis_client.register_script(_CONSUME_ON_POLL_LUA)
@ -157,7 +158,7 @@ class DeviceFlowRedis:
ok = self._redis.set(key, device_code, nx=True, ex=DEVICE_FLOW_TTL_SECONDS)
if ok:
return user_code
raise UserCodeExhausted("could not allocate a unique user_code in 5 attempts")
raise UserCodeExhaustedError("could not allocate a unique user_code in 5 attempts")
def load_by_user_code(self, user_code: str) -> tuple[str, DeviceFlowState] | None:
raw_dc = self._redis.get(USER_CODE_KEY_FMT.format(code=user_code))
@ -180,7 +181,7 @@ class DeviceFlowRedis:
try:
return DeviceFlowState.from_json(text_)
except (ValueError, KeyError):
logger.error("device_flow: corrupt state for %s", device_code)
logger.exception("device_flow: corrupt state for %s", device_code)
return None
def approve(
@ -195,9 +196,9 @@ class DeviceFlowRedis:
) -> None:
state = self._load_state(device_code)
if state is None:
raise StateNotFound(device_code)
raise StateNotFoundError(device_code)
if state.status is not DeviceFlowStatus.PENDING:
raise InvalidTransition(f"cannot approve {state.status}")
raise InvalidTransitionError(f"cannot approve {state.status}")
state.status = DeviceFlowStatus.APPROVED
state.subject_email = subject_email
@ -213,9 +214,9 @@ class DeviceFlowRedis:
def deny(self, device_code: str) -> None:
state = self._load_state(device_code)
if state is None:
raise StateNotFound(device_code)
raise StateNotFoundError(device_code)
if state.status is not DeviceFlowStatus.PENDING:
raise InvalidTransition(f"cannot deny {state.status}")
raise InvalidTransitionError(f"cannot deny {state.status}")
state.status = DeviceFlowStatus.DENIED
self._redis.setex(
DEVICE_CODE_KEY_FMT.format(code=device_code),
@ -239,7 +240,7 @@ class DeviceFlowRedis:
try:
return DeviceFlowState.from_json(text_)
except (ValueError, KeyError):
logger.error("device_flow: corrupt state on consume %s", device_code)
logger.exception("device_flow: corrupt state on consume %s", device_code)
return None
def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision:
@ -287,6 +288,7 @@ ACCOUNT_ISSUER_SENTINEL = "dify:account"
@dataclass(frozen=True, slots=True)
class MintResult:
"""Plaintext token surfaces to the caller once."""
token: str
token_id: uuid.UUID
expires_at: datetime
@ -308,7 +310,9 @@ def sha256_hex(token: str) -> str:
def mint_oauth_token(
session: Session,
# Accept either Session or Flask-SQLAlchemy's request-scoped wrapper —
# the wrapper proxies the same execute/commit surface.
session: Session | scoped_session,
redis_client,
*,
subject_email: str,
@ -328,9 +332,7 @@ def mint_oauth_token(
# Account flow always writes the sentinel — caller may pass None
# (for clarity) or the sentinel itself; nothing else is valid.
if subject_issuer not in (None, ACCOUNT_ISSUER_SENTINEL):
raise ValueError(
f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}"
)
raise ValueError(f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}")
subject_issuer = ACCOUNT_ISSUER_SENTINEL
elif prefix == PREFIX_OAUTH_EXTERNAL_SSO:
# Defense in depth: enterprise canonicalises + rejects empty,
@ -363,7 +365,7 @@ def mint_oauth_token(
def _upsert(
session: Session,
session: Session | scoped_session,
*,
subject_email: str,
subject_issuer: str | None,
@ -415,6 +417,8 @@ def _upsert(
row = session.execute(upsert_stmt).first()
session.commit()
if row is None:
raise RuntimeError("oauth_token upsert returned no row")
token_id = uuid.UUID(str(row.id))
return UpsertOutcome(
token_id=token_id,
@ -449,7 +453,9 @@ def oauth_ttl_days(tenant_id: str | None = None) -> int:
except ValueError:
logger.warning(
"%s=%r is not an int; falling back to %d",
_TTL_ENV_VAR, raw, DEFAULT_OAUTH_TTL_DAYS,
_TTL_ENV_VAR,
raw,
DEFAULT_OAUTH_TTL_DAYS,
)
return DEFAULT_OAUTH_TTL_DAYS
if value < MIN_TTL_DAYS:

View File

@ -38,12 +38,7 @@ def test_membership_strategy_uses_join_lookup(member):
def test_membership_strategy_rejects_external_sso():
assert (
MembershipStrategy().authorize(
_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)
)
is False
)
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False
def test_app_authz_check_raises_when_strategy_denies():

View File

@ -1,4 +1,5 @@
"""User-scoped identity + session endpoints under /openapi/v1/account."""
import builtins
import pytest

View File

@ -6,8 +6,10 @@ from flask_restx import Api
def _client():
from controllers.openapi import app_info # noqa: F401
from controllers.openapi import openapi_ns
from controllers.openapi import (
app_info, # noqa: F401
openapi_ns,
)
app = Flask(__name__)
api = Api(app)

View File

@ -6,8 +6,10 @@ from flask_restx import Api
def _client():
from controllers.openapi import chat_messages # noqa: F401
from controllers.openapi import openapi_ns
from controllers.openapi import (
chat_messages, # noqa: F401
openapi_ns,
)
app = Flask(__name__)
api = Api(app)
@ -32,12 +34,11 @@ def test_chat_dispatches_and_returns_response_model(svc, bypass_pipeline):
200,
)
fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1")
with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake), patch(
"controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()
with (
patch("controllers.openapi.chat_messages._unpack_app", return_value=fake),
patch("controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()),
):
r = _client().post(
"/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}}
)
r = _client().post("/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}})
assert r.status_code == 200
body = r.get_json()
assert body["conversation_id"] == "c1"
@ -62,8 +63,9 @@ def test_chat_strips_user_field_from_body(svc, bypass_pipeline):
200,
)
fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1")
with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake), patch(
"controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()
with (
patch("controllers.openapi.chat_messages._unpack_app", return_value=fake),
patch("controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()),
):
_client().post(
"/openapi/v1/apps/app1/chat-messages",
@ -76,9 +78,7 @@ def test_chat_strips_user_field_from_body(svc, bypass_pipeline):
def test_chat_rejects_non_chat_mode(bypass_pipeline):
fake = SimpleNamespace(mode="completion")
with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake):
r = _client().post(
"/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}}
)
r = _client().post("/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}})
assert r.status_code in (400, 403)

View File

@ -6,8 +6,10 @@ from flask_restx import Api
def _client():
from controllers.openapi import completion_messages # noqa: F401
from controllers.openapi import openapi_ns
from controllers.openapi import (
completion_messages, # noqa: F401
openapi_ns,
)
app = Flask(__name__)
api = Api(app)
@ -31,8 +33,9 @@ def test_completion_returns_response_model(svc, bypass_pipeline):
200,
)
fake = SimpleNamespace(mode="completion", id="app1", tenant_id="t1")
with patch("controllers.openapi.completion_messages._unpack_app", return_value=fake), patch(
"controllers.openapi.completion_messages._unpack_caller", return_value=SimpleNamespace()
with (
patch("controllers.openapi.completion_messages._unpack_app", return_value=fake),
patch("controllers.openapi.completion_messages._unpack_caller", return_value=SimpleNamespace()),
):
r = _client().post(
"/openapi/v1/apps/app1/completion-messages",

View File

@ -7,9 +7,9 @@ Tests use a fresh Blueprint + Flask-CORS per case because the production
blueprint is a module-level singleton and can't be reconfigured once
registered.
"""
import builtins
import pytest
from flask import Blueprint, Flask
from flask.views import MethodView
from flask_cors import CORS

View File

@ -1,4 +1,5 @@
"""Account-branch device-flow approve/deny under /openapi/v1."""
import builtins
import pytest

View File

@ -4,6 +4,7 @@ authorization endpoint.
Tests verify URL routing without invoking the handler invoking would
require Redis, which the unit-test runtime does not initialise.
"""
import builtins
import pytest
@ -31,16 +32,12 @@ def test_openapi_route_registered(openapi_app: Flask):
def test_route_dispatches_to_class(openapi_app: Flask):
rule = next(
r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code"
)
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code")
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceCodeApi
def test_route_accepts_post(openapi_app: Flask):
rule = next(
r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code"
)
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code")
assert "POST" in rule.methods

View File

@ -1,4 +1,5 @@
"""GET /openapi/v1/oauth/device/lookup is the canonical user-code lookup."""
import builtins
import pytest
@ -26,14 +27,10 @@ def test_openapi_route_registered(openapi_app: Flask):
def test_route_dispatches_to_class(openapi_app: Flask):
rule = next(
r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup"
)
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup")
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceLookupApi
def test_route_accepts_get(openapi_app: Flask):
rule = next(
r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup"
)
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup")
assert "GET" in rule.methods

View File

@ -1,4 +1,5 @@
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/."""
import builtins
import pytest

View File

@ -1,4 +1,5 @@
"""POST /openapi/v1/oauth/device/token is the canonical poll endpoint."""
import builtins
import pytest
@ -26,7 +27,5 @@ def test_openapi_route_registered(openapi_app: Flask):
def test_route_dispatches_to_class(openapi_app: Flask):
rule = next(
r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/token"
)
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/token")
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceTokenApi

View File

@ -6,8 +6,10 @@ from flask_restx import Api
def _client():
from controllers.openapi import openapi_ns
from controllers.openapi import workflow_run # noqa: F401
from controllers.openapi import (
openapi_ns,
workflow_run, # noqa: F401
)
app = Flask(__name__)
api = Api(app)
@ -32,8 +34,9 @@ def test_workflow_run_returns_response_model(svc, bypass_pipeline):
200,
)
fake = SimpleNamespace(mode="workflow", id="app1", tenant_id="t1")
with patch("controllers.openapi.workflow_run._unpack_app", return_value=fake), patch(
"controllers.openapi.workflow_run._unpack_caller", return_value=SimpleNamespace()
with (
patch("controllers.openapi.workflow_run._unpack_app", return_value=fake),
patch("controllers.openapi.workflow_run._unpack_caller", return_value=SimpleNamespace()),
):
r = _client().post("/openapi/v1/apps/app1/workflows/run", json={"inputs": {"x": 1}})
assert r.status_code == 200

View File

@ -2,6 +2,7 @@
list + member-gated detail. No legacy /v1/ equivalent the cookie-authed
/console/api/workspaces is a separate consumer that stays in console.
"""
import builtins
import pytest

View File

@ -2,6 +2,7 @@
Tests use a fake auth_ctx attached directly to flask.g no
authenticator wiring needed.
"""
from __future__ import annotations
import uuid