mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 04:36:31 +08:00
Type and lint pass over the openapi controllers, auth pipeline, and
oauth bearer/device-flow plumbing. Pyright errors drop from 36 to 0 and
ruff errors from 16 to 0; all 93 openapi unit tests pass.
Logic fixes:
- libs/oauth_bearer.py: drop private-naming on the friend-API methods
consumed by _VariantResolver (cache_get / cache_set_positive /
cache_set_negative / hard_expire / session_factory). They were always
cross-class accessors — leading underscore was misleading. Add public
registry property on BearerAuthenticator. hard_expire's row_id parameter is
widened to UUID | str (matching the StringUUID column type).
- libs/oauth_bearer.py: type validate_bearer / bearer_feature_required
with ParamSpec / PEP-695 so wrapped routes preserve their signature.
- libs/rate_limit.py: same — typed rate_limit decorator.
- services/oauth_device_flow.py: mint_oauth_token / _upsert accept
Session | scoped_session (Flask-SQLAlchemy proxy). Guard row-is-None
after upsert.
- controllers/openapi/{chat,completion,workflow}_messages.py: tuple-vs-
Mapping shape narrowing on AppGenerateService.generate return —
production returns Mapping, tests mock as (body, status). Validate
through Pydantic Response model in both shapes.
- controllers/openapi/oauth_device.py: replace flask_restx.reqparse (banned)
with Pydantic Request/Query models — DeviceCodeRequest, DevicePollRequest,
DeviceLookupQuery, DeviceMutateRequest. Two PEP-695 generic helpers
(_validate_json / _validate_query) translate ValidationError to BadRequest.
- controllers/openapi/auth/strategies.py: Protocol param-name match
(subject_type), Optional narrowing on app/tenant/account_id/subject_email.
- controllers/openapi/auth/steps.py: subject_type-is-None guard before
mounter dispatch.
- core/app/apps/workflow/generate_task_pipeline.py + models/workflow.py:
add WorkflowAppLogCreatedFrom.OPENAPI + matching match-case branch.
Fixes match-exhaustiveness and possibly-unbound created_from.
- libs/device_flow_security.py: add a `# pyright: ignore` on the Flask after_request
hook (the framework registers it, but pyright otherwise reports it as unused).
- services/oauth_device_flow.py: rename Exceptions to *Error suffix
(StateNotFoundError / InvalidTransitionError / UserCodeExhaustedError);
same for libs/oauth_bearer.py (InvalidBearerError / TokenExpiredError).
Update all callers across openapi controllers.
- controllers/openapi/{oauth_device,oauth_device_sso}.py +
services/oauth_device_flow.py: switch logger.error in except blocks
to logger.exception (TRY400) — keeps the traceback for ops.
- configs/feature/__init__.py: OPENAPI_KNOWN_CLIENT_IDS computed_field
needs an @property alongside for pyright to see it as a value, not a
method. Matches the existing line-451 pattern.
Plus ruff format + import-sort across the openapi tree (pure formatting).
109 lines
3.5 KiB
Python
109 lines
3.5 KiB
Python
"""HS256 compact JWS keyed on the shared Dify SECRET_KEY. Used by the SSO
|
|
state envelope, external subject assertion, and approval-grant cookie —
|
|
all three share one key-set so api ↔ enterprise can verify each other.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
import jwt
|
|
|
|
from configs import dify_config
|
|
|
|
# JWT ``aud`` claim values, one per token kind listed in the module docstring.
# ``verify`` passes its ``expected_aud`` straight to PyJWT as ``audience``, so a
# token minted with one of these values is rejected when checked against another.
AUD_STATE_ENVELOPE: str = "api.sso.state_envelope"
AUD_EXT_SUBJECT_ASSERTION: str = "api.device_flow.external_subject_assertion"
AUD_APPROVAL_GRANT: str = "api.device_flow.approval_grant"

# ``kid`` header value for the single slot that ``KeySet.from_shared_secret``
# derives from ``dify_config.SECRET_KEY``.
ACTIVE_KID_V1: str = "dify-shared-v1"
|
|
|
|
|
|
class KeySetError(Exception):
    """Raised when a :class:`KeySet` cannot be built from the supplied secrets."""


class KeySet:
    """Table of ``kid`` -> HS256 secret, plus the kid used for signing.

    ``from_shared_secret`` builds the single-slot key-set from
    ``dify_config.SECRET_KEY``; ``from_entries`` reserves multi-kid
    construction for rotation slots.

    Every slot (not only the active one) must hold a non-empty bytes-like
    secret. :meth:`lookup` is how the module's ``verify`` picks the key for a
    token's ``kid``. An empty HMAC key in any slot would let anyone mint a
    signature that verifies under that kid.
    """

    def __init__(self, entries: dict[str, bytes], active_kid: str) -> None:
        """Validate and defensively copy *entries*.

        :param entries: mapping of kid -> secret; every secret must be a
            non-empty ``bytes``/``bytearray``/``memoryview``.
        :param active_kid: kid used for signing; must be a key of *entries*.
        :raises KeySetError: if *active_kid* is absent, or any secret is empty
            or not bytes-like.
        """
        if active_kid not in entries:
            raise KeySetError(f"active kid {active_kid!r} missing from key-set")
        if not entries[active_kid]:
            raise KeySetError(f"active kid {active_kid!r} has empty secret")

        copied: dict[str, bytes] = {}
        for kid, secret in entries.items():
            # ``bytes(int)`` would silently produce a zero-filled key, so only
            # genuine bytes-like values are accepted.
            if not isinstance(secret, (bytes, bytearray, memoryview)):
                raise KeySetError(f"kid {kid!r} secret must be bytes-like, got {type(secret).__name__}")
            # An empty HMAC key offers no authenticity; reject it in every
            # slot, not just the active one.
            if not secret:
                raise KeySetError(f"kid {kid!r} has empty secret")
            # Copy into immutable bytes so later mutation of a caller-owned
            # bytearray cannot change the key-set.
            copied[kid] = bytes(secret)

        self._entries: dict[str, bytes] = copied
        self._active_kid = active_kid

    @classmethod
    def from_shared_secret(cls) -> KeySet:
        """Build a one-slot key-set under ``ACTIVE_KID_V1`` from ``SECRET_KEY``.

        :raises KeySetError: if ``dify_config.SECRET_KEY`` is empty.
        """
        secret = dify_config.SECRET_KEY
        if not secret:
            raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
        return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1)

    @classmethod
    def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> KeySet:
        """Build a key-set from explicit kid -> secret slots (rotation support).

        :raises KeySetError: under the same conditions as ``__init__``.
        """
        return cls(entries, active_kid)

    @property
    def active_kid(self) -> str:
        """Kid whose secret signs new tokens."""
        return self._active_kid

    def lookup(self, kid: str) -> bytes | None:
        """Return the secret for *kid*, or ``None`` if the kid is unknown."""
        return self._entries.get(kid)
|
|
|
|
|
|
def sign(keyset: KeySet, payload: dict, aud: str, ttl_seconds: int) -> str:
|
|
"""``iat`` + ``exp`` are injected here; callers must not set them."""
|
|
if "aud" in payload or "iat" in payload or "exp" in payload:
|
|
raise ValueError("reserved claim present in payload (aud/iat/exp)")
|
|
if ttl_seconds <= 0:
|
|
raise ValueError("ttl_seconds must be positive")
|
|
|
|
kid = keyset.active_kid
|
|
secret = keyset.lookup(kid)
|
|
if secret is None:
|
|
raise KeySetError(f"active kid {kid!r} lookup miss")
|
|
|
|
iat = datetime.now(UTC)
|
|
exp = iat + timedelta(seconds=ttl_seconds)
|
|
claims = {**payload, "aud": aud, "iat": iat, "exp": exp}
|
|
return jwt.encode(
|
|
claims,
|
|
secret,
|
|
algorithm="HS256",
|
|
headers={"kid": kid, "typ": "JWT"},
|
|
)
|
|
|
|
|
|
class VerifyError(Exception):
    """Raised by :func:`verify` for any token that must not be trusted."""


def verify(keyset: KeySet, token: str, expected_aud: str) -> dict:
    """Verify an HS256 compact JWS and return its decoded claims.

    The key is selected by the token's ``kid`` header. Unknown kid is rejected:
    never fall back to the active kid, since a past kid value would otherwise
    be forgeable by anyone who saw it. A kid that maps to an empty secret is
    rejected too, because anyone can compute an HMAC with an empty key.

    The token must carry an ``aud`` equal to *expected_aud* and an ``exp``
    claim. PyJWT does not require ``exp`` by default, so without the explicit
    requirement a token lacking it would never expire.

    :param keyset: key-set used to resolve the token's ``kid``.
    :param token: compact JWS string.
    :param expected_aud: audience the token must have been minted for.
    :returns: the decoded claims dict.
    :raises VerifyError: on a malformed header, missing/unknown kid, empty
        secret, bad signature, audience mismatch, expiry, or missing required
        claim. The underlying PyJWT error is chained as ``__cause__``.
    """
    try:
        header = jwt.get_unverified_header(token)
    except jwt.PyJWTError as e:
        raise VerifyError(f"decode header: {e}") from e
    kid = header.get("kid")
    if not kid:
        raise VerifyError("no kid in header")
    secret = keyset.lookup(kid)
    if secret is None:
        raise VerifyError(f"unknown kid {kid!r}")
    if not secret:
        # Defense in depth: never verify a MAC against an empty key, even if a
        # key-set was built with an empty slot.
        raise VerifyError(f"empty secret for kid {kid!r}")
    try:
        return jwt.decode(
            token,
            secret,
            algorithms=["HS256"],
            audience=expected_aud,
            # sign() always injects exp and aud; refuse tokens that omit them
            # (a missing exp would otherwise mean a token that never expires).
            options={"require": ["exp", "aud"]},
        )
    except jwt.ExpiredSignatureError as e:
        raise VerifyError("token expired") from e
    except jwt.InvalidAudienceError as e:
        raise VerifyError("aud mismatch") from e
    except jwt.PyJWTError as e:
        # Includes MissingRequiredClaimError, InvalidSignatureError, DecodeError.
        raise VerifyError(f"decode: {e}") from e
|