mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 04:36:31 +08:00
chore(api): pyright + ruff cleanup for openapi/cli surface
Type and lint pass over the openapi controllers, auth pipeline, and
oauth bearer/device-flow plumbing. Down from 36 pyright errors and 16
ruff errors to 0/0; 93 openapi unit tests pass.
Logic fixes:
- libs/oauth_bearer.py: drop private-naming on the friend-API methods
consumed by _VariantResolver (cache_get / cache_set_positive /
cache_set_negative / hard_expire / session_factory). They were always
cross-class accessors — leading underscore was misleading. Add public
registry property on BearerAuthenticator. _hard_expire row_id widened
to UUID | str (matches the StringUUID column type).
- libs/oauth_bearer.py: type validate_bearer / bearer_feature_required
with ParamSpec / PEP-695 so wrapped routes preserve their signature.
- libs/rate_limit.py: same — typed rate_limit decorator.
- services/oauth_device_flow.py: mint_oauth_token / _upsert accept
Session | scoped_session (Flask-SQLAlchemy proxy). Guard row-is-None
after upsert.
- controllers/openapi/{chat,completion,workflow}_messages.py: tuple-vs-
Mapping shape narrowing on AppGenerateService.generate return —
production returns Mapping, tests mock as (body, status). Validate
through Pydantic Response model in both shapes.
- controllers/openapi/oauth_device.py: replace flask_restx.reqparse (banned)
with Pydantic Request/Query models — DeviceCodeRequest, DevicePollRequest,
DeviceLookupQuery, DeviceMutateRequest. Two PEP-695 generic helpers
(_validate_json / _validate_query) translate ValidationError to BadRequest.
- controllers/openapi/auth/strategies.py: Protocol param-name match
(subject_type), Optional narrowing on app/tenant/account_id/subject_email.
- controllers/openapi/auth/steps.py: subject_type-is-None guard before
mounter dispatch.
- core/app/apps/workflow/generate_task_pipeline.py + models/workflow.py:
add WorkflowAppLogCreatedFrom.OPENAPI + matching match-case branch.
Fixes match-exhaustiveness and possibly-unbound created_from.
- libs/device_flow_security.py: pyright ignore on flask after_request
hook (registered by the framework, pyright sees as unused).
- services/oauth_device_flow.py: rename Exceptions to *Error suffix
(StateNotFoundError / InvalidTransitionError / UserCodeExhaustedError);
same for libs/oauth_bearer.py (InvalidBearerError / TokenExpiredError).
Update all callers across openapi controllers.
- controllers/openapi/{oauth_device,oauth_device_sso}.py +
services/oauth_device_flow.py: switch logger.error in except blocks
to logger.exception (TRY400) — keeps the traceback for ops.
- configs/feature/__init__.py: OPENAPI_KNOWN_CLIENT_IDS computed_field
needs an @property alongside for pyright to see it as a value, not a
method. Matches the existing line-451 pattern.
Plus ruff format + import-sort across the openapi tree (pure formatting).
This commit is contained in:
parent
b083c910b3
commit
8a62c1d915
@ -523,7 +523,8 @@ class HttpConfig(BaseSettings):
|
||||
default="difyctl",
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
|
||||
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
|
||||
matches the existing oauth_device convention. The EE OTel exporter consults
|
||||
its own allowlist to decide whether to ship the line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Shared response substructures for openapi endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
identity read; /account/sessions and /account/sessions/<id> manage
|
||||
the user's active OAuth tokens.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
@ -16,9 +17,9 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
TOKEN_CACHE_KEY_FMT,
|
||||
AuthContext,
|
||||
SubjectType,
|
||||
TOKEN_CACHE_KEY_FMT,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
@ -51,8 +52,7 @@ class AccountApi(Resource):
|
||||
}
|
||||
|
||||
account = (
|
||||
db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none()
|
||||
if ctx.account_id else None
|
||||
db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none() if ctx.account_id else None
|
||||
)
|
||||
memberships = _load_memberships(ctx.account_id) if ctx.account_id else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
@ -129,8 +129,7 @@ class AccountSessionByIdApi(Resource):
|
||||
# Subject-match guard. 404 (not 403) on cross-subject so the
|
||||
# endpoint doesn't leak token IDs that belong to other subjects.
|
||||
owns = db.session.execute(
|
||||
select(OAuthAccessToken.id)
|
||||
.where(
|
||||
select(OAuthAccessToken.id).where(
|
||||
and_(
|
||||
OAuthAccessToken.id == session_id,
|
||||
*_subject_match(ctx),
|
||||
@ -160,8 +159,7 @@ def _subject_match(ctx: AuthContext) -> tuple:
|
||||
def _require_oauth_subject(ctx: AuthContext) -> None:
|
||||
if not ctx.source.startswith("oauth"):
|
||||
raise BadRequest(
|
||||
"this endpoint revokes OAuth bearer tokens; "
|
||||
"use /openapi/v1/personal-access-tokens/self for PATs"
|
||||
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""GET /openapi/v1/apps/<app_id>/info — port of service_api/app/app.py:AppInfoApi."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
Endpoints attach via @APP_PIPELINE.guard(scope=…). No alternative paths.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
|
||||
@ -4,6 +4,7 @@ Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
|
||||
@ -3,21 +3,22 @@
|
||||
BearerCheck is the only step that touches the token registry; downstream
|
||||
steps see only the populated Context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import TokenExpired, get_authenticator, sha256_hex
|
||||
from libs.oauth_bearer import TokenExpiredError, get_authenticator, sha256_hex
|
||||
from models import App, Tenant, TenantStatus
|
||||
|
||||
|
||||
def _registry():
|
||||
return get_authenticator()._registry # noqa: SLF001
|
||||
return get_authenticator().registry
|
||||
|
||||
|
||||
def _extract_bearer(req) -> str | None:
|
||||
@ -45,7 +46,7 @@ class BearerCheck:
|
||||
|
||||
try:
|
||||
row = kind.resolver.resolve(_hash_token(token))
|
||||
except TokenExpired:
|
||||
except TokenExpiredError:
|
||||
raise Unauthorized("token expired")
|
||||
if row is None:
|
||||
raise Unauthorized("invalid bearer")
|
||||
@ -105,6 +106,8 @@ class CallerMount:
|
||||
self._mounters = mounters
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if ctx.subject_type is None:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
for m in self._mounters:
|
||||
if m.applies_to(ctx.subject_type):
|
||||
m.mount(ctx)
|
||||
|
||||
@ -4,6 +4,7 @@ App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
@ -33,9 +34,11 @@ class AclStrategy:
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_email is None or ctx.app is None:
|
||||
return False
|
||||
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=ctx.subject_email,
|
||||
app_id=ctx.app.id,
|
||||
app_id=ctx.app.id, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
|
||||
@ -50,7 +53,9 @@ class MembershipStrategy:
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return False
|
||||
return _has_tenant_membership(ctx.account_id, ctx.tenant.id)
|
||||
if ctx.tenant is None:
|
||||
return False
|
||||
return _has_tenant_membership(ctx.account_id, ctx.tenant.id) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _has_tenant_membership(account_id: str | None, tenant_id: str) -> bool:
|
||||
@ -67,8 +72,8 @@ def _has_tenant_membership(account_id: str | None, tenant_id: str) -> bool:
|
||||
|
||||
def _login_as(user) -> None:
|
||||
"""Set Flask-Login request user so downstream services see the caller."""
|
||||
current_app.login_manager._update_request_context_with_user(user) # noqa: SLF001
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # noqa: SLF001
|
||||
current_app.login_manager._update_request_context_with_user(user)
|
||||
user_logged_in.send(current_app._get_current_object(), user=user)
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
|
||||
@ -78,25 +83,31 @@ class CallerMounter(Protocol):
|
||||
|
||||
|
||||
class AccountMounter:
|
||||
def applies_to(self, st: SubjectType) -> bool:
|
||||
return st == SubjectType.ACCOUNT
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.ACCOUNT
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.account_id is None:
|
||||
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||
account = db.session.get(Account, ctx.account_id)
|
||||
account.current_tenant = ctx.tenant
|
||||
if account is None:
|
||||
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||
account.current_tenant = ctx.tenant # type: ignore[assignment]
|
||||
_login_as(account)
|
||||
ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
|
||||
def applies_to(self, st: SubjectType) -> bool:
|
||||
return st == SubjectType.EXTERNAL_SSO
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.EXTERNAL_SSO
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
|
||||
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=ctx.tenant.id,
|
||||
app_id=ctx.app.id,
|
||||
tenant_id=ctx.tenant.id, # type: ignore[attr-defined]
|
||||
app_id=ctx.app.id, # type: ignore[attr-defined]
|
||||
user_id=ctx.subject_email,
|
||||
)
|
||||
_login_as(end_user)
|
||||
|
||||
@ -8,9 +8,11 @@ Differences from service_api:
|
||||
- Typed Request and Response models.
|
||||
- invoke_from = InvokeFrom.OPENAPI.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
@ -163,5 +165,12 @@ class ChatMessagesApi(Resource):
|
||||
if streaming:
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
body_dict = response[0] if isinstance(response, tuple) else response
|
||||
return ChatMessageResponse.model_validate(body_dict).model_dump(mode="json"), 200
|
||||
# Some upstream paths (and tests) return (body, status); production
|
||||
# generate returns Mapping. Accept both, then validate.
|
||||
if isinstance(response, tuple):
|
||||
body_dict: Any = response[0] # pyright: ignore[reportArgumentType]
|
||||
else:
|
||||
body_dict = response
|
||||
if not isinstance(body_dict, Mapping):
|
||||
raise InternalServerError("blocking generate returned non-mapping response")
|
||||
return ChatMessageResponse.model_validate(dict(body_dict)).model_dump(mode="json"), 200
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/completion-messages — port of
|
||||
service_api/app/completion.py:CompletionApi."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
@ -120,5 +122,10 @@ class CompletionMessagesApi(Resource):
|
||||
if streaming:
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
body_dict = response[0] if isinstance(response, tuple) else response
|
||||
return CompletionMessageResponse.model_validate(body_dict).model_dump(mode="json"), 200
|
||||
if isinstance(response, tuple):
|
||||
body_dict: Any = response[0] # pyright: ignore[reportArgumentType]
|
||||
else:
|
||||
body_dict = response
|
||||
if not isinstance(body_dict, Mapping):
|
||||
raise InternalServerError("blocking generate returned non-mapping response")
|
||||
return CompletionMessageResponse.model_validate(dict(body_dict)).model_dump(mode="json"), 200
|
||||
|
||||
@ -12,13 +12,16 @@ sub-groups in one module:
|
||||
|
||||
SSO branch lives in oauth_device_sso.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
@ -41,9 +44,9 @@ from services.oauth_device_flow import (
|
||||
PREFIX_OAUTH_ACCOUNT,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransition,
|
||||
InvalidTransitionError,
|
||||
SlowDownDecision,
|
||||
StateNotFound,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
@ -52,22 +55,41 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Parsers
|
||||
# Request / query schemas
|
||||
# =========================================================================
|
||||
|
||||
_code_parser = reqparse.RequestParser()
|
||||
_code_parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
_code_parser.add_argument("device_label", type=str, required=True, location="json")
|
||||
|
||||
_poll_parser = reqparse.RequestParser()
|
||||
_poll_parser.add_argument("device_code", type=str, required=True, location="json")
|
||||
_poll_parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
class DeviceCodeRequest(BaseModel):
|
||||
client_id: str
|
||||
device_label: str
|
||||
|
||||
_lookup_parser = reqparse.RequestParser()
|
||||
_lookup_parser.add_argument("user_code", type=str, required=True, location="args")
|
||||
|
||||
_mutate_parser = reqparse.RequestParser()
|
||||
_mutate_parser.add_argument("user_code", type=str, required=True, location="json")
|
||||
class DevicePollRequest(BaseModel):
|
||||
device_code: str
|
||||
client_id: str
|
||||
|
||||
|
||||
class DeviceLookupQuery(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class DeviceMutateRequest(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
def _validate_json[M: BaseModel](model: type[M]) -> M:
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _validate_query[M: BaseModel](model: type[M]) -> M:
|
||||
try:
|
||||
return model.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
@ -79,9 +101,9 @@ _mutate_parser.add_argument("user_code", type=str, required=True, location="json
|
||||
class OAuthDeviceCodeApi(Resource):
|
||||
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
|
||||
def post(self):
|
||||
args = _code_parser.parse_args()
|
||||
client_id = args["client_id"]
|
||||
device_label = args["device_label"]
|
||||
payload = _validate_json(DeviceCodeRequest)
|
||||
client_id = payload.client_id
|
||||
device_label = payload.device_label
|
||||
|
||||
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
|
||||
return {"error": "unsupported_client"}, 400
|
||||
@ -104,8 +126,8 @@ class OAuthDeviceTokenApi(Resource):
|
||||
"""RFC 8628 poll."""
|
||||
|
||||
def post(self):
|
||||
args = _poll_parser.parse_args()
|
||||
device_code = args["device_code"]
|
||||
payload = _validate_json(DevicePollRequest)
|
||||
device_code = payload.device_code
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
@ -145,8 +167,8 @@ class OAuthDeviceLookupApi(Resource):
|
||||
|
||||
@rate_limit(LIMIT_LOOKUP_PUBLIC)
|
||||
def get(self):
|
||||
args = _lookup_parser.parse_args()
|
||||
user_code = args["user_code"].strip().upper()
|
||||
payload = _validate_query(DeviceLookupQuery)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
@ -181,8 +203,8 @@ class DeviceApproveApi(Resource):
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
args = _mutate_parser.parse_args()
|
||||
user_code = args["user_code"].strip().upper()
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
account, tenant = current_account_with_tenant()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
@ -226,10 +248,10 @@ class DeviceApproveApi(Resource):
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFound, InvalidTransition) as e:
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
# Row minted but state vanished — roll forward; the orphan
|
||||
# token is revocable via auth devices list / Authorized Apps.
|
||||
logger.error("device_flow: approve raced on %s: %s", device_code, e)
|
||||
logger.exception("device_flow: approve raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
@ -246,8 +268,8 @@ class DeviceDenyApi(Resource):
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
args = _mutate_parser.parse_args()
|
||||
user_code = args["user_code"].strip().upper()
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
@ -259,8 +281,8 @@ class DeviceDenyApi(Resource):
|
||||
|
||||
try:
|
||||
store.deny(device_code)
|
||||
except (StateNotFound, InvalidTransition) as e:
|
||||
logger.error("device_flow: deny raced on %s: %s", device_code, e)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
logger.exception("device_flow: deny raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
|
||||
_emit_deny_audit(state)
|
||||
@ -284,7 +306,9 @@ def _audit_cross_ip_if_needed(state) -> None:
|
||||
if state.created_ip and poll_ip and poll_ip != state.created_ip:
|
||||
logger.warning(
|
||||
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
|
||||
state.token_id, state.created_ip, poll_ip,
|
||||
state.token_id,
|
||||
state.created_ip,
|
||||
poll_ip,
|
||||
extra={
|
||||
"audit": True,
|
||||
"token_id": state.token_id,
|
||||
@ -299,16 +323,14 @@ def _build_account_poll_payload(account, tenant, mint) -> dict:
|
||||
handler doesn't re-query accounts/tenants for authz data.
|
||||
"""
|
||||
from models import Tenant, TenantAccountJoin
|
||||
|
||||
rows = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.filter(TenantAccountJoin.account_id == account.id)
|
||||
.all()
|
||||
)
|
||||
workspaces = [
|
||||
{"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")}
|
||||
for t, m in rows
|
||||
]
|
||||
workspaces = [{"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")} for t, m in rows]
|
||||
# Prefer active session tenant → DB-flagged current join → first membership.
|
||||
default_ws_id = None
|
||||
if tenant and any(w["id"] == str(tenant) for w in workspaces):
|
||||
@ -335,7 +357,11 @@ def _build_account_poll_payload(account, tenant, mint) -> dict:
|
||||
def _emit_approve_audit(state, account, tenant, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
|
||||
mint.token_id, account.email, state.client_id, state.device_label, mint.expires_at,
|
||||
mint.token_id,
|
||||
account.email,
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
mint.expires_at,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
@ -355,7 +381,8 @@ def _emit_approve_audit(state, account, tenant, mint) -> None:
|
||||
def _emit_deny_audit(state) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
|
||||
state.client_id, state.device_label,
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_denied",
|
||||
@ -363,5 +390,3 @@ def _emit_deny_audit(state) -> None:
|
||||
"device_label": state.device_label,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ EE-only. Browser flow:
|
||||
Function-based (raw @bp.route) rather than Resource classes because the
|
||||
handlers do redirects + cookie kwargs that don't fit the Resource shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
@ -51,8 +52,8 @@ from services.oauth_device_flow import (
|
||||
PREFIX_OAUTH_EXTERNAL_SSO,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransition,
|
||||
StateNotFound,
|
||||
InvalidTransitionError,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
@ -171,13 +172,15 @@ def approval_context():
|
||||
logger.warning("approval-context: bad cookie: %s", e)
|
||||
raise Unauthorized("no_session") from e
|
||||
|
||||
return jsonify({
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"user_code": claims.user_code,
|
||||
"csrf_token": claims.csrf_token,
|
||||
"expires_at": claims.expires_at.isoformat(),
|
||||
}), 200
|
||||
return jsonify(
|
||||
{
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"user_code": claims.user_code,
|
||||
"csrf_token": claims.csrf_token,
|
||||
"expires_at": claims.expires_at.isoformat(),
|
||||
}
|
||||
), 200
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approve-external", methods=["POST"])
|
||||
@ -251,8 +254,8 @@ def approve_external():
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFound, InvalidTransition) as e:
|
||||
logger.error("approve-external: state transition raced: %s", e)
|
||||
except (StateNotFoundError, InvalidTransitionError) as e:
|
||||
logger.exception("approve-external: state transition raced")
|
||||
raise Conflict("state_lost") from e
|
||||
|
||||
_emit_approve_external_audit(state, claims, mint)
|
||||
@ -264,9 +267,11 @@ def approve_external():
|
||||
|
||||
def _emit_approve_external_audit(state, claims, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved subject_type=%s "
|
||||
"subject_email=%s subject_issuer=%s token_id=%s",
|
||||
SubjectType.EXTERNAL_SSO, claims.subject_email, claims.subject_issuer, mint.token_id,
|
||||
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
mint.token_id,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/workflows/run — port of
|
||||
service_api/app/workflow.py:WorkflowRunApi."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
@ -128,5 +130,10 @@ class WorkflowRunApi(Resource):
|
||||
if streaming:
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
body_dict = response[0] if isinstance(response, tuple) else response
|
||||
return WorkflowRunResponse.model_validate(body_dict).model_dump(mode="json"), 200
|
||||
if isinstance(response, tuple):
|
||||
body_dict: Any = response[0] # pyright: ignore[reportArgumentType]
|
||||
else:
|
||||
body_dict = response
|
||||
if not isinstance(body_dict, Mapping):
|
||||
raise InternalServerError("blocking generate returned non-mapping response")
|
||||
return WorkflowRunResponse.model_validate(dict(body_dict)).model_dump(mode="json"), 200
|
||||
|
||||
@ -5,8 +5,11 @@ Account bearers (dfoa_) see every tenant they're a member of. External
|
||||
SSO bearers (dfoe_) have no account_id and so see an empty list — that
|
||||
matches /openapi/v1/account.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import starmap
|
||||
|
||||
from flask import g
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
@ -37,7 +40,7 @@ class WorkspacesApi(Resource):
|
||||
.order_by(Tenant.created_at.asc())
|
||||
).all()
|
||||
|
||||
return {"workspaces": [_workspace_summary(t, m) for t, m in rows]}, 200
|
||||
return {"workspaces": list(starmap(_workspace_summary, rows))}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
|
||||
@ -685,6 +685,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
match invoke_from:
|
||||
case InvokeFrom.SERVICE_API:
|
||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||
case InvokeFrom.OPENAPI:
|
||||
created_from = WorkflowAppLogCreatedFrom.OPENAPI
|
||||
case InvokeFrom.EXPLORE:
|
||||
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
|
||||
case InvokeFrom.WEB_APP:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
"""Bind the bearer authenticator at startup. Must run after ext_database
|
||||
and ext_redis (needs both factories).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
@ -1,14 +1,15 @@
|
||||
"""Device-flow security primitives: enterprise_only gate, approval-grant
|
||||
cookie mint/verify/consume, and anti-framing headers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from flask import Blueprint
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -122,7 +123,10 @@ def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
|
||||
return False
|
||||
return bool(
|
||||
redis_client.set(
|
||||
NONCE_KEY_FMT.format(nonce=nonce), "1", nx=True, ex=NONCE_TTL_SECONDS,
|
||||
NONCE_KEY_FMT.format(nonce=nonce),
|
||||
"1",
|
||||
nx=True,
|
||||
ex=NONCE_TTL_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
@ -132,7 +136,10 @@ def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
|
||||
return False
|
||||
return bool(
|
||||
redis_client.set(
|
||||
SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce), "1", nx=True, ex=NONCE_TTL_SECONDS,
|
||||
SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce),
|
||||
"1",
|
||||
nx=True,
|
||||
ex=NONCE_TTL_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
@ -183,7 +190,7 @@ def attach_anti_framing(bp: Blueprint) -> None:
|
||||
"""X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4)."""
|
||||
|
||||
@bp.after_request
|
||||
def _apply_headers(response):
|
||||
def _apply_headers(response): # pyright: ignore[reportUnusedFunction]
|
||||
for name, value in _ANTI_FRAMING_HEADERS.items():
|
||||
response.headers.setdefault(name, value)
|
||||
return response
|
||||
|
||||
@ -2,11 +2,13 @@
|
||||
state envelope, external subject assertion, and approval-grant cookie —
|
||||
all three share one key-set so api ↔ enterprise can verify each other.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
AUD_STATE_ENVELOPE = "api.sso.state_envelope"
|
||||
@ -32,14 +34,14 @@ class KeySet:
|
||||
self._active_kid = active_kid
|
||||
|
||||
@classmethod
|
||||
def from_shared_secret(cls) -> "KeySet":
|
||||
def from_shared_secret(cls) -> KeySet:
|
||||
secret = dify_config.SECRET_KEY
|
||||
if not secret:
|
||||
raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
|
||||
return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1)
|
||||
|
||||
@classmethod
|
||||
def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> "KeySet":
|
||||
def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> KeySet:
|
||||
return cls(entries, active_kid)
|
||||
|
||||
@property
|
||||
|
||||
@ -4,17 +4,19 @@ To add a token kind: write a Resolver, add a SubjectType + Accepts member,
|
||||
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
|
||||
Authenticator + validate_bearer stay untouched.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from functools import wraps
|
||||
from typing import Callable, Iterable, Literal, Protocol
|
||||
from typing import Literal, ParamSpec, Protocol, TypeVar
|
||||
|
||||
from flask import g, request
|
||||
from sqlalchemy import update
|
||||
@ -79,11 +81,11 @@ class TokenKind:
|
||||
return token.startswith(self.prefix)
|
||||
|
||||
|
||||
class InvalidBearer(Exception):
|
||||
class InvalidBearerError(Exception):
|
||||
"""Token missing, unknown prefix, or no live row."""
|
||||
|
||||
|
||||
class TokenExpired(Exception):
|
||||
class TokenExpiredError(Exception):
|
||||
"""Hard-expire bookkeeping is the resolver's job before raising."""
|
||||
|
||||
|
||||
@ -122,13 +124,17 @@ class BearerAuthenticator:
|
||||
def __init__(self, registry: TokenKindRegistry) -> None:
|
||||
self._registry = registry
|
||||
|
||||
@property
|
||||
def registry(self) -> TokenKindRegistry:
|
||||
return self._registry
|
||||
|
||||
def authenticate(self, token: str) -> AuthContext:
|
||||
kind = self._registry.find(token)
|
||||
if kind is None:
|
||||
raise InvalidBearer("unknown token prefix")
|
||||
raise InvalidBearerError("unknown token prefix")
|
||||
row = kind.resolver.resolve(sha256_hex(token))
|
||||
if row is None:
|
||||
raise InvalidBearer("token unknown or revoked")
|
||||
raise InvalidBearerError("token unknown or revoked")
|
||||
return AuthContext(
|
||||
subject_type=kind.subject_type,
|
||||
subject_email=row.subject_email,
|
||||
@ -165,7 +171,9 @@ class OAuthAccessTokenResolver:
|
||||
positive_ttl: int = POSITIVE_TTL_SECONDS,
|
||||
negative_ttl: int = NEGATIVE_TTL_SECONDS,
|
||||
) -> None:
|
||||
self._session_factory = session_factory
|
||||
# session_factory and the cache helpers below are friend-API for
|
||||
# _VariantResolver in this module — kept public-named on purpose.
|
||||
self.session_factory = session_factory
|
||||
self._redis = redis_client
|
||||
self._positive_ttl = positive_ttl
|
||||
self._negative_ttl = negative_ttl
|
||||
@ -179,7 +187,7 @@ class OAuthAccessTokenResolver:
|
||||
def _cache_key(self, token_hash: str) -> str:
|
||||
return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
|
||||
|
||||
def _cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
|
||||
def cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
|
||||
raw = self._redis.get(self._cache_key(token_hash))
|
||||
if raw is None:
|
||||
return None
|
||||
@ -193,17 +201,17 @@ class OAuthAccessTokenResolver:
|
||||
logger.warning("auth:token cache entry malformed; treating as miss")
|
||||
return None
|
||||
|
||||
def _cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
|
||||
def cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
|
||||
self._redis.setex(
|
||||
self._cache_key(token_hash),
|
||||
self._positive_ttl,
|
||||
json.dumps(_row_to_cache(row)),
|
||||
)
|
||||
|
||||
def _cache_set_negative(self, token_hash: str) -> None:
|
||||
def cache_set_negative(self, token_hash: str) -> None:
|
||||
self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")
|
||||
|
||||
def _hard_expire(self, session: Session, row_id: uuid.UUID, token_hash: str) -> None:
|
||||
def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None:
|
||||
"""Atomic CAS — only the worker that flips revoked_at emits audit;
|
||||
replays are idempotent. Spec: tokens.md §Detection + hard-expire.
|
||||
"""
|
||||
@ -216,21 +224,22 @@ class OAuthAccessTokenResolver:
|
||||
session.commit()
|
||||
if result.rowcount == 1:
|
||||
logger.warning(
|
||||
"audit: %s token_id=%s", AUDIT_OAUTH_EXPIRED, row_id,
|
||||
"audit: %s token_id=%s",
|
||||
AUDIT_OAUTH_EXPIRED,
|
||||
row_id,
|
||||
extra={"audit": True, "token_id": str(row_id)},
|
||||
)
|
||||
self._redis.delete(self._cache_key(token_hash))
|
||||
self._cache_set_negative(token_hash)
|
||||
self.cache_set_negative(token_hash)
|
||||
|
||||
|
||||
class _VariantResolver:
|
||||
|
||||
def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
|
||||
self._parent = parent
|
||||
self._variant = variant
|
||||
|
||||
def resolve(self, token_hash: str) -> ResolvedRow | None:
|
||||
cached = self._parent._cache_get(token_hash)
|
||||
cached = self._parent.cache_get(token_hash)
|
||||
if cached == "invalid":
|
||||
return None
|
||||
if cached is not None and not isinstance(cached, str):
|
||||
@ -238,23 +247,24 @@ class _VariantResolver:
|
||||
return None
|
||||
return cached
|
||||
|
||||
# _session_factory returns Flask-SQLAlchemy's scoped_session, which is
|
||||
# session_factory returns Flask-SQLAlchemy's scoped_session, which is
|
||||
# request-bound and not a context manager; use it directly.
|
||||
session = self._parent._session_factory()
|
||||
session = self._parent.session_factory()
|
||||
row = self._load_from_db(session, token_hash)
|
||||
if row is None:
|
||||
self._parent._cache_set_negative(token_hash)
|
||||
self._parent.cache_set_negative(token_hash)
|
||||
return None
|
||||
|
||||
now = datetime.now(UTC)
|
||||
if row.expires_at is not None and row.expires_at <= now:
|
||||
self._parent._hard_expire(session, row.id, token_hash)
|
||||
self._parent.hard_expire(session, row.id, token_hash)
|
||||
return None
|
||||
|
||||
if not self._matches_variant_model(row):
|
||||
logger.error(
|
||||
"internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
|
||||
row.id, row.prefix,
|
||||
row.id,
|
||||
row.prefix,
|
||||
)
|
||||
return None
|
||||
|
||||
@ -265,7 +275,7 @@ class _VariantResolver:
|
||||
token_id=uuid.UUID(str(row.id)),
|
||||
expires_at=row.expires_at,
|
||||
)
|
||||
self._parent._cache_set_positive(token_hash, resolved)
|
||||
self._parent.cache_set_positive(token_hash, resolved)
|
||||
return resolved
|
||||
|
||||
def _matches_variant(self, row: ResolvedRow) -> bool:
|
||||
@ -352,7 +362,11 @@ def _extract_bearer(req) -> str | None:
|
||||
return value.strip()
|
||||
|
||||
|
||||
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable:
|
||||
_DP = ParamSpec("_DP")
|
||||
_DR = TypeVar("_DR")
|
||||
|
||||
|
||||
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable[[Callable[_DP, _DR]], Callable[_DP, _DR]]:
|
||||
"""Opt-in: omitting it leaves the route unauthenticated.
|
||||
|
||||
Resolves user-level OAuth bearers (``dfoa_`` / ``dfoe_``). Legacy
|
||||
@ -360,21 +374,19 @@ def validate_bearer(*, accept: frozenset[Accepts]) -> Callable:
|
||||
and are rejected here as the wrong auth scheme for this surface.
|
||||
"""
|
||||
|
||||
def wrap(fn: Callable) -> Callable:
|
||||
def wrap(fn: Callable[_DP, _DR]) -> Callable[_DP, _DR]:
|
||||
@wraps(fn)
|
||||
def inner(*args, **kwargs):
|
||||
def inner(*args: _DP.args, **kwargs: _DP.kwargs) -> _DR:
|
||||
token = _extract_bearer(request)
|
||||
if token is None:
|
||||
raise Unauthorized("missing bearer token")
|
||||
|
||||
if _authenticator is None:
|
||||
raise ServiceUnavailable(
|
||||
"bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable"
|
||||
)
|
||||
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
|
||||
|
||||
try:
|
||||
ctx = get_authenticator().authenticate(token)
|
||||
except InvalidBearer as e:
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
|
||||
@ -388,17 +400,15 @@ def validate_bearer(*, accept: frozenset[Accepts]) -> Callable:
|
||||
return wrap
|
||||
|
||||
|
||||
def bearer_feature_required(fn: Callable) -> Callable:
|
||||
def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]:
|
||||
"""503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable
|
||||
without the authenticator, so fail fast instead of approving silently.
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
def inner(*args, **kwargs):
|
||||
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.ENABLE_OAUTH_BEARER:
|
||||
raise ServiceUnavailable(
|
||||
"bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable"
|
||||
)
|
||||
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
@ -423,8 +433,7 @@ def require_scope(scope: str) -> Callable:
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"require_scope used without validate_bearer; "
|
||||
"stack @validate_bearer above @require_scope"
|
||||
"require_scope used without validate_bearer; stack @validate_bearer above @require_scope"
|
||||
)
|
||||
if SCOPE_FULL not in ctx.scopes and scope not in ctx.scopes:
|
||||
raise Forbidden(f"insufficient_scope: {scope}")
|
||||
@ -442,22 +451,24 @@ def require_scope(scope: str) -> Callable:
|
||||
|
||||
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
|
||||
oauth = OAuthAccessTokenResolver(session_factory, redis_client)
|
||||
return TokenKindRegistry([
|
||||
TokenKind(
|
||||
prefix="dfoa_",
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
scopes=frozenset({"full"}),
|
||||
source="oauth_account",
|
||||
resolver=oauth.for_account(),
|
||||
),
|
||||
TokenKind(
|
||||
prefix="dfoe_",
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
scopes=frozenset({"apps:run"}),
|
||||
source="oauth_external_sso",
|
||||
resolver=oauth.for_external_sso(),
|
||||
),
|
||||
])
|
||||
return TokenKindRegistry(
|
||||
[
|
||||
TokenKind(
|
||||
prefix="dfoa_",
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
scopes=frozenset({"full"}),
|
||||
source="oauth_account",
|
||||
resolver=oauth.for_account(),
|
||||
),
|
||||
TokenKind(
|
||||
prefix="dfoe_",
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
scopes=frozenset({"apps:run"}),
|
||||
source="oauth_external_sso",
|
||||
resolver=oauth.for_external_sso(),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
|
||||
|
||||
@ -4,13 +4,15 @@ window Redis ZSET). Apply after auth decorators so scopes can read
|
||||
in-handler. RFC-8628 ``slow_down`` is inline — its response shape isn't
|
||||
generic 429. Spec: docs/specs/v1.0/server/security.md.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import g, request, session
|
||||
from werkzeug.exceptions import TooManyRequests
|
||||
@ -81,13 +83,17 @@ def _build_limiter(spec: RateLimit) -> RateLimiter:
|
||||
)
|
||||
|
||||
|
||||
def rate_limit(spec: RateLimit) -> Callable:
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def rate_limit(spec: RateLimit) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
|
||||
"""Apply after auth decorators that the scopes read from."""
|
||||
limiter = _build_limiter(spec)
|
||||
|
||||
def wrap(fn: Callable) -> Callable:
|
||||
def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@wraps(fn)
|
||||
def inner(*args, **kwargs):
|
||||
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
key = _composite_key(spec.scopes)
|
||||
if limiter.is_rate_limited(key):
|
||||
raise TooManyRequests("rate_limited")
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import func
|
||||
@ -97,9 +97,7 @@ class OAuthAccessToken(TypeBase):
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_access_tokens"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),
|
||||
)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
|
||||
@ -109,15 +107,11 @@ class OAuthAccessToken(TypeBase):
|
||||
device_label: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
prefix: Mapped[str] = mapped_column(sa.String(8), nullable=False)
|
||||
expires_at: Mapped[datetime] = mapped_column(sa.DateTime(timezone=True), nullable=False)
|
||||
subject_issuer: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True, default=None)
|
||||
account_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
token_hash: Mapped[Optional[str]] = mapped_column(sa.String(64), nullable=True, default=None)
|
||||
last_used_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=True, default=None
|
||||
)
|
||||
revoked_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=True, default=None
|
||||
)
|
||||
subject_issuer: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
|
||||
account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
token_hash: Mapped[str | None] = mapped_column(sa.String(64), nullable=True, default=None)
|
||||
last_used_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False, server_default=func.now(), init=False
|
||||
|
||||
@ -1206,6 +1206,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
INSTALLED_APP = "installed-app"
|
||||
OPENAPI = "openapi"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
expired-but-never-presented rows have no hard-expire trigger — both get
|
||||
pruned here. Spec: docs/specs/v1.0/server/tokens.md §Hard-expire.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
@ -32,26 +33,22 @@ def clean_oauth_access_tokens_task():
|
||||
candidates = or_(
|
||||
OAuthAccessToken.revoked_at < cutoff,
|
||||
# Zombies: expired but never re-presented, so middleware never flipped them.
|
||||
(OAuthAccessToken.revoked_at.is_(None))
|
||||
& (OAuthAccessToken.expires_at < cutoff),
|
||||
(OAuthAccessToken.revoked_at.is_(None)) & (OAuthAccessToken.expires_at < cutoff),
|
||||
)
|
||||
|
||||
total = 0
|
||||
while True:
|
||||
ids = db.session.scalars(
|
||||
select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)
|
||||
).all()
|
||||
ids = db.session.scalars(select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)).all()
|
||||
if not ids:
|
||||
break
|
||||
db.session.execute(
|
||||
delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids))
|
||||
)
|
||||
db.session.execute(delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids)))
|
||||
db.session.commit()
|
||||
total += len(ids)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
click.echo(click.style(
|
||||
f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d "
|
||||
f"in {end_at - start_at:.2f}s",
|
||||
fg="green",
|
||||
))
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d in {end_at - start_at:.2f}s",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
(DB upsert + plaintext generation), and TTL policy. Specs:
|
||||
docs/specs/v1.0/server/{device-flow.md, tokens.md}.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
@ -15,11 +16,12 @@ from dataclasses import asdict, dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import StrEnum
|
||||
|
||||
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT
|
||||
from models.oauth import OAuthAccessToken
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT
|
||||
from models.oauth import OAuthAccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -51,7 +53,7 @@ return raw
|
||||
"""
|
||||
|
||||
DEVICE_FLOW_TTL_SECONDS = 15 * 60 # RFC 8628 expires_in
|
||||
APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor
|
||||
APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor
|
||||
|
||||
USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXY3456789" # ambiguous chars dropped
|
||||
USER_CODE_SEGMENT_LEN = 4
|
||||
@ -95,7 +97,7 @@ class DeviceFlowState:
|
||||
return json.dumps(asdict(self))
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> "DeviceFlowState":
|
||||
def from_json(cls, raw: str) -> DeviceFlowState:
|
||||
data = json.loads(raw)
|
||||
if "status" in data:
|
||||
data["status"] = DeviceFlowStatus(data["status"])
|
||||
@ -114,20 +116,19 @@ def _random_user_code() -> str:
|
||||
return f"{_random_user_code_segment()}-{_random_user_code_segment()}"
|
||||
|
||||
|
||||
class StateNotFound(Exception):
|
||||
class StateNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidTransition(Exception):
|
||||
class InvalidTransitionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UserCodeExhausted(Exception):
|
||||
class UserCodeExhaustedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DeviceFlowRedis:
|
||||
|
||||
def __init__(self, redis_client) -> None:
|
||||
self._redis = redis_client
|
||||
self._consume_on_poll_script = redis_client.register_script(_CONSUME_ON_POLL_LUA)
|
||||
@ -157,7 +158,7 @@ class DeviceFlowRedis:
|
||||
ok = self._redis.set(key, device_code, nx=True, ex=DEVICE_FLOW_TTL_SECONDS)
|
||||
if ok:
|
||||
return user_code
|
||||
raise UserCodeExhausted("could not allocate a unique user_code in 5 attempts")
|
||||
raise UserCodeExhaustedError("could not allocate a unique user_code in 5 attempts")
|
||||
|
||||
def load_by_user_code(self, user_code: str) -> tuple[str, DeviceFlowState] | None:
|
||||
raw_dc = self._redis.get(USER_CODE_KEY_FMT.format(code=user_code))
|
||||
@ -180,7 +181,7 @@ class DeviceFlowRedis:
|
||||
try:
|
||||
return DeviceFlowState.from_json(text_)
|
||||
except (ValueError, KeyError):
|
||||
logger.error("device_flow: corrupt state for %s", device_code)
|
||||
logger.exception("device_flow: corrupt state for %s", device_code)
|
||||
return None
|
||||
|
||||
def approve(
|
||||
@ -195,9 +196,9 @@ class DeviceFlowRedis:
|
||||
) -> None:
|
||||
state = self._load_state(device_code)
|
||||
if state is None:
|
||||
raise StateNotFound(device_code)
|
||||
raise StateNotFoundError(device_code)
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise InvalidTransition(f"cannot approve {state.status}")
|
||||
raise InvalidTransitionError(f"cannot approve {state.status}")
|
||||
|
||||
state.status = DeviceFlowStatus.APPROVED
|
||||
state.subject_email = subject_email
|
||||
@ -213,9 +214,9 @@ class DeviceFlowRedis:
|
||||
def deny(self, device_code: str) -> None:
|
||||
state = self._load_state(device_code)
|
||||
if state is None:
|
||||
raise StateNotFound(device_code)
|
||||
raise StateNotFoundError(device_code)
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise InvalidTransition(f"cannot deny {state.status}")
|
||||
raise InvalidTransitionError(f"cannot deny {state.status}")
|
||||
state.status = DeviceFlowStatus.DENIED
|
||||
self._redis.setex(
|
||||
DEVICE_CODE_KEY_FMT.format(code=device_code),
|
||||
@ -239,7 +240,7 @@ class DeviceFlowRedis:
|
||||
try:
|
||||
return DeviceFlowState.from_json(text_)
|
||||
except (ValueError, KeyError):
|
||||
logger.error("device_flow: corrupt state on consume %s", device_code)
|
||||
logger.exception("device_flow: corrupt state on consume %s", device_code)
|
||||
return None
|
||||
|
||||
def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision:
|
||||
@ -287,6 +288,7 @@ ACCOUNT_ISSUER_SENTINEL = "dify:account"
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class MintResult:
|
||||
"""Plaintext token surfaces to the caller once."""
|
||||
|
||||
token: str
|
||||
token_id: uuid.UUID
|
||||
expires_at: datetime
|
||||
@ -308,7 +310,9 @@ def sha256_hex(token: str) -> str:
|
||||
|
||||
|
||||
def mint_oauth_token(
|
||||
session: Session,
|
||||
# Accept either Session or Flask-SQLAlchemy's request-scoped wrapper —
|
||||
# the wrapper proxies the same execute/commit surface.
|
||||
session: Session | scoped_session,
|
||||
redis_client,
|
||||
*,
|
||||
subject_email: str,
|
||||
@ -328,9 +332,7 @@ def mint_oauth_token(
|
||||
# Account flow always writes the sentinel — caller may pass None
|
||||
# (for clarity) or the sentinel itself; nothing else is valid.
|
||||
if subject_issuer not in (None, ACCOUNT_ISSUER_SENTINEL):
|
||||
raise ValueError(
|
||||
f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}"
|
||||
)
|
||||
raise ValueError(f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}")
|
||||
subject_issuer = ACCOUNT_ISSUER_SENTINEL
|
||||
elif prefix == PREFIX_OAUTH_EXTERNAL_SSO:
|
||||
# Defense in depth: enterprise canonicalises + rejects empty,
|
||||
@ -363,7 +365,7 @@ def mint_oauth_token(
|
||||
|
||||
|
||||
def _upsert(
|
||||
session: Session,
|
||||
session: Session | scoped_session,
|
||||
*,
|
||||
subject_email: str,
|
||||
subject_issuer: str | None,
|
||||
@ -415,6 +417,8 @@ def _upsert(
|
||||
row = session.execute(upsert_stmt).first()
|
||||
session.commit()
|
||||
|
||||
if row is None:
|
||||
raise RuntimeError("oauth_token upsert returned no row")
|
||||
token_id = uuid.UUID(str(row.id))
|
||||
return UpsertOutcome(
|
||||
token_id=token_id,
|
||||
@ -449,7 +453,9 @@ def oauth_ttl_days(tenant_id: str | None = None) -> int:
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"%s=%r is not an int; falling back to %d",
|
||||
_TTL_ENV_VAR, raw, DEFAULT_OAUTH_TTL_DAYS,
|
||||
_TTL_ENV_VAR,
|
||||
raw,
|
||||
DEFAULT_OAUTH_TTL_DAYS,
|
||||
)
|
||||
return DEFAULT_OAUTH_TTL_DAYS
|
||||
if value < MIN_TTL_DAYS:
|
||||
|
||||
@ -38,12 +38,7 @@ def test_membership_strategy_uses_join_lookup(member):
|
||||
|
||||
|
||||
def test_membership_strategy_rejects_external_sso():
|
||||
assert (
|
||||
MembershipStrategy().authorize(
|
||||
_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)
|
||||
)
|
||||
is False
|
||||
)
|
||||
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False
|
||||
|
||||
|
||||
def test_app_authz_check_raises_when_strategy_denies():
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""User-scoped identity + session endpoints under /openapi/v1/account."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
|
||||
@ -6,8 +6,10 @@ from flask_restx import Api
|
||||
|
||||
|
||||
def _client():
|
||||
from controllers.openapi import app_info # noqa: F401
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi import (
|
||||
app_info, # noqa: F401
|
||||
openapi_ns,
|
||||
)
|
||||
|
||||
app = Flask(__name__)
|
||||
api = Api(app)
|
||||
|
||||
@ -6,8 +6,10 @@ from flask_restx import Api
|
||||
|
||||
|
||||
def _client():
|
||||
from controllers.openapi import chat_messages # noqa: F401
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi import (
|
||||
chat_messages, # noqa: F401
|
||||
openapi_ns,
|
||||
)
|
||||
|
||||
app = Flask(__name__)
|
||||
api = Api(app)
|
||||
@ -32,12 +34,11 @@ def test_chat_dispatches_and_returns_response_model(svc, bypass_pipeline):
|
||||
200,
|
||||
)
|
||||
fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1")
|
||||
with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake), patch(
|
||||
"controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()
|
||||
with (
|
||||
patch("controllers.openapi.chat_messages._unpack_app", return_value=fake),
|
||||
patch("controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()),
|
||||
):
|
||||
r = _client().post(
|
||||
"/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}}
|
||||
)
|
||||
r = _client().post("/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}})
|
||||
assert r.status_code == 200
|
||||
body = r.get_json()
|
||||
assert body["conversation_id"] == "c1"
|
||||
@ -62,8 +63,9 @@ def test_chat_strips_user_field_from_body(svc, bypass_pipeline):
|
||||
200,
|
||||
)
|
||||
fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1")
|
||||
with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake), patch(
|
||||
"controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()
|
||||
with (
|
||||
patch("controllers.openapi.chat_messages._unpack_app", return_value=fake),
|
||||
patch("controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()),
|
||||
):
|
||||
_client().post(
|
||||
"/openapi/v1/apps/app1/chat-messages",
|
||||
@ -76,9 +78,7 @@ def test_chat_strips_user_field_from_body(svc, bypass_pipeline):
|
||||
def test_chat_rejects_non_chat_mode(bypass_pipeline):
|
||||
fake = SimpleNamespace(mode="completion")
|
||||
with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake):
|
||||
r = _client().post(
|
||||
"/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}}
|
||||
)
|
||||
r = _client().post("/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}})
|
||||
assert r.status_code in (400, 403)
|
||||
|
||||
|
||||
|
||||
@ -6,8 +6,10 @@ from flask_restx import Api
|
||||
|
||||
|
||||
def _client():
|
||||
from controllers.openapi import completion_messages # noqa: F401
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi import (
|
||||
completion_messages, # noqa: F401
|
||||
openapi_ns,
|
||||
)
|
||||
|
||||
app = Flask(__name__)
|
||||
api = Api(app)
|
||||
@ -31,8 +33,9 @@ def test_completion_returns_response_model(svc, bypass_pipeline):
|
||||
200,
|
||||
)
|
||||
fake = SimpleNamespace(mode="completion", id="app1", tenant_id="t1")
|
||||
with patch("controllers.openapi.completion_messages._unpack_app", return_value=fake), patch(
|
||||
"controllers.openapi.completion_messages._unpack_caller", return_value=SimpleNamespace()
|
||||
with (
|
||||
patch("controllers.openapi.completion_messages._unpack_app", return_value=fake),
|
||||
patch("controllers.openapi.completion_messages._unpack_caller", return_value=SimpleNamespace()),
|
||||
):
|
||||
r = _client().post(
|
||||
"/openapi/v1/apps/app1/completion-messages",
|
||||
|
||||
@ -7,9 +7,9 @@ Tests use a fresh Blueprint + Flask-CORS per case because the production
|
||||
blueprint is a module-level singleton and can't be reconfigured once
|
||||
registered.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Blueprint, Flask
|
||||
from flask.views import MethodView
|
||||
from flask_cors import CORS
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Account-branch device-flow approve/deny under /openapi/v1."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
|
||||
@ -4,6 +4,7 @@ authorization endpoint.
|
||||
Tests verify URL routing without invoking the handler — invoking would
|
||||
require Redis, which the unit-test runtime does not initialise.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
@ -31,16 +32,12 @@ def test_openapi_route_registered(openapi_app: Flask):
|
||||
|
||||
|
||||
def test_route_dispatches_to_class(openapi_app: Flask):
|
||||
rule = next(
|
||||
r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code"
|
||||
)
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceCodeApi
|
||||
|
||||
|
||||
def test_route_accepts_post(openapi_app: Flask):
|
||||
rule = next(
|
||||
r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code"
|
||||
)
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code")
|
||||
assert "POST" in rule.methods
|
||||
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""GET /openapi/v1/oauth/device/lookup is the canonical user-code lookup."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
@ -26,14 +27,10 @@ def test_openapi_route_registered(openapi_app: Flask):
|
||||
|
||||
|
||||
def test_route_dispatches_to_class(openapi_app: Flask):
|
||||
rule = next(
|
||||
r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup"
|
||||
)
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceLookupApi
|
||||
|
||||
|
||||
def test_route_accepts_get(openapi_app: Flask):
|
||||
rule = next(
|
||||
r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup"
|
||||
)
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup")
|
||||
assert "GET" in rule.methods
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""POST /openapi/v1/oauth/device/token is the canonical poll endpoint."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
@ -26,7 +27,5 @@ def test_openapi_route_registered(openapi_app: Flask):
|
||||
|
||||
|
||||
def test_route_dispatches_to_class(openapi_app: Flask):
|
||||
rule = next(
|
||||
r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/token"
|
||||
)
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/token")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceTokenApi
|
||||
|
||||
@ -6,8 +6,10 @@ from flask_restx import Api
|
||||
|
||||
|
||||
def _client():
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi import workflow_run # noqa: F401
|
||||
from controllers.openapi import (
|
||||
openapi_ns,
|
||||
workflow_run, # noqa: F401
|
||||
)
|
||||
|
||||
app = Flask(__name__)
|
||||
api = Api(app)
|
||||
@ -32,8 +34,9 @@ def test_workflow_run_returns_response_model(svc, bypass_pipeline):
|
||||
200,
|
||||
)
|
||||
fake = SimpleNamespace(mode="workflow", id="app1", tenant_id="t1")
|
||||
with patch("controllers.openapi.workflow_run._unpack_app", return_value=fake), patch(
|
||||
"controllers.openapi.workflow_run._unpack_caller", return_value=SimpleNamespace()
|
||||
with (
|
||||
patch("controllers.openapi.workflow_run._unpack_app", return_value=fake),
|
||||
patch("controllers.openapi.workflow_run._unpack_caller", return_value=SimpleNamespace()),
|
||||
):
|
||||
r = _client().post("/openapi/v1/apps/app1/workflows/run", json={"inputs": {"x": 1}})
|
||||
assert r.status_code == 200
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
list + member-gated detail. No legacy /v1/ equivalent — the cookie-authed
|
||||
/console/api/workspaces is a separate consumer that stays in console.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
Tests use a fake auth_ctx attached directly to flask.g — no
|
||||
authenticator wiring needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
Loading…
Reference in New Issue
Block a user