dify/api/models/oauth.py
GareArc 8a62c1d915
chore(api): pyright + ruff cleanup for openapi/cli surface
Type and lint pass over the openapi controllers, auth pipeline, and
oauth bearer/device-flow plumbing. Down from 36 pyright errors and 16
ruff errors to 0/0; 93 openapi unit tests pass.

Logic fixes:
- libs/oauth_bearer.py: drop private-naming on the friend-API methods
  consumed by _VariantResolver (cache_get / cache_set_positive /
  cache_set_negative / hard_expire / session_factory). They were always
  cross-class accessors — leading underscore was misleading. Add public
  registry property on BearerAuthenticator. _hard_expire row_id widened
  to UUID | str (matches the StringUUID column type).
- libs/oauth_bearer.py: type validate_bearer / bearer_feature_required
  with ParamSpec / PEP-695 so wrapped routes preserve their signature.
- libs/rate_limit.py: same — typed rate_limit decorator.
- services/oauth_device_flow.py: mint_oauth_token / _upsert accept
  Session | scoped_session (Flask-SQLAlchemy proxy). Guard row-is-None
  after upsert.
- controllers/openapi/{chat,completion,workflow}_messages.py: tuple-vs-
  Mapping shape narrowing on AppGenerateService.generate return —
  production returns Mapping, tests mock as (body, status). Validate
  through Pydantic Response model in both shapes.
- controllers/openapi/oauth_device.py: replace flask_restx.reqparse (banned)
  with Pydantic Request/Query models — DeviceCodeRequest, DevicePollRequest,
  DeviceLookupQuery, DeviceMutateRequest. Two PEP-695 generic helpers
  (_validate_json / _validate_query) translate ValidationError to BadRequest.
- controllers/openapi/auth/strategies.py: Protocol param-name match
  (subject_type), Optional narrowing on app/tenant/account_id/subject_email.
- controllers/openapi/auth/steps.py: subject_type-is-None guard before
  mounter dispatch.
- core/app/apps/workflow/generate_task_pipeline.py + models/workflow.py:
  add WorkflowAppLogCreatedFrom.OPENAPI + matching match-case branch.
  Fixes match-exhaustiveness and possibly-unbound created_from.
- libs/device_flow_security.py: pyright ignore on flask after_request
  hook (registered by the framework, pyright sees as unused).
- services/oauth_device_flow.py: rename Exceptions to *Error suffix
  (StateNotFoundError / InvalidTransitionError / UserCodeExhaustedError);
  same for libs/oauth_bearer.py (InvalidBearerError / TokenExpiredError).
  Update all callers across openapi controllers.
- controllers/openapi/{oauth_device,oauth_device_sso}.py +
  services/oauth_device_flow.py: switch logger.error in except blocks
  to logger.exception (TRY400) — keeps the traceback for ops.
- configs/feature/__init__.py: OPENAPI_KNOWN_CLIENT_IDS computed_field
  needs an @property alongside for pyright to see it as a value, not a
  method. Matches the existing line-451 pattern.

Plus ruff format + import-sort across the openapi tree (pure formatting).
2026-04-28 21:44:54 -07:00

119 lines
5.6 KiB
Python

from datetime import datetime
from typing import Any
import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .base import TypeBase
from .types import AdjustedJSON, LongText, StringUUID
class DatasourceOauthParamConfig(TypeBase):
    """ORM model for the ``datasource_oauth_params`` table.

    Each row stores one ``system_credentials`` JSON document for a
    ``(plugin_id, provider)`` pair; the table constraints keep that pair
    unique.
    """

    __tablename__ = "datasource_oauth_params"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
        sa.UniqueConstraint(
            "plugin_id",
            "provider",
            name="datasource_oauth_config_datasource_id_provider_idx",
        ),
    )

    # Primary key: a stringified UUIDv7 produced on the Python side. It is
    # declared with init=False, so it is not supplied by the caller.
    id: Mapped[str] = mapped_column(
        StringUUID,
        init=False,
        default_factory=lambda: str(uuidv7()),
        insert_default=lambda: str(uuidv7()),
    )

    # Half of the unique (plugin_id, provider) key.
    plugin_id: Mapped[str] = mapped_column(
        sa.String(255),
        nullable=False,
    )
    # Other half of the unique (plugin_id, provider) key.
    provider: Mapped[str] = mapped_column(
        sa.String(255),
        nullable=False,
    )
    # Required JSON document; its schema is not defined in this module.
    system_credentials: Mapped[dict[str, Any]] = mapped_column(
        AdjustedJSON,
        nullable=False,
    )
class DatasourceProvider(TypeBase):
    """ORM model for the ``datasource_providers`` table.

    A row is identified within a tenant by ``(tenant_id, plugin_id,
    provider, name)``, which the table constraints keep unique. A secondary
    index covers ``(tenant_id, plugin_id, provider)``.
    """

    __tablename__ = "datasource_providers"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
        sa.UniqueConstraint(
            "tenant_id",
            "plugin_id",
            "provider",
            "name",
            name="datasource_provider_unique_name",
        ),
        sa.Index(
            "datasource_provider_auth_type_provider_idx",
            "tenant_id",
            "plugin_id",
            "provider",
        ),
    )

    # Primary key: a stringified UUIDv7 produced on the Python side. It is
    # declared with init=False, so it is not supplied by the caller.
    id: Mapped[str] = mapped_column(
        StringUUID,
        init=False,
        default_factory=lambda: str(uuidv7()),
        insert_default=lambda: str(uuidv7()),
    )

    # Required columns with no Python-side default.
    tenant_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
    )
    name: Mapped[str] = mapped_column(
        sa.String(255),
        nullable=False,
    )
    provider: Mapped[str] = mapped_column(
        sa.String(128),
        nullable=False,
    )
    plugin_id: Mapped[str] = mapped_column(
        sa.String(255),
        nullable=False,
    )
    auth_type: Mapped[str] = mapped_column(
        sa.String(255),
        nullable=False,
    )
    # Required JSON document; its schema is not defined in this module.
    encrypted_credentials: Mapped[dict[str, Any]] = mapped_column(
        AdjustedJSON,
        nullable=False,
    )

    # Columns with defaults.
    # NOTE(review): the column is nullable but the annotation is Mapped[str];
    # confirm whether None is an expected value before tightening either side.
    avatar_url: Mapped[str] = mapped_column(
        LongText,
        default="default",
        nullable=True,
    )
    is_default: Mapped[bool] = mapped_column(
        sa.Boolean,
        default=False,
        server_default=sa.text("false"),
        nullable=False,
    )
    # Defaults to -1 both in Python and at the database level.
    # NOTE(review): -1 presumably means "no expiry" — confirm against callers.
    expires_at: Mapped[int] = mapped_column(
        sa.Integer,
        default=-1,
        server_default="-1",
        nullable=False,
    )

    # Timestamps are filled by the database, never by the constructor.
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        init=False,
        server_default=func.current_timestamp(),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        init=False,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        nullable=False,
    )
class DatasourceOauthTenantParamConfig(TypeBase):
    """ORM model for the ``datasource_oauth_tenant_params`` table.

    Each row stores the ``client_params`` JSON document and an ``enabled``
    flag for a ``(tenant_id, plugin_id, provider)`` triple, which the table
    constraints keep unique.
    """

    __tablename__ = "datasource_oauth_tenant_params"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
        sa.UniqueConstraint(
            "tenant_id",
            "plugin_id",
            "provider",
            name="datasource_oauth_tenant_config_unique",
        ),
    )

    # Primary key: a stringified UUIDv7 produced on the Python side. It is
    # declared with init=False, so it is not supplied by the caller.
    id: Mapped[str] = mapped_column(
        StringUUID,
        init=False,
        default_factory=lambda: str(uuidv7()),
        insert_default=lambda: str(uuidv7()),
    )

    # The three columns forming the unique key.
    tenant_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
    )
    provider: Mapped[str] = mapped_column(
        sa.String(255),
        nullable=False,
    )
    plugin_id: Mapped[str] = mapped_column(
        sa.String(255),
        nullable=False,
    )

    # default_factory=dict gives each new instance its own empty mapping.
    client_params: Mapped[dict[str, Any]] = mapped_column(
        AdjustedJSON,
        default_factory=dict,
        nullable=False,
    )
    # Off unless explicitly switched on; there is no server-side default.
    enabled: Mapped[bool] = mapped_column(
        sa.Boolean,
        default=False,
        nullable=False,
    )

    # Timestamps are filled by the database, never by the constructor.
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        init=False,
        server_default=func.current_timestamp(),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        init=False,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        nullable=False,
    )
class OAuthAccessToken(TypeBase):
    """Bearer token issued through the device flow.

    Two kinds of row share this table:

    * ``dfoa_`` - Dify account token: ``account_id`` is set and
      ``subject_issuer`` holds the ``"dify:account"`` sentinel.
    * ``dfoe_`` - external SSO token (EE-only): ``account_id`` is NULL and
      ``subject_issuer`` holds the verified IdP issuer.

    The application writes a non-NULL ``subject_issuer`` on every row.
    Postgres treats NULLs as distinct in unique indices, so a NULL issuer
    would stop the partial unique index on ``(subject_email,
    subject_issuer, client_id, device_label) WHERE revoked_at IS NULL``
    from rotating a token in place.

    NOTE(review): that partial unique index is not declared in
    ``__table_args__`` below; presumably it lives in a migration - confirm.
    """

    __tablename__ = "oauth_access_tokens"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),
    )

    # Primary key: a stringified UUIDv7 produced on the Python side. It is
    # declared with init=False, so it is not supplied by the caller.
    id: Mapped[str] = mapped_column(
        StringUUID,
        init=False,
        default_factory=lambda: str(uuidv7()),
        insert_default=lambda: str(uuidv7()),
    )

    # Required columns with no default.
    subject_email: Mapped[str] = mapped_column(
        sa.Text,
        nullable=False,
    )
    client_id: Mapped[str] = mapped_column(
        sa.String(64),
        nullable=False,
    )
    device_label: Mapped[str] = mapped_column(
        sa.Text,
        nullable=False,
    )
    prefix: Mapped[str] = mapped_column(
        sa.String(8),
        nullable=False,
    )
    expires_at: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True),
        nullable=False,
    )

    # Nullable columns; each defaults to None when omitted.
    subject_issuer: Mapped[str | None] = mapped_column(
        sa.Text,
        default=None,
        nullable=True,
    )
    account_id: Mapped[str | None] = mapped_column(
        StringUUID,
        default=None,
        nullable=True,
    )
    token_hash: Mapped[str | None] = mapped_column(
        sa.String(64),
        default=None,
        nullable=True,
    )
    last_used_at: Mapped[datetime | None] = mapped_column(
        sa.DateTime(timezone=True),
        default=None,
        nullable=True,
    )
    revoked_at: Mapped[datetime | None] = mapped_column(
        sa.DateTime(timezone=True),
        default=None,
        nullable=True,
    )

    # Filled by the database at insert time, never by the constructor.
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True),
        init=False,
        server_default=func.now(),
        nullable=False,
    )