diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 00ac6ecb46..e036003ca3 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -523,7 +523,8 @@ class HttpConfig(BaseSettings): default="difyctl", ) - @computed_field + @computed_field # type: ignore[misc] + @property def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]: return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c) diff --git a/api/controllers/openapi/_audit.py b/api/controllers/openapi/_audit.py index 30c3e1d143..4c3ae888b9 100644 --- a/api/controllers/openapi/_audit.py +++ b/api/controllers/openapi/_audit.py @@ -4,6 +4,7 @@ Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...} matches the existing oauth_device convention. The EE OTel exporter consults its own allowlist to decide whether to ship the line. """ + from __future__ import annotations import logging diff --git a/api/controllers/openapi/_models.py b/api/controllers/openapi/_models.py index 5971f59e42..9e70250823 100644 --- a/api/controllers/openapi/_models.py +++ b/api/controllers/openapi/_models.py @@ -1,4 +1,5 @@ """Shared response substructures for openapi endpoints.""" + from __future__ import annotations from typing import Any diff --git a/api/controllers/openapi/account.py b/api/controllers/openapi/account.py index 6ce043d49e..e6a7974266 100644 --- a/api/controllers/openapi/account.py +++ b/api/controllers/openapi/account.py @@ -2,6 +2,7 @@ identity read; /account/sessions and /account/sessions/ manage the user's active OAuth tokens. """ + from __future__ import annotations from datetime import UTC, datetime @@ -16,9 +17,9 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.oauth_bearer import ( ACCEPT_USER_ANY, + TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType, - TOKEN_CACHE_KEY_FMT, validate_bearer, ) from libs.rate_limit import ( @@ -51,8 +52,7 @@ class AccountApi(Resource): } account = ( - db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none() - if ctx.account_id else None + db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none() if ctx.account_id else None ) memberships = _load_memberships(ctx.account_id) if ctx.account_id else [] default_ws_id = _pick_default_workspace(memberships) @@ -129,8 +129,7 @@ class AccountSessionByIdApi(Resource): # Subject-match guard. 404 (not 403) on cross-subject so the # endpoint doesn't leak token IDs that belong to other subjects. owns = db.session.execute( - select(OAuthAccessToken.id) - .where( + select(OAuthAccessToken.id).where( and_( OAuthAccessToken.id == session_id, *_subject_match(ctx), @@ -160,8 +159,7 @@ def _subject_match(ctx: AuthContext) -> tuple: def _require_oauth_subject(ctx: AuthContext) -> None: if not ctx.source.startswith("oauth"): raise BadRequest( - "this endpoint revokes OAuth bearer tokens; " - "use /openapi/v1/personal-access-tokens/self for PATs" + "this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs" ) diff --git a/api/controllers/openapi/app_info.py b/api/controllers/openapi/app_info.py index aa9b8d20a7..b9a805d015 100644 --- a/api/controllers/openapi/app_info.py +++ b/api/controllers/openapi/app_info.py @@ -1,4 +1,5 @@ """GET /openapi/v1/apps//info — port of service_api/app/app.py:AppInfoApi.""" + from __future__ import annotations from flask_restx import Resource diff --git a/api/controllers/openapi/auth/composition.py b/api/controllers/openapi/auth/composition.py index fa78f07e3a..a8da919f29 100644 --- a/api/controllers/openapi/auth/composition.py +++ b/api/controllers/openapi/auth/composition.py @@ -2,6 +2,7 @@ Endpoints attach via @APP_PIPELINE.guard(scope=…). No alternative paths. """ + from __future__ import annotations from controllers.openapi.auth.pipeline import Pipeline diff --git a/api/controllers/openapi/auth/context.py b/api/controllers/openapi/auth/context.py index a23e4e981d..df43a5183d 100644 --- a/api/controllers/openapi/auth/context.py +++ b/api/controllers/openapi/auth/context.py @@ -4,6 +4,7 @@ Every field starts None / empty and is filled in by a step. The pipeline is the only thing that should construct or mutate Context — handlers read populated values via the decorator's kwargs unpacking. """ + from __future__ import annotations from dataclasses import dataclass, field diff --git a/api/controllers/openapi/auth/pipeline.py b/api/controllers/openapi/auth/pipeline.py index c0df85367e..b4ca1e793b 100644 --- a/api/controllers/openapi/auth/pipeline.py +++ b/api/controllers/openapi/auth/pipeline.py @@ -4,6 +4,7 @@ that is the design lock-in: forgetting an auth layer is structurally impossible because there is no "sometimes wrap, sometimes don't" choice. """ + from __future__ import annotations from functools import wraps diff --git a/api/controllers/openapi/auth/steps.py b/api/controllers/openapi/auth/steps.py index bf64cc5472..c671fec21f 100644 --- a/api/controllers/openapi/auth/steps.py +++ b/api/controllers/openapi/auth/steps.py @@ -3,21 +3,22 @@ BearerCheck is the only step that touches the token registry; downstream steps see only the populated Context. """ + from __future__ import annotations -from typing import Callable +from collections.abc import Callable from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized from controllers.openapi.auth.context import Context from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter from extensions.ext_database import db -from libs.oauth_bearer import TokenExpired, get_authenticator, sha256_hex +from libs.oauth_bearer import TokenExpiredError, get_authenticator, sha256_hex from models import App, Tenant, TenantStatus def _registry(): - return get_authenticator()._registry # noqa: SLF001 + return get_authenticator().registry def _extract_bearer(req) -> str | None: @@ -45,7 +46,7 @@ class BearerCheck: try: row = kind.resolver.resolve(_hash_token(token)) - except TokenExpired: + except TokenExpiredError: raise Unauthorized("token expired") if row is None: raise Unauthorized("invalid bearer") @@ -105,6 +106,8 @@ class CallerMount: self._mounters = mounters def __call__(self, ctx: Context) -> None: + if ctx.subject_type is None: + raise Unauthorized("subject_type unset — BearerCheck did not run") for m in self._mounters: if m.applies_to(ctx.subject_type): m.mount(ctx) diff --git a/api/controllers/openapi/auth/strategies.py b/api/controllers/openapi/auth/strategies.py index cda2e2ae51..d8c02f7881 100644 --- a/api/controllers/openapi/auth/strategies.py +++ b/api/controllers/openapi/auth/strategies.py @@ -4,6 +4,7 @@ App authorization (Acl/Membership) and caller mounting (Account/EndUser) vary along independent axes; each strategy is one class so the pipeline composition stays a flat list. """ + from __future__ import annotations from typing import Protocol @@ -33,9 +34,11 @@ class AclStrategy: """ def authorize(self, ctx: Context) -> bool: + if ctx.subject_email is None or ctx.app is None: + return False return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( user_id=ctx.subject_email, - app_id=ctx.app.id, + app_id=ctx.app.id, # type: ignore[attr-defined] ) @@ -50,7 +53,9 @@ class MembershipStrategy: def authorize(self, ctx: Context) -> bool: if ctx.subject_type == SubjectType.EXTERNAL_SSO: return False - return _has_tenant_membership(ctx.account_id, ctx.tenant.id) + if ctx.tenant is None: + return False + return _has_tenant_membership(ctx.account_id, ctx.tenant.id) # type: ignore[attr-defined] def _has_tenant_membership(account_id: str | None, tenant_id: str) -> bool: @@ -67,8 +72,8 @@ def _has_tenant_membership(account_id: str | None, tenant_id: str) -> bool: def _login_as(user) -> None: """Set Flask-Login request user so downstream services see the caller.""" - current_app.login_manager._update_request_context_with_user(user) # noqa: SLF001 - user_logged_in.send(current_app._get_current_object(), user=user) # noqa: SLF001 + current_app.login_manager._update_request_context_with_user(user) + user_logged_in.send(current_app._get_current_object(), user=user) class CallerMounter(Protocol): @@ -78,25 +83,31 @@ class CallerMounter(Protocol): class AccountMounter: - def applies_to(self, st: SubjectType) -> bool: - return st == SubjectType.ACCOUNT + def applies_to(self, subject_type: SubjectType) -> bool: + return subject_type == SubjectType.ACCOUNT def mount(self, ctx: Context) -> None: + if ctx.account_id is None: + raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run") account = db.session.get(Account, ctx.account_id) - account.current_tenant = ctx.tenant + if account is None: + raise RuntimeError("AccountMounter: account row missing for resolved bearer") + account.current_tenant = ctx.tenant # type: ignore[assignment] _login_as(account) ctx.caller, ctx.caller_kind = account, "account" class EndUserMounter: - def applies_to(self, st: SubjectType) -> bool: - return st == SubjectType.EXTERNAL_SSO + def applies_to(self, subject_type: SubjectType) -> bool: + return subject_type == SubjectType.EXTERNAL_SSO def mount(self, ctx: Context) -> None: + if ctx.tenant is None or ctx.app is None or ctx.subject_email is None: + raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run") end_user = EndUserService.get_or_create_end_user_by_type( InvokeFrom.OPENAPI, - tenant_id=ctx.tenant.id, - app_id=ctx.app.id, + tenant_id=ctx.tenant.id, # type: ignore[attr-defined] + app_id=ctx.app.id, # type: ignore[attr-defined] user_id=ctx.subject_email, ) _login_as(end_user) diff --git a/api/controllers/openapi/chat_messages.py b/api/controllers/openapi/chat_messages.py index 2335e59da9..f746edc7da 100644 --- a/api/controllers/openapi/chat_messages.py +++ b/api/controllers/openapi/chat_messages.py @@ -8,9 +8,11 @@ Differences from service_api: - Typed Request and Response models. - invoke_from = InvokeFrom.OPENAPI. """ + from __future__ import annotations import logging +from collections.abc import Mapping from typing import Any, Literal from uuid import UUID @@ -163,5 +165,12 @@ class ChatMessagesApi(Resource): if streaming: return helper.compact_generate_response(response) - body_dict = response[0] if isinstance(response, tuple) else response - return ChatMessageResponse.model_validate(body_dict).model_dump(mode="json"), 200 + # Some upstream paths (and tests) return (body, status); production + # generate returns Mapping. Accept both, then validate. + if isinstance(response, tuple): + body_dict: Any = response[0] # pyright: ignore[reportArgumentType] + else: + body_dict = response + if not isinstance(body_dict, Mapping): + raise InternalServerError("blocking generate returned non-mapping response") + return ChatMessageResponse.model_validate(dict(body_dict)).model_dump(mode="json"), 200 diff --git a/api/controllers/openapi/completion_messages.py b/api/controllers/openapi/completion_messages.py index 9085297793..1180d43113 100644 --- a/api/controllers/openapi/completion_messages.py +++ b/api/controllers/openapi/completion_messages.py @@ -1,8 +1,10 @@ """POST /openapi/v1/apps//completion-messages — port of service_api/app/completion.py:CompletionApi.""" + from __future__ import annotations import logging +from collections.abc import Mapping from typing import Any, Literal from flask import request @@ -120,5 +122,10 @@ class CompletionMessagesApi(Resource): if streaming: return helper.compact_generate_response(response) - body_dict = response[0] if isinstance(response, tuple) else response - return CompletionMessageResponse.model_validate(body_dict).model_dump(mode="json"), 200 + if isinstance(response, tuple): + body_dict: Any = response[0] # pyright: ignore[reportArgumentType] + else: + body_dict = response + if not isinstance(body_dict, Mapping): + raise InternalServerError("blocking generate returned non-mapping response") + return CompletionMessageResponse.model_validate(dict(body_dict)).model_dump(mode="json"), 200 diff --git a/api/controllers/openapi/oauth_device.py b/api/controllers/openapi/oauth_device.py index 48a7d7f8c4..b45865d2b0 100644 --- a/api/controllers/openapi/oauth_device.py +++ b/api/controllers/openapi/oauth_device.py @@ -12,13 +12,16 @@ sub-groups in one module: SSO branch lives in oauth_device_sso.py. """ + from __future__ import annotations import logging from flask import request from flask_login import login_required -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, ValidationError +from werkzeug.exceptions import BadRequest from configs import dify_config from controllers.console.wraps import account_initialization_required, setup_required @@ -41,9 +44,9 @@ from services.oauth_device_flow import ( PREFIX_OAUTH_ACCOUNT, DeviceFlowRedis, DeviceFlowStatus, - InvalidTransition, + InvalidTransitionError, SlowDownDecision, - StateNotFound, + StateNotFoundError, mint_oauth_token, oauth_ttl_days, ) @@ -52,22 +55,41 @@ logger = logging.getLogger(__name__) # ========================================================================= -# Parsers +# Request / query schemas # ========================================================================= -_code_parser = reqparse.RequestParser() -_code_parser.add_argument("client_id", type=str, required=True, location="json") -_code_parser.add_argument("device_label", type=str, required=True, location="json") -_poll_parser = reqparse.RequestParser() -_poll_parser.add_argument("device_code", type=str, required=True, location="json") -_poll_parser.add_argument("client_id", type=str, required=True, location="json") +class DeviceCodeRequest(BaseModel): + client_id: str + device_label: str -_lookup_parser = reqparse.RequestParser() -_lookup_parser.add_argument("user_code", type=str, required=True, location="args") -_mutate_parser = reqparse.RequestParser() -_mutate_parser.add_argument("user_code", type=str, required=True, location="json") +class DevicePollRequest(BaseModel): + device_code: str + client_id: str + + +class DeviceLookupQuery(BaseModel): + user_code: str + + +class DeviceMutateRequest(BaseModel): + user_code: str + + +def _validate_json[M: BaseModel](model: type[M]) -> M: + body = request.get_json(silent=True) or {} + try: + return model.model_validate(body) + except ValidationError as exc: + raise BadRequest(str(exc)) + + +def _validate_query[M: BaseModel](model: type[M]) -> M: + try: + return model.model_validate(request.args.to_dict(flat=True)) + except ValidationError as exc: + raise BadRequest(str(exc)) # ========================================================================= @@ -79,9 +101,9 @@ _mutate_parser.add_argument("user_code", type=str, required=True, location="json class OAuthDeviceCodeApi(Resource): @rate_limit(LIMIT_DEVICE_CODE_PER_IP) def post(self): - args = _code_parser.parse_args() - client_id = args["client_id"] - device_label = args["device_label"] + payload = _validate_json(DeviceCodeRequest) + client_id = payload.client_id + device_label = payload.device_label if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS: return {"error": "unsupported_client"}, 400 @@ -104,8 +126,8 @@ class OAuthDeviceTokenApi(Resource): """RFC 8628 poll.""" def post(self): - args = _poll_parser.parse_args() - device_code = args["device_code"] + payload = _validate_json(DevicePollRequest) + device_code = payload.device_code store = DeviceFlowRedis(redis_client) @@ -145,8 +167,8 @@ class OAuthDeviceLookupApi(Resource): @rate_limit(LIMIT_LOOKUP_PUBLIC) def get(self): - args = _lookup_parser.parse_args() - user_code = args["user_code"].strip().upper() + payload = _validate_query(DeviceLookupQuery) + user_code = payload.user_code.strip().upper() store = DeviceFlowRedis(redis_client) found = store.load_by_user_code(user_code) @@ -181,8 +203,8 @@ class DeviceApproveApi(Resource): @bearer_feature_required @rate_limit(LIMIT_APPROVE_CONSOLE) def post(self): - args = _mutate_parser.parse_args() - user_code = args["user_code"].strip().upper() + payload = _validate_json(DeviceMutateRequest) + user_code = payload.user_code.strip().upper() account, tenant = current_account_with_tenant() store = DeviceFlowRedis(redis_client) @@ -226,10 +248,10 @@ class DeviceApproveApi(Resource): token_id=str(mint.token_id), poll_payload=poll_payload, ) - except (StateNotFound, InvalidTransition) as e: + except (StateNotFoundError, InvalidTransitionError): # Row minted but state vanished — roll forward; the orphan # token is revocable via auth devices list / Authorized Apps. - logger.error("device_flow: approve raced on %s: %s", device_code, e) + logger.exception("device_flow: approve raced on %s", device_code) return {"error": "state_lost"}, 409 finally: redis_client.delete(guard_key) @@ -246,8 +268,8 @@ class DeviceDenyApi(Resource): @bearer_feature_required @rate_limit(LIMIT_APPROVE_CONSOLE) def post(self): - args = _mutate_parser.parse_args() - user_code = args["user_code"].strip().upper() + payload = _validate_json(DeviceMutateRequest) + user_code = payload.user_code.strip().upper() store = DeviceFlowRedis(redis_client) found = store.load_by_user_code(user_code) @@ -259,8 +281,8 @@ class DeviceDenyApi(Resource): try: store.deny(device_code) - except (StateNotFound, InvalidTransition) as e: - logger.error("device_flow: deny raced on %s: %s", device_code, e) + except (StateNotFoundError, InvalidTransitionError): + logger.exception("device_flow: deny raced on %s", device_code) return {"error": "state_lost"}, 409 _emit_deny_audit(state) @@ -284,7 +306,9 @@ def _audit_cross_ip_if_needed(state) -> None: if state.created_ip and poll_ip and poll_ip != state.created_ip: logger.warning( "audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s", - state.token_id, state.created_ip, poll_ip, + state.token_id, + state.created_ip, + poll_ip, extra={ "audit": True, "token_id": state.token_id, @@ -299,16 +323,14 @@ def _build_account_poll_payload(account, tenant, mint) -> dict: handler doesn't re-query accounts/tenants for authz data. """ from models import Tenant, TenantAccountJoin + rows = ( db.session.query(Tenant, TenantAccountJoin) .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id) .filter(TenantAccountJoin.account_id == account.id) .all() ) - workspaces = [ - {"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")} - for t, m in rows - ] + workspaces = [{"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")} for t, m in rows] # Prefer active session tenant → DB-flagged current join → first membership. default_ws_id = None if tenant and any(w["id"] == str(tenant) for w in workspaces): @@ -335,7 +357,11 @@ def _build_account_poll_payload(account, tenant, mint) -> dict: def _emit_approve_audit(state, account, tenant, mint) -> None: logger.warning( "audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s", - mint.token_id, account.email, state.client_id, state.device_label, mint.expires_at, + mint.token_id, + account.email, + state.client_id, + state.device_label, + mint.expires_at, extra={ "audit": True, "event": "oauth.device_flow_approved", @@ -355,7 +381,8 @@ def _emit_approve_audit(state, account, tenant, mint) -> None: def _emit_deny_audit(state) -> None: logger.warning( "audit: oauth.device_flow_denied client_id=%s device_label=%s", - state.client_id, state.device_label, + state.client_id, + state.device_label, extra={ "audit": True, "event": "oauth.device_flow_denied", @@ -363,5 +390,3 @@ def _emit_deny_audit(state) -> None: "device_label": state.device_label, }, ) - - diff --git a/api/controllers/openapi/oauth_device_sso.py b/api/controllers/openapi/oauth_device_sso.py index 1b0e55993a..334c5806f6 100644 --- a/api/controllers/openapi/oauth_device_sso.py +++ b/api/controllers/openapi/oauth_device_sso.py @@ -9,6 +9,7 @@ EE-only. Browser flow: Function-based (raw @bp.route) rather than Resource classes because the handlers do redirects + cookie kwargs that don't fit the Resource shape. """ + from __future__ import annotations import logging @@ -51,8 +52,8 @@ from services.oauth_device_flow import ( PREFIX_OAUTH_EXTERNAL_SSO, DeviceFlowRedis, DeviceFlowStatus, - InvalidTransition, - StateNotFound, + InvalidTransitionError, + StateNotFoundError, mint_oauth_token, oauth_ttl_days, ) @@ -171,13 +172,15 @@ def approval_context(): logger.warning("approval-context: bad cookie: %s", e) raise Unauthorized("no_session") from e - return jsonify({ - "subject_email": claims.subject_email, - "subject_issuer": claims.subject_issuer, - "user_code": claims.user_code, - "csrf_token": claims.csrf_token, - "expires_at": claims.expires_at.isoformat(), - }), 200 + return jsonify( + { + "subject_email": claims.subject_email, + "subject_issuer": claims.subject_issuer, + "user_code": claims.user_code, + "csrf_token": claims.csrf_token, + "expires_at": claims.expires_at.isoformat(), + } + ), 200 @bp.route("/oauth/device/approve-external", methods=["POST"]) @@ -251,8 +254,8 @@ def approve_external(): token_id=str(mint.token_id), poll_payload=poll_payload, ) - except (StateNotFound, InvalidTransition) as e: - logger.error("approve-external: state transition raced: %s", e) + except (StateNotFoundError, InvalidTransitionError) as e: + logger.exception("approve-external: state transition raced") raise Conflict("state_lost") from e _emit_approve_external_audit(state, claims, mint) @@ -264,9 +267,11 @@ def approve_external(): def _emit_approve_external_audit(state, claims, mint) -> None: logger.warning( - "audit: oauth.device_flow_approved subject_type=%s " - "subject_email=%s subject_issuer=%s token_id=%s", - SubjectType.EXTERNAL_SSO, claims.subject_email, claims.subject_issuer, mint.token_id, + "audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s", + SubjectType.EXTERNAL_SSO, + claims.subject_email, + claims.subject_issuer, + mint.token_id, extra={ "audit": True, "event": "oauth.device_flow_approved", diff --git a/api/controllers/openapi/workflow_run.py b/api/controllers/openapi/workflow_run.py index d76ff553de..c71e9ab529 100644 --- a/api/controllers/openapi/workflow_run.py +++ b/api/controllers/openapi/workflow_run.py @@ -1,8 +1,10 @@ """POST /openapi/v1/apps//workflows/run — port of service_api/app/workflow.py:WorkflowRunApi.""" + from __future__ import annotations import logging +from collections.abc import Mapping from typing import Any, Literal from flask import request @@ -128,5 +130,10 @@ class WorkflowRunApi(Resource): if streaming: return helper.compact_generate_response(response) - body_dict = response[0] if isinstance(response, tuple) else response - return WorkflowRunResponse.model_validate(body_dict).model_dump(mode="json"), 200 + if isinstance(response, tuple): + body_dict: Any = response[0] # pyright: ignore[reportArgumentType] + else: + body_dict = response + if not isinstance(body_dict, Mapping): + raise InternalServerError("blocking generate returned non-mapping response") + return WorkflowRunResponse.model_validate(dict(body_dict)).model_dump(mode="json"), 200 diff --git a/api/controllers/openapi/workspaces.py b/api/controllers/openapi/workspaces.py index 173ebcbb57..fd22c2e620 100644 --- a/api/controllers/openapi/workspaces.py +++ b/api/controllers/openapi/workspaces.py @@ -5,8 +5,11 @@ Account bearers (dfoa_) see every tenant they're a member of. External SSO bearers (dfoe_) have no account_id and so see an empty list — that matches /openapi/v1/account. """ + from __future__ import annotations +from itertools import starmap + from flask import g from flask_restx import Resource from sqlalchemy import select @@ -37,7 +40,7 @@ class WorkspacesApi(Resource): .order_by(Tenant.created_at.asc()) ).all() - return {"workspaces": [_workspace_summary(t, m) for t, m in rows]}, 200 + return {"workspaces": list(starmap(_workspace_summary, rows))}, 200 @openapi_ns.route("/workspaces/") diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 15645add57..23f040d838 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -685,6 +685,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): match invoke_from: case InvokeFrom.SERVICE_API: created_from = WorkflowAppLogCreatedFrom.SERVICE_API + case InvokeFrom.OPENAPI: + created_from = WorkflowAppLogCreatedFrom.OPENAPI case InvokeFrom.EXPLORE: created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP case InvokeFrom.WEB_APP: diff --git a/api/extensions/ext_oauth_bearer.py b/api/extensions/ext_oauth_bearer.py index d881a88c87..58c2ac2d2c 100644 --- a/api/extensions/ext_oauth_bearer.py +++ b/api/extensions/ext_oauth_bearer.py @@ -1,6 +1,7 @@ """Bind the bearer authenticator at startup. Must run after ext_database and ext_redis (needs both factories). """ + from __future__ import annotations from configs import dify_config diff --git a/api/libs/device_flow_security.py b/api/libs/device_flow_security.py index db962bfca5..d973a0820b 100644 --- a/api/libs/device_flow_security.py +++ b/api/libs/device_flow_security.py @@ -1,14 +1,15 @@ """Device-flow security primitives: enterprise_only gate, approval-grant cookie mint/verify/consume, and anti-framing headers. """ + from __future__ import annotations import logging import secrets +from collections.abc import Callable from dataclasses import dataclass from datetime import UTC, datetime, timedelta from functools import wraps -from typing import Callable from flask import Blueprint from werkzeug.exceptions import NotFound @@ -122,7 +123,10 @@ def consume_approval_grant_nonce(redis_client, nonce: str) -> bool: return False return bool( redis_client.set( - NONCE_KEY_FMT.format(nonce=nonce), "1", nx=True, ex=NONCE_TTL_SECONDS, + NONCE_KEY_FMT.format(nonce=nonce), + "1", + nx=True, + ex=NONCE_TTL_SECONDS, ) ) @@ -132,7 +136,10 @@ def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool: return False return bool( redis_client.set( - SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce), "1", nx=True, ex=NONCE_TTL_SECONDS, + SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce), + "1", + nx=True, + ex=NONCE_TTL_SECONDS, ) ) @@ -183,7 +190,7 @@ def attach_anti_framing(bp: Blueprint) -> None: """X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4).""" @bp.after_request - def _apply_headers(response): + def _apply_headers(response): # pyright: ignore[reportUnusedFunction] for name, value in _ANTI_FRAMING_HEADERS.items(): response.headers.setdefault(name, value) return response diff --git a/api/libs/jws.py b/api/libs/jws.py index f66811aabd..692ccb39fa 100644 --- a/api/libs/jws.py +++ b/api/libs/jws.py @@ -2,11 +2,13 @@ state envelope, external subject assertion, and approval-grant cookie — all three share one key-set so api ↔ enterprise can verify each other. """ + from __future__ import annotations from datetime import UTC, datetime, timedelta import jwt + from configs import dify_config AUD_STATE_ENVELOPE = "api.sso.state_envelope" @@ -32,14 +34,14 @@ class KeySet: self._active_kid = active_kid @classmethod - def from_shared_secret(cls) -> "KeySet": + def from_shared_secret(cls) -> KeySet: secret = dify_config.SECRET_KEY if not secret: raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set") return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1) @classmethod - def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> "KeySet": + def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> KeySet: return cls(entries, active_kid) @property diff --git a/api/libs/oauth_bearer.py b/api/libs/oauth_bearer.py index eaa7f2d5f8..f524c0b0b4 100644 --- a/api/libs/oauth_bearer.py +++ b/api/libs/oauth_bearer.py @@ -4,17 +4,19 @@ To add a token kind: write a Resolver, add a SubjectType + Accepts member, append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT. Authenticator + validate_bearer stay untouched. """ + from __future__ import annotations import hashlib import json import logging import uuid +from collections.abc import Callable, Iterable from dataclasses import dataclass from datetime import UTC, datetime from enum import StrEnum from functools import wraps -from typing import Callable, Iterable, Literal, Protocol +from typing import Literal, ParamSpec, Protocol, TypeVar from flask import g, request from sqlalchemy import update @@ -79,11 +81,11 @@ class TokenKind: return token.startswith(self.prefix) -class InvalidBearer(Exception): +class InvalidBearerError(Exception): """Token missing, unknown prefix, or no live row.""" -class TokenExpired(Exception): +class TokenExpiredError(Exception): """Hard-expire bookkeeping is the resolver's job before raising.""" @@ -122,13 +124,17 @@ class BearerAuthenticator: def __init__(self, registry: TokenKindRegistry) -> None: self._registry = registry + @property + def registry(self) -> TokenKindRegistry: + return self._registry + def authenticate(self, token: str) -> AuthContext: kind = self._registry.find(token) if kind is None: - raise InvalidBearer("unknown token prefix") + raise InvalidBearerError("unknown token prefix") row = kind.resolver.resolve(sha256_hex(token)) if row is None: - raise InvalidBearer("token unknown or revoked") + raise InvalidBearerError("token unknown or revoked") return AuthContext( subject_type=kind.subject_type, subject_email=row.subject_email, @@ -165,7 +171,9 @@ class OAuthAccessTokenResolver: positive_ttl: int = POSITIVE_TTL_SECONDS, negative_ttl: int = NEGATIVE_TTL_SECONDS, ) -> None: - self._session_factory = session_factory + # session_factory and the cache helpers below are friend-API for + # _VariantResolver in this module — kept public-named on purpose. + self.session_factory = session_factory self._redis = redis_client self._positive_ttl = positive_ttl self._negative_ttl = negative_ttl @@ -179,7 +187,7 @@ class OAuthAccessTokenResolver: def _cache_key(self, token_hash: str) -> str: return TOKEN_CACHE_KEY_FMT.format(hash=token_hash) - def _cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]: + def cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]: raw = self._redis.get(self._cache_key(token_hash)) if raw is None: return None @@ -193,17 +201,17 @@ class OAuthAccessTokenResolver: logger.warning("auth:token cache entry malformed; treating as miss") return None - def _cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None: + def cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None: self._redis.setex( self._cache_key(token_hash), self._positive_ttl, json.dumps(_row_to_cache(row)), ) - def _cache_set_negative(self, token_hash: str) -> None: + def cache_set_negative(self, token_hash: str) -> None: self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid") - def _hard_expire(self, session: Session, row_id: uuid.UUID, token_hash: str) -> None: + def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None: """Atomic CAS — only the worker that flips revoked_at emits audit; replays are idempotent. Spec: tokens.md §Detection + hard-expire. """ @@ -216,21 +224,22 @@ class OAuthAccessTokenResolver: session.commit() if result.rowcount == 1: logger.warning( - "audit: %s token_id=%s", AUDIT_OAUTH_EXPIRED, row_id, + "audit: %s token_id=%s", + AUDIT_OAUTH_EXPIRED, + row_id, extra={"audit": True, "token_id": str(row_id)}, ) self._redis.delete(self._cache_key(token_hash)) - self._cache_set_negative(token_hash) + self.cache_set_negative(token_hash) class _VariantResolver: - def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None: self._parent = parent self._variant = variant def resolve(self, token_hash: str) -> ResolvedRow | None: - cached = self._parent._cache_get(token_hash) + cached = self._parent.cache_get(token_hash) if cached == "invalid": return None if cached is not None and not isinstance(cached, str): @@ -238,23 +247,24 @@ class _VariantResolver: return None return cached - # _session_factory returns Flask-SQLAlchemy's scoped_session, which is + # session_factory returns Flask-SQLAlchemy's scoped_session, which is # request-bound and not a context manager; use it directly. - session = self._parent._session_factory() + session = self._parent.session_factory() row = self._load_from_db(session, token_hash) if row is None: - self._parent._cache_set_negative(token_hash) + self._parent.cache_set_negative(token_hash) return None now = datetime.now(UTC) if row.expires_at is not None and row.expires_at <= now: - self._parent._hard_expire(session, row.id, token_hash) + self._parent.hard_expire(session, row.id, token_hash) return None if not self._matches_variant_model(row): logger.error( "internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s", - row.id, row.prefix, + row.id, + row.prefix, ) return None @@ -265,7 +275,7 @@ class _VariantResolver: token_id=uuid.UUID(str(row.id)), expires_at=row.expires_at, ) - self._parent._cache_set_positive(token_hash, resolved) + self._parent.cache_set_positive(token_hash, resolved) return resolved def _matches_variant(self, row: ResolvedRow) -> bool: @@ -352,7 +362,11 @@ def _extract_bearer(req) -> str | None: return value.strip() -def validate_bearer(*, accept: frozenset[Accepts]) -> Callable: +_DP = ParamSpec("_DP") +_DR = TypeVar("_DR") + + +def validate_bearer(*, accept: frozenset[Accepts]) -> Callable[[Callable[_DP, _DR]], Callable[_DP, _DR]]: """Opt-in: omitting it leaves the route unauthenticated. Resolves user-level OAuth bearers (``dfoa_`` / ``dfoe_``). Legacy @@ -360,21 +374,19 @@ def validate_bearer(*, accept: frozenset[Accepts]) -> Callable: and are rejected here as the wrong auth scheme for this surface. """ - def wrap(fn: Callable) -> Callable: + def wrap(fn: Callable[_DP, _DR]) -> Callable[_DP, _DR]: @wraps(fn) - def inner(*args, **kwargs): + def inner(*args: _DP.args, **kwargs: _DP.kwargs) -> _DR: token = _extract_bearer(request) if token is None: raise Unauthorized("missing bearer token") if _authenticator is None: - raise ServiceUnavailable( - "bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable" - ) + raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable") try: ctx = get_authenticator().authenticate(token) - except InvalidBearer as e: + except InvalidBearerError as e: raise Unauthorized(str(e)) if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept: @@ -388,17 +400,15 @@ def validate_bearer(*, accept: frozenset[Accepts]) -> Callable: return wrap -def bearer_feature_required(fn: Callable) -> Callable: +def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]: """503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable without the authenticator, so fail fast instead of approving silently. """ @wraps(fn) - def inner(*args, **kwargs): + def inner(*args: P.args, **kwargs: P.kwargs) -> R: if not dify_config.ENABLE_OAUTH_BEARER: - raise ServiceUnavailable( - "bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable" - ) + raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable") return fn(*args, **kwargs) return inner @@ -423,8 +433,7 @@ def require_scope(scope: str) -> Callable: ctx = getattr(g, "auth_ctx", None) if ctx is None: raise RuntimeError( - "require_scope used without validate_bearer; " - "stack @validate_bearer above @require_scope" + "require_scope used without validate_bearer; stack @validate_bearer above @require_scope" ) if SCOPE_FULL not in ctx.scopes and scope not in ctx.scopes: raise Forbidden(f"insufficient_scope: {scope}") @@ -442,22 +451,24 @@ def require_scope(scope: str) -> Callable: def build_registry(session_factory, redis_client) -> TokenKindRegistry: oauth = OAuthAccessTokenResolver(session_factory, redis_client) - return TokenKindRegistry([ - TokenKind( - prefix="dfoa_", - subject_type=SubjectType.ACCOUNT, - scopes=frozenset({"full"}), - source="oauth_account", - resolver=oauth.for_account(), - ), - TokenKind( - prefix="dfoe_", - subject_type=SubjectType.EXTERNAL_SSO, - scopes=frozenset({"apps:run"}), - source="oauth_external_sso", - resolver=oauth.for_external_sso(), - ), - ]) + return TokenKindRegistry( + [ + TokenKind( + prefix="dfoa_", + subject_type=SubjectType.ACCOUNT, + scopes=frozenset({"full"}), + source="oauth_account", + resolver=oauth.for_account(), + ), + TokenKind( + prefix="dfoe_", + subject_type=SubjectType.EXTERNAL_SSO, + scopes=frozenset({"apps:run"}), + source="oauth_external_sso", + resolver=oauth.for_external_sso(), + ), + ] + ) def build_and_bind(session_factory, redis_client) -> BearerAuthenticator: diff --git a/api/libs/rate_limit.py b/api/libs/rate_limit.py index dd9322bba6..8f43f1b312 100644 --- a/api/libs/rate_limit.py +++ b/api/libs/rate_limit.py @@ -4,13 +4,15 @@ window Redis ZSET). Apply after auth decorators so scopes can read in-handler. RFC-8628 ``slow_down`` is inline — its response shape isn't generic 429. Spec: docs/specs/v1.0/server/security.md. """ + from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta from enum import StrEnum from functools import wraps -from typing import Callable +from typing import ParamSpec, TypeVar from flask import g, request, session from werkzeug.exceptions import TooManyRequests @@ -81,13 +83,17 @@ def _build_limiter(spec: RateLimit) -> RateLimiter: ) -def rate_limit(spec: RateLimit) -> Callable: +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +def rate_limit(spec: RateLimit) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: """Apply after auth decorators that the scopes read from.""" limiter = _build_limiter(spec) - def wrap(fn: Callable) -> Callable: + def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]: @wraps(fn) - def inner(*args, **kwargs): + def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: key = _composite_key(spec.scopes) if limiter.is_rate_limited(key): raise TooManyRequests("rate_limited") diff --git a/api/models/oauth.py b/api/models/oauth.py index 5ab10fb7d0..f85448ea75 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Optional +from typing import Any import sqlalchemy as sa from sqlalchemy import func @@ -97,9 +97,7 @@ class OAuthAccessToken(TypeBase): """ __tablename__ = "oauth_access_tokens" - __table_args__ = ( - sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"), - ) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),) id: Mapped[str] = mapped_column( StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False @@ -109,15 +107,11 @@ class OAuthAccessToken(TypeBase): device_label: Mapped[str] = mapped_column(sa.Text, nullable=False) prefix: Mapped[str] = mapped_column(sa.String(8), nullable=False) expires_at: Mapped[datetime] = mapped_column(sa.DateTime(timezone=True), nullable=False) - subject_issuer: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True, default=None) - account_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True, default=None) - token_hash: Mapped[Optional[str]] = mapped_column(sa.String(64), nullable=True, default=None) - last_used_at: Mapped[Optional[datetime]] = mapped_column( - sa.DateTime(timezone=True), nullable=True, default=None - ) - revoked_at: Mapped[Optional[datetime]] = mapped_column( - sa.DateTime(timezone=True), nullable=True, default=None - ) + subject_issuer: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) + account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + token_hash: Mapped[str | None] = mapped_column(sa.String(64), nullable=True, default=None) + last_used_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None) + revoked_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( sa.DateTime(timezone=True), nullable=False, server_default=func.now(), init=False diff --git a/api/models/workflow.py b/api/models/workflow.py index d127244b0f..23133f51dd 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1206,6 +1206,7 @@ class WorkflowAppLogCreatedFrom(StrEnum): SERVICE_API = "service-api" WEB_APP = "web-app" INSTALLED_APP = "installed-app" + OPENAPI = "openapi" @classmethod def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": diff --git a/api/schedule/clean_oauth_access_tokens_task.py b/api/schedule/clean_oauth_access_tokens_task.py index b4b7dc0236..10250e986e 100644 --- a/api/schedule/clean_oauth_access_tokens_task.py +++ b/api/schedule/clean_oauth_access_tokens_task.py @@ -3,6 +3,7 @@ expired-but-never-presented rows have no hard-expire trigger — both get pruned here. Spec: docs/specs/v1.0/server/tokens.md §Hard-expire. """ + from __future__ import annotations import logging @@ -32,26 +33,22 @@ def clean_oauth_access_tokens_task(): candidates = or_( OAuthAccessToken.revoked_at < cutoff, # Zombies: expired but never re-presented, so middleware never flipped them. - (OAuthAccessToken.revoked_at.is_(None)) - & (OAuthAccessToken.expires_at < cutoff), + (OAuthAccessToken.revoked_at.is_(None)) & (OAuthAccessToken.expires_at < cutoff), ) total = 0 while True: - ids = db.session.scalars( - select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE) - ).all() + ids = db.session.scalars(select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)).all() if not ids: break - db.session.execute( - delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids)) - ) + db.session.execute(delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids))) db.session.commit() total += len(ids) end_at = time.perf_counter() - click.echo(click.style( - f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d " - f"in {end_at - start_at:.2f}s", - fg="green", - )) + click.echo( + click.style( + f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d in {end_at - start_at:.2f}s", + fg="green", + ) + ) diff --git a/api/services/oauth_device_flow.py b/api/services/oauth_device_flow.py index 6aa12cd536..11e92f8ae9 100644 --- a/api/services/oauth_device_flow.py +++ b/api/services/oauth_device_flow.py @@ -2,6 +2,7 @@ (DB upsert + plaintext generation), and TTL policy. Specs: docs/specs/v1.0/server/{device-flow.md, tokens.md}. """ + from __future__ import annotations import hashlib @@ -15,11 +16,12 @@ from dataclasses import asdict, dataclass, field from datetime import UTC, datetime, timedelta from enum import StrEnum -from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT -from models.oauth import OAuthAccessToken from sqlalchemy import func, select from sqlalchemy.dialects.postgresql import insert as pg_insert -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, scoped_session + +from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT +from models.oauth import OAuthAccessToken logger = logging.getLogger(__name__) @@ -51,7 +53,7 @@ return raw """ DEVICE_FLOW_TTL_SECONDS = 15 * 60 # RFC 8628 expires_in -APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor +APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXY3456789" # ambiguous chars dropped USER_CODE_SEGMENT_LEN = 4 @@ -95,7 +97,7 @@ class DeviceFlowState: return json.dumps(asdict(self)) @classmethod - def from_json(cls, raw: str) -> "DeviceFlowState": + def from_json(cls, raw: str) -> DeviceFlowState: data = json.loads(raw) if "status" in data: data["status"] = DeviceFlowStatus(data["status"]) @@ -114,20 +116,19 @@ def _random_user_code() -> str: return f"{_random_user_code_segment()}-{_random_user_code_segment()}" -class StateNotFound(Exception): +class StateNotFoundError(Exception): pass -class InvalidTransition(Exception): +class InvalidTransitionError(Exception): pass -class UserCodeExhausted(Exception): +class UserCodeExhaustedError(Exception): pass class DeviceFlowRedis: - def __init__(self, redis_client) -> None: self._redis = redis_client self._consume_on_poll_script = redis_client.register_script(_CONSUME_ON_POLL_LUA) @@ -157,7 +158,7 @@ class DeviceFlowRedis: ok = self._redis.set(key, device_code, nx=True, ex=DEVICE_FLOW_TTL_SECONDS) if ok: return user_code - raise UserCodeExhausted("could not allocate a unique user_code in 5 attempts") + raise UserCodeExhaustedError("could not allocate a unique user_code in 5 attempts") def load_by_user_code(self, user_code: str) -> tuple[str, DeviceFlowState] | None: raw_dc = self._redis.get(USER_CODE_KEY_FMT.format(code=user_code)) @@ -180,7 +181,7 @@ class DeviceFlowRedis: try: return DeviceFlowState.from_json(text_) except (ValueError, KeyError): - logger.error("device_flow: corrupt state for %s", device_code) + logger.exception("device_flow: corrupt state for %s", device_code) return None def approve( @@ -195,9 +196,9 @@ class DeviceFlowRedis: ) -> None: state = self._load_state(device_code) if state is None: - raise StateNotFound(device_code) + raise StateNotFoundError(device_code) if state.status is not DeviceFlowStatus.PENDING: - raise InvalidTransition(f"cannot approve {state.status}") + raise InvalidTransitionError(f"cannot approve {state.status}") state.status = DeviceFlowStatus.APPROVED state.subject_email = subject_email @@ -213,9 +214,9 @@ class DeviceFlowRedis: def deny(self, device_code: str) -> None: state = self._load_state(device_code) if state is None: - raise StateNotFound(device_code) + raise StateNotFoundError(device_code) if state.status is not DeviceFlowStatus.PENDING: - raise InvalidTransition(f"cannot deny {state.status}") + raise InvalidTransitionError(f"cannot deny {state.status}") state.status = DeviceFlowStatus.DENIED self._redis.setex( DEVICE_CODE_KEY_FMT.format(code=device_code), @@ -239,7 +240,7 @@ class DeviceFlowRedis: try: return DeviceFlowState.from_json(text_) except (ValueError, KeyError): - logger.error("device_flow: corrupt state on consume %s", device_code) + logger.exception("device_flow: corrupt state on consume %s", device_code) return None def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision: @@ -287,6 +288,7 @@ ACCOUNT_ISSUER_SENTINEL = "dify:account" @dataclass(frozen=True, slots=True) class MintResult: """Plaintext token surfaces to the caller once.""" + token: str token_id: uuid.UUID expires_at: datetime @@ -308,7 +310,9 @@ def sha256_hex(token: str) -> str: def mint_oauth_token( - session: Session, + # Accept either Session or Flask-SQLAlchemy's request-scoped wrapper — + # the wrapper proxies the same execute/commit surface. + session: Session | scoped_session, redis_client, *, subject_email: str, @@ -328,9 +332,7 @@ def mint_oauth_token( # Account flow always writes the sentinel — caller may pass None # (for clarity) or the sentinel itself; nothing else is valid. if subject_issuer not in (None, ACCOUNT_ISSUER_SENTINEL): - raise ValueError( - f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}" - ) + raise ValueError(f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}") subject_issuer = ACCOUNT_ISSUER_SENTINEL elif prefix == PREFIX_OAUTH_EXTERNAL_SSO: # Defense in depth: enterprise canonicalises + rejects empty, @@ -363,7 +365,7 @@ def mint_oauth_token( def _upsert( - session: Session, + session: Session | scoped_session, *, subject_email: str, subject_issuer: str | None, @@ -415,6 +417,8 @@ def _upsert( row = session.execute(upsert_stmt).first() session.commit() + if row is None: + raise RuntimeError("oauth_token upsert returned no row") token_id = uuid.UUID(str(row.id)) return UpsertOutcome( token_id=token_id, @@ -449,7 +453,9 @@ def oauth_ttl_days(tenant_id: str | None = None) -> int: except ValueError: logger.warning( "%s=%r is not an int; falling back to %d", - _TTL_ENV_VAR, raw, DEFAULT_OAUTH_TTL_DAYS, + _TTL_ENV_VAR, + raw, + DEFAULT_OAUTH_TTL_DAYS, ) return DEFAULT_OAUTH_TTL_DAYS if value < MIN_TTL_DAYS: diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py index e1f5114446..e0d286e783 100644 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py +++ b/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py @@ -38,12 +38,7 @@ def test_membership_strategy_uses_join_lookup(member): def test_membership_strategy_rejects_external_sso(): - assert ( - MembershipStrategy().authorize( - _ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None) - ) - is False - ) + assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False def test_app_authz_check_raises_when_strategy_denies(): diff --git a/api/tests/unit_tests/controllers/openapi/test_account.py b/api/tests/unit_tests/controllers/openapi/test_account.py index 69ec70a4ac..5a08db4964 100644 --- a/api/tests/unit_tests/controllers/openapi/test_account.py +++ b/api/tests/unit_tests/controllers/openapi/test_account.py @@ -1,4 +1,5 @@ """User-scoped identity + session endpoints under /openapi/v1/account.""" + import builtins import pytest diff --git a/api/tests/unit_tests/controllers/openapi/test_app_info.py b/api/tests/unit_tests/controllers/openapi/test_app_info.py index 3aeec7f0ca..f024e1bd3f 100644 --- a/api/tests/unit_tests/controllers/openapi/test_app_info.py +++ b/api/tests/unit_tests/controllers/openapi/test_app_info.py @@ -6,8 +6,10 @@ from flask_restx import Api def _client(): - from controllers.openapi import app_info # noqa: F401 - from controllers.openapi import openapi_ns + from controllers.openapi import ( + app_info, # noqa: F401 + openapi_ns, + ) app = Flask(__name__) api = Api(app) diff --git a/api/tests/unit_tests/controllers/openapi/test_chat_messages.py b/api/tests/unit_tests/controllers/openapi/test_chat_messages.py index 35ec43cdd2..22e45e473d 100644 --- a/api/tests/unit_tests/controllers/openapi/test_chat_messages.py +++ b/api/tests/unit_tests/controllers/openapi/test_chat_messages.py @@ -6,8 +6,10 @@ from flask_restx import Api def _client(): - from controllers.openapi import chat_messages # noqa: F401 - from controllers.openapi import openapi_ns + from controllers.openapi import ( + chat_messages, # noqa: F401 + openapi_ns, + ) app = Flask(__name__) api = Api(app) @@ -32,12 +34,11 @@ def test_chat_dispatches_and_returns_response_model(svc, bypass_pipeline): 200, ) fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1") - with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake), patch( - "controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace() + with ( + patch("controllers.openapi.chat_messages._unpack_app", return_value=fake), + patch("controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()), ): - r = _client().post( - "/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}} - ) + r = _client().post("/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}}) assert r.status_code == 200 body = r.get_json() assert body["conversation_id"] == "c1" @@ -62,8 +63,9 @@ def test_chat_strips_user_field_from_body(svc, bypass_pipeline): 200, ) fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1") - with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake), patch( - "controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace() + with ( + patch("controllers.openapi.chat_messages._unpack_app", return_value=fake), + patch("controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()), ): _client().post( "/openapi/v1/apps/app1/chat-messages", @@ -76,9 +78,7 @@ def test_chat_strips_user_field_from_body(svc, bypass_pipeline): def test_chat_rejects_non_chat_mode(bypass_pipeline): fake = SimpleNamespace(mode="completion") with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake): - r = _client().post( - "/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}} - ) + r = _client().post("/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}}) assert r.status_code in (400, 403) diff --git a/api/tests/unit_tests/controllers/openapi/test_completion_messages.py b/api/tests/unit_tests/controllers/openapi/test_completion_messages.py index 84fe214a26..5c4d0ee946 100644 --- a/api/tests/unit_tests/controllers/openapi/test_completion_messages.py +++ b/api/tests/unit_tests/controllers/openapi/test_completion_messages.py @@ -6,8 +6,10 @@ from flask_restx import Api def _client(): - from controllers.openapi import completion_messages # noqa: F401 - from controllers.openapi import openapi_ns + from controllers.openapi import ( + completion_messages, # noqa: F401 + openapi_ns, + ) app = Flask(__name__) api = Api(app) @@ -31,8 +33,9 @@ def test_completion_returns_response_model(svc, bypass_pipeline): 200, ) fake = SimpleNamespace(mode="completion", id="app1", tenant_id="t1") - with patch("controllers.openapi.completion_messages._unpack_app", return_value=fake), patch( - "controllers.openapi.completion_messages._unpack_caller", return_value=SimpleNamespace() + with ( + patch("controllers.openapi.completion_messages._unpack_app", return_value=fake), + patch("controllers.openapi.completion_messages._unpack_caller", return_value=SimpleNamespace()), ): r = _client().post( "/openapi/v1/apps/app1/completion-messages", diff --git a/api/tests/unit_tests/controllers/openapi/test_cors.py b/api/tests/unit_tests/controllers/openapi/test_cors.py index e13c285657..895c685da1 100644 --- a/api/tests/unit_tests/controllers/openapi/test_cors.py +++ b/api/tests/unit_tests/controllers/openapi/test_cors.py @@ -7,9 +7,9 @@ Tests use a fresh Blueprint + Flask-CORS per case because the production blueprint is a module-level singleton and can't be reconfigured once registered. """ + import builtins -import pytest from flask import Blueprint, Flask from flask.views import MethodView from flask_cors import CORS diff --git a/api/tests/unit_tests/controllers/openapi/test_device_approve_deny.py b/api/tests/unit_tests/controllers/openapi/test_device_approve_deny.py index 552a8164e8..dbe2f7bfae 100644 --- a/api/tests/unit_tests/controllers/openapi/test_device_approve_deny.py +++ b/api/tests/unit_tests/controllers/openapi/test_device_approve_deny.py @@ -1,4 +1,5 @@ """Account-branch device-flow approve/deny under /openapi/v1.""" + import builtins import pytest diff --git a/api/tests/unit_tests/controllers/openapi/test_device_code.py b/api/tests/unit_tests/controllers/openapi/test_device_code.py index 374e5b03e3..821a423805 100644 --- a/api/tests/unit_tests/controllers/openapi/test_device_code.py +++ b/api/tests/unit_tests/controllers/openapi/test_device_code.py @@ -4,6 +4,7 @@ authorization endpoint. Tests verify URL routing without invoking the handler — invoking would require Redis, which the unit-test runtime does not initialise. """ + import builtins import pytest @@ -31,16 +32,12 @@ def test_openapi_route_registered(openapi_app: Flask): def test_route_dispatches_to_class(openapi_app: Flask): - rule = next( - r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code" - ) + rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code") assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceCodeApi def test_route_accepts_post(openapi_app: Flask): - rule = next( - r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code" - ) + rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code") assert "POST" in rule.methods diff --git a/api/tests/unit_tests/controllers/openapi/test_device_lookup.py b/api/tests/unit_tests/controllers/openapi/test_device_lookup.py index 5a56ae5fc5..5907378a73 100644 --- a/api/tests/unit_tests/controllers/openapi/test_device_lookup.py +++ b/api/tests/unit_tests/controllers/openapi/test_device_lookup.py @@ -1,4 +1,5 @@ """GET /openapi/v1/oauth/device/lookup is the canonical user-code lookup.""" + import builtins import pytest @@ -26,14 +27,10 @@ def test_openapi_route_registered(openapi_app: Flask): def test_route_dispatches_to_class(openapi_app: Flask): - rule = next( - r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup" - ) + rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup") assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceLookupApi def test_route_accepts_get(openapi_app: Flask): - rule = next( - r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup" - ) + rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup") assert "GET" in rule.methods diff --git a/api/tests/unit_tests/controllers/openapi/test_device_sso.py b/api/tests/unit_tests/controllers/openapi/test_device_sso.py index b40e6c1689..0125c583f0 100644 --- a/api/tests/unit_tests/controllers/openapi/test_device_sso.py +++ b/api/tests/unit_tests/controllers/openapi/test_device_sso.py @@ -1,4 +1,5 @@ """SSO-branch device-flow endpoints under /openapi/v1/oauth/device/.""" + import builtins import pytest diff --git a/api/tests/unit_tests/controllers/openapi/test_device_token.py b/api/tests/unit_tests/controllers/openapi/test_device_token.py index 6a9577637d..8b83068856 100644 --- a/api/tests/unit_tests/controllers/openapi/test_device_token.py +++ b/api/tests/unit_tests/controllers/openapi/test_device_token.py @@ -1,4 +1,5 @@ """POST /openapi/v1/oauth/device/token is the canonical poll endpoint.""" + import builtins import pytest @@ -26,7 +27,5 @@ def test_openapi_route_registered(openapi_app: Flask): def test_route_dispatches_to_class(openapi_app: Flask): - rule = next( - r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/token" - ) + rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/token") assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceTokenApi diff --git a/api/tests/unit_tests/controllers/openapi/test_workflow_run.py b/api/tests/unit_tests/controllers/openapi/test_workflow_run.py index ce0114a507..7c0de76916 100644 --- a/api/tests/unit_tests/controllers/openapi/test_workflow_run.py +++ b/api/tests/unit_tests/controllers/openapi/test_workflow_run.py @@ -6,8 +6,10 @@ from flask_restx import Api def _client(): - from controllers.openapi import openapi_ns - from controllers.openapi import workflow_run # noqa: F401 + from controllers.openapi import ( + openapi_ns, + workflow_run, # noqa: F401 + ) app = Flask(__name__) api = Api(app) @@ -32,8 +34,9 @@ def test_workflow_run_returns_response_model(svc, bypass_pipeline): 200, ) fake = SimpleNamespace(mode="workflow", id="app1", tenant_id="t1") - with patch("controllers.openapi.workflow_run._unpack_app", return_value=fake), patch( - "controllers.openapi.workflow_run._unpack_caller", return_value=SimpleNamespace() + with ( + patch("controllers.openapi.workflow_run._unpack_app", return_value=fake), + patch("controllers.openapi.workflow_run._unpack_caller", return_value=SimpleNamespace()), ): r = _client().post("/openapi/v1/apps/app1/workflows/run", json={"inputs": {"x": 1}}) assert r.status_code == 200 diff --git a/api/tests/unit_tests/controllers/openapi/test_workspaces.py b/api/tests/unit_tests/controllers/openapi/test_workspaces.py index 8e90bf27fe..9cdc13a395 100644 --- a/api/tests/unit_tests/controllers/openapi/test_workspaces.py +++ b/api/tests/unit_tests/controllers/openapi/test_workspaces.py @@ -2,6 +2,7 @@ list + member-gated detail. No legacy /v1/ equivalent — the cookie-authed /console/api/workspaces is a separate consumer that stays in console. """ + import builtins import pytest diff --git a/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py b/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py index 4545b38690..e9f26e59ea 100644 --- a/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py +++ b/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py @@ -2,6 +2,7 @@ Tests use a fake auth_ctx attached directly to flask.g — no authenticator wiring needed. """ + from __future__ import annotations import uuid