feat(api,web,cli): difyctl v1.0 — OAuth device flow, /openapi/v1 auth pipeline, CLI client

This commit is contained in:
GareArc 2026-05-11 18:40:39 -07:00
parent 8f070f2190
commit 6779366dca
No known key found for this signature in database
333 changed files with 28550 additions and 101 deletions

4
.github/CODEOWNERS vendored
View File

@ -15,6 +15,10 @@
# Docs
/docs/ @crazywoola
# CLI
/cli/ @langgenius/maintainers
/.github/workflows/cli-tests.yml @langgenius/maintainers
# Backend (default owner, more specific rules below will override)
/api/ @QuantumGhost

131
.github/workflows/cli-release.yml vendored Normal file
View File

@ -0,0 +1,131 @@
# Builds, signs, and publishes difyctl artifacts, then attaches them to a
# dify GitHub release. Triggered automatically on release publish (gated by
# the CLI_AUTO_RELEASE repo variable) or manually via workflow_dispatch.
name: CLI Release

on:
  release:
    types: [published]
  workflow_dispatch:
    inputs:
      dify_release_tag:
        description: "dify release tag to attach cli artifacts to (e.g. 1.14.0). Bare semver — dify tags are NOT v-prefixed."
        type: string
        required: true

concurrency:
  group: cli-release-${{ github.event.release.tag_name || inputs.dify_release_tag }}
  cancel-in-progress: true

jobs:
  release:
    runs-on: ubuntu-latest
    # Manual dispatch is always allowed on the canonical repo; the automatic
    # path additionally requires CLI_AUTO_RELEASE and skips prereleases.
    if: >-
      github.repository == 'langgenius/dify' &&
      (github.event_name == 'workflow_dispatch' ||
      (vars.CLI_AUTO_RELEASE == 'true' && !github.event.release.prerelease))
    env:
      DIFY_TAG: ${{ github.event.release.tag_name || inputs.dify_release_tag }}
    permissions:
      contents: write # upload assets to the GH release
      id-token: write # Sigstore keyless signing
    defaults:
      run:
        shell: bash
        working-directory: ./cli
    steps:
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false
          fetch-depth: 0 # full history: build metadata below uses git log
      - name: Setup web environment
        uses: ./.github/actions/setup-web
      # NOTE(review): setup-node / cosign-installer / upload-artifact below are
      # tag-pinned while checkout is SHA-pinned — consider pinning all actions
      # to commit SHAs for consistency and supply-chain hygiene.
      - name: Setup Node registry auth
        uses: actions/setup-node@v4
        with:
          node-version-file: .nvmrc
          registry-url: 'https://registry.npmjs.org'
      - name: Read cli/package.json
        id: manifest
        run: |
          version=$(node -p "require('./package.json').version")
          channel=$(node -p "require('./package.json').difyctl.channel")
          minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
          maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
          {
            echo "version=$version"
            echo "channel=$channel"
            echo "minDify=$minDify"
            echo "maxDify=$maxDify"
          } >> "$GITHUB_OUTPUT"
      - name: Validate manifest
        run: scripts/release-validate-manifest.sh
      - name: Bump guard (auto-path only)
        if: github.event_name == 'release'
        run: scripts/release-bump-guard.sh
        env:
          NEW_VERSION: ${{ steps.manifest.outputs.version }}
          NEW_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
          NEW_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
      - name: Build cli
        run: |
          DIFYCTL_VERSION="${{ steps.manifest.outputs.version }}" \
          DIFYCTL_CHANNEL="${{ steps.manifest.outputs.channel }}" \
          DIFYCTL_MIN_DIFY="${{ steps.manifest.outputs.minDify }}" \
          DIFYCTL_MAX_DIFY="${{ steps.manifest.outputs.maxDify }}" \
          DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
          DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
          pnpm build
      - name: Pack tarballs
        run: pnpm pack:tarballs
      # Fail before publishing anything if the target release doesn't exist.
      - name: Verify target dify release exists
        run: gh release view "$DIFY_TAG" --repo langgenius/dify > /dev/null
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      - name: Publish to npm (idempotent)
        run: scripts/release-npm-publish.sh
        env:
          CHANNEL: ${{ steps.manifest.outputs.channel }}
          NEW_VERSION: ${{ steps.manifest.outputs.version }}
          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
      - name: Generate sha256 checksum file
        run: scripts/release-write-checksums.sh
        env:
          CLI_VERSION: ${{ steps.manifest.outputs.version }}
      - name: Install cosign
        uses: sigstore/cosign-installer@v3
      - name: Keyless-sign tarballs + checksum file (Sigstore)
        run: scripts/release-cosign-sign.sh
        env:
          CLI_VERSION: ${{ steps.manifest.outputs.version }}
          COSIGN_EXPERIMENTAL: '1'
      # Keep a workflow-artifact copy even when a later step fails.
      - name: Snapshot tarballs + checksum + signatures as workflow artifact
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: difyctl-${{ steps.manifest.outputs.version }}-${{ env.DIFY_TAG }}
          path: |
            cli/dist/difyctl-v*.tar.xz
            cli/dist/difyctl-v*-checksums.txt
            cli/dist/difyctl-v*.sig
            cli/dist/difyctl-v*.pem
          retention-days: 90
          if-no-files-found: error
      - name: Upload tarballs + checksum + signatures to dify GH release (idempotent)
        run: scripts/release-upload-tarballs.sh
        env:
          CLI_VERSION: ${{ steps.manifest.outputs.version }}
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

57
.github/workflows/cli-smoke.yml vendored Normal file
View File

@ -0,0 +1,57 @@
# Manually-triggered end-to-end smoke: boots a real dify stack with docker
# compose, then runs the cli smoke suite against it.
name: CLI Smoke (live dify)

on:
  workflow_dispatch:
    inputs:
      dify_version:
        description: "Dify image tag to test against (e.g. 1.7.0)"
        type: string
        required: true
      cli_ref:
        description: "Git ref to build the cli from (default: current branch)"
        type: string
        required: false

jobs:
  smoke:
    runs-on: ubuntu-latest
    timeout-minutes: 30
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout cli ref
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          ref: ${{ inputs.cli_ref || github.ref }}
          persist-credentials: false
      - name: Setup web environment
        uses: ./.github/actions/setup-web
      - name: Bring up dify
        env:
          DIFY_VERSION: ${{ inputs.dify_version }}
        run: |
          cd docker
          cp .env.example .env
          DIFY_API_IMAGE_TAG="$DIFY_VERSION" \
          DIFY_WEB_IMAGE_TAG="$DIFY_VERSION" \
          docker compose up -d api worker web db redis
          # FIX: the original wait loop fell through silently when the API
          # never became healthy, letting the smoke step fail later with a
          # confusing connection error. Fail this step explicitly instead.
          ready=''
          for i in $(seq 1 60); do
            if curl -fsS http://localhost:5001/health >/dev/null 2>&1; then
              echo "dify api ready after ${i} attempt(s)"
              ready=1
              break
            fi
            sleep 1
          done
          if [[ -z "$ready" ]]; then
            echo "dify api did not become healthy within 60 attempts" >&2
            docker compose logs api --tail=100
            exit 1
          fi
      - name: Run smoke against live dify
        working-directory: ./cli
        run: pnpm exec tsx scripts/run-smoke.ts --base-url http://localhost:5001
      - name: Dump dify logs on failure
        if: failure()
        run: |
          cd docker
          docker compose logs api worker web --tail=200

46
.github/workflows/cli-tests.yml vendored Normal file
View File

@ -0,0 +1,46 @@
# Reusable CLI test workflow, invoked from the main CI pipeline via
# workflow_call. Coverage upload is best-effort: it is skipped when no
# CODECOV_TOKEN is provided (e.g. on forks).
name: CLI Tests

on:
  workflow_call:
    secrets:
      CODECOV_TOKEN:
        required: false

permissions:
  contents: read

concurrency:
  group: cli-tests-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  test:
    name: CLI Tests
    runs-on: ubuntu-latest
    env:
      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
    defaults:
      run:
        shell: bash
        working-directory: ./cli
    steps:
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false
      - name: Setup web environment
        uses: ./.github/actions/setup-web
      - name: CI pipeline (typecheck, lint, coverage, build)
        run: make ci
      # Secrets are not visible in `if:` directly; the job-level env mirror
      # above makes the empty-token check possible here.
      - name: Report coverage
        if: ${{ env.CODECOV_TOKEN != '' }}
        uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
        with:
          directory: cli/coverage
          flags: cli
        env:
          CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}

View File

@ -42,6 +42,7 @@ jobs:
runs-on: ubuntu-latest
outputs:
api-changed: ${{ steps.changes.outputs.api }}
cli-changed: ${{ steps.changes.outputs.cli }}
e2e-changed: ${{ steps.changes.outputs.e2e }}
web-changed: ${{ steps.changes.outputs.web }}
vdb-changed: ${{ steps.changes.outputs.vdb }}
@ -63,6 +64,16 @@ jobs:
- 'docker/generate_docker_compose'
- 'docker/ssrf_proxy/**'
- 'docker/volumes/sandbox/conf/**'
cli:
- 'cli/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- 'eslint.config.mjs'
- '.npmrc'
- '.nvmrc'
- '.github/workflows/cli-tests.yml'
- '.github/actions/setup-web/**'
web:
- 'web/**'
- 'packages/**'
@ -186,6 +197,66 @@ jobs:
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
cli-tests-run:
name: Run CLI Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed == 'true'
uses: ./.github/workflows/cli-tests.yml
secrets: inherit
cli-tests-skip:
name: Skip CLI Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed != 'true'
runs-on: ubuntu-latest
steps:
- name: Report skipped CLI tests
run: echo "No CLI-related changes detected; skipping CLI tests."
cli-tests:
name: CLI Tests
if: ${{ always() }}
needs:
- pre_job
- check-changes
- cli-tests-run
- cli-tests-skip
runs-on: ubuntu-latest
steps:
- name: Finalize CLI Tests status
env:
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
TESTS_CHANGED: ${{ needs.check-changes.outputs.cli-changed }}
RUN_RESULT: ${{ needs.cli-tests-run.result }}
SKIP_RESULT: ${{ needs.cli-tests-skip.result }}
run: |
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
echo "CLI tests were skipped because this workflow run duplicated a successful or newer run."
exit 0
fi
if [[ "$TESTS_CHANGED" == 'true' ]]; then
if [[ "$RUN_RESULT" == 'success' ]]; then
echo "CLI tests ran successfully."
exit 0
fi
echo "CLI tests were required but finished with result: $RUN_RESULT" >&2
exit 1
fi
if [[ "$SKIP_RESULT" == 'success' ]]; then
echo "CLI tests were skipped because no CLI-related files changed."
exit 0
fi
echo "CLI tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
web-tests-run:
name: Run Web Tests
needs:

7
.gitignore vendored
View File

@ -115,6 +115,12 @@ venv/
ENV/
env.bak/
venv.bak/
# cli/ has a src/env/ module (DIFY_* registry) — don't treat it as a venv
!/cli/src/env/
!/cli/src/commands/env/
# cli/scripts/lib/ holds TS build helpers (resolve-buildinfo etc.) — don't treat as Python lib/
!/cli/scripts/lib/
.conda/
# Spyder project settings
@ -240,6 +246,7 @@ scripts/stress-test/reports/
# settings
*.local.json
*.local.md
*.local.toml
# Code Agent Folder
.qoder/*

View File

@ -159,6 +159,7 @@ def initialize_extensions(app: DifyApp):
ext_logstore,
ext_mail,
ext_migrate,
ext_oauth_bearer,
ext_orjson,
ext_otel,
ext_proxy_fix,
@ -203,6 +204,7 @@ def initialize_extensions(app: DifyApp):
ext_enterprise_telemetry,
ext_request_logging,
ext_session_factory,
ext_oauth_bearer,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]

View File

@ -499,6 +499,35 @@ class HttpConfig(BaseSettings):
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
description=(
"Comma-separated allowlist for /openapi/v1/* CORS. "
"Default empty = same-origin only. Browser-cookie routes within "
"the group reject cross-origin OPTIONS regardless of this list."
),
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
default="",
)
@computed_field
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
description=(
"Comma-separated client_id values accepted at "
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
"without code changes. Unknown client_id returns 400 unsupported_client."
),
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
default="difyctl",
)
@computed_field # type: ignore[misc]
@property
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
)
@ -874,6 +903,17 @@ class AuthConfig(BaseSettings):
default=86400,
)
ENABLE_OAUTH_BEARER: bool = Field(
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
default=True,
)
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
default=60,
)
class ModerationConfig(BaseSettings):
"""
@ -1148,6 +1188,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable scheduled workflow run cleanup task",
default=False,
)
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
default=True,
)
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
description="Days to retain revoked OAuth access-token rows before deletion.",
default=30,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,

View File

@ -0,0 +1,41 @@
from flask import Blueprint
from flask_restx import Namespace
from libs.device_flow_security import attach_anti_framing
from libs.external_api import ExternalApi

# Blueprint for the user-scoped programmatic API; every route lives under /openapi/v1.
bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
# Attach anti-framing protections to all responses on this blueprint.
attach_anti_framing(bp)
api = ExternalApi(
bp,
version="1.0",
title="OpenAPI",
description="User-scoped programmatic API (bearer auth)",
)
openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")
# Route modules are imported AFTER `openapi_ns` exists because each module
# registers its Resources on the namespace at import time (deliberate
# late-import pattern; keep these below the namespace definition).
from . import (
account,
app_run,
apps,
apps_permitted_external,
index,
oauth_device,
oauth_device_sso,
workspaces,
)
__all__ = [
"account",
"app_run",
"apps",
"apps_permitted_external",
"index",
"oauth_device",
"oauth_device_sso",
"workspaces",
]
# Attach the namespace last, once all route modules above have registered.
api.add_namespace(openapi_ns)

View File

@ -0,0 +1,66 @@
"""Audit emission for openapi app-run endpoints.
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
matches the existing oauth_device convention. The EE OTel exporter consults
its own allowlist to decide whether to ship the line.
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
EVENT_OPENAPI_WRONG_SURFACE_DENIED = "openapi.wrong_surface_denied"
def emit_app_run(
*,
app_id: str,
tenant_id: str,
caller_kind: str,
mode: str,
surface: str,
) -> None:
logger.info(
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s surface=%s",
EVENT_APP_RUN_OPENAPI,
app_id,
tenant_id,
caller_kind,
mode,
surface,
extra={
"audit": True,
"event": EVENT_APP_RUN_OPENAPI,
"app_id": app_id,
"tenant_id": tenant_id,
"caller_kind": caller_kind,
"mode": mode,
"surface": surface,
},
)
def emit_wrong_surface(
*,
subject_type: str | None,
attempted_path: str,
client_id: str | None,
token_id: str | None,
) -> None:
logger.warning(
"audit: %s subject_type=%s attempted_path=%s",
EVENT_OPENAPI_WRONG_SURFACE_DENIED,
subject_type,
attempted_path,
extra={
"audit": True,
"event": EVENT_OPENAPI_WRONG_SURFACE_DENIED,
"subject_type": subject_type,
"attempted_path": attempted_path,
"client_id": client_id,
"token_id": token_id,
},
)

View File

@ -0,0 +1,143 @@
"""Server-side JSON Schema derivation from Dify `user_input_form`."""
from __future__ import annotations
from typing import Any, cast
from controllers.service_api.app.error import AppUnavailableError
from models import App
from models.model import AppMode
JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"
EMPTY_INPUT_SCHEMA: dict[str, Any] = {
"$schema": JSON_SCHEMA_DRAFT,
"type": "object",
"properties": {},
"required": [],
}
_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})
def _file_object_shape() -> dict[str, Any]:
"""Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
return {
"type": "object",
"properties": {
"type": {"type": "string"},
"transfer_method": {"type": "string"},
"url": {"type": "string"},
"upload_file_id": {"type": "string"},
},
"additionalProperties": True,
}
def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
label = row.get("label") or row.get("variable", "")
base: dict[str, Any] = {"title": label} if label else {}
if row_type in ("text-input", "paragraph"):
out = {"type": "string"} | base
max_length = row.get("max_length")
if isinstance(max_length, int) and max_length > 0:
out["maxLength"] = max_length
return out
if row_type == "select":
return {"type": "string"} | base | {"enum": list(row.get("options") or [])}
if row_type == "number":
return {"type": "number"} | base
if row_type == "file":
return _file_object_shape() | base
if row_type == "file-list":
return {
"type": "array",
"items": _file_object_shape(),
} | base
return None
def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
"""Translate a user_input_form row list into (properties, required-list).
Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
Unknown variable types are skipped (forward-compat).
"""
properties: dict[str, Any] = {}
required: list[str] = []
for row in form:
if not isinstance(row, dict) or len(row) != 1:
continue
((row_type, row_body),) = row.items()
if not isinstance(row_body, dict):
continue
variable = row_body.get("variable")
if not variable:
continue
schema = _row_to_schema(row_type, row_body)
if schema is None:
continue
properties[variable] = schema
if row_body.get("required"):
required.append(variable)
return properties, required
def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Resolve `(features_dict, user_input_form)` for parameters / schema derivation.

    Workflow-backed modes read from the published workflow; legacy modes
    read from the app model config. Raises `AppUnavailableError` when the
    backing config is missing (misconfigured app).
    """
    if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
        workflow = app.workflow
        if workflow is None:
            raise AppUnavailableError()
        form = cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True))
        return workflow.features_dict, form
    config = app.app_model_config
    if config is None:
        raise AppUnavailableError()
    features = cast(dict[str, Any], config.to_dict())
    return features, cast(list[dict[str, Any]], features.get("user_input_form", []))
def build_input_schema(app: App) -> dict[str, Any]:
    """Derive Draft 2020-12 JSON Schema from `user_input_form` + app mode.

    chat / agent-chat / advanced-chat: top-level `query` (required, minLength=1)
    plus an `inputs` object. completion / workflow: `inputs` object only.
    Raises `AppUnavailableError` on misconfigured apps.
    """
    _, user_input_form = resolve_app_config(app)
    inputs_props, inputs_required = _form_to_jsonschema(user_input_form)

    top_props: dict[str, Any] = {}
    top_required: list[str] = []
    if app.mode in _CHAT_FAMILY:
        top_props["query"] = {"type": "string", "minLength": 1}
        top_required.append("query")
    # `inputs` is always present and closed: unknown variables are rejected.
    top_props["inputs"] = {
        "type": "object",
        "properties": inputs_props,
        "required": inputs_required,
        "additionalProperties": False,
    }
    top_required.append("inputs")
    return {
        "$schema": JSON_SCHEMA_DRAFT,
        "type": "object",
        "properties": top_props,
        "required": top_required,
    }

View File

@ -0,0 +1,112 @@
"""Shared response substructures for openapi endpoints."""
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, Field
# Server-side cap on `limit` query param for any /openapi/v1/* list endpoint.
# Sibling endpoints (`/apps`, `/account/sessions`, future routes) all clamp to
# this; do not introduce per-endpoint caps without raising the constant.
MAX_PAGE_LIMIT = 200
class UsageInfo(BaseModel):
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
class MessageMetadata(BaseModel):
usage: UsageInfo | None = None
retriever_resources: list[dict[str, Any]] = []
class PaginationEnvelope[T](BaseModel):
"""Canonical pagination envelope for `/openapi/v1/*` list endpoints."""
page: int
limit: int
total: int
has_more: bool
data: list[T]
@classmethod
def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
class AppListRow(BaseModel):
id: str
name: str
description: str | None = None
mode: str
tags: list[dict[str, str]] = []
updated_at: str | None = None
created_by_name: str | None = None
workspace_id: str | None = None
workspace_name: str | None = None
class AppInfoResponse(BaseModel):
id: str
name: str
description: str | None = None
mode: str
author: str | None = None
tags: list[dict[str, str]] = []
class AppDescribeInfo(AppInfoResponse):
updated_at: str | None = None
service_api_enabled: bool
class AppDescribeResponse(BaseModel):
info: AppDescribeInfo | None = None
parameters: dict[str, Any] | None = None
input_schema: dict[str, Any] | None = None
class ChatMessageResponse(BaseModel):
event: str
task_id: str
id: str
message_id: str
conversation_id: str
mode: str
answer: str
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
created_at: int
class CompletionMessageResponse(BaseModel):
event: str
task_id: str
id: str
message_id: str
mode: str
answer: str
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
created_at: int
class WorkflowRunData(BaseModel):
id: str
workflow_id: str
status: str
outputs: dict[str, Any] = Field(default_factory=dict)
error: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: int | None = None
finished_at: int | None = None
class WorkflowRunResponse(BaseModel):
workflow_run_id: str
task_id: str
mode: Literal["workflow"] = "workflow"
data: WorkflowRunData

View File

@ -0,0 +1,236 @@
"""User-scoped account endpoints. /account is the bearer-authed
identity read; /account/sessions and /account/sessions/<id> manage
the user's active OAuth tokens.
"""
from __future__ import annotations
from datetime import UTC, datetime
from flask import g, request
from flask_restx import Resource
from sqlalchemy import and_, select, update
from werkzeug.exceptions import BadRequest, NotFound
from controllers.openapi import openapi_ns
from controllers.openapi._models import MAX_PAGE_LIMIT, PaginationEnvelope
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
TOKEN_CACHE_KEY_FMT,
AuthContext,
SubjectType,
validate_bearer,
)
from libs.rate_limit import (
LIMIT_ME_PER_ACCOUNT,
LIMIT_ME_PER_EMAIL,
enforce,
)
from models import Account, OAuthAccessToken, Tenant, TenantAccountJoin
@openapi_ns.route("/account")
class AccountApi(Resource):
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
ctx = g.auth_ctx
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
else:
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return {
"subject_type": ctx.subject_type,
"subject_email": ctx.subject_email,
"subject_issuer": ctx.subject_issuer,
"account": None,
"workspaces": [],
"default_workspace_id": None,
}
account = (
db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none() if ctx.account_id else None
)
memberships = _load_memberships(ctx.account_id) if ctx.account_id else []
default_ws_id = _pick_default_workspace(memberships)
return {
"subject_type": ctx.subject_type,
"subject_email": ctx.subject_email or (account.email if account else None),
"account": _account_payload(account) if account else None,
"workspaces": [_workspace_payload(m) for m in memberships],
"default_workspace_id": default_ws_id,
}
@openapi_ns.route("/account/sessions/self")
class AccountSessionsSelfApi(Resource):
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self):
ctx = g.auth_ctx
_require_oauth_subject(ctx)
_revoke_token_by_id(str(ctx.token_id))
return {"status": "revoked"}, 200
@openapi_ns.route("/account/sessions")
class AccountSessionsApi(Resource):
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
ctx = g.auth_ctx
now = datetime.now(UTC)
page = int(request.args.get("page", "1"))
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
all_rows = db.session.execute(
select(
OAuthAccessToken.id,
OAuthAccessToken.prefix,
OAuthAccessToken.client_id,
OAuthAccessToken.device_label,
OAuthAccessToken.created_at,
OAuthAccessToken.last_used_at,
OAuthAccessToken.expires_at,
)
.where(
and_(
*_subject_match(ctx),
OAuthAccessToken.revoked_at.is_(None),
OAuthAccessToken.token_hash.is_not(None),
OAuthAccessToken.expires_at > now,
)
)
.order_by(OAuthAccessToken.created_at.desc())
).all()
total = len(all_rows)
sliced = all_rows[(page - 1) * limit : page * limit]
items = [
{
"id": str(r.id),
"prefix": r.prefix,
"client_id": r.client_id,
"device_label": r.device_label,
"created_at": _iso(r.created_at),
"last_used_at": _iso(r.last_used_at),
"expires_at": _iso(r.expires_at),
}
for r in sliced
]
return (
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
200,
)
@openapi_ns.route("/account/sessions/<string:session_id>")
class AccountSessionByIdApi(Resource):
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self, session_id: str):
ctx = g.auth_ctx
_require_oauth_subject(ctx)
# Subject-match guard. 404 (not 403) on cross-subject so the
# endpoint doesn't leak token IDs that belong to other subjects.
owns = db.session.execute(
select(OAuthAccessToken.id).where(
and_(
OAuthAccessToken.id == session_id,
*_subject_match(ctx),
)
)
).first()
if owns is None:
raise NotFound("session not found")
_revoke_token_by_id(session_id)
return {"status": "revoked"}, 200
def _subject_match(ctx: AuthContext) -> tuple:
    """Where-clauses that scope a query to the bearer's subject.

    Account bearers match on account_id; external-SSO bearers match on
    (email, issuer) with no local account row attached.
    """
    if ctx.subject_type == SubjectType.ACCOUNT:
        return (OAuthAccessToken.account_id == str(ctx.account_id),)
    return (
        OAuthAccessToken.subject_email == ctx.subject_email,
        OAuthAccessToken.subject_issuer == ctx.subject_issuer,
        OAuthAccessToken.account_id.is_(None),
    )
def _require_oauth_subject(ctx: AuthContext) -> None:
if not ctx.source.startswith("oauth"):
raise BadRequest(
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
)
def _revoke_token_by_id(token_id: str) -> None:
    """Idempotently revoke one token row and drop its Redis cache entry.

    The live hash is snapshotted BEFORE the UPDATE because the UPDATE nulls
    token_hash, and the cache key is derived from that hash. The
    `revoked_at IS NULL` predicate makes a double-revoke a no-op.
    """
    row = (
        db.session.query(OAuthAccessToken.token_hash)
        .filter(
            OAuthAccessToken.id == token_id,
            OAuthAccessToken.revoked_at.is_(None),
        )
        .one_or_none()
    )
    pre_revoke_hash = None if row is None else row[0]
    db.session.execute(
        update(OAuthAccessToken)
        .where(
            OAuthAccessToken.id == token_id,
            OAuthAccessToken.revoked_at.is_(None),
        )
        .values(revoked_at=datetime.now(UTC), token_hash=None)
    )
    db.session.commit()
    if pre_revoke_hash:
        redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=pre_revoke_hash))
def _iso(dt: datetime | None) -> str | None:
if dt is None:
return None
if dt.tzinfo is None:
dt = dt.replace(tzinfo=UTC)
return dt.isoformat().replace("+00:00", "Z")
def _load_memberships(account_id):
return (
db.session.query(TenantAccountJoin, Tenant)
.join(Tenant, Tenant.id == TenantAccountJoin.tenant_id)
.filter(TenantAccountJoin.account_id == account_id)
.all()
)
def _pick_default_workspace(memberships) -> str | None:
if not memberships:
return None
for join, tenant in memberships:
if getattr(join, "current", False):
return str(tenant.id)
return str(memberships[0][1].id)
def _workspace_payload(row) -> dict:
join, tenant = row
return {"id": str(tenant.id), "name": tenant.name, "role": getattr(join, "role", "")}
def _account_payload(account) -> dict:
return {"id": str(account.id), "email": account.email, "name": account.name}

View File

@ -0,0 +1,200 @@
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
from __future__ import annotations
import logging
from collections.abc import Callable, Iterator, Mapping
from contextlib import contextmanager
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, ValidationError, field_validator
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
import services
from controllers.openapi import openapi_ns
from controllers.openapi._audit import emit_app_run
from controllers.openapi._models import (
ChatMessageResponse,
CompletionMessageResponse,
WorkflowRunResponse,
)
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
ConversationCompletedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.oauth_bearer import Scope
from models.model import App, AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import (
IsDraftWorkflowError,
WorkflowIdFormatError,
WorkflowNotFoundError,
)
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
class AppRunRequest(BaseModel):
    """Request body for POST /openapi/v1/apps/<app_id>/run.

    One payload shape serves all app modes; the per-mode handlers reject
    fields that do not apply (e.g. `query` for workflows).
    """

    inputs: dict[str, Any]
    query: str | None = None
    files: list[dict[str, Any]] | None = None
    response_mode: Literal["blocking", "streaming"] | None = None
    conversation_id: UUIDStrOrEmpty | None = None
    auto_generate_name: bool = True
    workflow_id: str | None = None
    workspace_id: UUIDStrOrEmpty | None = None

    @field_validator("conversation_id", mode="before")
    @classmethod
    def _normalize_conv(cls, value: str | UUID | None) -> str | None:
        # Blank / whitespace-only strings mean "no conversation"; everything
        # else must parse as a UUID. NOTE(review): reconstructed nesting after
        # formatting loss — confirm non-str UUID inputs reach uuid_value.
        if isinstance(value, str):
            value = value.strip()
        if not value:
            return None
        try:
            return helper.uuid_value(value)
        except ValueError as exc:
            raise ValueError("conversation_id must be a valid UUID") from exc
@contextmanager
def _translate_service_errors() -> Iterator[None]:
    """Map service/core-layer exceptions raised inside the body to this
    surface's HTTP errors. Clause order matters: specific service errors are
    matched before the broad InvokeError fallback — preserve it.
    """
    try:
        yield
    except WorkflowNotFoundError as ex:
        raise NotFound(str(ex))
    except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
        raise BadRequest(str(ex))
    except services.errors.conversation.ConversationNotExistsError:
        raise NotFound("Conversation Not Exists.")
    except services.errors.conversation.ConversationCompletedError:
        raise ConversationCompletedError()
    except services.errors.app_model_config.AppModelConfigBrokenError:
        logger.exception("App model config broken.")
        raise AppUnavailableError()
    except ProviderTokenNotInitError as ex:
        raise ProviderNotInitializeError(ex.description)
    except QuotaExceededError:
        raise ProviderQuotaExceededError()
    except ModelCurrentlyNotSupportError:
        raise ProviderModelCurrentlyNotSupportError()
    except InvokeRateLimitError as ex:
        raise InvokeRateLimitHttpError(ex.description)
    except InvokeError as e:
        raise CompletionRequestError(e.description)
def _unpack_blocking(response: Any) -> Mapping[str, Any]:
if isinstance(response, tuple):
response = response[0]
if not isinstance(response, Mapping):
raise InternalServerError("blocking generate returned non-mapping response")
return response
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
    """Thin wrapper pinning invoke_from=OPENAPI for every run mode."""
    return AppGenerateService.generate(
        app_model=app,
        user=caller,
        args=args,
        invoke_from=InvokeFrom.OPENAPI,
        streaming=streaming,
    )
def _run_chat(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
    """Chat-family runner; `query` is mandatory and must not be blank."""
    if not payload.query or not payload.query.strip():
        raise UnprocessableEntity("query_required_for_chat")
    with _translate_service_errors():
        result = _generate(app, caller, payload.model_dump(exclude_none=True), streaming)
    if streaming:
        return result, None
    body = ChatMessageResponse.model_validate(_unpack_blocking(result))
    return None, body.model_dump(mode="json")
def _run_completion(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
    """Completion runner: conversations are never auto-named, query may be empty."""
    args = payload.model_dump(exclude_none=True)
    args["auto_generate_name"] = False
    args.setdefault("query", "")
    with _translate_service_errors():
        result = _generate(app, caller, args, streaming)
    if streaming:
        return result, None
    body = CompletionMessageResponse.model_validate(_unpack_blocking(result))
    return None, body.model_dump(mode="json")
def _run_workflow(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
    """Workflow runner: chat-only fields are stripped and `query` is rejected."""
    if payload.query is not None:
        raise UnprocessableEntity("query_not_supported_for_workflow")
    chat_only_fields = {"query", "conversation_id", "auto_generate_name"}
    args = payload.model_dump(exclude=chat_only_fields, exclude_none=True)
    with _translate_service_errors():
        result = _generate(app, caller, args, streaming)
    if streaming:
        return result, None
    body = WorkflowRunResponse.model_validate(_unpack_blocking(result))
    return None, body.model_dump(mode="json")
# Mode -> handler table. Each handler returns (stream_obj, blocking_body);
# exactly one side is non-None depending on the `streaming` flag.
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest, bool], tuple[Any, dict[str, Any] | None]]] = {
AppMode.CHAT: _run_chat,
AppMode.AGENT_CHAT: _run_chat,
AppMode.ADVANCED_CHAT: _run_chat,
AppMode.COMPLETION: _run_completion,
AppMode.WORKFLOW: _run_workflow,
}
@openapi_ns.route("/apps/<string:app_id>/run")
class AppRunApi(Resource):
    """POST /openapi/v1/apps/<app_id>/run — execute an app via an OAuth bearer."""

    @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
    def post(self, app_id: str, app_model: App, caller, caller_kind: str):
        # `app_model` / `caller` / `caller_kind` are injected by the pipeline guard.
        body = request.get_json(silent=True) or {}
        # The caller identity comes from the bearer; a client-supplied `user`
        # field is discarded rather than trusted.
        body.pop("user", None)
        try:
            payload = AppRunRequest.model_validate(body)
        except ValidationError as exc:
            raise UnprocessableEntity(exc.json())
        handler = _DISPATCH.get(app_model.mode)
        if handler is None:
            raise UnprocessableEntity("mode_not_runnable")
        streaming = payload.response_mode == "streaming"
        try:
            stream_obj, blocking_body = handler(app_model, caller, payload, streaming)
        except HTTPException:
            # Deliberate 4xx/5xx raised by the handlers pass through unchanged.
            raise
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()
        # Audit emit happens only after a successful dispatch.
        emit_app_run(
            app_id=app_model.id,
            tenant_id=app_model.tenant_id,
            caller_kind=caller_kind,
            mode=str(app_model.mode),
            surface="apps",
        )
        if streaming:
            return helper.compact_generate_response(stream_obj)
        return blocking_body, 200

View File

@ -0,0 +1,330 @@
"""GET /openapi/v1/apps and per-app reads.

Decorator order: `method_decorators` is innermost-first. `validate_bearer`
is listed last (outermost), so it sets `g.auth_ctx` before `require_scope`
reads it.
"""
from __future__ import annotations
import uuid as _uuid
from typing import Any
import sqlalchemy as sa
from flask import g, request
from flask_restx import Resource
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
from controllers.common.fields import Parameters
from controllers.openapi import openapi_ns
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
from controllers.openapi._models import (
MAX_PAGE_LIMIT,
AppDescribeInfo,
AppDescribeResponse,
AppListRow,
PaginationEnvelope,
)
from controllers.openapi.auth.surface_gate import accept_subjects
from controllers.service_api.app.error import AppUnavailableError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from extensions.ext_database import db
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
AuthContext,
Scope,
SubjectType,
require_scope,
require_workspace_member,
validate_bearer,
)
from models import App, Tenant
from models.model import AppMode
from services.app_service import AppService
from services.openapi.visibility import apply_openapi_gate, is_openapi_visible
from services.tag_service import TagService
# method_decorators applies left-to-right innermost-first; flask_restx wraps
# in order, so the LAST entry is the outermost. Execution flows
# validate_bearer → accept_subjects → require_scope → handler.
_APPS_READ_DECORATORS = [
    require_scope(Scope.APPS_READ),
    accept_subjects(SubjectType.ACCOUNT),
    validate_bearer(accept=ACCEPT_USER_ANY),
]

# Members accepted by the `?fields=` filter on GET /apps/<id>/describe.
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
class AppDescribeQuery(BaseModel):
    """`?fields=` allow-list for GET /apps/<id>/describe.

    Empty / omitted `fields` → all sections are returned. An unknown
    member → ValidationError → 422 at the endpoint.
    """

    model_config = ConfigDict(extra="forbid")

    # None means "no filter": the endpoint includes every section.
    fields: set[str] | None = None
    # Required only for name-based app lookup; must be a UUID when present.
    workspace_id: str | None = None

    @field_validator("workspace_id", mode="before")
    @classmethod
    def _validate_workspace_id(cls, v: object) -> str | None:
        # Runs before coercion: normalise ""/None to None, then insist on a
        # UUID-shaped string.
        if v is None or v == "":
            return None
        if not isinstance(v, str):
            raise ValueError("workspace_id must be a string")
        try:
            _uuid.UUID(v)
        except ValueError:
            raise ValueError("workspace_id must be a valid UUID")
        return v

    @field_validator("fields", mode="before")
    @classmethod
    def _parse_fields(cls, v: object) -> set[str] | None:
        # Accepts the raw comma-separated query value; blank members are
        # dropped, unknown members fail validation with a listing.
        if v is None or v == "":
            return None
        if not isinstance(v, str):
            raise ValueError("fields must be a comma-separated string")
        members = {m.strip() for m in v.split(",") if m.strip()}
        unknown = members - _ALLOWED_DESCRIBE_FIELDS
        if unknown:
            raise ValueError(f"unknown field(s): {sorted(unknown)}")
        return members
# Fallback `parameters` body used when the app's config cannot be resolved
# (AppUnavailableError below) — keeps the response shape stable.
_EMPTY_PARAMETERS: dict[str, Any] = {
    "opening_statement": None,
    "suggested_questions": [],
    "user_input_form": [],
    "file_upload": None,
    "system_parameters": {},
}
class AppReadResource(Resource):
    """Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""

    method_decorators = _APPS_READ_DECORATORS

    def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
        """Resolve `app_id` (UUID or app name) to a visible App.

        Raises NotFound when missing/hidden, UnprocessableEntity when a
        name lookup lacks `workspace_id`, Conflict (409) when a name
        matches several apps, plus whatever `require_workspace_member`
        raises for non-members.
        """
        ctx: AuthContext = g.auth_ctx
        try:
            parsed_uuid = _uuid.UUID(app_id)
            is_uuid = True
        except ValueError:
            parsed_uuid = None
            is_uuid = False
        if is_uuid:
            app = db.session.get(App, str(parsed_uuid))  # normalised dashed form
            if not app or app.status != "normal" or not is_openapi_visible(app):
                raise NotFound("app not found")
        else:
            # Name-based lookup is tenant-scoped, so the workspace must be given.
            if not workspace_id:
                raise UnprocessableEntity("workspace_id is required for name-based lookup")
            matches = list(
                db.session.execute(
                    apply_openapi_gate(
                        sa.select(App).where(
                            App.name == app_id,
                            App.tenant_id == workspace_id,
                            App.status == "normal",
                        )
                    )
                ).scalars()
            )
            if len(matches) == 0:
                raise NotFound("app not found")
            if len(matches) > 1:
                # Ambiguity is surfaced as a 409 whose description is a small
                # text table a CLI can print verbatim.
                lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
                lines.append(f" {'ID':<36} {'MODE':<12} NAME\n")
                for m in matches:
                    lines.append(f" {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
                raise Conflict("".join(lines))
            app = matches[0]
        # Membership check applies to both lookup paths.
        require_workspace_member(ctx, str(app.tenant_id))
        return app, ctx
def parameters_payload(app: App) -> dict:
    """Build the app's `parameters` body (mirrors service_api/app/app.py::AppParameterApi)."""
    features, input_form = resolve_app_config(app)
    raw = get_parameters_from_feature_dict(features_dict=features, user_input_form=input_form)
    validated = Parameters.model_validate(raw)
    return validated.model_dump(mode="json")
@openapi_ns.route("/apps/<string:app_id>/describe")
class AppDescribeApi(AppReadResource):
    """GET /apps/<id>/describe — info / parameters / input_schema in one call."""

    def get(self, app_id: str):
        try:
            query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
        except ValidationError as exc:
            raise UnprocessableEntity(exc.json())
        app, _ = self._load(app_id, workspace_id=query.workspace_id)
        requested = query.fields
        # `fields` omitted → every section is included.
        want_info = requested is None or "info" in requested
        want_params = requested is None or "parameters" in requested
        want_schema = requested is None or "input_schema" in requested
        info = (
            AppDescribeInfo(
                id=str(app.id),
                name=app.name,
                mode=app.mode,
                description=app.description,
                tags=[{"name": t.name} for t in app.tags],
                author=app.author_name,
                updated_at=app.updated_at.isoformat() if app.updated_at else None,
                service_api_enabled=bool(app.enable_api),
            )
            if want_info
            else None
        )
        parameters: dict[str, Any] | None = None
        input_schema: dict[str, Any] | None = None
        if want_params:
            try:
                parameters = parameters_payload(app)
            except AppUnavailableError:
                # Unconfigured app: fall back to an empty-but-shaped body.
                parameters = dict(_EMPTY_PARAMETERS)
        if want_schema:
            try:
                input_schema = build_input_schema(app)
            except AppUnavailableError:
                input_schema = dict(EMPTY_INPUT_SCHEMA)
        # exclude_none=False keeps unrequested sections present as nulls.
        return (
            AppDescribeResponse(
                info=info,
                parameters=parameters,
                input_schema=input_schema,
            ).model_dump(mode="json", exclude_none=False),
            200,
        )
class AppListQuery(BaseModel):
    """Query params for GET /apps.

    `mode` is a closed enum — unknown values → 422 instead of
    silently-empty data.
    """

    workspace_id: str
    page: int = Field(1, ge=1)
    limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
    mode: AppMode | None = None
    name: str | None = Field(None, max_length=200)
    tag: str | None = Field(None, max_length=100)
@openapi_ns.route("/apps")
class AppListApi(Resource):
    """GET /apps — paginated, workspace-scoped app listing."""

    method_decorators = _APPS_READ_DECORATORS

    def get(self):
        ctx: AuthContext = g.auth_ctx
        try:
            query = AppListQuery.model_validate(request.args.to_dict(flat=True))
        except ValidationError as exc:
            raise UnprocessableEntity(exc.json())
        workspace_id = query.workspace_id
        require_workspace_member(ctx, workspace_id)
        # Reusable empty page — several filters below short-circuit to it.
        empty = (
            PaginationEnvelope[AppListRow]
            .build(page=query.page, limit=query.limit, total=0, items=[])
            .model_dump(mode="json"),
            200,
        )
        # A UUID-shaped `name` turns the listing into an exact-id probe.
        if query.name:
            try:
                parsed_uuid = _uuid.UUID(query.name)
            except ValueError:
                parsed_uuid = None
        else:
            parsed_uuid = None
        if parsed_uuid is not None:
            app = db.session.get(App, str(parsed_uuid))
            # Visibility / tenancy mismatches return an empty page, not 404,
            # so the probe doesn't leak existence across workspaces.
            if (
                not app
                or app.status != "normal"
                or str(app.tenant_id) != workspace_id
                or not is_openapi_visible(app)
            ):
                return empty
            tenant_name = db.session.execute(
                sa.select(Tenant.name).where(Tenant.id == workspace_id)
            ).scalar_one_or_none()
            item = AppListRow(
                id=str(app.id),
                name=app.name,
                description=app.description,
                mode=app.mode,
                tags=[{"name": t.name} for t in app.tags],
                updated_at=app.updated_at.isoformat() if app.updated_at else None,
                created_by_name=getattr(app, "author_name", None),
                workspace_id=str(workspace_id),
                workspace_name=tenant_name,
            )
            env = PaginationEnvelope[AppListRow].build(page=1, limit=1, total=1, items=[item])
            return env.model_dump(mode="json"), 200
        # Tag filter resolves names → ids up front; no match means no rows.
        tag_ids: list[str] | None = None
        if query.tag:
            tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
            if not tags:
                return empty
            tag_ids = [tag.id for tag in tags]
        args: dict[str, Any] = {
            "page": query.page,
            "limit": query.limit,
            "mode": query.mode.value if query.mode else "",
            "name": query.name,
            "status": "normal",
            # Visibility gate pushed into the query — pagination.total stays
            # consistent across pages because invisible rows never count.
            "openapi_visible": True,
        }
        if tag_ids:
            args["tag_ids"] = tag_ids
        pagination = AppService().get_paginate_apps(ctx.account_id, workspace_id, args)
        if pagination is None:
            return empty
        # Tenant name fetched once per page, only when there are rows.
        tenant_name: str | None = None
        if pagination.items:
            tenant_name = db.session.execute(
                sa.select(Tenant.name).where(Tenant.id == workspace_id)
            ).scalar_one_or_none()
        items = [
            AppListRow(
                id=str(r.id),
                name=r.name,
                description=r.description,
                mode=r.mode,
                tags=[{"name": t.name} for t in r.tags],
                updated_at=r.updated_at.isoformat() if r.updated_at else None,
                created_by_name=getattr(r, "author_name", None),
                workspace_id=str(workspace_id),
                workspace_name=tenant_name,
            )
            for r in pagination.items
        ]
        env = PaginationEnvelope[AppListRow].build(
            page=query.page,
            limit=query.limit,
            total=int(pagination.total),
            items=items,
        )
        return env.model_dump(mode="json"), 200

View File

@ -0,0 +1,121 @@
"""GET /openapi/v1/permitted-external-apps — external-subject app discovery (EE only).

`dfoe_` (External SSO) callers can reach only apps whose ACL access-mode
permits them (public / sso_verified). License-gated: CE deploys never
enable the EE blueprint chain, so this module is unreachable there.
"""
from __future__ import annotations
import sqlalchemy as sa
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from werkzeug.exceptions import UnprocessableEntity
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
MAX_PAGE_LIMIT,
AppListRow,
PaginationEnvelope,
)
from controllers.openapi.auth.surface_gate import accept_subjects
from extensions.ext_database import db
from libs.device_flow_security import enterprise_only
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
Scope,
SubjectType,
require_scope,
validate_bearer,
)
from models import App, Tenant
from models.model import AppMode
from services.enterprise.app_permitted_service import list_permitted_apps
from services.openapi.license_gate import license_required
from services.openapi.visibility import apply_openapi_gate
class PermittedExternalAppsListQuery(BaseModel):
    """Strict (`extra='forbid'`) — rejects `workspace_id`/`tag`/etc. that are valid on /apps but not here."""

    model_config = ConfigDict(extra="forbid")

    page: int = Field(1, ge=1)
    limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
    mode: AppMode | None = None
    name: str | None = Field(None, max_length=200)
@openapi_ns.route("/permitted-external-apps")
class PermittedExternalAppsListApi(Resource):
    """GET /permitted-external-apps — EE-side allow-list joined with local app rows."""

    # method_decorators applies left-to-right innermost-first; execution
    # flows enterprise_only → validate_bearer → accept_subjects →
    # license_required → require_scope → handler. validate_bearer is
    # widened to ACCEPT_USER_ANY so accept_subjects can emit the
    # `openapi.wrong_surface_denied` audit on dfoa_→external misses
    # instead of validate_bearer rejecting silently with "subject type
    # not accepted here".
    method_decorators = [
        require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
        license_required,
        accept_subjects(SubjectType.EXTERNAL_SSO),
        validate_bearer(accept=ACCEPT_USER_ANY),
        enterprise_only,
    ]

    def get(self):
        try:
            query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
        except ValidationError as exc:
            raise UnprocessableEntity(exc.json())
        # EE service owns pagination; we only hydrate the returned ids.
        page_result = list_permitted_apps(
            page=query.page,
            limit=query.limit,
            mode=query.mode.value if query.mode else None,
            name=query.name,
        )
        if not page_result.app_ids:
            env = PaginationEnvelope[AppListRow].build(
                page=query.page, limit=query.limit, total=page_result.total, items=[]
            )
            return env.model_dump(mode="json"), 200
        # Bulk-load visible apps, then their tenants, in two queries.
        apps_by_id = {
            str(a.id): a
            for a in db.session.execute(
                apply_openapi_gate(sa.select(App).where(App.id.in_(page_result.app_ids)))
            ).scalars().all()
        }
        tenant_ids = list({a.tenant_id for a in apps_by_id.values()})
        tenants_by_id = {
            str(t.id): t for t in db.session.execute(sa.select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all()
        }
        # Preserve the EE-side ordering; drop rows that are locally missing
        # or not "normal".
        items: list[AppListRow] = []
        for app_id in page_result.app_ids:
            app = apps_by_id.get(app_id)
            if not app or app.status != "normal":
                continue
            tenant = tenants_by_id.get(str(app.tenant_id))
            items.append(
                AppListRow(
                    id=str(app.id),
                    name=app.name,
                    description=app.description,
                    mode=app.mode,
                    tags=[],  # tenant-scoped; not surfaced cross-tenant
                    updated_at=app.updated_at.isoformat() if app.updated_at else None,
                    created_by_name=None,  # cross-tenant author leak prevention
                    workspace_id=str(app.tenant_id),
                    workspace_name=tenant.name if tenant else None,
                )
            )
        # total/has_more reflect the EE-side allow-list; len(items) may be < limit when local rows are dropped.
        env = PaginationEnvelope[AppListRow].build(
            page=query.page, limit=query.limit, total=page_result.total, items=items
        )
        return env.model_dump(mode="json"), 200

View File

@ -0,0 +1,3 @@
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
__all__ = ["OAUTH_BEARER_PIPELINE"]

View File

@ -0,0 +1,52 @@
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.

Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=)`. No alternative
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
inline — they don't need `AppAuthzCheck`/`CallerMount`.
"""
from __future__ import annotations
from controllers.openapi.auth.pipeline import Pipeline
from controllers.openapi.auth.steps import (
AppAuthzCheck,
AppResolver,
BearerCheck,
CallerMount,
ScopeCheck,
SurfaceCheck,
WorkspaceMembershipCheck,
)
from controllers.openapi.auth.strategies import (
AccountMounter,
AclStrategy,
AppAuthzStrategy,
EndUserMounter,
MembershipStrategy,
)
from libs.oauth_bearer import SubjectType
from services.feature_service import FeatureService
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
    """Pick the per-app authorization strategy from deployment features (evaluated per call)."""
    webapp_auth_on = FeatureService.get_system_features().webapp_auth.enabled
    return AclStrategy() if webapp_auth_on else MembershipStrategy()
# Pipeline currently serves only `/openapi/v1/apps/<id>/run` — an account
# (dfoa_) surface route. SurfaceCheck runs right after BearerCheck so
# pipeline-guarded routes get the same wrong_surface 403 + audit emit as
# the inline `@accept_subjects` decorator on read endpoints. When the
# external-surface run route lands, swap in an external-pipeline builder
# that constructs SurfaceCheck(accepted=frozenset({USER_EXT_SSO})).
# Step order is load-bearing: each step reads fields written by the ones
# before it (see steps.py).
OAUTH_BEARER_PIPELINE = Pipeline(
    BearerCheck(),
    SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
    ScopeCheck(),
    AppResolver(),
    WorkspaceMembershipCheck(),
    AppAuthzCheck(_resolve_app_authz_strategy),
    CallerMount(AccountMounter(), EndUserMounter()),
)

View File

@ -0,0 +1,46 @@
"""Mutable per-request context for the openapi auth pipeline.

Every field starts None / empty and is filled in by a step. The pipeline
is the only thing that should construct or mutate Context; handlers
read populated values via the decorator's kwargs unpacking.
"""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Literal, Protocol
from flask import Request
from libs.oauth_bearer import Scope, SubjectType
if TYPE_CHECKING:
from models import App, Tenant
@dataclass
class Context:
    """Per-request scratch state threaded through the pipeline steps."""

    # Set at construction by Pipeline.guard.
    request: Request
    required_scope: Scope
    # Identity fields — written by BearerCheck.
    subject_type: SubjectType | None = None
    subject_email: str | None = None
    subject_issuer: str | None = None
    account_id: uuid.UUID | None = None
    scopes: frozenset[Scope] = field(default_factory=frozenset)
    token_id: uuid.UUID | None = None
    token_hash: str | None = None
    cached_verified_tenants: dict[str, bool] | None = None
    source: str | None = None
    expires_at: datetime | None = None
    # Target fields — written by AppResolver.
    app: App | None = None
    tenant: Tenant | None = None
    # Caller fields — written by the mounters invoked from CallerMount.
    caller: object | None = None
    caller_kind: Literal["account", "end_user"] | None = None


class Step(Protocol):
    """One responsibility. Mutate ctx or raise to short-circuit."""

    def __call__(self, ctx: Context) -> None: ...

View File

@ -0,0 +1,41 @@
"""Pipeline IS the auth scheme.
`Pipeline.guard(scope=)` is the only attachment point for endpoints
that is the design lock-in: forgetting an auth layer is structurally
impossible because there is no "sometimes wrap, sometimes don't" choice.
"""
from __future__ import annotations
from functools import wraps
from flask import request
from controllers.openapi.auth.context import Context, Step
from libs.oauth_bearer import Scope
class Pipeline:
    """An ordered sequence of auth steps; `guard` wires them in front of a view."""

    def __init__(self, *steps: "Step") -> None:
        self._steps = steps

    def run(self, ctx: "Context") -> None:
        """Execute every step in order; any step may raise to short-circuit."""
        for current in self._steps:
            current(ctx)

    def guard(self, *, scope: "Scope"):
        """Decorator factory: run the pipeline, then inject resolved kwargs into the view."""

        def decorator(view):
            @wraps(view)
            def decorated(*args, **kwargs):
                ctx = Context(request=request, required_scope=scope)
                self.run(ctx)
                injected = {
                    "app_model": ctx.app,
                    "caller": ctx.caller,
                    "caller_kind": ctx.caller_kind,
                }
                kwargs.update(injected)
                return view(*args, **kwargs)

            return decorated

        return decorator

View File

@ -0,0 +1,174 @@
"""Pipeline steps. Each is one responsibility.
`BearerCheck` is the only step that touches the token registry; downstream
steps see only the populated `Context`. `BearerCheck` also assigns
``g.auth_ctx`` (the same way ``validate_bearer`` does) so the surface gate
+ any handler reading the request-scoped context has a single source of
truth across both auth-attach paths.
"""
from __future__ import annotations
from collections.abc import Callable
from flask import g
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
from configs import dify_config
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
from controllers.openapi.auth.surface_gate import check_surface
from extensions.ext_database import db
from libs.oauth_bearer import (
AuthContext,
InvalidBearerError,
Scope,
SubjectType,
_extract_bearer, # type: ignore[attr-defined]
check_workspace_membership,
get_authenticator,
)
from models import App, Tenant, TenantStatus
class BearerCheck:
    """Resolve bearer → populate identity fields. Rate-limit is enforced
    inside `BearerAuthenticator.authenticate`, so no separate step here.

    Also attaches the resolved `AuthContext` to ``g.auth_ctx`` — same shape
    the decorator-level ``validate_bearer`` writes — so the surface gate
    and downstream readers don't see two different identity sources."""

    def __call__(self, ctx: Context) -> None:
        # Missing/blank Authorization header → 401 before any registry IO.
        token = _extract_bearer(ctx.request)
        if not token:
            raise Unauthorized("bearer required")
        try:
            authn = get_authenticator().authenticate(token)
        except InvalidBearerError as e:
            raise Unauthorized(str(e))
        # Copy every identity field onto the pipeline context; downstream
        # steps read only ctx, never the token registry.
        ctx.subject_type = authn.subject_type
        ctx.subject_email = authn.subject_email
        ctx.subject_issuer = authn.subject_issuer
        ctx.account_id = authn.account_id
        ctx.scopes = frozenset(authn.scopes)
        ctx.source = authn.source
        ctx.token_id = authn.token_id
        ctx.expires_at = authn.expires_at
        ctx.token_hash = authn.token_hash
        ctx.cached_verified_tenants = dict(authn.verified_tenants)
        # Single source of truth for the request-scoped identity. Surface
        # gate + handlers read `g.auth_ctx` regardless of whether the route
        # ran the decorator path (`validate_bearer`) or the pipeline path.
        g.auth_ctx = authn
class ScopeCheck:
    """Verify ctx.scopes (already populated by BearerCheck) covers required."""

    def __call__(self, ctx: Context) -> None:
        granted = ctx.scopes
        has_access = Scope.FULL in granted or ctx.required_scope in granted
        if not has_access:
            raise Forbidden("insufficient_scope")
class SurfaceCheck:
    """Reject the request if `g.auth_ctx.subject_type` is not in `accepted`.

    Delegates to `surface_gate.check_surface` so the inline decorator and
    the pipeline step emit identical audit events. Relies on `BearerCheck`
    (above) having set `g.auth_ctx`.
    """

    def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
        # Frozen at construction; one instance per pipeline.
        self._accepted = accepted

    def __call__(self, ctx: Context) -> None:
        check_surface(self._accepted)
class AppResolver:
    """Read app_id from request.view_args, populate ctx.app + ctx.tenant.

    Every endpoint using the OAuth bearer pipeline must declare
    ``<string:app_id>`` in its route — that is the design lock-in (no
    body / header coupling).
    """

    def __call__(self, ctx: Context) -> None:
        app_id = (ctx.request.view_args or {}).get("app_id")
        if not app_id:
            raise BadRequest("app_id is required in path")
        app = db.session.get(App, app_id)
        if not app or app.status != "normal":
            raise NotFound("app not found")
        # Apps with the service API switched off are visible but not runnable.
        if not app.enable_api:
            raise Forbidden("service_api_disabled")
        tenant = db.session.get(Tenant, app.tenant_id)
        if tenant is None or tenant.status == TenantStatus.ARCHIVE:
            raise Forbidden("workspace unavailable")
        ctx.app, ctx.tenant = app, tenant
class WorkspaceMembershipCheck:
    """Layer 0 — workspace membership gate.

    CE-only (skipped when ENTERPRISE_ENABLED). Applies to account-subject
    bearers (dfoa_) only; SSO subjects skip.
    """

    def __call__(self, ctx: Context) -> None:
        if dify_config.ENTERPRISE_ENABLED:
            return
        if ctx.subject_type != SubjectType.ACCOUNT:
            return
        # Step-order violations are wiring bugs; fail loudly as 401.
        if ctx.account_id is None or ctx.tenant is None:
            raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
        if ctx.token_hash is None:
            raise Unauthorized("token_hash unset — BearerCheck did not run")
        check_workspace_membership(
            account_id=ctx.account_id,
            tenant_id=ctx.tenant.id,
            token_hash=ctx.token_hash,
            cached_verdicts=ctx.cached_verified_tenants or {},
        )
class AppAuthzCheck:
    """Delegate per-app authorization to the lazily-resolved strategy."""

    def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
        # Resolved per request so feature-flag flips take effect immediately.
        self._resolve = resolve_strategy

    def __call__(self, ctx: Context) -> None:
        strategy = self._resolve()
        allowed = strategy.authorize(ctx)
        if not allowed:
            raise Forbidden("subject_no_app_access")
class CallerMount:
    """Mount the caller object via the first mounter that accepts the subject type."""

    def __init__(self, *mounters: CallerMounter) -> None:
        self._mounters = mounters

    def __call__(self, ctx: Context) -> None:
        subject = ctx.subject_type
        if subject is None:
            raise Unauthorized("subject_type unset — BearerCheck did not run")
        for mounter in self._mounters:
            if not mounter.applies_to(subject):
                continue
            mounter.mount(ctx)
            return
        raise Unauthorized("no caller mounter for subject type")
# AuthContext re-export so callers reading `g.auth_ctx` after a pipeline
# run get a consistent import location next to the step that writes it.
__all__ = [
    "AppAuthzCheck",
    "AppResolver",
    "AuthContext",
    "BearerCheck",
    "CallerMount",
    "ScopeCheck",
    "SurfaceCheck",
    "WorkspaceMembershipCheck",
]

View File

@ -0,0 +1,188 @@
"""Strategy classes for the openapi auth pipeline.
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
vary along independent axes; each strategy is one class so the pipeline
composition stays a flat list.
"""
from __future__ import annotations
import uuid
from typing import Protocol
from flask import current_app
from flask_login import user_logged_in
from sqlalchemy import select
from controllers.openapi.auth.context import Context
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.oauth_bearer import SubjectType
from models import Account, TenantAccountJoin
from services.end_user_service import EndUserService
from services.enterprise.enterprise_service import (
EnterpriseService,
WebAppAccessMode,
)
class AppAuthzStrategy(Protocol):
    """Decides whether the subject in ctx may access ctx.app (True = allow)."""

    def authorize(self, ctx: Context) -> bool: ...
class AclStrategy:
    """Per-app ACL, evaluated in two stages.

    The EE gateway has already enforced tenancy and workspace membership
    by the time this strategy runs, so AclStrategy only owns per-app ACL:

    1. Subject vs access-mode compatibility (pure rule table). External-SSO
       bearers belong to public-facing apps only; account bearers cover the
       full set. A mismatch is an immediate deny — no IO.
    2. For modes that pair with the subject, decide whether the inner
       permission API must run. Only `PRIVATE` (per-app selected-user list)
       requires it; the remaining modes are pass-through.
    """

    # Stage-1 rule table: which access modes each subject type may pair with.
    _ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
        SubjectType.ACCOUNT: frozenset(
            {
                WebAppAccessMode.PUBLIC,
                WebAppAccessMode.SSO_VERIFIED,
                WebAppAccessMode.PRIVATE_ALL,
                WebAppAccessMode.PRIVATE,
            }
        ),
        SubjectType.EXTERNAL_SSO: frozenset(
            {
                WebAppAccessMode.PUBLIC,
                WebAppAccessMode.SSO_VERIFIED,
            }
        ),
    }
    # Stage 2: only PRIVATE needs the per-user permission API call.
    _MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset(
        {WebAppAccessMode.PRIVATE}
    )

    def authorize(self, ctx: Context) -> bool:
        """True when the subject may access ctx.app; denies on any unknown state (fail closed)."""
        if ctx.app is None:
            return False
        access_mode = self._fetch_access_mode(ctx.app.id)
        if access_mode is None:
            # Missing or unparseable ACL record → deny.
            return False
        if not self._subject_allowed_for_mode(ctx.subject_type, access_mode):
            return False
        if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
            return True
        return self._inner_permission_check(ctx)

    @staticmethod
    def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
        """Fetch the app's ACL access mode; None when absent or not a known mode."""
        settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
        if settings is None:
            return None
        try:
            return WebAppAccessMode(settings.access_mode)
        except ValueError:
            return None

    @classmethod
    def _subject_allowed_for_mode(
        cls, subject_type: SubjectType, access_mode: WebAppAccessMode
    ) -> bool:
        # Unknown subject types fall back to an empty set → deny.
        return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())

    def _inner_permission_check(self, ctx: Context) -> bool:
        """PRIVATE mode: ask the EE permission API whether this user is on the app's list."""
        if ctx.app is None:
            return False
        user_id = self._resolve_user_id(ctx)
        if user_id is None:
            return False
        return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
            user_id=user_id,
            app_id=ctx.app.id,
        )

    @staticmethod
    def _resolve_user_id(ctx: Context) -> str | None:
        # Account bearers carry the id directly; external-SSO subjects are
        # looked up by email.
        if ctx.subject_type == SubjectType.ACCOUNT:
            return str(ctx.account_id) if ctx.account_id is not None else None
        if ctx.subject_email is None:
            return None
        account = db.session.execute(
            select(Account).where(Account.email == ctx.subject_email),
        ).scalar_one_or_none()
        return str(account.id) if account is not None else None
class MembershipStrategy:
    """Tenant-membership fallback.

    Used when webapp-auth is disabled (CE deployment). Account-bearing
    subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
    denied (it requires the webapp-auth surface).
    """

    def authorize(self, ctx: Context) -> bool:
        is_external = ctx.subject_type == SubjectType.EXTERNAL_SSO
        if is_external or ctx.tenant is None:
            return False
        return _has_tenant_membership(ctx.account_id, ctx.tenant.id)
def _has_tenant_membership(account_id: uuid.UUID | str | None, tenant_id: str) -> bool:
    """True when `account_id` has a TenantAccountJoin row inside `tenant_id`."""
    if not account_id:
        return False
    stmt = select(TenantAccountJoin.id).where(
        TenantAccountJoin.tenant_id == tenant_id,
        TenantAccountJoin.account_id == account_id,
    )
    return db.session.execute(stmt).scalar_one_or_none() is not None
def _login_as(user) -> None:
    """Set Flask-Login request user so downstream services see the caller."""
    # NOTE(review): relies on a private Flask-Login API
    # (`login_manager._update_request_context_with_user`) — may break on
    # Flask-Login upgrades; confirm against the pinned version.
    current_app.login_manager._update_request_context_with_user(user)
    user_logged_in.send(current_app._get_current_object(), user=user)
class CallerMounter(Protocol):
    """One subject type → one way to materialise the caller object onto ctx."""

    def applies_to(self, subject_type: SubjectType) -> bool: ...

    def mount(self, ctx: Context) -> None: ...
class AccountMounter:
    """Mount a console Account as the caller for ACCOUNT-subject bearers."""

    def applies_to(self, subject_type: SubjectType) -> bool:
        return subject_type == SubjectType.ACCOUNT

    def mount(self, ctx: Context) -> None:
        # RuntimeError (not 401): a missing account_id here is a step-order
        # wiring bug, not a client failure.
        if ctx.account_id is None:
            raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
        account = db.session.get(Account, ctx.account_id)
        if account is None:
            raise RuntimeError("AccountMounter: account row missing for resolved bearer")
        # Scope the account to the app's workspace before login.
        account.current_tenant = ctx.tenant
        _login_as(account)
        ctx.caller, ctx.caller_kind = account, "account"
class EndUserMounter:
    """Mount an EndUser as the caller for EXTERNAL_SSO-subject bearers."""

    def applies_to(self, subject_type: SubjectType) -> bool:
        return subject_type == SubjectType.EXTERNAL_SSO

    def mount(self, ctx: Context) -> None:
        if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
            raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
        # The subject email doubles as the external user id for the
        # per-app EndUser row (created on first use).
        end_user = EndUserService.get_or_create_end_user_by_type(
            InvokeFrom.OPENAPI,
            tenant_id=ctx.tenant.id,
            app_id=ctx.app.id,
            user_id=ctx.subject_email,
        )
        _login_as(end_user)
        ctx.caller, ctx.caller_kind = end_user, "end_user"

View File

@ -0,0 +1,90 @@
"""Surface gate.

`@accept_subjects(...)` is the route-level form. `SurfaceCheck` (pipeline
step) is the pipeline-level form. Both delegate to `check_surface` so the
audit emit + canonical-path message are single-sourced.

Subjects come from `libs.oauth_bearer.SubjectType` directly — no parallel
vocabulary. A caller hitting the wrong surface gets 403 ``wrong_surface``
plus the audit event ``openapi.wrong_surface_denied``.
"""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TypeVar
from flask import g, request
from werkzeug.exceptions import Forbidden
from controllers.openapi._audit import emit_wrong_surface
from libs.oauth_bearer import SubjectType
# Where each subject type SHOULD be calling — used in the 403 hint text.
_CANONICAL_PATH: dict[SubjectType, str] = {
    SubjectType.ACCOUNT: "/openapi/v1/apps",
    SubjectType.EXTERNAL_SSO: "/openapi/v1/permitted-external-apps",
}

F = TypeVar("F", bound=Callable[..., object])
def check_surface(accepted: frozenset[SubjectType]) -> None:
    """Enforce that ``g.auth_ctx.subject_type`` is in ``accepted``.

    Raises ``Forbidden`` with ``wrong_surface`` + canonical-path hint on
    miss; emits ``openapi.wrong_surface_denied`` audit. If ``g.auth_ctx``
    is missing the bearer layer didn't run — that's a wiring bug, not a
    user-driven failure, so surface it as a ``RuntimeError`` instead of
    a silent 403.
    """
    ctx = getattr(g, "auth_ctx", None)
    if ctx is None:
        raise RuntimeError(
            "check_surface called without g.auth_ctx; "
            "stack validate_bearer or BearerCheck above the surface gate"
        )
    # Stored subject_type may be a raw string; coerce before comparing.
    subject = _coerce_subject_type(getattr(ctx, "subject_type", None))
    if subject in accepted:
        return
    canonical = _CANONICAL_PATH.get(subject, "/openapi/v1/") if subject else "/openapi/v1/"
    # Audit first, then reject.
    emit_wrong_surface(
        subject_type=subject.value if subject else None,
        attempted_path=request.path,
        client_id=getattr(ctx, "client_id", None),
        token_id=_stringify(getattr(ctx, "token_id", None)),
    )
    raise Forbidden(description=f"wrong_surface (canonical: {canonical})")
def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
    """Route-level surface gate: reject subjects outside `accepted` before the view runs."""
    allowed = frozenset(accepted)

    def deco(fn: F) -> F:
        @wraps(fn)
        def wrapper(*args: object, **kwargs: object) -> object:
            check_surface(allowed)
            return fn(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return deco
def _coerce_subject_type(raw: object) -> SubjectType | None:
    """Best-effort conversion of a stored subject_type value to the enum (None on failure)."""
    if isinstance(raw, SubjectType):
        return raw
    if raw is None:
        return None
    try:
        return SubjectType(raw)
    except ValueError:
        return None
def _stringify(value: object) -> str | None:
if value is None:
return None
return str(value)

View File

@ -0,0 +1,9 @@
from flask_restx import Resource
from controllers.openapi import openapi_ns
@openapi_ns.route("/_health")
class HealthApi(Resource):
    """Liveness probe for the /openapi/v1 surface (no auth decorators attached)."""

    def get(self):
        return {"ok": True}

View File

@ -0,0 +1,401 @@
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
sub-groups in one module:
Protocol (RFC 8628, public + rate-limited):
POST /oauth/device/code
POST /oauth/device/token
GET /oauth/device/lookup
Approval (account branch, console-cookie authed):
POST /oauth/device/approve
POST /oauth/device/deny
SSO branch lives in oauth_device_sso.py.
"""
from __future__ import annotations
import logging
from flask import request
from flask_login import login_required
from flask_restx import Resource
from pydantic import BaseModel, ValidationError
from werkzeug.exceptions import BadRequest
from configs import dify_config
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.openapi import openapi_ns
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
from libs.rate_limit import (
LIMIT_APPROVE_CONSOLE,
LIMIT_DEVICE_CODE_PER_IP,
LIMIT_LOOKUP_PUBLIC,
rate_limit,
)
from services.oauth_device_flow import (
ACCOUNT_ISSUER_SENTINEL,
DEFAULT_POLL_INTERVAL_SECONDS,
DEVICE_FLOW_TTL_SECONDS,
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransitionError,
SlowDownDecision,
StateNotFoundError,
mint_oauth_token,
oauth_ttl_days,
)
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
logger = logging.getLogger(__name__)
# =========================================================================
# Request / query schemas
# =========================================================================
class DeviceCodeRequest(BaseModel):
    """Body schema for POST /oauth/device/code (RFC 8628 authorization request)."""

    client_id: str  # must be in dify_config.OPENAPI_KNOWN_CLIENT_IDS or the request is rejected
    device_label: str  # free-form label recorded with the flow state and audit events
class DevicePollRequest(BaseModel):
    """Body schema for POST /oauth/device/token (RFC 8628 token poll)."""

    device_code: str  # opaque secret issued by /oauth/device/code
    client_id: str  # accepted for RFC 8628 request shape; the poll handler reads only device_code
class DeviceLookupQuery(BaseModel):
    """Query schema for GET /oauth/device/lookup."""

    user_code: str  # normalized (strip + upper) by the handler before lookup
class DeviceMutateRequest(BaseModel):
    """Body schema for the console-authed approve / deny endpoints."""

    user_code: str  # normalized (strip + upper) by the handler before lookup
def _validate_json[M: BaseModel](model: type[M]) -> M:
body = request.get_json(silent=True) or {}
try:
return model.model_validate(body)
except ValidationError as exc:
raise BadRequest(str(exc))
def _validate_query[M: BaseModel](model: type[M]) -> M:
try:
return model.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise BadRequest(str(exc))
# =========================================================================
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
# =========================================================================
@openapi_ns.route("/oauth/device/code")
class OAuthDeviceCodeApi(Resource):
    """RFC 8628 authorization endpoint: issue a device_code/user_code pair."""

    @rate_limit(LIMIT_DEVICE_CODE_PER_IP)
    def post(self):
        req = _validate_json(DeviceCodeRequest)
        # Only pre-registered client ids may start a flow.
        if req.client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
            return {"error": "unsupported_client"}, 400
        caller_ip = extract_remote_ip(request)
        device_code, user_code, expires_in = DeviceFlowRedis(redis_client).start(
            req.client_id, req.device_label, created_ip=caller_ip
        )
        return {
            "device_code": device_code,
            "user_code": user_code,
            "verification_uri": _verification_uri(),
            "expires_in": expires_in,
            "interval": DEFAULT_POLL_INTERVAL_SECONDS,
        }, 200
@openapi_ns.route("/oauth/device/token")
class OAuthDeviceTokenApi(Resource):
    """RFC 8628 token poll — unauthenticated; the device_code is the secret.

    Error vocabulary follows RFC 8628 §3.5: slow_down, authorization_pending,
    access_denied, expired_token. Branch order below is deliberate.
    """

    def post(self):
        payload = _validate_json(DevicePollRequest)
        device_code = payload.device_code
        store = DeviceFlowRedis(redis_client)
        # slow_down beats every other branch — polling-too-fast clients
        # see only that response regardless of underlying state.
        if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
            return {"error": "slow_down"}, 400
        state = store.load_by_device_code(device_code)
        if state is None:
            # Unknown and expired codes answer identically — no oracle for
            # probing which device_codes ever existed.
            return {"error": "expired_token"}, 400
        if state.status is DeviceFlowStatus.PENDING:
            return {"error": "authorization_pending"}, 400
        # consume_on_poll retires the terminal state — presumably single-use
        # delivery of the minted payload; confirm in services.oauth_device_flow.
        terminal = store.consume_on_poll(device_code)
        if terminal is None:
            return {"error": "expired_token"}, 400
        if terminal.status is DeviceFlowStatus.DENIED:
            return {"error": "access_denied"}, 400
        poll_payload = terminal.poll_payload or {}
        if "token" not in poll_payload:
            # Approved state without a token breaks an internal invariant;
            # log loudly but answer the client with a generic expiry.
            logger.error("device_flow: approved state missing poll_payload for %s", device_code)
            return {"error": "expired_token"}, 400
        _audit_cross_ip_if_needed(state)
        return poll_payload, 200
@openapi_ns.route("/oauth/device/lookup")
class OAuthDeviceLookupApi(Resource):
    """Read-only pre-validation of a user_code before login. Public on
    purpose: user_code is high-entropy with a short TTL, and the per-IP
    rate limit blocks enumeration.
    """

    @rate_limit(LIMIT_LOOKUP_PUBLIC)
    def get(self):
        query = _validate_query(DeviceLookupQuery)
        code = query.user_code.strip().upper()
        hit = DeviceFlowRedis(redis_client).load_by_user_code(code)
        if hit is None:
            return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
        _device_code, state = hit
        if state.status is DeviceFlowStatus.PENDING:
            return {
                "valid": True,
                "expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
                "client_id": state.client_id,
            }, 200
        # Already-resolved flows report invalid but still expose the client_id.
        return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
# =========================================================================
# Approval endpoints — account branch (cookie-authed)
# =========================================================================
# Per-device_code SET NX lock held for the duration of one approve call;
# see DeviceApproveApi.post for why concurrent approves must be excluded.
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
_APPROVE_GUARD_TTL_SECONDS = 10  # expiry backstop in case the finally-delete never runs
@openapi_ns.route("/oauth/device/approve")
class DeviceApproveApi(Resource):
    """Console-cookie-authed approval of a pending device flow.

    Sequence: validate body → load state → acquire per-code SET NX guard →
    mint-policy check → mint token → pre-render poll payload → transition
    state to approved → release guard → audit. The guard release sits in
    ``finally`` so a mint failure cannot leave the code locked for the
    guard's full TTL.
    """

    @setup_required
    @login_required
    @account_initialization_required
    @bearer_feature_required
    @rate_limit(LIMIT_APPROVE_CONSOLE)
    def post(self):
        payload = _validate_json(DeviceMutateRequest)
        user_code = payload.user_code.strip().upper()
        account, tenant = current_account_with_tenant()
        store = DeviceFlowRedis(redis_client)
        found = store.load_by_user_code(user_code)
        if found is None:
            return {"error": "expired_or_unknown"}, 404
        device_code, state = found
        if state.status is not DeviceFlowStatus.PENDING:
            return {"error": "already_resolved"}, 409
        # SET NX guard — without it, two in-flight approves both pass
        # PENDING, both mint, and the second upsert silently rotates the
        # first caller into an already-revoked token.
        guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
        if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
            return {"error": "approve_in_progress"}, 409
        try:
            profile = MINTABLE_PROFILES[SubjectType.ACCOUNT]
            try:
                validate_mint_policy(
                    subject_type=profile.subject_type,
                    prefix=profile.prefix,
                    scopes=profile.scopes,
                )
            except MintPolicyViolation as e:
                # from None: the policy message alone is the user-facing error.
                raise BadRequest(description=str(e)) from None
            ttl_days = oauth_ttl_days(tenant_id=tenant)
            mint = mint_oauth_token(
                db.session,
                redis_client,
                subject_email=account.email,
                subject_issuer=ACCOUNT_ISSUER_SENTINEL,
                account_id=str(account.id),
                client_id=state.client_id,
                device_label=state.device_label,
                prefix=profile.prefix,
                ttl_days=ttl_days,
            )
            # Poll payload is rendered now, while we still hold an authed
            # session — the unauthenticated poll handler just replays it.
            poll_payload = _build_account_poll_payload(account, tenant, mint)
            try:
                store.approve(
                    device_code,
                    subject_email=account.email,
                    account_id=str(account.id),
                    subject_issuer=ACCOUNT_ISSUER_SENTINEL,
                    minted_token=mint.token,
                    token_id=str(mint.token_id),
                    poll_payload=poll_payload,
                )
            except (StateNotFoundError, InvalidTransitionError):
                # Row minted but state vanished — roll forward; the orphan
                # token is revocable via auth devices list / Authorized Apps.
                logger.exception("device_flow: approve raced on %s", device_code)
                return {"error": "state_lost"}, 409
        finally:
            redis_client.delete(guard_key)
        _emit_approve_audit(state, account, tenant, mint)
        return {"status": "approved"}, 200
@openapi_ns.route("/oauth/device/deny")
class DeviceDenyApi(Resource):
    """Console-cookie-authed denial of a pending device flow."""

    @setup_required
    @login_required
    @account_initialization_required
    @bearer_feature_required
    @rate_limit(LIMIT_APPROVE_CONSOLE)
    def post(self):
        req = _validate_json(DeviceMutateRequest)
        code = req.user_code.strip().upper()
        store = DeviceFlowRedis(redis_client)
        found = store.load_by_user_code(code)
        if found is None:
            return {"error": "expired_or_unknown"}, 404
        device_code, state = found
        if state.status is not DeviceFlowStatus.PENDING:
            return {"error": "already_resolved"}, 409
        try:
            store.deny(device_code)
        except (StateNotFoundError, InvalidTransitionError):
            # Lost a race with another resolution between load and deny.
            logger.exception("device_flow: deny raced on %s", device_code)
            return {"error": "state_lost"}, 409
        _emit_deny_audit(state)
        return {"status": "denied"}, 200
# =========================================================================
# Helpers
# =========================================================================
def _verification_uri() -> str:
    """URL the end user must visit to enter the user_code.

    Prefers the configured console web URL; falls back to this request's
    host when none is configured.
    """
    configured = getattr(dify_config, "CONSOLE_WEB_URL", None)
    root = configured if configured else request.host_url
    return f"{root.rstrip('/')}/device"
def _audit_cross_ip_if_needed(state) -> None:
    """Emit a structured audit warning when the polling IP differs from the
    IP that created the device_code. Informational only — never blocks.
    """
    current_ip = extract_remote_ip(request)
    # Only compare when both sides are known.
    if not state.created_ip or not current_ip:
        return
    if current_ip == state.created_ip:
        return
    logger.warning(
        "audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
        state.token_id,
        state.created_ip,
        current_ip,
        extra={
            "audit": True,
            "token_id": state.token_id,
            "creation_ip": state.created_ip,
            "poll_ip": current_ip,
        },
    )
def _build_account_poll_payload(account, tenant, mint) -> dict:
    """Pre-render the poll-response body so the unauthenticated poll
    handler doesn't re-query accounts/tenants for authz data.

    Args:
        account: the approving console account (id / email / name are read).
        tenant: the session tenant id (compared via ``str()`` below —
            assumed to be an id-like value; confirm against
            ``current_account_with_tenant``).
        mint: mint result carrying token, token_id and expires_at.

    Returns:
        JSON-serializable dict stored as the flow's poll_payload.
    """
    from models import Tenant, TenantAccountJoin

    rows = (
        db.session.query(Tenant, TenantAccountJoin)
        .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
        .filter(TenantAccountJoin.account_id == account.id)
        .all()
    )
    workspaces = [{"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")} for t, m in rows]
    # Prefer active session tenant → DB-flagged current join → first membership.
    default_ws_id = None
    if tenant and any(w["id"] == str(tenant) for w in workspaces):
        default_ws_id = str(tenant)
    if default_ws_id is None:
        for _t, m in rows:
            if getattr(m, "current", False):
                default_ws_id = str(m.tenant_id)
                break
    if default_ws_id is None and workspaces:
        default_ws_id = workspaces[0]["id"]
    return {
        "token": mint.token,
        "expires_at": mint.expires_at.isoformat(),
        # NOTE(review): enum placed directly in a JSON payload — presumably
        # SubjectType is a StrEnum; confirm serialization matches the CLI.
        "subject_type": SubjectType.ACCOUNT,
        "account": {"id": str(account.id), "email": account.email, "name": account.name},
        "workspaces": workspaces,
        "default_workspace_id": default_ws_id,
        "token_id": str(mint.token_id),
    }
def _emit_approve_audit(state, account, tenant, mint) -> None:
    """Structured audit record for a console-account approval.

    Note: ``rotated=?`` in the message is literal text, not a format slot —
    the five %s slots map to token_id / subject / client_id / device_label /
    expires_at.
    """
    logger.warning(
        "audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
        mint.token_id,
        account.email,
        state.client_id,
        state.device_label,
        mint.expires_at,
        extra={
            "audit": True,
            "event": "oauth.device_flow_approved",
            "token_id": str(mint.token_id),
            "subject_type": SubjectType.ACCOUNT,
            "subject_email": account.email,
            "account_id": str(account.id),
            "tenant_id": tenant,
            "client_id": state.client_id,
            "device_label": state.device_label,
            # NOTE(review): scopes hard-coded here — confirm they track
            # MINTABLE_PROFILES[SubjectType.ACCOUNT].scopes.
            "scopes": ["full"],
            "expires_at": mint.expires_at.isoformat(),
        },
    )
def _emit_deny_audit(state) -> None:
    """Structured audit record for an explicit user denial."""
    structured = {
        "audit": True,
        "event": "oauth.device_flow_denied",
        "client_id": state.client_id,
        "device_label": state.device_label,
    }
    logger.warning(
        "audit: oauth.device_flow_denied client_id=%s device_label=%s",
        state.client_id,
        state.device_label,
        extra=structured,
    )

View File

@ -0,0 +1,369 @@
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
EE-only. Browser flow:
GET /oauth/device/sso-initiate 302 to IdP authorize URL
GET /oauth/device/sso-complete ACS callback, sets approval-grant cookie
GET /oauth/device/approval-context SPA reads cookie claims (idempotent)
POST /oauth/device/approve-external mints dfoe_ token + clears cookie
Function-based (raw @bp.route) rather than Resource classes because the
handlers do redirects + cookie kwargs that don't fit the Resource shape.
"""
from __future__ import annotations
import logging
import secrets
from dataclasses import dataclass
from flask import jsonify, make_response, redirect, request
from sqlalchemy import func, select
from werkzeug.exceptions import (
BadGateway,
BadRequest,
Conflict,
Forbidden,
NotFound,
Unauthorized,
)
from controllers.openapi import bp
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import jws
from libs.device_flow_security import (
APPROVAL_GRANT_COOKIE_NAME,
ApprovalGrantClaims,
approval_grant_cleared_cookie_kwargs,
approval_grant_cookie_kwargs,
consume_approval_grant_nonce,
consume_sso_assertion_nonce,
enterprise_only,
mint_approval_grant,
verify_approval_grant,
)
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType
from libs.rate_limit import (
LIMIT_APPROVE_EXT_PER_EMAIL,
LIMIT_SSO_INITIATE_PER_IP,
enforce,
rate_limit,
)
from models import Account
from models.account import AccountStatus
from services.enterprise.enterprise_service import EnterpriseService
from services.oauth_device_flow import (
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransitionError,
StateNotFoundError,
mint_oauth_token,
oauth_ttl_days,
)
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
logger = logging.getLogger(__name__)
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
# device_code it references.
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
@enterprise_only
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
def sso_initiate():
    """Kick off the SSO branch: 302 the browser to the IdP authorize URL.

    The pending user_code is bound into a signed state envelope so that
    sso-complete can re-associate the IdP assertion with this flow.
    """
    user_code = (request.args.get("user_code") or "").strip().upper()
    if not user_code:
        raise BadRequest("user_code required")
    found = DeviceFlowRedis(redis_client).load_by_user_code(user_code)
    if found is None:
        raise BadRequest("invalid_user_code")
    _, state = found
    if state.status is not DeviceFlowStatus.PENDING:
        raise BadRequest("invalid_user_code")
    keyset = jws.KeySet.from_shared_secret()
    envelope = {
        "redirect_url": "",
        "app_code": "",
        "intent": "device_flow",
        "user_code": user_code,
        "nonce": secrets.token_urlsafe(16),
        "return_to": "",
        "idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}",
    }
    signed_state = jws.sign(
        keyset,
        payload=envelope,
        aud=jws.AUD_STATE_ENVELOPE,
        ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
    )
    try:
        reply = EnterpriseService.initiate_device_flow_sso(signed_state)
    except Exception as e:
        logger.warning("sso-initiate: enterprise call failed: %s", e)
        raise BadGateway("sso_initiate_failed") from e
    url = (reply or {}).get("url")
    if not url:
        raise BadGateway("sso_initiate_missing_url")
    # Clear stale approval-grant — defends against cross-tab/back-button mixing.
    resp = redirect(url, code=302)
    resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
    return resp
@bp.route("/oauth/device/sso-complete", methods=["GET"])
@enterprise_only
def sso_complete():
    """ACS callback from the IdP: verify the signed assertion, stash the
    verified subject in the approval-grant cookie, and bounce to the SPA.

    Order is deliberate: signature/aud check → nonce burn (replay defense)
    → flow-state check → dify-account exclusion → cookie mint.
    """
    blob = request.args.get("sso_assertion")
    if not blob:
        raise BadRequest("sso_assertion required")
    keyset = jws.KeySet.from_shared_secret()
    try:
        claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
    except jws.VerifyError as e:
        logger.warning("sso-complete: rejected assertion: %s", e)
        raise BadRequest("invalid_sso_assertion") from e
    # Single-use assertion: a replayed blob fails here even with a valid sig.
    if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")):
        raise BadRequest("invalid_sso_assertion")
    user_code = (claims.get("user_code") or "").strip().upper()
    store = DeviceFlowRedis(redis_client)
    found = store.load_by_user_code(user_code)
    if found is None:
        raise Conflict("user_code_not_pending")
    _, state = found
    if state.status is not DeviceFlowStatus.PENDING:
        raise Conflict("user_code_not_pending")
    # Subjects with an active Dify account must use the internal login path
    # instead of the external branch (see _email_belongs_to_dify_account).
    if _email_belongs_to_dify_account(claims["email"]):
        _emit_external_rejection_audit(
            state,
            _RejectedClaims(subject_email=claims["email"], subject_issuer=claims["issuer"]),
            reason="email_belongs_to_dify_account",
        )
        return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
    iss = request.host_url.rstrip("/")
    cookie_value, _ = mint_approval_grant(
        keyset=keyset,
        iss=iss,
        subject_email=claims["email"],
        subject_issuer=claims["issuer"],
        user_code=user_code,
    )
    resp = redirect("/device?sso_verified=1", code=302)
    resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
    return resp
@bp.route("/oauth/device/approval-context", methods=["GET"])
@enterprise_only
def approval_context():
    """Return the approval-grant cookie claims to the SPA (idempotent read)."""
    cookie = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
    if not cookie:
        raise Unauthorized("no_session")
    keyset = jws.KeySet.from_shared_secret()
    try:
        claims = verify_approval_grant(keyset, cookie)
    except jws.VerifyError as e:
        # Expired / tampered cookies look identical to no cookie at all.
        logger.warning("approval-context: bad cookie: %s", e)
        raise Unauthorized("no_session") from e
    body = {
        "subject_email": claims.subject_email,
        "subject_issuer": claims.subject_issuer,
        "user_code": claims.user_code,
        "csrf_token": claims.csrf_token,
        "expires_at": claims.expires_at.isoformat(),
    }
    return jsonify(body), 200
@bp.route("/oauth/device/approve-external", methods=["POST"])
@enterprise_only
def approve_external():
    """Mint a dfoe_ token for an SSO-verified external subject.

    Check order is deliberate: cookie verify → per-email rate limit →
    CSRF echo → body/cookie user_code match → flow still PENDING →
    dify-account exclusion → nonce burn (single-use grant) → mint policy →
    token mint → state transition → cookie clear.
    """
    token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
    if not token:
        raise Unauthorized("invalid_session")
    keyset = jws.KeySet.from_shared_secret()
    try:
        claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
    except jws.VerifyError as e:
        logger.warning("approve-external: bad cookie: %s", e)
        raise Unauthorized("invalid_session") from e
    enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
    # Double-submit CSRF: header must echo the value embedded in the grant.
    csrf_header = request.headers.get("X-CSRF-Token", "")
    if not csrf_header or csrf_header != claims.csrf_token:
        raise Forbidden("csrf_mismatch")
    data = request.get_json(silent=True) or {}
    body_user_code = (data.get("user_code") or "").strip().upper()
    if body_user_code != claims.user_code:
        raise BadRequest("user_code_mismatch")
    store = DeviceFlowRedis(redis_client)
    found = store.load_by_user_code(claims.user_code)
    if found is None:
        raise NotFound("user_code_not_pending")
    device_code, state = found
    if state.status is not DeviceFlowStatus.PENDING:
        raise Conflict("user_code_not_pending")
    # Checked again here, not only in sso-complete — presumably because an
    # account could become active between cookie mint and approval; confirm.
    if _email_belongs_to_dify_account(claims.subject_email):
        _emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account")
        raise Forbidden("email_belongs_to_dify_account")
    # Single-use grant: replaying the cookie fails here.
    if not consume_approval_grant_nonce(redis_client, claims.nonce):
        raise Unauthorized("session_already_consumed")
    profile = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
    try:
        validate_mint_policy(
            subject_type=profile.subject_type,
            prefix=profile.prefix,
            scopes=profile.scopes,
        )
    except MintPolicyViolation as e:
        # from None: the policy message alone is the user-facing error.
        raise BadRequest(description=str(e)) from None
    ttl_days = oauth_ttl_days(tenant_id=None)
    mint = mint_oauth_token(
        db.session,
        redis_client,
        subject_email=claims.subject_email,
        subject_issuer=claims.subject_issuer,
        account_id=None,
        client_id=state.client_id,
        device_label=state.device_label,
        prefix=profile.prefix,
        ttl_days=ttl_days,
    )
    # External subjects carry no Dify account — hence the empty account /
    # workspace fields in the pre-rendered poll payload.
    poll_payload = {
        "token": mint.token,
        "expires_at": mint.expires_at.isoformat(),
        "subject_type": SubjectType.EXTERNAL_SSO,
        "subject_email": claims.subject_email,
        "subject_issuer": claims.subject_issuer,
        "account": None,
        "workspaces": [],
        "default_workspace_id": None,
        "token_id": str(mint.token_id),
    }
    try:
        store.approve(
            device_code,
            subject_email=claims.subject_email,
            account_id=None,
            subject_issuer=claims.subject_issuer,
            minted_token=mint.token,
            token_id=str(mint.token_id),
            poll_payload=poll_payload,
        )
    except (StateNotFoundError, InvalidTransitionError) as e:
        logger.exception("approve-external: state transition raced")
        raise Conflict("state_lost") from e
    _emit_approve_external_audit(state, claims, mint)
    resp = make_response(jsonify({"status": "approved"}), 200)
    # Grant is spent — clear the cookie so the SPA can't retry it.
    resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
    return resp
@dataclass(frozen=True)
class _RejectedClaims:
    """Minimal subject shape consumed by `_emit_external_rejection_audit`.
    Mirrors the attributes used from `ApprovalGrantClaims` so callers holding
    only a raw JWS claims dict (e.g. `sso_complete`) can emit the same audit
    event without reaching for the full dataclass.
    """

    subject_email: str  # email asserted by the external IdP
    subject_issuer: str  # issuer identifier from the assertion
def _email_belongs_to_dify_account(email: str) -> bool:
    """Return True when the email matches an ACTIVE Dify account.

    Such subjects must authenticate via the internal Dify login path (which
    mints dfoa_) rather than the external SSO device flow — True blocks
    dfoe_ minting. Pending / uninitialized / banned / closed accounts do
    not block: pending and uninitialized users may still complete their
    invitation via SSO, while banned and closed accounts are handled by
    separate enforcement paths.
    """
    if not email:
        return False
    normalized = email.strip().lower()
    if not normalized:
        return False
    stmt = select(Account.id).where(
        func.lower(Account.email) == normalized,
        Account.status == AccountStatus.ACTIVE,
    )
    return db.session.execute(stmt).scalar_one_or_none() is not None
def _emit_external_rejection_audit(state, claims, *, reason: str) -> None:
    """Structured audit record for a rejected external-SSO subject."""
    structured = {
        "audit": True,
        "event": "oauth.device_flow_rejected",
        "subject_type": SubjectType.EXTERNAL_SSO,
        "subject_email": claims.subject_email,
        "subject_issuer": claims.subject_issuer,
        "reason": reason,
        "client_id": state.client_id,
        "device_label": state.device_label,
    }
    logger.warning(
        "audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s",
        SubjectType.EXTERNAL_SSO,
        claims.subject_email,
        claims.subject_issuer,
        reason,
        extra=structured,
    )
def _emit_approve_external_audit(state, claims, mint) -> None:
    """Structured audit record for an approved external-SSO subject."""
    logger.warning(
        "audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
        SubjectType.EXTERNAL_SSO,
        claims.subject_email,
        claims.subject_issuer,
        mint.token_id,
        extra={
            "audit": True,
            "event": "oauth.device_flow_approved",
            "subject_type": SubjectType.EXTERNAL_SSO,
            "subject_email": claims.subject_email,
            "subject_issuer": claims.subject_issuer,
            "token_id": str(mint.token_id),
            "client_id": state.client_id,
            "device_label": state.device_label,
            # NOTE(review): scopes hard-coded here — confirm they track
            # MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO].scopes.
            "scopes": ["apps:run"],
            "expires_at": mint.expires_at.isoformat(),
        },
    )

View File

@ -0,0 +1,87 @@
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
counterparts to the cookie-authed /console/api/workspaces endpoints.
Account bearers (dfoa_) see every tenant they're a member of. External
SSO bearers (dfoe_) have no account_id and so see an empty list that
matches /openapi/v1/account.
"""
from __future__ import annotations
from itertools import starmap
from flask import g
from flask_restx import Resource
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.openapi import openapi_ns
from controllers.openapi.auth.surface_gate import accept_subjects
from extensions.ext_database import db
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
SubjectType,
validate_bearer,
)
from models import Tenant, TenantAccountJoin
@openapi_ns.route("/workspaces")
class WorkspacesApi(Resource):
    """List every workspace the bearer's account belongs to."""

    @validate_bearer(accept=ACCEPT_USER_ANY)
    @accept_subjects(SubjectType.ACCOUNT)
    def get(self):
        ctx = g.auth_ctx
        stmt = (
            select(Tenant, TenantAccountJoin)
            .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
            .where(TenantAccountJoin.account_id == str(ctx.account_id))
            .order_by(Tenant.created_at.asc())
        )
        memberships = db.session.execute(stmt).all()
        summaries = [_workspace_summary(tenant, join) for tenant, join in memberships]
        return {"workspaces": summaries}, 200
@openapi_ns.route("/workspaces/<string:workspace_id>")
class WorkspaceByIdApi(Resource):
    """Fetch a single workspace, scoped to the bearer's memberships."""

    @validate_bearer(accept=ACCEPT_USER_ANY)
    @accept_subjects(SubjectType.ACCOUNT)
    def get(self, workspace_id: str):
        ctx = g.auth_ctx
        stmt = (
            select(Tenant, TenantAccountJoin)
            .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
            .where(
                Tenant.id == workspace_id,
                TenantAccountJoin.account_id == str(ctx.account_id),
            )
        )
        row = db.session.execute(stmt).first()
        # 404 (not 403) on non-member so workspace IDs don't leak across tenants.
        if row is None:
            raise NotFound("workspace not found")
        tenant, membership = row
        return _workspace_detail(tenant, membership), 200
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> dict:
return {
"id": str(tenant.id),
"name": tenant.name,
"role": getattr(membership, "role", ""),
"status": tenant.status,
"current": getattr(membership, "current", False),
}
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> dict:
return {
"id": str(tenant.id),
"name": tenant.name,
"role": getattr(membership, "role", ""),
"status": tenant.status,
"current": getattr(membership, "current", False),
"created_at": tenant.created_at.isoformat() if tenant.created_at else None,
}

View File

@ -16,7 +16,7 @@ from libs.passport import PassportService
from libs.token import extract_webapp_passport
from models.model import App, EndUser, Site
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode, WebAppSettings
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
@ -74,7 +74,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
if not webapp_settings:
raise NotFound("Web app settings not found.")
app_web_auth_enabled = webapp_settings.access_mode != "public"
app_web_auth_enabled = webapp_settings.access_mode != WebAppAccessMode.PUBLIC
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
_validate_user_accessibility(
@ -88,7 +88,8 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
raise Unauthorized("Please re-login to access the web app.")
app_id = AppService.get_app_id_by_code(app_code)
app_web_auth_enabled = (
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode != "public"
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode
!= WebAppAccessMode.PUBLIC
)
if app_web_auth_enabled:
raise WebAppAuthRequiredError()

View File

@ -685,6 +685,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
match invoke_from:
case InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
case InvokeFrom.OPENAPI:
created_from = WorkflowAppLogCreatedFrom.OPENAPI
case InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
case InvokeFrom.WEB_APP:

View File

@ -24,6 +24,7 @@ class UserFrom(StrEnum):
class InvokeFrom(StrEnum):
SERVICE_API = "service-api"
OPENAPI = "openapi"
WEB_APP = "web-app"
TRIGGER = "trigger"
EXPLORE = "explore"

View File

@ -8,6 +8,8 @@ AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
OPENAPI_HEADERS: tuple[str, ...] = ("Authorization", "Content-Type", HEADER_NAME_CSRF_TOKEN)
OPENAPI_MAX_AGE_SECONDS: int = 600
def _apply_cors_once(bp, /, **cors_kwargs):
@ -29,6 +31,7 @@ def init_app(app: DifyApp):
from controllers.files import bp as files_bp
from controllers.inner_api import bp as inner_api_bp
from controllers.mcp import bp as mcp_bp
from controllers.openapi import bp as openapi_bp
from controllers.service_api import bp as service_api_bp
from controllers.trigger import bp as trigger_bp
from controllers.web import bp as web_bp
@ -41,6 +44,22 @@ def init_app(app: DifyApp):
)
app.register_blueprint(service_api_bp)
# User-scoped programmatic API. Default empty allowlist = same-origin
# only; expand via OPENAPI_CORS_ALLOW_ORIGINS for third-party
# integrations. supports_credentials so cookie-authed approve/deny
# work; cross-origin OPTIONS without an allowed origin will fail
# the same as on the console blueprint.
_apply_cors_once(
openapi_bp,
resources={r"/*": {"origins": dify_config.OPENAPI_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=list(OPENAPI_HEADERS),
methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
expose_headers=list(EXPOSED_HEADERS),
max_age=OPENAPI_MAX_AGE_SECONDS,
)
app.register_blueprint(openapi_bp)
_apply_cors_once(
web_bp,
resources={

View File

@ -222,6 +222,12 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
"schedule": crontab(minute="0", hour="0"),
}
if dify_config.ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK:
imports.append("schedule.clean_oauth_access_tokens_task")
beat_schedule["clean_oauth_access_tokens_task"] = {
"task": "schedule.clean_oauth_access_tokens_task.clean_oauth_access_tokens_task",
"schedule": crontab(minute="0", hour="5", day_of_month=f"*/{day}"),
}
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
imports.append("schedule.workflow_schedule_task")
beat_schedule["workflow_schedule_task"] = {

View File

@ -12,7 +12,7 @@ from constants import HEADER_NAME_APP_CODE
from dify_app import DifyApp
from extensions.ext_database import db
from libs.passport import PassportService
from libs.token import extract_access_token, extract_webapp_passport
from libs.token import extract_access_token, extract_console_cookie_token, extract_webapp_passport
from models import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser
from services.account_service import AccountService
@ -84,6 +84,24 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
return logged_in_account
elif request.blueprint == "openapi":
# Account-branch device-flow approval routes (approve / deny /
# approval-context) sit under @login_required and authenticate via
# the console session cookie. Cookie-only on purpose — bearer
# tokens (dfoa_/dfoe_) live on the Authorization header and are
# validated by AppPipeline, not flask-login.
cookie_token = extract_console_cookie_token(request)
if not cookie_token:
return None
try:
decoded = PassportService().verify(cookie_token)
except Exception:
return None
user_id = decoded.get("user_id")
source = decoded.get("token_source")
if source or not user_id:
return None
return AccountService.load_logged_in_account(account_id=user_id)
elif request.blueprint == "web":
app_code = request.headers.get(HEADER_NAME_APP_CODE)
webapp_token = extract_webapp_passport(app_code, request) if app_code else None

View File

@ -0,0 +1,23 @@
"""Bind the bearer authenticator at startup. Must run after ext_database
and ext_redis (needs both factories).
"""
from __future__ import annotations
from configs import dify_config
from dify_app import DifyApp
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import build_and_bind
def is_enabled() -> bool:
    """Extension toggle read by the app factory — bearer auth is opt-in."""
    return dify_config.ENABLE_OAUTH_BEARER
def init_app(app: DifyApp) -> None:
    """Wire the bearer authenticator to the live DB session and Redis client."""

    # scoped_session isn't a context manager; request teardown closes it.
    def _session_factory():
        return db.session

    build_and_bind(session_factory=_session_factory, redis_client=redis_client)

View File

@ -0,0 +1,196 @@
"""Device-flow security primitives: enterprise_only gate, approval-grant
cookie mint/verify/consume, and anti-framing headers.
"""
from __future__ import annotations
import logging
import secrets
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from functools import wraps
from flask import Blueprint
from werkzeug.exceptions import NotFound
from libs import jws
from libs.token import is_secure
from services.feature_service import FeatureService, LicenseStatus
logger = logging.getLogger(__name__)
# ============================================================================
# enterprise_only decorator
# ============================================================================
# Fail-closed: any non-EE-active status (default NONE on CE, plus INACTIVE / EXPIRED / LOST)
# is denied. Future LicenseStatus values default to denial unless explicitly admitted.
_EE_ENABLED_STATUSES = {LicenseStatus.ACTIVE, LicenseStatus.EXPIRING}
def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""404 on CE, passthrough on EE. Apply before rate-limit so CE
responses don't consume the bucket.
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features()
if settings.license.status not in _EE_ENABLED_STATUSES:
raise NotFound()
return view(*args, **kwargs)
return decorated
# ============================================================================
# approval_grant cookie
# ============================================================================
# Cookie identity + scoping for the device-approval grant.
APPROVAL_GRANT_COOKIE_NAME = "device_approval_grant"
# Path-scoped so the cookie only travels to the device-flow approval routes.
APPROVAL_GRANT_COOKIE_PATH = "/openapi/v1/oauth/device"
APPROVAL_GRANT_COOKIE_TTL_SECONDS = 300  # 5 min
NONCE_TTL_SECONDS = 600  # 2x cookie TTL — defeats clock-skew late replay
# Redis key templates for the single-use nonce markers below.
NONCE_KEY_FMT = "device_approval_grant_nonce:{nonce}"
SSO_ASSERTION_NONCE_KEY_FMT = "sso_assertion_nonce:{nonce}"
@dataclass(frozen=True, slots=True)
class ApprovalGrantClaims:
    """Claims carried by (and verified out of) the approval-grant cookie JWS."""

    subject_email: str
    # Issuer identifier for the subject — presumably the SSO IdP; confirm with callers.
    subject_issuer: str
    # The device-flow user_code this grant approves.
    user_code: str
    # Single-use replay guard; claimed via consume_approval_grant_nonce().
    nonce: str
    # CSRF token paired with the approval form submission.
    csrf_token: str
    expires_at: datetime
def mint_approval_grant(
    *,
    keyset: jws.KeySet,
    iss: str,
    subject_email: str,
    subject_issuer: str,
    user_code: str,
) -> tuple[str, ApprovalGrantClaims]:
    """Sign a fresh approval-grant token and return (token, claims).

    Use ``approval_grant_cookie_kwargs`` to set the cookie — single
    source of truth for Path/HttpOnly/Secure/SameSite.
    """
    issued_at = datetime.now(UTC)
    grant_nonce = _random_opaque()
    grant_csrf = _random_opaque()
    token = jws.sign(
        keyset,
        {
            "iss": iss,
            "subject_email": subject_email,
            "subject_issuer": subject_issuer,
            "user_code": user_code,
            "nonce": grant_nonce,
            "csrf_token": grant_csrf,
        },
        aud=jws.AUD_APPROVAL_GRANT,
        ttl_seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS,
    )
    claims = ApprovalGrantClaims(
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        user_code=user_code,
        nonce=grant_nonce,
        csrf_token=grant_csrf,
        expires_at=issued_at + timedelta(seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS),
    )
    return token, claims
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
    """Check signature + audience + expiry and return the decoded claims.

    Nonce consumption is deliberately left to the caller so verification
    stays side-effect free.
    """
    claims = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
    expiry = datetime.fromtimestamp(claims["exp"], tz=UTC)
    return ApprovalGrantClaims(
        subject_email=claims["subject_email"],
        subject_issuer=claims["subject_issuer"],
        user_code=claims["user_code"],
        nonce=claims["nonce"],
        csrf_token=claims["csrf_token"],
        expires_at=expiry,
    )
def _consume_nonce(redis_client, key: str) -> bool:
    # SET NX EX is the atomic claim: the first caller writes the marker and
    # gets True; any later caller (a replay) sees the key and gets False.
    return bool(redis_client.set(key, "1", nx=True, ex=NONCE_TTL_SECONDS))


def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
    """Atomically claim an approval-grant nonce.

    Returns True on first use, False on replay or empty nonce.
    """
    if not nonce:
        return False
    return _consume_nonce(redis_client, NONCE_KEY_FMT.format(nonce=nonce))


def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
    """Atomically claim an SSO-assertion nonce.

    Returns True on first use, False on replay or empty nonce.
    """
    if not nonce:
        return False
    return _consume_nonce(redis_client, SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce))
def _cookie_base(value: str, max_age: int) -> dict:
    # Shared attribute set for the approval-grant cookie. ``secure`` follows
    # is_secure() so HTTP-only deployments don't silently drop the cookie.
    return {
        "key": APPROVAL_GRANT_COOKIE_NAME,
        "value": value,
        "max_age": max_age,
        "path": APPROVAL_GRANT_COOKIE_PATH,
        "secure": is_secure(),
        "httponly": True,
        "samesite": "Lax",
    }


def approval_grant_cookie_kwargs(value: str) -> dict:
    """kwargs for Response.set_cookie() carrying a live approval grant."""
    return _cookie_base(value, APPROVAL_GRANT_COOKIE_TTL_SECONDS)


def approval_grant_cleared_cookie_kwargs() -> dict:
    """kwargs for Response.set_cookie() that expire the grant cookie."""
    return _cookie_base("", 0)
def _random_opaque() -> str:
return secrets.token_urlsafe(16)
# ============================================================================
# Anti-framing headers
# ============================================================================
_ANTI_FRAMING_HEADERS = {
    "X-Frame-Options": "DENY",
    "Content-Security-Policy": "frame-ancestors 'none'",
}


def attach_anti_framing(bp: Blueprint) -> None:
    """Register an after_request hook on ``bp`` that adds anti-framing
    headers (X-Frame-Options + CSP) to every response (CI invariant #4).

    setdefault() keeps any header a handler already set.
    """

    @bp.after_request
    def _apply_headers(response):  # pyright: ignore[reportUnusedFunction]
        for header_name in _ANTI_FRAMING_HEADERS:
            response.headers.setdefault(header_name, _ANTI_FRAMING_HEADERS[header_name])
        return response

View File

@ -75,6 +75,7 @@ def register_external_error_handlers(api: Api):
def handle_value_error(e: ValueError):
    # Mirror Flask's error-signal contract so monitoring hooks still see
    # the exception even though we convert it into a handled 400.
    got_request_exception.send(current_app, exception=e)
    current_app.logger.exception("value_error in request handler")
    status_code = 400
    # NOTE(review): str(e) is echoed to the client — assumes ValueError
    # messages raised in handlers never carry sensitive internals; confirm.
    data = {"code": "invalid_param", "message": str(e), "status": status_code}
    return data, status_code

View File

@ -542,3 +542,18 @@ class RateLimiter:
self._redis_client.zadd(key, {member: current_time})
self._redis_client.expire(key, self.time_window * 2)
def seconds_until_available(self, email: str) -> int:
    """Seconds until the oldest in-window entry expires, freeing a slot.
    Defensive floor of 1 second. Caller should only invoke this after
    is_rate_limited() returned True.
    """
    key = self._get_key(email)
    # ZSET scores are timestamps; index 0 is the oldest recorded attempt.
    oldest = cast(Any, self._redis_client).zrange(key, 0, 0, withscores=True)
    if not oldest:
        # Window emptied between the caller's check and now — report the
        # minimum instead of 0 so Retry-After stays sane.
        return 1
    _member, score = oldest[0]
    # A slot frees when the oldest entry ages out of the sliding window.
    free_at = int(score) + self.time_window
    remaining = free_at - int(time.time())
    return max(remaining, 1)

108
api/libs/jws.py Normal file
View File

@ -0,0 +1,108 @@
"""HS256 compact JWS keyed on the shared Dify SECRET_KEY. Used by the SSO
state envelope, external subject assertion, and approval-grant cookie
all three share one key-set so api enterprise can verify each other.
"""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
import jwt
from configs import dify_config
# JWS audience values — one per envelope type sharing this key-set.
AUD_STATE_ENVELOPE = "api.sso.state_envelope"
AUD_EXT_SUBJECT_ASSERTION = "api.device_flow.external_subject_assertion"
AUD_APPROVAL_GRANT = "api.device_flow.approval_grant"
# kid written into every signed header; bump for key rotation.
ACTIVE_KID_V1 = "dify-shared-v1"
class KeySetError(Exception):
    """Raised when a key-set is malformed or a signing key is unavailable."""


class KeySet:
    """Immutable kid → secret mapping with one designated signing kid.

    ``from_entries`` reserves multi-kid construction for rotation slots.
    """

    def __init__(self, entries: dict[str, bytes], active_kid: str) -> None:
        if active_kid not in entries:
            raise KeySetError(f"active kid {active_kid!r} missing from key-set")
        if not entries[active_kid]:
            raise KeySetError(f"active kid {active_kid!r} has empty secret")
        # Defensive copy — re-materialise each secret so later mutation of
        # the caller's dict (or a bytearray value) can't affect us.
        self._entries: dict[str, bytes] = {kid: bytes(secret) for kid, secret in entries.items()}
        self._active_kid = active_kid

    @classmethod
    def from_shared_secret(cls) -> KeySet:
        """Single-kid key-set derived from the shared Dify SECRET_KEY."""
        shared = dify_config.SECRET_KEY
        if not shared:
            raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
        return cls({ACTIVE_KID_V1: shared.encode("utf-8")}, ACTIVE_KID_V1)

    @classmethod
    def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> KeySet:
        """Explicit multi-kid constructor (key-rotation slot)."""
        return cls(entries, active_kid)

    @property
    def active_kid(self) -> str:
        return self._active_kid

    def lookup(self, kid: str) -> bytes | None:
        """Secret for ``kid``, or None when the kid is unknown."""
        return self._entries.get(kid)
def sign(keyset: KeySet, payload: dict, aud: str, ttl_seconds: int) -> str:
"""``iat`` + ``exp`` are injected here; callers must not set them."""
if "aud" in payload or "iat" in payload or "exp" in payload:
raise ValueError("reserved claim present in payload (aud/iat/exp)")
if ttl_seconds <= 0:
raise ValueError("ttl_seconds must be positive")
kid = keyset.active_kid
secret = keyset.lookup(kid)
if secret is None:
raise KeySetError(f"active kid {kid!r} lookup miss")
iat = datetime.now(UTC)
exp = iat + timedelta(seconds=ttl_seconds)
claims = {**payload, "aud": aud, "iat": iat, "exp": exp}
return jwt.encode(
claims,
secret,
algorithm="HS256",
headers={"kid": kid, "typ": "JWT"},
)
class VerifyError(Exception):
    # Single failure type for every verification error; message carries detail.
    pass


def verify(keyset: KeySet, token: str, expected_aud: str) -> dict:
    """Verify a compact JWS and return its claims dict.

    Unknown kid is rejected — never fall back to the active kid, since
    a past kid value would otherwise be forgeable by anyone who saw it.

    Raises VerifyError on malformed header, missing/unknown kid, expiry,
    audience mismatch, or any other PyJWT decode failure.
    """
    try:
        header = jwt.get_unverified_header(token)
    except jwt.PyJWTError as e:
        raise VerifyError(f"decode header: {e}") from e
    kid = header.get("kid")
    if not kid:
        raise VerifyError("no kid in header")
    secret = keyset.lookup(kid)
    if secret is None:
        raise VerifyError(f"unknown kid {kid!r}")
    try:
        # Algorithm is pinned to HS256 — never trust the token's own alg.
        return jwt.decode(
            token,
            secret,
            algorithms=["HS256"],
            audience=expected_aud,
        )
    # Specific subclasses first: PyJWTError is their base and would
    # otherwise shadow the more precise error messages.
    except jwt.ExpiredSignatureError as e:
        raise VerifyError("token expired") from e
    except jwt.InvalidAudienceError as e:
        raise VerifyError("aud mismatch") from e
    except jwt.PyJWTError as e:
        raise VerifyError(f"decode: {e}") from e

650
api/libs/oauth_bearer.py Normal file
View File

@ -0,0 +1,650 @@
"""OAuth bearer primitives.
To add a token kind: write a Resolver, add a SubjectType + Accepts member,
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
Authenticator + validate_bearer stay untouched.
"""
from __future__ import annotations
import hashlib
import json
import logging
import uuid
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import StrEnum
from functools import wraps
from typing import Literal, ParamSpec, Protocol, TypeVar
from flask import g, request
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.rate_limit import enforce_bearer_rate_limit
from models import Account, OAuthAccessToken, TenantAccountJoin
logger = logging.getLogger(__name__)
# ============================================================================
# Contract — types, enums, protocols
# ============================================================================
class SubjectType(StrEnum):
    """Who the bearer token represents."""

    # A Dify account (dfoa_ tokens; account_id present).
    ACCOUNT = "account"
    # An external SSO identity with no account row (dfoe_ tokens).
    EXTERNAL_SSO = "external_sso"


class Scope(StrEnum):
    """Catalog of bearer scopes recognised by the openapi surface.

    `FULL` is the catch-all carried by `dfoa_` account tokens — it satisfies
    any per-route `require_scope`. `dfoe_` tokens carry the per-feature
    scopes (`APPS_RUN`, `APPS_READ_PERMITTED_EXTERNAL`).
    """

    FULL = "full"
    APPS_READ = "apps:read"
    APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external"
    APPS_RUN = "apps:run"


class Accepts(StrEnum):
    """Subject types a route is willing to accept as caller."""

    USER_ACCOUNT = "user_account"
    USER_EXT_SSO = "user_ext_sso"


# Route-level accept sets: any user subject, or external-SSO subjects only.
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})
ACCEPT_USER_EXT_SSO: frozenset[Accepts] = frozenset({Accepts.USER_EXT_SSO})
# Maps an authenticated subject type to the Accepts member it satisfies.
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
    SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
    SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
}
@dataclass(frozen=True, slots=True)
class AuthContext:
    """Attached to ``g.auth_ctx``. ``scopes`` / ``subject_type`` / ``source``
    come from the TokenKind, not the DB — corrupt rows can't elevate scope.

    `verified_tenants` is a snapshot of the Layer-0 verdict cache at
    authenticate time. Per-request mutations write through to Redis via
    `record_layer0_verdict`; this snapshot is not updated in place (frozen).
    """

    subject_type: SubjectType
    subject_email: str | None
    subject_issuer: str | None
    # Present only when the resolved row carries an account (dfoa_).
    account_id: uuid.UUID | None
    client_id: str | None
    # Static grant from the TokenKind, never from the DB row.
    scopes: frozenset[Scope]
    token_id: uuid.UUID
    # Provenance label from the TokenKind (e.g. "oauth_account").
    source: str
    expires_at: datetime | None
    # sha256 of the raw bearer — cache key and rate-limit bucket key.
    token_hash: str
    verified_tenants: dict[str, bool] = field(default_factory=dict)
@dataclass(frozen=True, slots=True)
class ResolvedRow:
subject_email: str | None
subject_issuer: str | None
account_id: uuid.UUID | None
client_id: str | None
token_id: uuid.UUID
expires_at: datetime | None
verified_tenants: dict[str, bool] = field(default_factory=dict)
def to_cache(self) -> dict:
return {
"subject_email": self.subject_email,
"subject_issuer": self.subject_issuer,
"account_id": str(self.account_id) if self.account_id else None,
"client_id": self.client_id,
"token_id": str(self.token_id),
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"verified_tenants": dict(self.verified_tenants),
}
@classmethod
def from_cache(cls, data: dict) -> ResolvedRow:
return cls(
subject_email=data["subject_email"],
subject_issuer=data["subject_issuer"],
account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
client_id=data.get("client_id"),
token_id=uuid.UUID(data["token_id"]),
expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
verified_tenants=_coerce_verified_tenants(data.get("verified_tenants")),
)
def _coerce_verified_tenants(raw: object) -> dict[str, bool]:
"""Tolerate legacy entries that stored 'ok'/'denied' string verdicts.
TODO(post-v1.0): remove once the AuthContext cache TTL has fully cycled
on all live deployments (60s TTL safe to drop one release after rollout).
"""
if not isinstance(raw, dict):
return {}
out: dict[str, bool] = {}
for k, v in raw.items():
if isinstance(v, bool):
out[k] = v
elif v == "ok":
out[k] = True
elif v == "denied":
out[k] = False
return out
class Resolver(Protocol):
    """Maps a sha256 token hash to a live ResolvedRow, or None if unknown/dead."""

    def resolve(self, token_hash: str) -> ResolvedRow | None:  # pragma: no cover - contract
        ...
@dataclass(frozen=True, slots=True)
class TokenKind:
    """Static description of one token family: prefix → resolver + scopes."""

    prefix: str
    subject_type: SubjectType
    # Scopes granted to every token of this kind (never read from the DB).
    scopes: frozenset[Scope]
    # Provenance label copied into AuthContext.source.
    source: str
    resolver: Resolver

    def matches(self, token: str) -> bool:
        return token.startswith(self.prefix)
@dataclass(frozen=True, slots=True)
class MintProfile:
    """Single source of truth for (subject_type, prefix, scopes) at mint time.

    Consumers:
    - ``build_registry`` reads scopes here so the resolve-time TokenKind
      cannot drift from the mint-time intent.
    - Device-flow ``approve`` / ``approve-external`` read prefix + scopes
      here when calling ``mint_oauth_token`` and ``validate_mint_policy``.
    - ``services.openapi.mint_policy.validate_mint_policy`` cross-checks
      the (subject_type, prefix, scopes) triple a caller intends to mint
      against this table — a caller that assembles its own scope set
      from a non-canonical source will fail closed at approve time.
    """

    subject_type: SubjectType
    prefix: str
    scopes: frozenset[Scope]


# Canonical mint table: one profile per mintable subject type.
MINTABLE_PROFILES: dict[SubjectType, MintProfile] = {
    SubjectType.ACCOUNT: MintProfile(
        subject_type=SubjectType.ACCOUNT,
        prefix="dfoa_",
        scopes=frozenset({Scope.FULL}),
    ),
    SubjectType.EXTERNAL_SSO: MintProfile(
        subject_type=SubjectType.EXTERNAL_SSO,
        prefix="dfoe_",
        scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
    ),
}
class InvalidBearerError(Exception):
    """Token missing, unknown prefix, or no live row."""


class TokenExpiredError(Exception):
    """Hard-expire bookkeeping is the resolver's job before raising."""
# ============================================================================
# Registry
# ============================================================================
class TokenKindRegistry:
    """Ordered, prefix-unique collection of TokenKind entries."""

    def __init__(self, kinds: Iterable[TokenKind]) -> None:
        self._kinds: tuple[TokenKind, ...] = tuple(kinds)
        prefixes = [kind.prefix for kind in self._kinds]
        if len(prefixes) != len(set(prefixes)):
            raise ValueError(f"duplicate prefix in registry: {prefixes}")

    def find(self, token: str) -> TokenKind | None:
        """First registered kind whose prefix matches ``token``, else None."""
        return next((kind for kind in self._kinds if kind.matches(token)), None)

    def kinds(self) -> tuple[TokenKind, ...]:
        return self._kinds
# ============================================================================
# Authenticator
# ============================================================================
def sha256_hex(token: str) -> str:
    """Hex-encoded SHA-256 of the raw token string (UTF-8)."""
    digest = hashlib.sha256()
    digest.update(token.encode("utf-8"))
    return digest.hexdigest()
class BearerAuthenticator:
    """Resolves a raw bearer token into an AuthContext via the kind registry."""

    def __init__(self, registry: TokenKindRegistry) -> None:
        self._registry = registry

    @property
    def registry(self) -> TokenKindRegistry:
        return self._registry

    def authenticate(self, token: str) -> AuthContext:
        """Identity + per-token rate limit (single source).

        Both the openapi pipeline (`BearerCheck`) and the decorator
        (`validate_bearer`) call this — rate-limit fires exactly once per
        request regardless of which path hosts the route.

        Raises InvalidBearerError on unknown prefix or dead row.
        """
        kind = self._registry.find(token)
        if kind is None:
            raise InvalidBearerError("unknown token prefix")
        # Only the SHA-256 of the token is ever stored, cached, or logged.
        token_hash = sha256_hex(token)
        row = kind.resolver.resolve(token_hash)
        if row is None:
            raise InvalidBearerError("token unknown or revoked")
        # NOTE(review): the rate limit runs after resolve, so invalid tokens
        # never consume a bucket but do cost a cache/DB lookup — confirm
        # that ordering is intentional.
        enforce_bearer_rate_limit(token_hash)
        return AuthContext(
            subject_type=kind.subject_type,
            subject_email=row.subject_email,
            subject_issuer=row.subject_issuer,
            account_id=row.account_id,
            client_id=row.client_id,
            # Scopes come from the static TokenKind, never from the row.
            scopes=kind.scopes,
            token_id=row.token_id,
            source=kind.source,
            expires_at=row.expires_at,
            token_hash=token_hash,
            verified_tenants=dict(row.verified_tenants),
        )
# ============================================================================
# OAuth access token resolver (PAT resolver would be a sibling class)
# ============================================================================
# Redis key + TTL layout for the token-resolution cache.
TOKEN_CACHE_KEY_FMT = "auth:token:{hash}"
POSITIVE_TTL_SECONDS = 60
NEGATIVE_TTL_SECONDS = 10
# Audit event name emitted when a token is hard-expired.
AUDIT_OAUTH_EXPIRED = "oauth.token_expired"
ScopeVariant = Literal["account", "external_sso"]


class OAuthAccessTokenResolver:
    """``.for_account()`` / ``.for_external_sso()`` are variant-scoped views
    sharing DB + cache plumbing.
    """

    def __init__(
        self,
        session_factory,
        redis_client,
        positive_ttl: int = POSITIVE_TTL_SECONDS,
        negative_ttl: int = NEGATIVE_TTL_SECONDS,
    ) -> None:
        # session_factory: zero-arg callable returning a SQLAlchemy session.
        self.session_factory = session_factory
        self._redis = redis_client
        self._positive_ttl = positive_ttl
        self._negative_ttl = negative_ttl

    def for_account(self) -> Resolver:
        # Variant view accepting only dfoa_ rows (account_id present).
        return _VariantResolver(self, variant="account")

    def for_external_sso(self) -> Resolver:
        # Variant view accepting only dfoe_ rows (no account_id).
        return _VariantResolver(self, variant="external_sso")

    def _cache_key(self, token_hash: str) -> str:
        return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)

    def cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
        """Cache lookup: ResolvedRow hit, "invalid" negative hit, None miss."""
        raw = self._redis.get(self._cache_key(token_hash))
        if raw is None:
            return None
        # redis-py may hand back bytes or str depending on decode_responses.
        text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
        if text == "invalid":
            return "invalid"
        try:
            return ResolvedRow.from_cache(json.loads(text))
        except (ValueError, KeyError):
            # Malformed entry (e.g. written by an older build): treat as a
            # miss and let the caller fall back to the DB.
            logger.warning("auth:token cache entry malformed; treating as miss")
            return None

    def cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
        self._redis.setex(
            self._cache_key(token_hash),
            self._positive_ttl,
            json.dumps(row.to_cache()),
        )

    def cache_set_negative(self, token_hash: str) -> None:
        # Short TTL so a just-minted token isn't blocked for long.
        self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")

    def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None:
        """Atomic CAS — only the worker that flips revoked_at emits audit;
        replays are idempotent.
        """
        stmt = (
            update(OAuthAccessToken)
            .where(OAuthAccessToken.id == row_id, OAuthAccessToken.revoked_at.is_(None))
            .values(revoked_at=datetime.now(UTC), token_hash=None)
        )
        result = session.execute(stmt)
        session.commit()
        if result.rowcount == 1:
            # rowcount == 1 ⇒ this worker won the CAS; it alone audits.
            logger.warning(
                "audit: %s token_id=%s",
                AUDIT_OAUTH_EXPIRED,
                row_id,
                extra={"audit": True, "token_id": str(row_id)},
            )
        # Drop any stale positive entry, then park a negative one so other
        # replicas stop hitting the DB for this hash.
        self._redis.delete(self._cache_key(token_hash))
        self.cache_set_negative(token_hash)
class _VariantResolver:
    """Prefix-variant view over OAuthAccessTokenResolver.

    variant="account" accepts only rows with an account_id (dfoa_);
    variant="external_sso" only rows without one (dfoe_).
    """

    def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
        self._parent = parent
        self._variant = variant

    def resolve(self, token_hash: str) -> ResolvedRow | None:
        """Cache-first lookup of a live token row for this variant."""
        cached = self._parent.cache_get(token_hash)
        if cached == "invalid":
            return None
        if cached is not None and not isinstance(cached, str):
            # Positive cache hit — still enforce the variant so a dfoa_ row
            # cached by one resolver can't satisfy the dfoe_ resolver.
            if not self._matches_variant(cached):
                return None
            return cached
        # Flask-SQLAlchemy's scoped_session is request-bound and not a
        # context manager; use it directly.
        session = self._parent.session_factory()
        row = self._load_from_db(session, token_hash)
        if row is None:
            self._parent.cache_set_negative(token_hash)
            return None
        now = datetime.now(UTC)
        if row.expires_at is not None and row.expires_at <= now:
            # Lazy hard-expire on first touch after the deadline.
            self._parent.hard_expire(session, row.id, token_hash)
            return None
        if not self._matches_variant_model(row):
            # Row exists but its account_id/prefix combination contradicts
            # this variant — invariant broken; refuse to authenticate.
            logger.error(
                "internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
                row.id,
                row.prefix,
            )
            return None
        resolved = ResolvedRow(
            subject_email=row.subject_email,
            subject_issuer=row.subject_issuer,
            account_id=uuid.UUID(str(row.account_id)) if row.account_id else None,
            client_id=row.client_id,
            token_id=uuid.UUID(str(row.id)),
            expires_at=row.expires_at,
        )
        self._parent.cache_set_positive(token_hash, resolved)
        return resolved

    def _matches_variant(self, row: ResolvedRow) -> bool:
        # Cached rows carry no prefix; presence of account_id decides.
        has_account = row.account_id is not None
        if self._variant == "account":
            return has_account
        return not has_account

    def _matches_variant_model(self, row: OAuthAccessToken) -> bool:
        # DB rows additionally pin the stored prefix to the variant.
        has_account = row.account_id is not None
        if self._variant == "account":
            return has_account and row.prefix == "dfoa_"
        return (not has_account) and row.prefix == "dfoe_"

    def _load_from_db(self, session: Session, token_hash: str) -> OAuthAccessToken | None:
        # Live rows only (revoked_at IS NULL); token_hash is unique.
        return (
            session.query(OAuthAccessToken)
            .filter(
                OAuthAccessToken.token_hash == token_hash,
                OAuthAccessToken.revoked_at.is_(None),
            )
            .one_or_none()
        )
# ============================================================================
# Layer 0 — workspace membership cache + helper
# ============================================================================
def record_layer0_verdict(token_hash: str, tenant_id: str, verdict: bool) -> None:
    """Merge a Layer-0 membership verdict into the AuthContext cache entry at
    `auth:token:{hash}`. No-op if entry missing/expired/invalid — next request
    rebuilds via authenticate() and re-runs Layer 0.
    """
    cache_key = TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
    raw = redis_client.get(cache_key)
    if raw is None:
        return
    text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
    if text == "invalid":
        # Negative entry: nothing to merge into.
        return
    try:
        data = json.loads(text)
    except (ValueError, KeyError):
        return
    # Preserve the entry's remaining lifetime; never extend it.
    ttl = redis_client.ttl(cache_key)
    if ttl <= 0:
        # Key vanished or lost its expiry between GET and TTL — skip rather
        # than risk writing an immortal entry.
        return
    data.setdefault("verified_tenants", {})[tenant_id] = verdict
    # NOTE(review): GET/TTL/SETEX is not atomic — a concurrent writer's
    # verdict can be overwritten. Likely benign given the short positive
    # TTL, but confirm verdicts are only advisory here.
    redis_client.setex(cache_key, ttl, json.dumps(data))
def check_workspace_membership(
    *,
    account_id: uuid.UUID | str,
    tenant_id: str,
    token_hash: str,
    cached_verdicts: dict[str, bool],
) -> None:
    """Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.

    Shared by the pipeline step (`WorkspaceMembershipCheck`) and the
    inline helper (`require_workspace_member`). Caller is responsible for
    short-circuiting on EE / SSO subjects before invoking — this function
    runs the membership + active-status checks unconditionally.
    """
    # Snapshot verdicts short-circuit both DB queries below.
    cached = cached_verdicts.get(tenant_id)
    if cached is True:
        return
    if cached is False:
        raise Forbidden("workspace_membership_revoked")
    join = db.session.execute(
        select(TenantAccountJoin.id).where(
            TenantAccountJoin.account_id == account_id,
            TenantAccountJoin.tenant_id == tenant_id,
        )
    ).scalar_one_or_none()
    if join is None:
        # Not a member: record the denial so repeats stay cache-only.
        record_layer0_verdict(token_hash, tenant_id, False)
        raise Forbidden("workspace_membership_revoked")
    # A join row alone isn't enough — the account itself must be active.
    status = db.session.execute(select(Account.status).where(Account.id == account_id)).scalar_one_or_none()
    if status != "active":
        record_layer0_verdict(token_hash, tenant_id, False)
        raise Forbidden("workspace_membership_revoked")
    record_layer0_verdict(token_hash, tenant_id, True)
def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
    """AuthContext-flavoured wrapper around `check_workspace_membership`.

    No-op on EE (gateway RBAC owns tenant isolation) and for SSO subjects
    (no `tenant_account_joins` row by definition).
    """
    if dify_config.ENTERPRISE_ENABLED:
        return
    # SSO subjects (or rows without an account) carry no membership to check.
    if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
        return
    check_workspace_membership(
        account_id=ctx.account_id,
        tenant_id=tenant_id,
        token_hash=ctx.token_hash,
        cached_verdicts=ctx.verified_tenants,
    )
# ============================================================================
# Decorator — route-level bearer gate
# ============================================================================
_authenticator: BearerAuthenticator | None = None


def bind_authenticator(authenticator: BearerAuthenticator) -> None:
    """Install the process-wide authenticator (called once at startup)."""
    global _authenticator
    _authenticator = authenticator


def get_authenticator() -> BearerAuthenticator:
    """Return the bound authenticator; RuntimeError if wiring never ran."""
    bound = _authenticator
    if bound is None:
        raise RuntimeError("BearerAuthenticator not bound; call bind_authenticator at startup")
    return bound
def _extract_bearer(req) -> str | None:
header = req.headers.get("Authorization", "")
scheme, _, value = header.partition(" ")
if scheme.lower() != "bearer" or not value:
return None
return value.strip()
_DP = ParamSpec("_DP")
_DR = TypeVar("_DR")


def validate_bearer(*, accept: frozenset[Accepts]) -> Callable[[Callable[_DP, _DR]], Callable[_DP, _DR]]:
    """Opt-in: omitting it leaves the route unauthenticated.

    Resolves user-level OAuth bearers (``dfoa_`` / ``dfoe_``). Legacy
    ``app-`` keys belong to ``service_api/wraps.py:validate_app_token``
    and are rejected here as the wrong auth scheme for this surface.

    On success the resolved AuthContext is published as ``g.auth_ctx``.
    Raises Unauthorized (missing/invalid token), Forbidden (subject type
    not in ``accept``), ServiceUnavailable (authenticator never bound).
    """
    def wrap(fn: Callable[_DP, _DR]) -> Callable[_DP, _DR]:
        @wraps(fn)
        def inner(*args: _DP.args, **kwargs: _DP.kwargs) -> _DR:
            token = _extract_bearer(request)
            if token is None:
                raise Unauthorized("missing bearer token")
            if _authenticator is None:
                # Feature flag off: the extension never bound the authenticator.
                raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
            try:
                ctx = get_authenticator().authenticate(token)
            except InvalidBearerError as e:
                raise Unauthorized(str(e))
            if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
                raise Forbidden("token subject type not accepted here")
            g.auth_ctx = ctx
            return fn(*args, **kwargs)
        return inner
    return wrap
def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]:
    """503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable
    without the authenticator, so fail fast instead of approving silently.
    """
    @wraps(fn)
    def inner(*args: P.args, **kwargs: P.kwargs) -> R:
        # Checked at call time, not decoration time.
        if not dify_config.ENABLE_OAUTH_BEARER:
            raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
        return fn(*args, **kwargs)
    return inner
def require_scope(scope: Scope) -> Callable:
    """Route-level scope gate — must run AFTER validate_bearer so that
    g.auth_ctx is set. Raises Forbidden('insufficient_scope: <scope>')
    when the bearer lacks both the requested scope and `Scope.FULL`.
    """

    def wrap(fn: Callable) -> Callable:
        @wraps(fn)
        def inner(*args, **kwargs):
            ctx = getattr(g, "auth_ctx", None)
            if ctx is None:
                raise RuntimeError(
                    "require_scope used without validate_bearer; stack @validate_bearer above @require_scope"
                )
            granted = ctx.scopes
            if scope in granted or Scope.FULL in granted:
                return fn(*args, **kwargs)
            raise Forbidden(f"insufficient_scope: {scope}")

        return inner

    return wrap
# ============================================================================
# Wiring — called once from the app factory
# ============================================================================
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
    """Registry holding the two user-level OAuth token kinds (dfoa_/dfoe_).

    Prefix + scopes come from MINTABLE_PROFILES so the resolve-time kinds
    cannot drift from mint-time intent.
    """
    oauth = OAuthAccessTokenResolver(session_factory, redis_client)
    account = MINTABLE_PROFILES[SubjectType.ACCOUNT]
    external = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
    return TokenKindRegistry(
        [
            TokenKind(
                prefix=account.prefix,
                subject_type=account.subject_type,
                scopes=account.scopes,
                source="oauth_account",
                resolver=oauth.for_account(),
            ),
            TokenKind(
                prefix=external.prefix,
                subject_type=external.subject_type,
                scopes=external.scopes,
                source="oauth_external_sso",
                resolver=oauth.for_external_sso(),
            ),
        ]
    )
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
    """App-factory entry point: build the registry, wrap it, bind globally."""
    registry = build_registry(session_factory, redis_client)
    auth = BearerAuthenticator(registry)
    bind_authenticator(auth)
    return auth

140
api/libs/rate_limit.py Normal file
View File

@ -0,0 +1,140 @@
"""Typed rate-limit decorator over ``libs.helper.RateLimiter`` (sliding-
window Redis ZSET). Apply after auth decorators so scopes can read
``g.auth_ctx``. Use :func:`enforce` when the bucket key is computed
in-handler. RFC-8628 ``slow_down`` is inline its response shape isn't
generic 429.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from enum import StrEnum
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import g, jsonify, make_response, request, session
from werkzeug.exceptions import TooManyRequests
from configs import dify_config
from libs.helper import RateLimiter, extract_remote_ip
class RateLimitScope(StrEnum):
    """Dimension a rate-limit bucket is keyed on."""

    IP = "ip"
    SESSION = "session"
    ACCOUNT = "account"
    SUBJECT_EMAIL = "subject_email"
    TOKEN_ID = "token_id"


@dataclass(frozen=True, slots=True)
class RateLimit:
    """A sliding-window budget: ``limit`` events per ``window`` per bucket."""

    limit: int
    window: timedelta
    # Key dimensions combined into one composite bucket key (see _composite_key).
    scopes: tuple[RateLimitScope, ...]


# Route budgets; names encode surface + key dimension.
LIMIT_DEVICE_CODE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
LIMIT_SSO_INITIATE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
LIMIT_APPROVE_EXT_PER_EMAIL = RateLimit(10, timedelta(hours=1), (RateLimitScope.SUBJECT_EMAIL,))
LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSION,))
LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,))
LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,))
LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,))
# limit is read from config at import time — changing it needs a restart.
LIMIT_BEARER_PER_TOKEN = RateLimit(
    limit=dify_config.OPENAPI_RATE_LIMIT_PER_TOKEN,
    window=timedelta(minutes=1),
    scopes=(RateLimitScope.TOKEN_ID,),  # bucket key composed by caller from sha256(token)
)
def _one_key(scope: RateLimitScope) -> str:
    """Bucket-key fragment for one scope, read from the current request
    context. Unidentifiable callers collapse into a shared anon bucket."""
    match scope:
        case RateLimitScope.IP:
            return f"ip:{extract_remote_ip(request) or 'unknown'}"
        case RateLimitScope.SESSION:
            # 'anon' when the session has no _id yet.
            return f"session:{session.get('_id', 'anon')}"
        case RateLimitScope.ACCOUNT:
            ctx = getattr(g, "auth_ctx", None)
            if ctx and ctx.account_id:
                return f"account:{ctx.account_id}"
            return "account:anon"
        case RateLimitScope.SUBJECT_EMAIL:
            ctx = getattr(g, "auth_ctx", None)
            if ctx and ctx.subject_email:
                return f"subject:{ctx.subject_email}"
            return "subject:anon"
        case RateLimitScope.TOKEN_ID:
            ctx = getattr(g, "auth_ctx", None)
            if ctx and ctx.token_id:
                return f"token:{ctx.token_id}"
            return "token:anon"
def _composite_key(scopes: tuple[RateLimitScope, ...]) -> str:
    """Bucket key: the per-scope fragments joined with '|'."""
    return "|".join(map(_one_key, scopes))
def _limiter_prefix(scopes: tuple[RateLimitScope, ...]) -> str:
return "rl:" + "+".join(s.value for s in scopes)
def _build_limiter(spec: RateLimit) -> RateLimiter:
    """RateLimiter (Redis sliding-window ZSET) configured from ``spec``."""
    return RateLimiter(
        prefix=_limiter_prefix(spec.scopes),
        max_attempts=spec.limit,
        time_window=int(spec.window.total_seconds()),
    )
_P = ParamSpec("_P")
_R = TypeVar("_R")


def rate_limit(spec: RateLimit) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    """Apply after auth decorators that the scopes read from.

    Raises TooManyRequests (429) when the composite bucket is over budget.
    """
    # One limiter per decorated route, built at decoration time.
    limiter = _build_limiter(spec)
    def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]:
        @wraps(fn)
        def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            key = _composite_key(spec.scopes)
            if limiter.is_rate_limited(key):
                raise TooManyRequests("rate_limited")
            # NOTE(review): check-then-increment is two Redis ops, so a
            # concurrent burst can slightly exceed the budget — confirm
            # that's acceptable for these routes.
            limiter.increment_rate_limit(key)
            return fn(*args, **kwargs)
        return inner
    return wrap
def enforce(spec: RateLimit, *, key: str) -> None:
    """Imperative form — caller composes the bucket key to match scope
    semantics (the key is opaque here).

    Raises TooManyRequests (429) when the bucket is over budget.
    """
    # Built per call — assumes RateLimiter construction is cheap; confirm.
    limiter = _build_limiter(spec)
    if limiter.is_rate_limited(key):
        raise TooManyRequests("rate_limited")
    limiter.increment_rate_limit(key)
def enforce_bearer_rate_limit(token_hash: str) -> None:
    """Per-token rate limit on /openapi/v1/* bearer-authed routes.

    Bucket key = ``token:<sha256_hex>`` so the same token shares one
    bucket across api replicas (Redis-backed sliding window).

    Raises TooManyRequests carrying a custom JSON body plus a Retry-After
    header (not the generic 429 shape).
    """
    limiter = _build_limiter(LIMIT_BEARER_PER_TOKEN)
    key = f"token:{token_hash}"
    if limiter.is_rate_limited(key):
        retry_after = limiter.seconds_until_available(key)
        # Clients read retry_after_ms from the body; proxies read Retry-After.
        response = make_response(
            jsonify({"error": "rate_limited", "retry_after_ms": retry_after * 1000}),
            429,
        )
        response.headers["Retry-After"] = str(retry_after)
        # Passing response= makes werkzeug return this exact payload.
        raise TooManyRequests(response=response)
    limiter.increment_rate_limit(key)

View File

@ -72,11 +72,15 @@ def extract_csrf_token_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
def extract_access_token(request: Request) -> str | None:
def _try_extract_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
def extract_console_cookie_token(request: Request) -> str | None:
"""Cookie-only console session token. Used by /openapi/v1/oauth/device/*
approval routes, which must not fall through to the Authorization header
(that's where dfoa_/dfoe_ bearers live — they aren't JWTs)."""
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
def extract_access_token(request: Request) -> str | None:
    # Cookie wins over the Authorization header when both are present.
    return extract_console_cookie_token(request) or _try_extract_from_header(request)
def extract_webapp_access_token(request: Request) -> str | None:

View File

@ -0,0 +1,104 @@
"""add oauth_access_tokens table
Revision ID: d4a5e1f3c9b7
Revises: 227822d22895, b69ca54b9208, 2a3aebbbf4bb
Create Date: 2026-04-23 22:00:00.000000
Merges the three open heads at time of authoring (add_workflow_comments_table,
add_chatbot_color_theme, add_app_tracing) into a single parent so the new
oauth_access_tokens table sits on a definite linear chain thereafter.
Table stores user-level OAuth bearer tokens minted via the device-flow grant
(difyctl auth login). PAT storage (personal_access_tokens) is a separate
table not added in this migration.
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "d4a5e1f3c9b7"
down_revision = ("227822d22895", "b69ca54b9208", "2a3aebbbf4bb")
branch_labels = None
depends_on = None
def upgrade():
    """Create oauth_access_tokens plus its partial (revoked_at IS NULL) indexes."""
    op.create_table(
        "oauth_access_tokens",
        sa.Column(
            "id",
            postgresql.UUID(as_uuid=True),
            # DB-side default; the app never supplies ids on raw inserts.
            server_default=sa.text("gen_random_uuid()"),
            nullable=False,
            primary_key=True,
        ),
        # Subject identity. subject_issuer is nullable at the DDL level,
        # but the app always writes a value (sentinel or IdP issuer) — see
        # the comment above uq_oauth_active_per_device below.
        sa.Column("subject_email", sa.Text(), nullable=False),
        sa.Column("subject_issuer", sa.Text(), nullable=True),
        # Nullable: external-SSO tokens may have no local account row.
        sa.Column("account_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("client_id", sa.String(length=64), nullable=False),
        sa.Column("device_label", sa.Text(), nullable=False),
        # Token prefix (max 8 chars), e.g. "dfoa_"/"dfoe_".
        sa.Column("prefix", sa.String(length=8), nullable=False),
        # 64-hex-char token hash (SHA-256 sized); globally unique when set.
        sa.Column("token_hash", sa.String(length=64), nullable=True, unique=True),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            server_default=sa.text("NOW()"),
            nullable=False,
        ),
        sa.Column("last_used_at", sa.TIMESTAMP(timezone=True), nullable=True),
        sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
        # Soft revocation: row is kept for audit; partial indexes exclude it.
        sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True),
        sa.ForeignKeyConstraint(
            ["account_id"],
            ["accounts.id"],
            name="fk_oauth_access_tokens_account_id",
            # Keep the token row when the account goes away.
            ondelete="SET NULL",
        ),
    )
    # Lookup indexes cover only live (unrevoked) rows.
    op.create_index(
        "idx_oauth_subject_email",
        "oauth_access_tokens",
        ["subject_email"],
        postgresql_where=sa.text("revoked_at IS NULL"),
    )
    op.create_index(
        "idx_oauth_account",
        "oauth_access_tokens",
        ["account_id"],
        postgresql_where=sa.text("revoked_at IS NULL AND account_id IS NOT NULL"),
    )
    op.create_index(
        "idx_oauth_client",
        "oauth_access_tokens",
        ["subject_email", "client_id"],
        postgresql_where=sa.text("revoked_at IS NULL"),
    )
    op.create_index(
        "idx_oauth_token_hash",
        "oauth_access_tokens",
        ["token_hash"],
        postgresql_where=sa.text("revoked_at IS NULL"),
    )
    # Partial unique index — rotate-in-place keyed on (subject, client, device).
    # The app always writes a non-NULL subject_issuer (account flow uses a
    # sentinel, external-SSO uses the verified IdP issuer); without that the
    # composite key would never collide because Postgres treats NULLs as
    # distinct in unique indices.
    op.create_index(
        "uq_oauth_active_per_device",
        "oauth_access_tokens",
        ["subject_email", "subject_issuer", "client_id", "device_label"],
        unique=True,
        postgresql_where=sa.text("revoked_at IS NULL"),
    )
def downgrade():
    """Drop the oauth_access_tokens table and all of its partial indexes."""
    # Reverse order of creation: indexes first, then the table itself.
    for index_name in (
        "uq_oauth_active_per_device",
        "idx_oauth_token_hash",
        "idx_oauth_client",
        "idx_oauth_account",
        "idx_oauth_subject_email",
    ):
        op.drop_index(index_name, table_name="oauth_access_tokens")
    op.drop_table("oauth_access_tokens")

View File

@ -73,7 +73,7 @@ from .model import (
TrialApp,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
from .oauth import DatasourceOauthParamConfig, DatasourceProvider, OAuthAccessToken
from .provider import (
LoadBalancingModelConfig,
Provider,
@ -177,6 +177,7 @@ __all__ = [
"MessageChain",
"MessageFeedback",
"MessageFile",
"OAuthAccessToken",
"OperationLog",
"PinnedConversation",
"Provider",

View File

@ -84,3 +84,35 @@ class DatasourceOauthTenantParamConfig(TypeBase):
onupdate=func.current_timestamp(),
init=False,
)
class OAuthAccessToken(TypeBase):
    """Device-flow bearer token row.

    Two variants:
      * account_id NOT NULL -> ``dfoa_`` (Dify account; subject_issuer is
        the ``"dify:account"`` sentinel).
      * account_id NULL with subject_issuer = verified IdP issuer ->
        ``dfoe_`` (external SSO, EE-only).

    subject_issuer is non-NULL for all rows the app writes: Postgres
    treats NULLs as distinct in unique indices, so the partial unique
    index on (subject_email, subject_issuer, client_id, device_label)
    WHERE revoked_at IS NULL would otherwise fail to rotate in place.
    """

    __tablename__ = "oauth_access_tokens"
    __table_args__ = (sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),)

    # Surrogate key; UUIDv7 gives time-ordered ids.
    id: Mapped[str] = mapped_column(
        StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
    )
    subject_email: Mapped[str] = mapped_column(sa.Text, nullable=False)
    # OAuth client identifier, max 64 chars.
    client_id: Mapped[str] = mapped_column(sa.String(64), nullable=False)
    # Free-form device name — presumably shown on the approval page; confirm.
    device_label: Mapped[str] = mapped_column(sa.Text, nullable=False)
    # Token prefix, e.g. "dfoa_"/"dfoe_" (max 8 chars).
    prefix: Mapped[str] = mapped_column(sa.String(8), nullable=False)
    expires_at: Mapped[datetime] = mapped_column(sa.DateTime(timezone=True), nullable=False)
    # "dify:account" sentinel for the account flow; IdP issuer for external SSO.
    subject_issuer: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
    account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
    # 64-hex-char hash of the plaintext token (SHA-256 sized).
    token_hash: Mapped[str | None] = mapped_column(sa.String(64), nullable=True, default=None)
    last_used_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
    # Soft-revocation timestamp; NULL means the token is live.
    revoked_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True), nullable=False, server_default=func.now(), init=False
    )

View File

@ -1206,6 +1206,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
SERVICE_API = "service-api"
WEB_APP = "web-app"
INSTALLED_APP = "installed-app"
OPENAPI = "openapi"
@classmethod
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":

View File

@ -0,0 +1,54 @@
"""DELETE oauth_access_tokens past retention. Revocation is UPDATE
(token_id stays for audits) so rows accumulate across re-logins, and
expired-but-never-presented rows have no hard-expire trigger both get
pruned here. Spec: docs/specs/v1.0/server/tokens.md §Hard-expire.
"""
from __future__ import annotations
import logging
import time
from datetime import UTC, datetime, timedelta
import click
from sqlalchemy import delete, or_, select
import app
from configs import dify_config
from extensions.ext_database import db
from models.oauth import OAuthAccessToken
logger = logging.getLogger(__name__)
DELETE_BATCH_SIZE = 500
@app.celery.task(queue="retention")
def clean_oauth_access_tokens_task():
    """Prune oauth_access_tokens rows past the retention window.

    Two kinds of rows qualify: (a) revoked more than retention_days ago,
    and (b) never revoked but expired more than retention_days ago.
    Deletes in batches of DELETE_BATCH_SIZE, committing each batch.
    """
    click.echo(click.style("Start clean oauth_access_tokens.", fg="green"))
    retention_days = int(dify_config.OAUTH_ACCESS_TOKEN_RETENTION_DAYS)
    cutoff = datetime.now(UTC) - timedelta(days=retention_days)
    start_at = time.perf_counter()
    candidates = or_(
        OAuthAccessToken.revoked_at < cutoff,
        # Zombies: expired but never re-presented, so middleware never flipped them.
        (OAuthAccessToken.revoked_at.is_(None)) & (OAuthAccessToken.expires_at < cutoff),
    )
    total = 0
    # Batched deletes keep each transaction (and its row locks) short.
    while True:
        ids = db.session.scalars(select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)).all()
        if not ids:
            break
        db.session.execute(delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids)))
        db.session.commit()
        total += len(ids)
    end_at = time.perf_counter()
    click.echo(
        click.style(
            f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d in {end_at - start_at:.2f}s",
            fg="green",
        )
    )

View File

@ -37,7 +37,7 @@ class AppService:
Get app list with pagination
:param user_id: user id
:param tenant_id: tenant id
:param args: request args
:param args: request args. Optional keys: status (e.g. "normal") restricts App.status.
:return:
"""
filters = [App.tenant_id == tenant_id, App.is_universal == False]
@ -53,6 +53,14 @@ class AppService:
elif args["mode"] == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT)
if args.get("status"):
filters.append(App.status == args["status"])
# OpenAPI surface visibility gate. Pushed into the query so
# `pagination.total` reflects only apps the openapi caller can
# actually reach — post-filtering by enable_api after the page
# arrives would make `total` page-dependent.
if args.get("openapi_visible"):
filters.append(App.enable_api.is_(True))
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):

View File

@ -0,0 +1,44 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from werkzeug.exceptions import ServiceUnavailable
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.enterprise import EnterpriseAPIError
logger = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class PermittedAppsPage:
    """One page of app ids the EE service reports as externally accessible."""

    # App ids on this page, in the order the EE service returned them.
    app_ids: list[str]
    # Total matching apps across all pages.
    total: int
    # True when further pages exist beyond this one.
    has_more: bool
def list_permitted_apps(
    *,
    page: int,
    limit: int,
    mode: str | None = None,
    name: str | None = None,
) -> PermittedAppsPage:
    """Fetch one page of externally accessible app ids from the EE service.

    ``page``/``limit``/``mode``/``name`` are forwarded as-is to
    ``EnterpriseService.WebAppAuth.list_externally_accessible_apps``.

    Raises:
        ServiceUnavailable: description ``permitted_apps_unavailable`` when
            the EE call fails; the original EnterpriseAPIError is chained.
    """
    try:
        body = EnterpriseService.WebAppAuth.list_externally_accessible_apps(
            page=page, limit=limit, mode=mode, name=name
        )
    except EnterpriseAPIError as exc:
        logger.warning(
            "permitted_apps EE call failed: status=%s message=%s",
            getattr(exc, "status_code", None),
            str(exc),
        )
        raise ServiceUnavailable("permitted_apps_unavailable") from exc
    # EE responds in camelCase; missing keys degrade to an empty page.
    return PermittedAppsPage(
        app_ids=[row["appId"] for row in body.get("data", [])],
        total=int(body.get("total", 0)),
        has_more=bool(body.get("hasMore", False)),
    )

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import enum
import logging
import uuid
from datetime import datetime
@ -23,10 +24,22 @@ VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
class WebAppAccessMode(enum.StrEnum):
PUBLIC = "public"
PRIVATE = "private"
PRIVATE_ALL = "private_all"
SSO_VERIFIED = "sso_verified"
PERMISSION_CHECK_MODES: frozenset[WebAppAccessMode] = frozenset(
{WebAppAccessMode.PRIVATE, WebAppAccessMode.PRIVATE_ALL}
)
class WebAppSettings(BaseModel):
access_mode: str = Field(
description="Access mode for the web app. Can be 'public', 'private', 'private_all', 'sso_verified'",
default="private",
description=f"Access mode for the web app. One of: {', '.join(m.value for m in WebAppAccessMode)}",
default=WebAppAccessMode.PRIVATE.value,
alias="accessMode",
)
@ -106,6 +119,15 @@ class EnterpriseService:
def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
@classmethod
def initiate_device_flow_sso(cls, signed_state: str) -> dict:
    """Ask the EE service to start SSO for a device-flow approval.

    ``signed_state`` is forwarded as an opaque signed blob.
    raise_for_status=True presumably makes a non-2xx EE response raise
    instead of returning a body — confirm in EnterpriseRequest.
    """
    return EnterpriseRequest.send_request(
        "POST",
        "/device-flow/sso-initiate",
        json={"signed_state": signed_state},
        raise_for_status=True,
    )
@classmethod
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
"""
@ -217,8 +239,9 @@ class EnterpriseService:
def update_app_access_mode(cls, app_id: str, access_mode: str):
if not app_id:
raise ValueError("app_id must be provided.")
if access_mode not in ["public", "private", "private_all"]:
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
allowed = {WebAppAccessMode.PUBLIC, WebAppAccessMode.PRIVATE, WebAppAccessMode.PRIVATE_ALL}
if access_mode not in allowed:
raise ValueError(f"access_mode must be one of: {', '.join(m.value for m in allowed)}")
data = {"appId": app_id, "accessMode": access_mode}
@ -234,6 +257,32 @@ class EnterpriseService:
params = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
@classmethod
def list_externally_accessible_apps(
    cls,
    *,
    page: int,
    limit: int,
    mode: str | None = None,
    name: str | None = None,
) -> dict:
    """Call EE InnerListExternallyAccessibleApps; returns raw camelCase response.

    Response shape: ``{"data": [{"appId", "tenantId", "mode", "name",
    "updatedAt"}], "total": int, "hasMore": bool}``. Optional filters are
    omitted from the request body when None.
    """
    payload: dict[str, str | int] = {"page": page, "limit": limit}
    for key, value in (("mode", mode), ("name", name)):
        if value is not None:
            payload[key] = value
    return EnterpriseRequest.send_request(
        "POST",
        "/webapp/externally-accessible-apps",
        json=payload,
        timeout=5.0,
    )
@classmethod
def get_cached_license_status(cls) -> LicenseStatus | None:
"""Get enterprise license status with Redis caching to reduce HTTP calls.

View File

@ -0,0 +1,467 @@
"""Device-flow service layer: Redis state machine, OAuth token mint
(DB upsert + plaintext generation), and TTL policy. Specs:
docs/specs/v1.0/server/{device-flow.md, tokens.md}.
"""
from __future__ import annotations
import hashlib
import json
import logging
import os
import secrets
import time
import uuid
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime, timedelta
from enum import StrEnum
from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session, scoped_session
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT
from models.oauth import OAuthAccessToken
logger = logging.getLogger(__name__)
# ============================================================================
# Redis state machine — device_code + user_code ephemeral state
# ============================================================================
_DEVICE_CODE_KEY_PREFIX = "device_code:"
_USER_CODE_KEY_PREFIX = "user_code:"
# Key templates; `.format(code=...)` fills in the opaque / human-typable code.
DEVICE_CODE_KEY_FMT = _DEVICE_CODE_KEY_PREFIX + "{code}"
USER_CODE_KEY_FMT = _USER_CODE_KEY_PREFIX + "{code}"
# Atomic GET → status-check → DEL(both keys). Two concurrent pollers must
# not both observe APPROVED — only the winner gets the plaintext token,
# the loser sees nil and the caller maps that to expired_token.
# KEYS[1] = device_code key; ARGV[1] = the user_code key prefix.
_CONSUME_ON_POLL_LUA = """
local raw = redis.call('GET', KEYS[1])
if not raw then return nil end
local ok, decoded = pcall(cjson.decode, raw)
if not ok then return nil end
if decoded.status == 'pending' then return nil end
if decoded.user_code then
redis.call('DEL', ARGV[1] .. decoded.user_code)
end
redis.call('DEL', KEYS[1])
return raw
"""
DEVICE_FLOW_TTL_SECONDS = 15 * 60  # RFC 8628 expires_in
APPROVED_TTL_SECONDS_MIN = 60  # plaintext-token lifetime floor
USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXY3456789"  # ambiguous chars dropped
USER_CODE_SEGMENT_LEN = 4
# Max SET NX attempts before giving up on a unique user_code.
USER_CODE_MAX_CLAIM_ATTEMPTS = 5
DEFAULT_POLL_INTERVAL_SECONDS = 5  # RFC 8628 minimum
class DeviceFlowStatus(StrEnum):
    """Lifecycle of one device-flow authorization session."""

    PENDING = "pending"
    APPROVED = "approved"
    DENIED = "denied"
class SlowDownDecision(StrEnum):
    """Outcome of the poll-rate check (RFC 8628 ``slow_down``)."""

    OK = "ok"
    SLOW_DOWN = "slow_down"
@dataclass
class DeviceFlowState:
    """Per-session state, serialized as JSON under the device_code key.

    ``minted_token`` is plaintext between approve and the next poll;
    DEL'd after the poll reads it.
    """

    # Human-typable pairing code, e.g. "ABCD-EFGH".
    user_code: str
    client_id: str
    device_label: str
    status: DeviceFlowStatus
    # Populated at approve time.
    subject_email: str | None = None
    account_id: str | None = None
    subject_issuer: str | None = None
    # Plaintext token — present only between approve and consume.
    minted_token: str | None = None
    token_id: str | None = None
    created_at: str = ""
    created_ip: str = ""
    last_poll_at: str = ""
    # Extra payload handed back on poll — shape not fixed here; confirm callers.
    poll_payload: dict | None = field(default=None)

    def to_json(self) -> str:
        # StrEnum status serializes as its plain string value.
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, raw: str) -> DeviceFlowState:
        """Inverse of ``to_json``; re-hydrates ``status`` into the enum."""
        data = json.loads(raw)
        if "status" in data:
            data["status"] = DeviceFlowStatus(data["status"])
        return cls(**data)
def _random_device_code() -> str:
return "dc_" + secrets.token_urlsafe(24)
def _random_user_code_segment() -> str:
    # One 4-char segment drawn from the ambiguity-reduced alphabet.
    return "".join(secrets.choice(USER_CODE_ALPHABET) for _ in range(USER_CODE_SEGMENT_LEN))
def _random_user_code() -> str:
    # Human-typable code shown to the user, e.g. "ABCD-EFGH".
    return f"{_random_user_code_segment()}-{_random_user_code_segment()}"
class StateNotFoundError(Exception):
    """No device-flow state exists for the given code (expired or unknown)."""


class InvalidTransitionError(Exception):
    """Approve/deny attempted on a state that is no longer PENDING."""


class UserCodeExhaustedError(Exception):
    """Could not allocate a unique user_code within the attempt budget."""
class DeviceFlowRedis:
    """Redis-backed state machine for RFC 8628 device-flow sessions.

    State lives under two keys that expire together: ``device_code:<code>``
    (JSON-encoded DeviceFlowState) and ``user_code:<code>`` (maps the
    human-typable code back to its device_code).
    """

    def __init__(self, redis_client) -> None:
        self._redis = redis_client
        # Pre-register the Lua script so polls execute it server-side.
        self._consume_on_poll_script = redis_client.register_script(_CONSUME_ON_POLL_LUA)

    def start(self, client_id: str, device_label: str, created_ip: str) -> tuple[str, str, int]:
        """Create a PENDING session; returns (device_code, user_code, expires_in)."""
        device_code = _random_device_code()
        user_code = self._claim_user_code(device_code)
        state = DeviceFlowState(
            user_code=user_code,
            client_id=client_id,
            device_label=device_label,
            status=DeviceFlowStatus.PENDING,
            created_at=datetime.now(UTC).isoformat(),
            created_ip=created_ip,
        )
        self._redis.setex(
            DEVICE_CODE_KEY_FMT.format(code=device_code),
            DEVICE_FLOW_TTL_SECONDS,
            state.to_json(),
        )
        return device_code, user_code, DEVICE_FLOW_TTL_SECONDS

    def _claim_user_code(self, device_code: str) -> str:
        """Reserve a collision-free user_code via SET NX.

        Raises:
            UserCodeExhaustedError: after USER_CODE_MAX_CLAIM_ATTEMPTS misses.
        """
        for _ in range(USER_CODE_MAX_CLAIM_ATTEMPTS):
            user_code = _random_user_code()
            key = USER_CODE_KEY_FMT.format(code=user_code)
            # NX: claim only if free; EX: reservation dies with the flow TTL.
            ok = self._redis.set(key, device_code, nx=True, ex=DEVICE_FLOW_TTL_SECONDS)
            if ok:
                return user_code
        raise UserCodeExhaustedError("could not allocate a unique user_code in 5 attempts")

    def load_by_user_code(self, user_code: str) -> tuple[str, DeviceFlowState] | None:
        """Resolve user_code -> (device_code, state); None if either key is gone."""
        raw_dc = self._redis.get(USER_CODE_KEY_FMT.format(code=user_code))
        if not raw_dc:
            return None
        device_code = raw_dc.decode() if isinstance(raw_dc, (bytes, bytearray)) else raw_dc
        state = self._load_state(device_code)
        if state is None:
            return None
        return device_code, state

    def load_by_device_code(self, device_code: str) -> DeviceFlowState | None:
        return self._load_state(device_code)

    def _load_state(self, device_code: str) -> DeviceFlowState | None:
        """Fetch + decode state; corrupt JSON is logged and treated as absent."""
        raw = self._redis.get(DEVICE_CODE_KEY_FMT.format(code=device_code))
        if not raw:
            return None
        text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
        try:
            return DeviceFlowState.from_json(text_)
        except (ValueError, KeyError):
            logger.exception("device_flow: corrupt state for %s", device_code)
            return None

    def approve(
        self,
        device_code: str,
        subject_email: str,
        account_id: str | None,
        minted_token: str,
        token_id: str,
        subject_issuer: str | None = None,
        poll_payload: dict | None = None,
    ) -> None:
        """PENDING -> APPROVED, stashing the plaintext token for the next poll.

        Raises:
            StateNotFoundError: no state under this device_code.
            InvalidTransitionError: state is not PENDING.
        """
        state = self._load_state(device_code)
        if state is None:
            raise StateNotFoundError(device_code)
        if state.status is not DeviceFlowStatus.PENDING:
            raise InvalidTransitionError(f"cannot approve {state.status}")
        state.status = DeviceFlowStatus.APPROVED
        state.subject_email = subject_email
        state.account_id = account_id
        state.subject_issuer = subject_issuer
        state.minted_token = minted_token
        state.token_id = token_id
        state.poll_payload = poll_payload
        # Guarantee the CLI at least APPROVED_TTL_SECONDS_MIN to poll.
        new_ttl = self._remaining_ttl(device_code, floor=APPROVED_TTL_SECONDS_MIN)
        self._redis.setex(DEVICE_CODE_KEY_FMT.format(code=device_code), new_ttl, state.to_json())

    def deny(self, device_code: str) -> None:
        """PENDING -> DENIED; same precondition errors as approve()."""
        state = self._load_state(device_code)
        if state is None:
            raise StateNotFoundError(device_code)
        if state.status is not DeviceFlowStatus.PENDING:
            raise InvalidTransitionError(f"cannot deny {state.status}")
        state.status = DeviceFlowStatus.DENIED
        self._redis.setex(
            DEVICE_CODE_KEY_FMT.format(code=device_code),
            self._remaining_ttl(device_code, floor=1),
            state.to_json(),
        )

    def consume_on_poll(self, device_code: str) -> DeviceFlowState | None:
        """Race-safe via Lua EVAL: GET + status-check + DEL execute atomically
        so only one of N concurrent pollers observes the APPROVED state.
        Losers get None, mapped to expired_token by the caller. PENDING
        states also return None (the script leaves them in place).
        """
        raw = self._consume_on_poll_script(
            keys=[DEVICE_CODE_KEY_FMT.format(code=device_code)],
            args=[_USER_CODE_PREFIX_ARG := _USER_CODE_KEY_PREFIX],
        ) if False else self._consume_on_poll_script(
            keys=[DEVICE_CODE_KEY_FMT.format(code=device_code)],
            args=[_USER_CODE_KEY_PREFIX],
        )
        if raw is None:
            return None
        text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
        try:
            return DeviceFlowState.from_json(text_)
        except (ValueError, KeyError):
            logger.exception("device_flow: corrupt state on consume %s", device_code)
            return None

    def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision:
        """Advisory poll-rate check: SLOW_DOWN when two polls land inside
        ``interval_seconds``.

        NOTE(review): GET then SETEX is not atomic — two near-simultaneous
        polls can both see OK. Presumably acceptable for an advisory
        slow_down signal; confirm.
        """
        now = time.time()
        key = f"device_code:{device_code}:last_poll"
        prev_raw = self._redis.get(key)
        self._redis.setex(key, DEVICE_FLOW_TTL_SECONDS, str(now))
        if prev_raw is None:
            return SlowDownDecision.OK
        prev_s = prev_raw.decode() if isinstance(prev_raw, (bytes, bytearray)) else prev_raw
        try:
            prev = float(prev_s)
        except ValueError:
            return SlowDownDecision.OK
        if now - prev < interval_seconds:
            return SlowDownDecision.SLOW_DOWN
        return SlowDownDecision.OK

    def _remaining_ttl(self, device_code: str, floor: int) -> int:
        """``max(remaining, floor)`` — guarantees the CLI has at least
        ``floor`` seconds to poll after a near-expiry approve. Missing or
        persistent keys (ttl < 0) collapse to ``floor``.
        """
        ttl = self._redis.ttl(DEVICE_CODE_KEY_FMT.format(code=device_code))
        if ttl is None or ttl < 0:
            return floor
        return max(int(ttl), floor)
# ============================================================================
# Token mint — generate + upsert
# ============================================================================
OAUTH_BODY_BYTES = 32  # ~256 bits entropy
# Prefixes distinguish the two mint flows at a glance.
PREFIX_OAUTH_ACCOUNT = "dfoa_"
PREFIX_OAUTH_EXTERNAL_SSO = "dfoe_"
# Sentinel issuer for account-flow rows. Postgres' default partial unique
# index treats NULLs as distinct, which would let two live `dfoa_` rows
# share (email, client, device) and break rotate-in-place. Storing a
# non-empty literal makes the composite key collide as intended.
ACCOUNT_ISSUER_SENTINEL = "dify:account"
@dataclass(frozen=True, slots=True)
class MintResult:
    """Result of a successful mint. The plaintext token surfaces to the
    caller exactly once — only its hash is persisted."""

    token: str
    token_id: uuid.UUID
    expires_at: datetime
@dataclass(frozen=True, slots=True)
class UpsertOutcome:
    """Internal result of _upsert: row id, whether a prior live row was
    rotated, and that row's old hash (for Redis cache invalidation)."""

    token_id: uuid.UUID
    rotated: bool
    old_hash: str | None
def generate_token(prefix: str) -> str:
    # prefix ("dfoa_"/"dfoe_") + OAUTH_BODY_BYTES of url-safe randomness.
    return prefix + secrets.token_urlsafe(OAUTH_BODY_BYTES)
def sha256_hex(token: str) -> str:
    """Hex-encoded SHA-256 of ``token`` (UTF-8 bytes)."""
    digest = hashlib.sha256(token.encode("utf-8"))
    return digest.hexdigest()
def mint_oauth_token(
    # Accept either Session or Flask-SQLAlchemy's request-scoped wrapper —
    # the wrapper proxies the same execute/commit surface.
    session: Session | scoped_session,
    redis_client,
    *,
    subject_email: str,
    subject_issuer: str | None,
    account_id: str | None,
    client_id: str,
    device_label: str,
    prefix: str,
    ttl_days: int,
) -> MintResult:
    """Generate a fresh token and upsert its row; the plaintext is returned once.

    A live row rotates in place via the partial unique index
    ``uq_oauth_active_per_device``; hard-expired rows are excluded by the
    index predicate so re-login INSERTs fresh. The pre-rotate Redis cache
    entry is deleted so a stale AuthContext drops immediately.

    Raises:
        ValueError: unknown ``prefix``, or ``subject_issuer`` inconsistent
            with the chosen prefix.
    """
    if prefix == PREFIX_OAUTH_ACCOUNT:
        # Account flow always writes the sentinel — caller may pass None
        # (for clarity) or the sentinel itself; nothing else is valid.
        if subject_issuer not in (None, ACCOUNT_ISSUER_SENTINEL):
            raise ValueError(f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}")
        subject_issuer = ACCOUNT_ISSUER_SENTINEL
    elif prefix == PREFIX_OAUTH_EXTERNAL_SSO:
        # Defense in depth: enterprise canonicalises + rejects empty,
        # but a regression there must not yield a NULL composite key here.
        if not subject_issuer or not subject_issuer.strip():
            raise ValueError("external-SSO token requires non-empty subject_issuer")
    else:
        raise ValueError(f"unknown oauth prefix: {prefix!r}")
    token = generate_token(prefix)
    new_hash = sha256_hex(token)
    expires_at = datetime.now(UTC) + timedelta(days=ttl_days)
    outcome = _upsert(
        session,
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        account_id=account_id,
        client_id=client_id,
        device_label=device_label,
        prefix=prefix,
        new_hash=new_hash,
        expires_at=expires_at,
    )
    # Drop the cache entry keyed on the replaced token's hash.
    if outcome.rotated and outcome.old_hash:
        redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=outcome.old_hash))
    return MintResult(token=token, token_id=outcome.token_id, expires_at=expires_at)
def _upsert(
    session: Session | scoped_session,
    *,
    subject_email: str,
    subject_issuer: str | None,
    account_id: str | None,
    client_id: str,
    device_label: str,
    prefix: str,
    new_hash: str,
    expires_at: datetime,
) -> UpsertOutcome:
    """INSERT ... ON CONFLICT against the live-row partial unique index.

    On conflict the existing live row is rewritten in place (new hash,
    prefix, account_id, expires_at; created_at reset, last_used_at
    cleared). Commits the session before returning.

    Raises:
        RuntimeError: the RETURNING clause produced no row (should not happen).
    """
    # Snapshot prior live row's hash for Redis invalidation post-rotate.
    # subject_issuer is always non-null here (account flow uses sentinel,
    # external-SSO is validated upstream), so equality matches the index.
    prior = session.execute(
        select(OAuthAccessToken.id, OAuthAccessToken.token_hash)
        .where(
            OAuthAccessToken.subject_email == subject_email,
            OAuthAccessToken.subject_issuer == subject_issuer,
            OAuthAccessToken.client_id == client_id,
            OAuthAccessToken.device_label == device_label,
            OAuthAccessToken.revoked_at.is_(None),
        )
        .limit(1)
    ).first()
    old_hash = prior.token_hash if prior else None
    insert_stmt = pg_insert(OAuthAccessToken).values(
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        account_id=account_id,
        client_id=client_id,
        device_label=device_label,
        prefix=prefix,
        token_hash=new_hash,
        expires_at=expires_at,
    )
    # index_elements + index_where target the partial unique index
    # uq_oauth_active_per_device (live rows only).
    upsert_stmt = insert_stmt.on_conflict_do_update(
        index_elements=["subject_email", "subject_issuer", "client_id", "device_label"],
        index_where=OAuthAccessToken.revoked_at.is_(None),
        set_={
            "token_hash": insert_stmt.excluded.token_hash,
            "prefix": insert_stmt.excluded.prefix,
            "account_id": insert_stmt.excluded.account_id,
            "expires_at": insert_stmt.excluded.expires_at,
            "created_at": func.now(),
            "last_used_at": None,
        },
    ).returning(OAuthAccessToken.id)
    row = session.execute(upsert_stmt).first()
    session.commit()
    if row is None:
        raise RuntimeError("oauth_token upsert returned no row")
    token_id = uuid.UUID(str(row.id))
    return UpsertOutcome(
        token_id=token_id,
        rotated=prior is not None,
        old_hash=old_hash,
    )
# ============================================================================
# TTL policy — days new OAuth tokens live
# ============================================================================
DEFAULT_OAUTH_TTL_DAYS = 14
MIN_TTL_DAYS = 1
MAX_TTL_DAYS = 365
_TTL_ENV_VAR = "OAUTH_TTL_DAYS"
def oauth_ttl_days(tenant_id: str | None = None) -> int:
"""``OAUTH_TTL_DAYS`` env, else default. EE tenant-level lookup
is deferred; when it lands it wins over the env (Redis-cached 60s).
"""
_ = tenant_id
raw = os.environ.get(_TTL_ENV_VAR)
if raw is None:
return DEFAULT_OAUTH_TTL_DAYS
try:
value = int(raw)
except ValueError:
logger.warning(
"%s=%r is not an int; falling back to %d",
_TTL_ENV_VAR,
raw,
DEFAULT_OAUTH_TTL_DAYS,
)
return DEFAULT_OAUTH_TTL_DAYS
if value < MIN_TTL_DAYS:
logger.warning("%s=%d below min %d; clamping", _TTL_ENV_VAR, value, MIN_TTL_DAYS)
return MIN_TTL_DAYS
if value > MAX_TTL_DAYS:
logger.warning("%s=%d above max %d; clamping", _TTL_ENV_VAR, value, MAX_TTL_DAYS)
return MAX_TTL_DAYS
return value

View File

View File

@ -0,0 +1,54 @@
"""License gate for the /openapi/v1/permitted-external-apps* surface.
EE-only. CE deploys (``ENTERPRISE_ENABLED=false``) skip the gate entirely
the EE blueprint chain is what gives CE deploys no callers on this surface
in practice, but the explicit short-circuit avoids any test/fixture that
flips the surface on without flipping the license.
Reuses ``FeatureService.get_system_features()`` so the license status
travels the same path as the console reads.
Companion to ``controllers.console.wraps.enterprise_license_required``
that one is for console (cookie-authed, force-logout 401). This one is
for bearer surface (token-authed, 403 ``license_required``).
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from functools import wraps
from werkzeug.exceptions import Forbidden
from configs import dify_config
from services.feature_service import FeatureService, LicenseStatus
logger = logging.getLogger(__name__)
# License states that pass the gate: ACTIVE plus EXPIRING (presumably a
# renewal grace state — confirm against LicenseStatus semantics).
_VALID_LICENSE_STATUSES: frozenset[LicenseStatus] = frozenset(
    {LicenseStatus.ACTIVE, LicenseStatus.EXPIRING}
)
def license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
    """Decorator form. Raises ``Forbidden('license_required')`` when the EE
    deployment has no valid license. No-op on CE (``ENTERPRISE_ENABLED=false``).
    """

    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
        # Gate only applies to EE deployments; CE short-circuits to the view.
        if dify_config.ENTERPRISE_ENABLED and not _is_license_valid():
            raise Forbidden(description="license_required")
        return view(*args, **kwargs)

    return decorated
def _is_license_valid() -> bool:
    """True when the system license status is in _VALID_LICENSE_STATUSES.

    Fails closed: any error fetching features is logged and treated as invalid.
    """
    try:
        features = FeatureService.get_system_features()
    except Exception:
        logger.exception("license_gate: FeatureService.get_system_features failed")
        return False
    return features.license.status in _VALID_LICENSE_STATUSES

View File

@ -0,0 +1,51 @@
"""Hard mint policy.
``validate_mint_policy`` cross-checks a (subject_type, prefix, scopes)
triple a caller intends to mint against ``MINTABLE_PROFILES``
the single source of truth in ``libs.oauth_bearer``.
The defense-in-depth value: if a future caller assembles ``prefix`` or
``scopes`` from a non-canonical source (env, request body, plug-in
contribution), the mismatch fails closed at approve time before any
row hits the DB. When the caller reads straight from
``MINTABLE_PROFILES``, the check is a structural pin it confirms the
table entry is well-formed and the caller picked the right key.
"""
from __future__ import annotations
from libs.oauth_bearer import MINTABLE_PROFILES, Scope, SubjectType
class MintPolicyViolation(Exception):  # noqa: N818 — spec-defined name, used in BadRequest message
    """Raised on a (subject_type, prefix, scopes) mismatch.

    Callers translate this to HTTP 400 with code ``mint_policy_violation``.
    """
def validate_mint_policy(
    *,
    subject_type: SubjectType,
    prefix: str,
    scopes: frozenset[Scope],
) -> None:
    """Raise ``MintPolicyViolation`` when the (subject_type, prefix, scopes)
    triple does not match the canonical ``MINTABLE_PROFILES`` entry for
    ``subject_type``.

    Args:
        subject_type: key into ``MINTABLE_PROFILES``.
        prefix: token prefix the caller intends to mint.
        scopes: scope set the caller intends to grant.

    Raises:
        MintPolicyViolation: unknown ``subject_type``, or prefix/scopes
            drifting from the canonical profile.
    """
    profile = MINTABLE_PROFILES.get(subject_type)
    if profile is None:
        raise MintPolicyViolation(f"mint_policy_violation: unknown subject_type={subject_type!r}")

    drift: list[str] = []
    if profile.prefix != prefix:
        drift.append(f"prefix got={prefix!r} expected={profile.prefix!r}")
    if frozenset(scopes) != profile.scopes:
        got = sorted(s.value for s in scopes)
        want = sorted(s.value for s in profile.scopes)
        drift.append(f"scopes got={got} expected={want}")
    if drift:
        # BUG FIX: the original concatenated the header straight onto the
        # first drift item (no separator), producing messages like
        # "…subject_type=accountprefix got=…". Join every part with "; ".
        parts = [f"mint_policy_violation: subject_type={subject_type.value}", *drift]
        raise MintPolicyViolation("; ".join(parts))

View File

@ -0,0 +1,32 @@
"""Single-source visibility filter for the /openapi/v1/* surface.
Keep every openapi-surface app query routed through ``_apply_openapi_gate``;
retiring or replacing the gate then becomes a one-line change here.
The Service API (/v1/* app-key surface) does NOT use this helper that
surface has its own per-request guard (``service_api_disabled``) wired
into the legacy ``validate_app_token`` decorator.
"""
from __future__ import annotations
from typing import Any
from models.model import App
def apply_openapi_gate(query: Any) -> Any:
    """Filter a SQLAlchemy Select/Query to apps visible on /openapi/v1/*.

    Works with both legacy ``Query.filter`` and 2.0-style ``Select.filter``
    (alias of ``.where``). Visibility today is exactly ``enable_api IS TRUE``.
    """
    return query.filter(App.enable_api.is_(True))
def is_openapi_visible(app: App) -> bool:
    """Row-level twin of ``apply_openapi_gate`` for apps fetched by primary
    key (``session.get`` / ``session.scalar``), where no query filter ran.
    """
    enabled = app.enable_api
    return bool(enabled)

View File

@ -15,7 +15,7 @@ from models import Account, AccountStatus
from models.model import App, EndUser, Site
from services.account_service import AccountService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.enterprise.enterprise_service import PERMISSION_CHECK_MODES, EnterpriseService, WebAppAccessMode
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from tasks.mail_email_code_login import send_email_code_login_mail_task
@ -137,12 +137,8 @@ class WebAppAuthService:
"""
Check if the app requires permission check based on its access mode.
"""
modes_requiring_permission_check = [
"private",
"private_all",
]
if access_mode:
return access_mode in modes_requiring_permission_check
return access_mode in PERMISSION_CHECK_MODES
if not app_code and not app_id:
raise ValueError("Either app_code or app_id must be provided.")
@ -153,7 +149,7 @@ class WebAppAuthService:
raise ValueError("App ID could not be determined from the provided app_code.")
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check:
if webapp_settings and webapp_settings.access_mode in PERMISSION_CHECK_MODES:
return True
return False
@ -166,11 +162,11 @@ class WebAppAuthService:
raise ValueError("Either app_code or access_mode must be provided.")
if access_mode:
if access_mode == "public":
if access_mode == WebAppAccessMode.PUBLIC:
return WebAppAuthType.PUBLIC
elif access_mode in ["private", "private_all"]:
elif access_mode in PERMISSION_CHECK_MODES:
return WebAppAuthType.INTERNAL
elif access_mode == "sso_verified":
elif access_mode == WebAppAccessMode.SSO_VERIFIED:
return WebAppAuthType.EXTERNAL
if app_code:

View File

@ -0,0 +1,125 @@
"""Shared fixtures for /openapi/v1/* integration tests."""
from __future__ import annotations
import hashlib
import uuid
from collections.abc import Generator
from datetime import UTC, datetime, timedelta
import pytest
from flask import Flask
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Account, App, OAuthAccessToken, Tenant, TenantAccountJoin
from models.account import AccountStatus
def _sha256(token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
@pytest.fixture(autouse=True)
def disable_enterprise(monkeypatch):
    """Default to CE behaviour for /openapi/v1 tests. Tests that exercise the
    EE branch override this with their own monkeypatch in-test."""
    from configs import dify_config
    # autouse: every test in this package runs with ENTERPRISE_ENABLED=False
    # unless it re-patches the flag itself.
    monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False)
@pytest.fixture
def workspace_account(flask_app: Flask) -> Generator[tuple[Account, Tenant, TenantAccountJoin], None, None]:
    """Create a tenant, an active account, and an owner membership join row;
    tear all three down after the test."""
    with flask_app.app_context():
        tenant = Tenant(name="t1", status="normal")
        account = Account(email="u@example.com", name="u")
        db.session.add_all([tenant, account])
        db.session.commit()
        # Activated after the initial insert; the change is persisted by the
        # commit below together with the join row.
        account.status = AccountStatus.ACTIVE
        join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role="owner")
        db.session.add(join)
        db.session.commit()
        yield account, tenant, join
        # Teardown: delete the join before the rows it references
        # (presumably FK ordering — confirm against the schema).
        db.session.delete(join)
        db.session.delete(account)
        db.session.delete(tenant)
        db.session.commit()
@pytest.fixture
def app_in_workspace(flask_app: Flask, workspace_account) -> Generator[App, None, None]:
    """A chat-mode app (site + api enabled) under the workspace_account
    tenant, deleted on teardown."""
    _, tenant, _ = workspace_account
    with flask_app.app_context():
        app = App(tenant_id=tenant.id, name="a", mode="chat", status="normal", enable_site=True, enable_api=True)
        db.session.add(app)
        db.session.commit()
        yield app
        db.session.delete(app)
        db.session.commit()
@pytest.fixture
def mint_token(flask_app: Flask):
    """Factory fixture; tracks minted rows and deletes them on teardown so
    the auth-related test runs don't accumulate `oauth_access_tokens` rows."""
    minted: list[OAuthAccessToken] = []

    def _mint(
        token: str,
        *,
        account_id: str | None,
        prefix: str,
        subject_email: str,
        subject_issuer: str | None,
    ) -> OAuthAccessToken:
        # Only the SHA-256 of the raw token is stored; callers keep the raw
        # string for the Authorization header.
        with flask_app.app_context():
            row = OAuthAccessToken(
                token_hash=_sha256(token),
                prefix=prefix,
                account_id=account_id,
                subject_email=subject_email,
                subject_issuer=subject_issuer,
                client_id="difyctl",
                device_label="test-device",
                expires_at=datetime.now(UTC) + timedelta(hours=1),
            )
            db.session.add(row)
            db.session.commit()
            minted.append(row)
            return row

    yield _mint
    with flask_app.app_context():
        for row in minted:
            # merge() re-attaches the row to this fresh session so the
            # delete targets a session-bound instance.
            db.session.delete(db.session.merge(row))
        db.session.commit()
@pytest.fixture
def account_token(workspace_account, mint_token) -> str:
    """Mint a valid account-surface bearer (``dfoa_`` prefix) for the
    workspace account and return the raw token string."""
    account = workspace_account[0]
    raw_token = f"dfoa_{uuid.uuid4().hex}"
    mint_token(
        raw_token,
        account_id=account.id,
        prefix="dfoa_",
        subject_email=account.email,
        subject_issuer="dify:account",
    )
    return raw_token
@pytest.fixture(autouse=True)
def _flush_auth_redis(flask_app: Flask) -> Generator[None, None, None]:
    """Clear auth- and rate-limit-related Redis keys before and after each
    test so token caches / rate-limit buckets never leak between tests.

    Improvement over the per-key loop: both key patterns share one code
    path, and all matched keys are removed with a single variadic
    ``DELETE`` instead of one round-trip per key.
    """

    def _flush() -> None:
        with flask_app.app_context():
            for pattern in ("auth:*", "rl:*"):
                keys = redis_client.keys(pattern)
                if keys:
                    # redis-py DELETE accepts multiple keys in one call
                    redis_client.delete(*keys)

    _flush()
    yield
    _flush()

View File

@ -0,0 +1,252 @@
"""Integration tests for POST /openapi/v1/apps/<id>/run."""
from __future__ import annotations
import uuid
from collections.abc import Generator
import pytest
from flask import Flask
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models import App
def test_run_chat_dispatches_to_chat_handler(flask_app, account_token, app_in_workspace, monkeypatch):
    """A chat-mode run hits AppGenerateService.generate with
    InvokeFrom.OPENAPI, and the server strips the client-supplied
    ``user`` field from the forwarded args."""
    captured = {}

    def _fake_generate(*, app_model, user, args, invoke_from, streaming):
        # Record what the controller forwarded; return a minimal blocking
        # chat envelope so the response path can serialize it.
        captured["mode"] = app_model.mode
        captured["args"] = args
        captured["invoke_from"] = invoke_from
        return {
            "event": "message",
            "task_id": "t",
            "id": "m",
            "message_id": "m",
            "conversation_id": "c",
            "mode": "chat",
            "answer": "ok",
            "created_at": 0,
        }

    monkeypatch.setattr(
        "controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate)
    )
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app_in_workspace.id}/run",
        json={"inputs": {}, "query": "hi", "response_mode": "blocking", "user": "spoof@x.com"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    assert res.get_json()["mode"] == "chat"
    assert captured["mode"] == "chat"
    assert captured["invoke_from"] == InvokeFrom.OPENAPI
    assert "user" not in captured["args"], "server must strip body.user; identity comes from bearer"
@pytest.fixture
def app_with_mode(flask_app: Flask, workspace_account):
    """Factory that creates an App row in the workspace_account tenant with
    a specified mode. Tracks rows for teardown.
    """
    _, tenant, _ = workspace_account
    created: list[App] = []

    def _make(mode: str) -> App:
        with flask_app.app_context():
            app = App(
                tenant_id=tenant.id,
                name=f"a-{mode}",
                mode=mode,
                status="normal",
                enable_site=True,
                enable_api=True,
            )
            db.session.add(app)
            db.session.commit()
            # refresh then expunge — presumably so the detached instance
            # keeps loaded attributes usable outside this app_context;
            # confirm against the session config.
            db.session.refresh(app)
            db.session.expunge(app)
            created.append(app)
            return app

    yield _make
    with flask_app.app_context():
        for app in created:
            # re-attach the detached row before deleting it
            db.session.delete(db.session.merge(app))
        db.session.commit()
def test_run_chat_without_query_returns_422(flask_app, account_token, app_in_workspace, monkeypatch):
    """Chat mode requires ``query``; omitting it is a 422 with a named code."""
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app_in_workspace.id}/run",
        json={"inputs": {}, "response_mode": "blocking"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 422
    assert b"query_required_for_chat" in res.data


def test_run_completion_dispatches_to_completion_handler(
    flask_app, account_token, app_with_mode, monkeypatch
):
    """A completion-mode app run is forwarded to the generate service and
    answered with a completion envelope."""
    app = app_with_mode("completion")
    captured: dict = {}

    def _fake_generate(*, app_model, user, args, invoke_from, streaming):
        captured["mode"] = app_model.mode
        captured["args"] = args
        return {
            "event": "message",
            "task_id": "t",
            "id": "m",
            "message_id": "m",
            "mode": "completion",
            "answer": "ok",
            "created_at": 0,
        }

    monkeypatch.setattr(
        "controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate)
    )
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app.id}/run",
        json={"inputs": {}, "response_mode": "blocking"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    assert res.get_json()["mode"] == "completion"
    assert captured["mode"] == "completion"


def test_run_workflow_with_query_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
    """Workflow mode rejects a ``query`` field with a named 422 code."""
    app = app_with_mode("workflow")
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app.id}/run",
        json={"inputs": {}, "query": "hi", "response_mode": "blocking"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 422
    assert b"query_not_supported_for_workflow" in res.data
def test_run_workflow_no_query_dispatches_to_workflow_handler(
    flask_app, account_token, app_with_mode, monkeypatch
):
    """A workflow-mode run (no query) returns the workflow envelope with
    mode and workflow_run_id surfaced in the body."""
    app = app_with_mode("workflow")

    def _fake_generate(*, app_model, user, args, invoke_from, streaming):
        return {
            "workflow_run_id": "wfr",
            "task_id": "t",
            "data": {"id": "wf-d", "workflow_id": "wf", "status": "succeeded"},
        }

    monkeypatch.setattr(
        "controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate)
    )
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app.id}/run",
        json={"inputs": {}, "response_mode": "blocking"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    body = res.get_json()
    assert body["mode"] == "workflow"
    assert body["workflow_run_id"] == "wfr"


def test_run_unsupported_mode_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
    """An app whose mode is not runnable on this surface yields a named 422."""
    app = app_with_mode("channel")
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app.id}/run",
        json={"inputs": {}, "response_mode": "blocking"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 422
    assert b"mode_not_runnable" in res.data


def test_run_without_bearer_returns_401(flask_app, app_in_workspace):
    """No Authorization header → 401 before any app dispatch."""
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app_in_workspace.id}/run",
        json={"inputs": {}, "query": "hi"},
    )
    assert res.status_code == 401
def test_run_with_insufficient_scope_returns_403(
    flask_app, account_token, app_in_workspace, monkeypatch
):
    """Stub the authenticator to return an AuthContext with empty scopes."""
    from libs import oauth_bearer

    # Delegate to the real authenticator so the token is still validated,
    # then blank the scopes on the returned (frozen dataclass) context.
    real_authenticate = oauth_bearer.BearerAuthenticator.authenticate

    def _stub_authenticate(self, token: str):
        ctx = real_authenticate(self, token)
        from dataclasses import replace
        return replace(ctx, scopes=frozenset())

    monkeypatch.setattr(oauth_bearer.BearerAuthenticator, "authenticate", _stub_authenticate)
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app_in_workspace.id}/run",
        json={"inputs": {}, "query": "hi"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 403


def test_run_with_unknown_app_returns_404(flask_app, account_token):
    """A random app id that matches no row → 404."""
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{uuid.uuid4()}/run",
        json={"inputs": {}, "query": "hi"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 404


def test_run_streaming_returns_event_stream(
    flask_app, account_token, app_in_workspace, monkeypatch
):
    """response_mode=streaming yields a text/event-stream body."""
    def _stream() -> Generator[str, None, None]:
        yield "event: message\ndata: {\"x\": 1}\n\n"

    monkeypatch.setattr(
        "controllers.openapi.app_run.AppGenerateService.generate",
        staticmethod(lambda **kw: _stream()),
    )
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app_in_workspace.id}/run",
        json={"inputs": {}, "query": "hi", "response_mode": "streaming"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    assert res.headers["Content-Type"].startswith("text/event-stream")
    assert b"event: message" in res.data


def test_run_without_inputs_returns_422(flask_app, account_token, app_in_workspace):
    """``inputs`` is mandatory; omitting it is a 422."""
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app_in_workspace.id}/run",
        json={"query": "hi"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 422

View File

@ -0,0 +1,210 @@
"""Integration tests for /openapi/v1/apps* read surface."""
from __future__ import annotations
from flask.testing import FlaskClient
from models import App
# NOTE(review): the auth-surface suite in this same change asserts that
# GET /openapi/v1/apps/<id>/info returns 200 for a valid bearer, while the
# three tests below expect 404 on bare-id/parameters/info. Presumably the
# two suites target different route registrations — confirm which is the
# canonical surface before merging.
def test_apps_bare_id_route_404(test_client, app_in_workspace, account_token):
    """No bare GET /apps/<id> route is registered on this surface."""
    resp = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 404


def test_apps_parameters_route_404(test_client, app_in_workspace, account_token):
    """No /parameters sub-route on this surface."""
    resp = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/parameters",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 404


def test_apps_info_route_404(test_client, app_in_workspace, account_token):
    """No /info sub-route on this surface."""
    resp = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/info",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 404
def test_apps_describe_returns_merged_shape(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """Default /describe merges app info and parameters into one body."""
    res = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    body = res.json
    assert body["info"]["id"] == app_in_workspace.id
    assert body["info"]["mode"] == "chat"
    assert isinstance(body["parameters"], dict)


def test_apps_describe_full_includes_input_schema(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """Without a ``fields`` filter, all three sections are present and
    input_schema declares JSON Schema draft 2020-12."""
    res = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    body = res.json
    assert body["info"] is not None
    assert body["parameters"] is not None
    assert body["input_schema"] is not None
    assert body["input_schema"]["$schema"] == "https://json-schema.org/draft/2020-12/schema"
def test_apps_describe_fields_info_only(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """``fields=info`` nulls out the other two sections."""
    res = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    body = res.json
    assert body["info"] is not None
    assert body["parameters"] is None
    assert body["input_schema"] is None


def test_apps_describe_fields_parameters_only(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """``fields=parameters`` returns only the parameters section."""
    res = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=parameters",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    body = res.json
    assert body["info"] is None
    assert body["parameters"] is not None
    assert body["input_schema"] is None


def test_apps_describe_fields_input_schema_only(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """``fields=input_schema`` returns only the input_schema section."""
    res = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=input_schema",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    body = res.json
    assert body["info"] is None
    assert body["parameters"] is None
    assert body["input_schema"] is not None


def test_apps_describe_fields_combined(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """A comma-separated ``fields`` list selects exactly those sections."""
    res = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info,input_schema",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    body = res.json
    assert body["info"] is not None
    assert body["parameters"] is None
    assert body["input_schema"] is not None
def test_apps_describe_fields_unknown_returns_422(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """An unrecognized field name in ``fields`` is rejected with 422."""
    res = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=garbage",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 422


def test_apps_describe_fields_extra_param_returns_422(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """Unknown query parameters (here ``page``) are rejected, not ignored."""
    res = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info&page=1",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 422
def test_apps_list_returns_pagination_envelope(
    test_client: FlaskClient,
    workspace_account,
    app_in_workspace: App,
    account_token: str,
):
    """GET /apps returns page/limit/total plus the caller's app in data."""
    _, tenant, _ = workspace_account
    res = test_client.get(
        f"/openapi/v1/apps?workspace_id={tenant.id}&page=1&limit=20",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    body = res.json
    assert body["page"] == 1
    assert body["limit"] == 20
    assert body["total"] >= 1
    assert any(d["id"] == app_in_workspace.id for d in body["data"])


def test_apps_list_requires_workspace_id(test_client: FlaskClient, account_token: str):
    """``workspace_id`` is mandatory for the list route → 400 when absent."""
    res = test_client.get("/openapi/v1/apps", headers={"Authorization": f"Bearer {account_token}"})
    assert res.status_code == 400


def test_apps_list_tag_no_match_returns_empty_data_not_400(
    test_client: FlaskClient,
    workspace_account,
    app_in_workspace: App,
    account_token: str,
):
    """A tag filter that matches nothing yields an empty page, not an error."""
    _, tenant, _ = workspace_account
    res = test_client.get(
        f"/openapi/v1/apps?workspace_id={tenant.id}&tag=nonexistent",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    assert res.json["data"] == []
def test_account_sessions_returns_envelope(
    test_client: FlaskClient,
    account_token: str,
):
    """/account/sessions uses the canonical pagination envelope (data/page/
    limit/total/has_more) and includes the caller's own session."""
    res = test_client.get("/openapi/v1/account/sessions", headers={"Authorization": f"Bearer {account_token}"})
    assert res.status_code == 200
    body = res.json
    # canonical envelope shape
    assert isinstance(body["data"], list)
    assert "page" in body
    assert "limit" in body
    assert "total" in body
    assert "has_more" in body
    # the bearer's own minted session must appear
    assert any(s["prefix"] == "dfoa_" for s in body["data"])
    # legacy "sessions" key must NOT appear
    assert "sessions" not in body

View File

@ -0,0 +1,127 @@
"""Integration tests for the /openapi/v1 bearer auth surface.
Layer 0 (workspace membership), per-token rate limit, and read-scope (`apps:read`)
acceptance/rejection on app-scoped routes.
"""
from __future__ import annotations
from collections.abc import Generator
import pytest
from flask import Flask
from flask.testing import FlaskClient
from extensions.ext_database import db
from models import App, Tenant
def test_info_accepts_account_bearer_with_apps_read_scope(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
) -> None:
    """A valid account bearer with read scope can fetch /apps/<id>/info.

    NOTE(review): the apps-read suite in this same change expects this
    exact route to 404 — confirm which registration is canonical.
    """
    res = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/info",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    assert res.json["id"] == app_in_workspace.id
@pytest.fixture
def other_workspace_app(flask_app: Flask) -> Generator[App, None, None]:
    """A fresh app under a *different* tenant — caller has no membership row."""
    with flask_app.app_context():
        other_tenant = Tenant(name="other", status="normal")
        db.session.add(other_tenant)
        db.session.commit()
        app = App(
            tenant_id=other_tenant.id,
            name="b",
            mode="chat",
            status="normal",
            enable_site=True,
            enable_api=True,
        )
        db.session.add(app)
        db.session.commit()
        yield app
        # Teardown inside the same app_context: app first, then its tenant.
        db.session.delete(app)
        db.session.delete(other_tenant)
        db.session.commit()
def test_layer0_denies_account_bearer_without_membership(
    test_client: FlaskClient,
    account_token: str,
    other_workspace_app: App,
) -> None:
    """Account A bearer hitting an app under tenant B — Layer 0 denies on CE."""
    res = test_client.get(
        f"/openapi/v1/apps/{other_workspace_app.id}/info",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 403
    assert res.json.get("message") == "workspace_membership_revoked"
def test_layer0_skipped_when_enterprise_enabled(
    test_client: FlaskClient,
    account_token: str,
    other_workspace_app: App,
    monkeypatch,
) -> None:
    """On EE, Layer 0 short-circuits — gateway RBAC owns tenant isolation.

    /info uses validate_bearer + require_workspace_member inline (no
    AppAuthzCheck), so a cross-tenant bearer reaches the app lookup and
    gets 200 — the gateway is expected to enforce isolation upstream.
    """
    from configs import dify_config

    # Override the conftest autouse default for this test only.
    monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True)
    res = test_client.get(
        f"/openapi/v1/apps/{other_workspace_app.id}/info",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    assert res.json.get("message") != "workspace_membership_revoked"
def test_rate_limit_returns_429_after_60_requests(
    test_client: FlaskClient,
    account_token: str,
) -> None:
    """61st sequential GET to /account on the same bearer → 429 with Retry-After."""
    headers = {"Authorization": f"Bearer {account_token}"}
    # First 60 requests must all succeed (per-token window of 60).
    for i in range(60):
        r = test_client.get("/openapi/v1/account", headers=headers)
        assert r.status_code == 200, f"unexpected fail at i={i}"
    r = test_client.get("/openapi/v1/account", headers=headers)
    assert r.status_code == 429
    assert r.headers.get("Retry-After"), "Retry-After header missing"
    assert int(r.headers["Retry-After"]) >= 1
    # Body carries a machine-readable error code and a ms-precision hint.
    body = r.json or {}
    assert body.get("error") == "rate_limited"
    assert isinstance(body.get("retry_after_ms"), int)
    assert body["retry_after_ms"] >= 1000
def test_rate_limit_bucket_shared_across_surfaces(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
) -> None:
    """30 calls to /account + 30 calls to /apps/<id>/info on same token → 61st 429s."""
    headers = {"Authorization": f"Bearer {account_token}"}
    for _ in range(30):
        assert test_client.get("/openapi/v1/account", headers=headers).status_code == 200
    for _ in range(30):
        assert test_client.get(f"/openapi/v1/apps/{app_in_workspace.id}/info", headers=headers).status_code == 200
    # Both routes drain the same per-token bucket, so the 61st call 429s.
    r = test_client.get("/openapi/v1/account", headers=headers)
    assert r.status_code == 429

View File

@ -0,0 +1,66 @@
from unittest.mock import patch
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE, _resolve_app_authz_strategy
from controllers.openapi.auth.pipeline import Pipeline
from controllers.openapi.auth.steps import (
AppAuthzCheck,
AppResolver,
BearerCheck,
CallerMount,
ScopeCheck,
SurfaceCheck,
WorkspaceMembershipCheck,
)
from controllers.openapi.auth.strategies import (
AccountMounter,
AclStrategy,
EndUserMounter,
MembershipStrategy,
)
from libs.oauth_bearer import SubjectType
def test_pipeline_is_composed():
    """The module-level auth pipeline is an actual Pipeline instance."""
    assert isinstance(OAUTH_BEARER_PIPELINE, Pipeline)


def test_pipeline_step_order():
    """BearerCheck → SurfaceCheck → ScopeCheck → AppResolver →
    WorkspaceMembershipCheck → AppAuthzCheck → CallerMount.

    SurfaceCheck enforces the dfoa_/dfoe_ surface split + emits
    `openapi.wrong_surface_denied`. Rate-limit is enforced inside
    `BearerAuthenticator.authenticate`, not as a separate pipeline step."""
    steps = OAUTH_BEARER_PIPELINE._steps
    assert isinstance(steps[0], BearerCheck)
    assert isinstance(steps[1], SurfaceCheck)
    assert isinstance(steps[2], ScopeCheck)
    assert isinstance(steps[3], AppResolver)
    assert isinstance(steps[4], WorkspaceMembershipCheck)
    assert isinstance(steps[5], AppAuthzCheck)
    assert isinstance(steps[6], CallerMount)
def test_pipeline_surface_check_accepts_account_only():
    """Current pipeline serves /apps/<id>/run — account surface only."""
    surface = OAUTH_BEARER_PIPELINE._steps[1]
    assert isinstance(surface, SurfaceCheck)
    assert surface._accepted == frozenset({SubjectType.ACCOUNT})


def test_caller_mount_has_both_mounters():
    """CallerMount is configured with both account and end-user mounters."""
    cm = OAUTH_BEARER_PIPELINE._steps[6]
    kinds = {type(m) for m in cm._mounters}
    assert AccountMounter in kinds
    assert EndUserMounter in kinds
@patch("controllers.openapi.auth.composition.FeatureService")
def test_strategy_resolver_picks_acl_when_enabled(fs):
    """webapp_auth enabled → ACL-backed authorization strategy."""
    fs.get_system_features.return_value.webapp_auth.enabled = True
    assert isinstance(_resolve_app_authz_strategy(), AclStrategy)


@patch("controllers.openapi.auth.composition.FeatureService")
def test_strategy_resolver_picks_membership_when_disabled(fs):
    """webapp_auth disabled → fall back to tenant-membership strategy."""
    fs.get_system_features.return_value.webapp_auth.enabled = False
    assert isinstance(_resolve_app_authz_strategy(), MembershipStrategy)

View File

@ -0,0 +1,21 @@
from unittest.mock import MagicMock
from controllers.openapi.auth.context import Context
def test_context_starts_unpopulated():
    """A freshly built Context carries no identity, app, or caller state."""
    ctx = Context(request=MagicMock(), required_scope="apps:run")
    unset_attrs = (
        "subject_type",
        "subject_email",
        "account_id",
        "app",
        "tenant",
        "caller",
        "caller_kind",
    )
    for attr in unset_attrs:
        assert getattr(ctx, attr) is None
    assert ctx.scopes == frozenset()
def test_context_fields_are_mutable():
    """Pipeline steps may write onto the context after construction."""
    ctx = Context(request=MagicMock(), required_scope="apps:run")
    ctx.scopes = frozenset({"full"})
    assert ctx.scopes >= {"full"}

View File

@ -0,0 +1,61 @@
from unittest.mock import MagicMock
import pytest
from flask import Flask
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.pipeline import Pipeline
def test_run_invokes_each_step_in_order():
    """Pipeline.run calls every step exactly once, left to right."""
    trace = []

    class Recorder:
        def __init__(self, label):
            self.label = label

        def __call__(self, ctx):
            trace.append(self.label)

    recorders = [Recorder(label) for label in "abc"]
    Pipeline(*recorders).run(Context(request=MagicMock(), required_scope="x"))
    assert trace == ["a", "b", "c"]
def test_run_short_circuits_on_raise():
    """A raising step aborts the run; later steps never execute."""
    executed = []

    class Exploding:
        def __call__(self, ctx):
            raise RuntimeError("boom")

    class Recorder:
        def __call__(self, ctx):
            executed.append("ran")

    with pytest.raises(RuntimeError):
        Pipeline(Exploding(), Recorder()).run(
            Context(request=MagicMock(), required_scope="x")
        )
    assert executed == []
def test_guard_decorator_runs_pipeline_and_unpacks_handler_kwargs():
    """pipeline.guard(scope=...) runs the pipeline for the request and passes
    the populated context's app/caller/caller_kind into the view handler."""
    seen = {}

    class FakeStep:
        def __call__(self, ctx):
            # Simulate the pipeline populating the context.
            ctx.app = "APP"
            ctx.caller = "CALLER"
            ctx.caller_kind = "account"

    pipeline = Pipeline(FakeStep())

    @pipeline.guard(scope="apps:run")
    def handler(app_model, caller, caller_kind):
        seen["app_model"] = app_model
        seen["caller"] = caller
        seen["caller_kind"] = caller_kind
        return "ok"

    # guard reads the live Flask request, so a request context is required.
    app = Flask(__name__)
    with app.test_request_context("/x", method="POST"):
        assert handler() == "ok"
    assert seen == {"app_model": "APP", "caller": "CALLER", "caller_kind": "account"}

View File

@ -0,0 +1,64 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import AppResolver
from models import TenantStatus
def _ctx(view_args):
    """Context whose mocked request carries the given ``view_args``."""
    req = MagicMock()
    req.view_args = view_args
    return Context(request=req, required_scope="apps:run")


def _app(*, status="normal", enable_api=True):
    """Minimal duck-typed stand-in for an App row."""
    return SimpleNamespace(id="app1", tenant_id="t1", status=status, enable_api=enable_api)


def _tenant(*, status=TenantStatus.NORMAL):
    """Minimal duck-typed stand-in for a Tenant row."""
    return SimpleNamespace(id="t1", status=status)
def test_resolver_rejects_missing_path_param():
    """No app_id in view_args → 400."""
    with pytest.raises(BadRequest):
        AppResolver()(_ctx({}))


def test_resolver_rejects_none_view_args():
    """view_args may be None entirely (no URL rule match) → 400."""
    with pytest.raises(BadRequest):
        AppResolver()(_ctx(None))
@patch("controllers.openapi.auth.steps.db")
def test_resolver_404_when_app_missing(db):
    """First session.get (the App lookup) returns None → 404."""
    db.session.get.side_effect = [None]
    with pytest.raises(NotFound):
        AppResolver()(_ctx({"app_id": "x"}))


@patch("controllers.openapi.auth.steps.db")
def test_resolver_403_when_disabled(db):
    """App with enable_api=False → 403 with the service_api_disabled code."""
    db.session.get.side_effect = [_app(enable_api=False)]
    with pytest.raises(Forbidden) as exc:
        AppResolver()(_ctx({"app_id": "x"}))
    assert "service_api_disabled" in str(exc.value.description)


@patch("controllers.openapi.auth.steps.db")
def test_resolver_403_when_tenant_archived(db):
    """Second session.get (the Tenant lookup) yields an archived tenant → 403."""
    db.session.get.side_effect = [_app(), _tenant(status=TenantStatus.ARCHIVE)]
    with pytest.raises(Forbidden):
        AppResolver()(_ctx({"app_id": "x"}))


@patch("controllers.openapi.auth.steps.db")
def test_resolver_populates_app_and_tenant(db):
    """Happy path: both lookups succeed and the context gets app + tenant."""
    db.session.get.side_effect = [_app(), _tenant()]
    ctx = _ctx({"app_id": "x"})
    AppResolver()(ctx)
    assert ctx.app.id == "app1"
    assert ctx.tenant.id == "t1"

View File

@ -0,0 +1,75 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import AppAuthzCheck
from controllers.openapi.auth.strategies import AclStrategy, MembershipStrategy
from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id="acc1"):
    """Context pre-populated as if BearerCheck + AppResolver already ran."""
    c = Context(request=MagicMock(), required_scope="apps:run")
    c.subject_type = subject_type
    c.subject_email = "alice@example.com"
    c.account_id = account_id
    c.app = SimpleNamespace(id="app1")
    c.tenant = SimpleNamespace(id="t1")
    return c
@patch("controllers.openapi.auth.strategies.EnterpriseService")
def test_acl_strategy_private_calls_inner_api(ent):
    """private mode + account subject → the per-user webapp ACL is consulted."""
    ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode="private")
    ent.WebAppAuth.is_user_allowed_to_access_webapp.return_value = True
    assert AclStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
    ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_called_once_with(
        user_id="acc1",
        app_id="app1",
    )
@pytest.mark.parametrize(
    ("access_mode", "subject_type", "expected"),
    [
        ("public", SubjectType.ACCOUNT, True),
        ("public", SubjectType.EXTERNAL_SSO, True),
        ("sso_verified", SubjectType.ACCOUNT, True),
        ("sso_verified", SubjectType.EXTERNAL_SSO, True),
        ("private_all", SubjectType.ACCOUNT, True),
        ("private_all", SubjectType.EXTERNAL_SSO, False),
        ("private", SubjectType.EXTERNAL_SSO, False),
    ],
)
@patch("controllers.openapi.auth.strategies.EnterpriseService")
def test_acl_strategy_subject_mode_matrix(ent, access_mode, subject_type, expected):
    """Step 1 matrix: subject vs access-mode compatibility. No inner API call expected."""
    ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode=access_mode)
    # External-SSO subjects carry no account id.
    account_id = "acc1" if subject_type == SubjectType.ACCOUNT else None
    assert AclStrategy().authorize(_ctx(subject_type=subject_type, account_id=account_id)) is expected
    ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_not_called()
@patch("controllers.openapi.auth.strategies._has_tenant_membership")
def test_membership_strategy_uses_join_lookup(member):
    """MembershipStrategy delegates to the tenant-join helper with (account, tenant)."""
    member.return_value = True
    assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
    member.assert_called_once_with("acc1", "t1")


def test_membership_strategy_rejects_external_sso():
    """External-SSO subjects (no account) can never hold a membership."""
    assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False


def test_app_authz_check_raises_when_strategy_denies():
    """A denying strategy surfaces as 403 with the subject_no_app_access code."""
    deny = SimpleNamespace(authorize=lambda c: False)
    with pytest.raises(Forbidden) as exc:
        AppAuthzCheck(lambda: deny)(_ctx(subject_type=SubjectType.ACCOUNT))
    assert "subject_no_app_access" in str(exc.value.description)


def test_app_authz_check_passes_when_strategy_allows():
    """An allowing strategy lets the step return without raising."""
    allow = SimpleNamespace(authorize=lambda c: True)
    AppAuthzCheck(lambda: allow)(_ctx(subject_type=SubjectType.ACCOUNT))

View File

@ -0,0 +1,67 @@
import uuid
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, g
from werkzeug.exceptions import Unauthorized
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import BearerCheck
from libs.oauth_bearer import AuthContext, InvalidBearerError, Scope, SubjectType
def _ctx(headers):
    """Context whose mocked request carries the given HTTP headers."""
    req = MagicMock()
    req.headers = headers
    return Context(request=req, required_scope="apps:run")
def test_bearer_check_rejects_missing_header():
    """No Authorization header at all → 401."""
    app = Flask(__name__)
    with app.test_request_context(), pytest.raises(Unauthorized):
        BearerCheck()(_ctx({}))


@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_rejects_unknown_prefix(get_auth):
    """Authenticator raising InvalidBearerError is mapped to 401."""
    get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix")
    app = Flask(__name__)
    with app.test_request_context(), pytest.raises(Unauthorized):
        BearerCheck()(_ctx({"Authorization": "Bearer xxx_abc"}))
@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_populates_context_and_g_auth_ctx(get_auth):
    """A successful authenticate() copies identity fields onto the pipeline
    Context and publishes the same AuthContext on flask.g."""
    tok_id = uuid.uuid4()
    authn = AuthContext(
        subject_type=SubjectType.ACCOUNT,
        subject_email="a@x.com",
        subject_issuer=None,
        account_id=None,
        client_id="difyctl",
        scopes=frozenset({Scope.FULL}),
        token_id=tok_id,
        source="oauth-account",
        expires_at=datetime.now(UTC),
        token_hash="hash-1",
        verified_tenants={},
    )
    get_auth.return_value.authenticate.return_value = authn
    app = Flask(__name__)
    ctx = _ctx({"Authorization": "Bearer dfoa_abc"})
    with app.test_request_context():
        BearerCheck()(ctx)
        assert ctx.subject_type == SubjectType.ACCOUNT
        assert ctx.subject_email == "a@x.com"
        assert ctx.scopes == frozenset({Scope.FULL})
        assert ctx.source == "oauth-account"
        assert ctx.token_id == tok_id
        assert ctx.token_hash == "hash-1"
        # BearerCheck must also publish the same identity on `g.auth_ctx`
        # so the surface gate + downstream handlers don't see two
        # different identity sources between the decorator + pipeline paths.
        assert g.auth_ctx is authn
        assert g.auth_ctx.client_id == "difyctl"

View File

@ -0,0 +1,157 @@
"""Unit tests for WorkspaceMembershipCheck (Layer 0)."""
from __future__ import annotations
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import WorkspaceMembershipCheck
from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context:
    """Context pre-seeded with the fields WorkspaceMembershipCheck reads."""
    ctx = Context(request=MagicMock(), required_scope="apps:read")
    ctx.subject_type = subject_type
    ctx.account_id = account_id
    ctx.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None
    ctx.cached_verified_tenants = cached_verified_tenants
    ctx.token_hash = token_hash
    return ctx
@pytest.fixture
def step():
    # Fresh WorkspaceMembershipCheck instance per test.
    return WorkspaceMembershipCheck()
# NOTE: the @patch decorators apply bottom-up, so the injected mocks arrive
# positionally as (mock_db, mock_record, mock_cfg) — db first, config last.
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_skips_when_enterprise_enabled(mock_db, mock_record, mock_cfg, step):
    # Enterprise mode: the check is bypassed — no DB query, no verdict recorded.
    mock_cfg.ENTERPRISE_ENABLED = True
    ctx = _ctx(
        subject_type=SubjectType.ACCOUNT,
        account_id=str(uuid.uuid4()),
        tenant_id=str(uuid.uuid4()),
        cached_verified_tenants={},
        token_hash="hash-1",
    )
    step(ctx)  # no raise
    mock_db.session.execute.assert_not_called()
    mock_record.assert_not_called()


@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_skips_for_external_sso(mock_db, mock_record, mock_cfg, step):
    # External-SSO subjects have no workspace membership row — check skipped.
    mock_cfg.ENTERPRISE_ENABLED = False
    ctx = _ctx(
        subject_type=SubjectType.EXTERNAL_SSO,
        account_id=None,
        tenant_id=str(uuid.uuid4()),
        cached_verified_tenants={},
        token_hash="hash-1",
    )
    step(ctx)  # no raise
    mock_db.session.execute.assert_not_called()
    mock_record.assert_not_called()


@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_uses_cached_ok(mock_db, mock_record, mock_cfg, step):
    # A cached True verdict for this tenant short-circuits: no DB, no re-record.
    mock_cfg.ENTERPRISE_ENABLED = False
    ctx = _ctx(
        subject_type=SubjectType.ACCOUNT,
        account_id="a1",
        tenant_id="t1",
        cached_verified_tenants={"t1": True},
        token_hash="hash-1",
    )
    step(ctx)
    mock_db.session.execute.assert_not_called()
    mock_record.assert_not_called()


@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_uses_cached_denied(mock_db, mock_record, mock_cfg, step):
    # A cached False verdict denies immediately without touching the DB.
    mock_cfg.ENTERPRISE_ENABLED = False
    ctx = _ctx(
        subject_type=SubjectType.ACCOUNT,
        account_id="a1",
        tenant_id="t1",
        cached_verified_tenants={"t1": False},
        token_hash="hash-1",
    )
    with pytest.raises(Forbidden, match="workspace_membership_revoked"):
        step(ctx)
    mock_db.session.execute.assert_not_called()
    mock_record.assert_not_called()


@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_denies_when_no_membership(mock_db, mock_record, mock_cfg, step):
    # No join row found → deny, and record the negative verdict.
    mock_cfg.ENTERPRISE_ENABLED = False
    mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
    ctx = _ctx(
        subject_type=SubjectType.ACCOUNT,
        account_id="a1",
        tenant_id="t1",
        cached_verified_tenants={},
        token_hash="hash-1",
    )
    with pytest.raises(Forbidden, match="workspace_membership_revoked"):
        step(ctx)
    mock_record.assert_called_once_with("hash-1", "t1", False)


@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_denies_when_account_inactive(mock_db, mock_record, mock_cfg, step):
    # Membership exists but the account status query returns "banned" → deny.
    # side_effect list: first execute() is the join lookup, second the status.
    mock_cfg.ENTERPRISE_ENABLED = False
    mock_db.session.execute.side_effect = [
        MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
        MagicMock(scalar_one_or_none=MagicMock(return_value="banned")),
    ]
    ctx = _ctx(
        subject_type=SubjectType.ACCOUNT,
        account_id="a1",
        tenant_id="t1",
        cached_verified_tenants={},
        token_hash="hash-1",
    )
    with pytest.raises(Forbidden, match="workspace_membership_revoked"):
        step(ctx)
    mock_record.assert_called_once_with("hash-1", "t1", False)


@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_allows_active_member(mock_db, mock_record, mock_cfg, step):
    # Join row present and account "active" → pass, positive verdict recorded.
    mock_cfg.ENTERPRISE_ENABLED = False
    mock_db.session.execute.side_effect = [
        MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
        MagicMock(scalar_one_or_none=MagicMock(return_value="active")),
    ]
    ctx = _ctx(
        subject_type=SubjectType.ACCOUNT,
        account_id="a1",
        tenant_id="t1",
        cached_verified_tenants={},
        token_hash="hash-1",
    )
    step(ctx)  # no raise
    mock_record.assert_called_once_with("hash-1", "t1", True)

View File

@ -0,0 +1,77 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Unauthorized
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import CallerMount
from controllers.openapi.auth.strategies import AccountMounter, EndUserMounter
from core.app.entities.app_invoke_entities import InvokeFrom
from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id=None, subject_email=None):
    """Context with app/tenant stubs in place, ready for a mounter to act on."""
    ctx = Context(request=MagicMock(), required_scope="apps:run")
    ctx.subject_type = subject_type
    ctx.account_id = account_id
    ctx.subject_email = subject_email
    ctx.app = SimpleNamespace(id="app1")
    ctx.tenant = SimpleNamespace(id="t1")
    return ctx
@patch("controllers.openapi.auth.strategies._login_as")
@patch("controllers.openapi.auth.strategies.db")
def test_account_mounter(mock_db, mock_login):
    """AccountMounter loads the account row, binds the tenant, and logs it in."""
    account_row = SimpleNamespace()
    mock_db.session.get.return_value = account_row
    ctx = _ctx(subject_type=SubjectType.ACCOUNT, account_id="acc1")
    AccountMounter().mount(ctx)
    assert ctx.caller_kind == "account"
    assert ctx.caller is account_row
    assert ctx.caller.current_tenant is ctx.tenant
    mock_login.assert_called_once_with(account_row)
@patch("controllers.openapi.auth.strategies._login_as")
@patch("controllers.openapi.auth.strategies.EndUserService")
def test_end_user_mounter(mock_svc, mock_login):
    """EndUserMounter resolves an end-user keyed by the subject's email."""
    end_user = SimpleNamespace()
    mock_svc.get_or_create_end_user_by_type.return_value = end_user
    ctx = _ctx(subject_type=SubjectType.EXTERNAL_SSO, subject_email="a@x.com")
    EndUserMounter().mount(ctx)
    mock_svc.get_or_create_end_user_by_type.assert_called_once_with(
        InvokeFrom.OPENAPI,
        tenant_id="t1",
        app_id="app1",
        user_id="a@x.com",
    )
    assert ctx.caller_kind == "end_user"
    assert ctx.caller is end_user
def test_caller_mount_dispatches_by_subject_type():
    """CallerMount must route to the strategy whose applies_to() matches."""
    calls = {}

    class Stub:
        def __init__(self, subject, label):
            self._subject = subject
            self._label = label

        def applies_to(self, subject):
            return subject == self._subject

        def mount(self, ctx):
            calls["who"] = self._label

    mounter = CallerMount(
        Stub(SubjectType.ACCOUNT, "acct"),
        Stub(SubjectType.EXTERNAL_SSO, "sso"),
    )
    mounter(_ctx(subject_type=SubjectType.EXTERNAL_SSO))
    assert calls == {"who": "sso"}
def test_caller_mount_raises_when_none_applies():
    """With no applicable strategy, mounting must fail as 401."""
    empty_mounter = CallerMount()
    with pytest.raises(Unauthorized):
        empty_mounter(_ctx(subject_type=SubjectType.ACCOUNT))

View File

@ -0,0 +1,27 @@
from unittest.mock import MagicMock
import pytest
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import ScopeCheck
def _ctx(scopes, required):
    """Context granted *scopes* while the route demands *required*."""
    ctx = Context(request=MagicMock(), required_scope=required)
    ctx.scopes = frozenset(scopes)
    return ctx
def test_scope_check_passes_on_full():
    # The "full" scope satisfies any requirement.
    ScopeCheck()(_ctx({"full"}, "apps:run"))


def test_scope_check_passes_on_explicit_match():
    # An exact scope grant also passes.
    ScopeCheck()(_ctx({"apps:run"}, "apps:run"))


def test_scope_check_rejects_when_missing():
    # A disjoint grant is a 403 carrying the insufficient_scope marker.
    with pytest.raises(Forbidden) as excinfo:
        ScopeCheck()(_ctx({"apps:read"}, "apps:run"))
    assert "insufficient_scope" in str(excinfo.value.description)

View File

@ -0,0 +1,181 @@
"""Surface gate tests.
The gate has two attachment forms — decorator (`accept_subjects`) and
pipeline step (`SurfaceCheck`) — and both must:
- 403 on mismatched subject type with a canonical-path hint
- emit `openapi.wrong_surface_denied` once with the right payload
- pass-through on match
- raise RuntimeError (not 403) if g.auth_ctx is missing — that's a
wiring bug, not a user-driven failure
"""
from __future__ import annotations
import uuid
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, g
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import SurfaceCheck
from controllers.openapi.auth.surface_gate import accept_subjects, check_surface
from libs.oauth_bearer import AuthContext, Scope, SubjectType
def _account_ctx() -> AuthContext:
    """AuthContext for a first-party account bearer holding the full scope."""
    fields = dict(
        subject_type=SubjectType.ACCOUNT,
        subject_email="user@example.com",
        subject_issuer="dify:account",
        account_id=uuid.uuid4(),
        client_id="difyctl",
        scopes=frozenset({Scope.FULL}),
        token_id=uuid.uuid4(),
        source="oauth_account",
        expires_at=datetime.now(UTC),
        token_hash="h1",
        verified_tenants={},
    )
    return AuthContext(**fields)
def _sso_ctx() -> AuthContext:
    """AuthContext for an external-SSO bearer with the run/read scopes."""
    fields = dict(
        subject_type=SubjectType.EXTERNAL_SSO,
        subject_email="sso@partner.com",
        subject_issuer="https://idp.partner.com",
        account_id=None,
        client_id="difyctl",
        scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
        token_id=uuid.uuid4(),
        source="oauth_external_sso",
        expires_at=datetime.now(UTC),
        token_hash="h2",
        verified_tenants={},
    )
    return AuthContext(**fields)
# ---------------------------------------------------------------------------
# check_surface — shared core
# ---------------------------------------------------------------------------
def test_check_surface_passes_when_subject_in_accepted():
    """A subject type inside the accepted set passes without raising."""
    app = Flask(__name__)
    with app.test_request_context("/openapi/v1/apps"):
        g.auth_ctx = _account_ctx()
        check_surface(frozenset({SubjectType.ACCOUNT}))
def test_check_surface_rejects_on_wrong_subject_and_emits_audit():
    """Wrong subject type → 403 with a canonical-path hint plus one audit event."""
    app = Flask(__name__)
    with app.test_request_context("/openapi/v1/permitted-external-apps"):
        g.auth_ctx = _account_ctx()
        with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
            with pytest.raises(Forbidden) as exc:
                check_surface(frozenset({SubjectType.EXTERNAL_SSO}))
            assert "wrong_surface" in exc.value.description
            # The canonical-path hint points at the caller's own surface,
            # not the surface they were rejected from.
            assert "/openapi/v1/apps" in exc.value.description
            emit.assert_called_once()
            payload = emit.call_args.kwargs
            assert payload["subject_type"] == SubjectType.ACCOUNT.value
            assert payload["attempted_path"] == "/openapi/v1/permitted-external-apps"
            assert payload["client_id"] == "difyctl"
            assert payload["token_id"] is not None
def test_check_surface_rejects_sso_on_account_surface():
    """The inverse mismatch (SSO bearer, account surface) is also denied."""
    app = Flask(__name__)
    with app.test_request_context("/openapi/v1/apps"):
        g.auth_ctx = _sso_ctx()
        with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
            with pytest.raises(Forbidden):
                check_surface(frozenset({SubjectType.ACCOUNT}))
            payload = emit.call_args.kwargs
            assert payload["subject_type"] == SubjectType.EXTERNAL_SSO.value
def test_check_surface_runtime_error_when_g_auth_ctx_missing():
    """When g.auth_ctx was never set, the bearer layer did not run — a wiring
    bug, not a user failure. It must surface as a loud RuntimeError so an
    unauthenticated route can never masquerade as a legitimate
    wrong-surface 403.
    """
    app = Flask(__name__)
    with app.test_request_context("/openapi/v1/apps"):
        with pytest.raises(RuntimeError):
            check_surface(frozenset({SubjectType.ACCOUNT}))
# ---------------------------------------------------------------------------
# @accept_subjects — decorator form
# ---------------------------------------------------------------------------
def _make_app() -> Flask:
    """Flask app with one account-only route and one external-only route."""
    app = Flask(__name__)

    @app.route("/account-only")
    @accept_subjects(SubjectType.ACCOUNT)
    def _account_only():
        return "ok"

    @app.route("/external-only")
    @accept_subjects(SubjectType.EXTERNAL_SSO)
    def _external_only():
        return "ok"

    return app
def test_accept_subjects_decorator_passes_on_match():
    """Decorated view runs normally when the subject type matches."""
    app = _make_app()
    with app.test_request_context("/account-only"):
        g.auth_ctx = _account_ctx()
        # Invoke the decorated function directly via the view-function table.
        account_view = app.view_functions["_account_only"]
        assert account_view() == "ok"
def test_accept_subjects_decorator_403_on_miss():
    """Decorated view denies with 403 when the subject type mismatches."""
    app = _make_app()
    with app.test_request_context("/external-only"):
        g.auth_ctx = _account_ctx()
        external_view = app.view_functions["_external_only"]
        with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface"):
            with pytest.raises(Forbidden):
                external_view()
# ---------------------------------------------------------------------------
# SurfaceCheck — pipeline step form
# ---------------------------------------------------------------------------
def _pipeline_ctx() -> Context:
    """Pipeline Context around a mocked request on the app-run path."""
    fake_request = MagicMock()
    fake_request.path = "/openapi/v1/apps/<id>/run"
    return Context(request=fake_request, required_scope=Scope.APPS_RUN)
def test_surface_check_passes_on_match():
    """The pipeline-step form passes when the subject type is accepted."""
    check = SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT}))
    app = Flask(__name__)
    with app.test_request_context("/openapi/v1/apps/x/run"):
        g.auth_ctx = _account_ctx()
        check(_pipeline_ctx())
def test_surface_check_rejects_on_miss_and_emits_audit():
    """The pipeline-step form denies and audits exactly once on a mismatch."""
    check = SurfaceCheck(accepted=frozenset({SubjectType.EXTERNAL_SSO}))
    app = Flask(__name__)
    with app.test_request_context("/openapi/v1/apps/x/run"):
        g.auth_ctx = _account_ctx()
        with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
            with pytest.raises(Forbidden):
                check(_pipeline_ctx())
            emit.assert_called_once()

View File

@ -0,0 +1,15 @@
import pytest
from controllers.openapi.auth.pipeline import Pipeline
@pytest.fixture
def bypass_pipeline(monkeypatch):
    """Neutralise Pipeline.run so decorated endpoints skip real auth.

    The module-level @OAUTH_BEARER_PIPELINE.guard(...) decorator captures
    the real pipeline object at import time, so mocking the module attribute
    is ineffective. Replacing Pipeline.run on the class is the bypass that
    actually works.
    """
    def _noop_run(self, ctx):
        return None

    monkeypatch.setattr(Pipeline, "run", _noop_run)

View File

@ -0,0 +1,140 @@
"""User-scoped identity + session endpoints under /openapi/v1/account."""
import builtins
import pytest
from flask import Flask
from flask.views import MethodView
from controllers.openapi import bp as openapi_bp
from controllers.openapi.account import (
AccountApi,
AccountSessionByIdApi,
AccountSessionsApi,
AccountSessionsSelfApi,
)
# NOTE(review): presumably some imported module resolves MethodView as a
# builtin-like global name — confirm; the shim is a no-op when already set.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]
@pytest.fixture
def openapi_app() -> Flask:
    """Flask app with only the openapi blueprint registered."""
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    flask_app.register_blueprint(openapi_bp)
    return flask_app
def _rule(app: Flask, path: str):
    """Return the url_map rule whose pattern equals *path*.

    A bare ``next()`` would raise an opaque StopIteration when the route is
    absent; fail with a clear assertion message instead so the test report
    names the missing route.
    """
    rule = next((r for r in app.url_map.iter_rules() if r.rule == path), None)
    assert rule is not None, f"route {path!r} is not registered"
    return rule
def test_account_route_registered(openapi_app: Flask):
    assert "/openapi/v1/account" in {rule.rule for rule in openapi_app.url_map.iter_rules()}


def test_account_dispatches_to_class(openapi_app: Flask):
    endpoint = _rule(openapi_app, "/openapi/v1/account").endpoint
    assert openapi_app.view_functions[endpoint].view_class is AccountApi


def test_account_sessions_self_route_registered(openapi_app: Flask):
    assert "/openapi/v1/account/sessions/self" in {rule.rule for rule in openapi_app.url_map.iter_rules()}


def test_sessions_self_dispatches_to_class(openapi_app: Flask):
    endpoint = _rule(openapi_app, "/openapi/v1/account/sessions/self").endpoint
    assert openapi_app.view_functions[endpoint].view_class is AccountSessionsSelfApi


def test_account_methods(openapi_app: Flask):
    assert "GET" in _rule(openapi_app, "/openapi/v1/account").methods


def test_sessions_self_methods(openapi_app: Flask):
    assert "DELETE" in _rule(openapi_app, "/openapi/v1/account/sessions/self").methods


def test_sessions_list_route_registered(openapi_app: Flask):
    assert "/openapi/v1/account/sessions" in {rule.rule for rule in openapi_app.url_map.iter_rules()}


def test_sessions_list_dispatches_to_sessions_api(openapi_app: Flask):
    rule = _rule(openapi_app, "/openapi/v1/account/sessions")
    assert openapi_app.view_functions[rule.endpoint].view_class is AccountSessionsApi
    assert "GET" in rule.methods


def test_session_by_id_route_registered(openapi_app: Flask):
    assert "/openapi/v1/account/sessions/<string:session_id>" in {rule.rule for rule in openapi_app.url_map.iter_rules()}


def test_session_by_id_dispatches_to_correct_class(openapi_app: Flask):
    rule = _rule(openapi_app, "/openapi/v1/account/sessions/<string:session_id>")
    assert openapi_app.view_functions[rule.endpoint].view_class is AccountSessionByIdApi
    assert "DELETE" in rule.methods
def test_subject_match_for_account_filters_by_account_id():
    """Account subjects are scoped to their own rows via account_id."""
    import uuid as _uuid

    from controllers.openapi.account import _subject_match
    from libs.oauth_bearer import AuthContext, SubjectType

    account_id = _uuid.uuid4()
    ctx = AuthContext(
        subject_type=SubjectType.ACCOUNT,
        subject_email="user@example.com",
        subject_issuer="dify:account",
        account_id=account_id,
        client_id="difyctl",
        scopes=frozenset({"full"}),
        token_id=_uuid.uuid4(),
        source="oauth_account",
        expires_at=None,
        token_hash="h1",
        verified_tenants={},
    )
    predicates = _subject_match(ctx)
    # Exactly one predicate, keyed on account_id.
    assert len(predicates) == 1
    assert "account_id" in str(predicates[0])
def test_subject_match_for_external_sso_filters_by_email_and_issuer():
    """External-SSO subjects are scoped by (subject_email, subject_issuer)
    plus account_id IS NULL, so a same-email account row from a federated
    tenant cannot be revoked through an SSO bearer.
    """
    import uuid as _uuid

    from controllers.openapi.account import _subject_match
    from libs.oauth_bearer import AuthContext, SubjectType

    ctx = AuthContext(
        subject_type=SubjectType.EXTERNAL_SSO,
        subject_email="sso@partner.com",
        subject_issuer="https://idp.partner.com",
        account_id=None,
        client_id="difyctl",
        scopes=frozenset({"apps:run"}),
        token_id=_uuid.uuid4(),
        source="oauth_external_sso",
        expires_at=None,
        token_hash="h1",
        verified_tenants={},
    )
    predicates = _subject_match(ctx)
    assert len(predicates) == 3
    rendered = " ".join(str(predicate) for predicate in predicates)
    assert "subject_email" in rendered
    assert "subject_issuer" in rendered
    assert "account_id IS NULL" in rendered

View File

@ -0,0 +1,48 @@
"""Unit tests for AppDescribeQuery (`?fields=` allow-list)."""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from controllers.openapi.apps import AppDescribeQuery
def test_no_fields_returns_none() -> None:
    assert AppDescribeQuery.model_validate({}).fields is None


def test_empty_string_returns_none() -> None:
    assert AppDescribeQuery.model_validate({"fields": ""}).fields is None


def test_single_field() -> None:
    assert AppDescribeQuery.model_validate({"fields": "info"}).fields == {"info"}


def test_comma_list() -> None:
    parsed = AppDescribeQuery.model_validate({"fields": "info,parameters"})
    assert parsed.fields == {"info", "parameters"}


def test_whitespace_tolerant() -> None:
    # Surrounding whitespace around members is stripped.
    parsed = AppDescribeQuery.model_validate({"fields": " info , input_schema "})
    assert parsed.fields == {"info", "input_schema"}


def test_unknown_member_rejected() -> None:
    with pytest.raises(ValidationError):
        AppDescribeQuery.model_validate({"fields": "garbage"})


def test_unknown_among_known_rejected() -> None:
    # One bad member taints the whole list.
    with pytest.raises(ValidationError):
        AppDescribeQuery.model_validate({"fields": "info,garbage"})


def test_extra_param_forbidden() -> None:
    with pytest.raises(ValidationError):
        AppDescribeQuery.model_validate({"fields": "info", "page": "1"})

View File

@ -0,0 +1,105 @@
"""Unit tests for AppListQuery — the /apps query-param validator.
Runs against the model directly, not the HTTP layer. Pins:
- defaults match the plan (page=1, limit=20).
- workspace_id is required.
- numeric bounds enforced (page >= 1, limit in [1, MAX_PAGE_LIMIT]).
- mode validates against the AppMode enum.
- name and tag have length caps.
"""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from controllers.openapi._models import MAX_PAGE_LIMIT
from controllers.openapi.apps import AppListQuery
def test_defaults():
    query = AppListQuery.model_validate({"workspace_id": "ws-1"})
    assert query.workspace_id == "ws-1"
    assert query.page == 1
    assert query.limit == 20
    assert query.mode is None
    assert query.name is None
    assert query.tag is None


def test_workspace_id_required():
    with pytest.raises(ValidationError):
        AppListQuery.model_validate({})


def test_page_must_be_positive():
    for bad_page in (0, -1):
        with pytest.raises(ValidationError):
            AppListQuery.model_validate({"workspace_id": "ws-1", "page": bad_page})


def test_page_rejects_non_integer_string():
    with pytest.raises(ValidationError):
        AppListQuery.model_validate({"workspace_id": "ws-1", "page": "abc"})


def test_limit_must_be_positive():
    for bad_limit in (0, -1):
        with pytest.raises(ValidationError):
            AppListQuery.model_validate({"workspace_id": "ws-1", "limit": bad_limit})


def test_limit_caps_at_max_page_limit():
    # The boundary value is accepted...
    query = AppListQuery.model_validate({"workspace_id": "ws-1", "limit": MAX_PAGE_LIMIT})
    assert query.limit == MAX_PAGE_LIMIT
    # ...and one past it is rejected.
    with pytest.raises(ValidationError):
        AppListQuery.model_validate({"workspace_id": "ws-1", "limit": MAX_PAGE_LIMIT + 1})


def test_mode_whitelisted_against_app_mode():
    # A valid AppMode value passes.
    query = AppListQuery.model_validate({"workspace_id": "ws-1", "mode": "chat"})
    assert query.mode is not None
    assert query.mode.value == "chat"
    # Anything outside the enum is rejected.
    with pytest.raises(ValidationError):
        AppListQuery.model_validate({"workspace_id": "ws-1", "mode": "not-a-mode"})


def test_name_length_capped():
    AppListQuery.model_validate({"workspace_id": "ws-1", "name": "x" * 200})
    with pytest.raises(ValidationError):
        AppListQuery.model_validate({"workspace_id": "ws-1", "name": "x" * 201})


def test_tag_length_capped():
    AppListQuery.model_validate({"workspace_id": "ws-1", "tag": "x" * 100})
    with pytest.raises(ValidationError):
        AppListQuery.model_validate({"workspace_id": "ws-1", "tag": "x" * 101})
def test_all_fields_accept_valid_values():
    """Pin the happy-path acceptance for every field in one place."""
    payload = {
        "workspace_id": "ws-1",
        "page": 5,
        "limit": 50,
        "mode": "workflow",
        "name": "search",
        "tag": "prod",
    }
    query = AppListQuery.model_validate(payload)
    assert query.workspace_id == "ws-1"
    assert query.page == 5
    assert query.limit == 50
    assert query.mode is not None
    assert query.mode.value == "workflow"
    assert query.name == "search"
    assert query.tag == "prod"

View File

@ -0,0 +1,55 @@
"""Unit tests for app payload-rendering helpers — independent of
HTTP plumbing or DB. Pin the response shapes that are CLI contracts.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from controllers.openapi.apps import ( # pyright: ignore[reportPrivateUsage]
_EMPTY_PARAMETERS,
parameters_payload,
)
from controllers.service_api.app.error import AppUnavailableError
def _fake_app(**overrides):
    """SimpleNamespace standing in for an App row; *overrides* win over defaults."""
    attrs = {
        "id": "app1",
        "name": "X",
        "description": "d",
        "mode": "chat",
        "author_name": "alice",
        "tags": [SimpleNamespace(name="prod")],
        "updated_at": None,
        "enable_api": True,
        "workflow": None,
        "app_model_config": None,
    }
    attrs.update(overrides)
    return SimpleNamespace(**attrs)
def test_parameters_payload_raises_app_unavailable_when_no_config():
    """A chat app with no model config cannot render parameters."""
    configless_app = _fake_app(mode="chat", app_model_config=None)
    with pytest.raises(AppUnavailableError):
        parameters_payload(configless_app)
def test_empty_parameters_constant_matches_describe_fallback_shape():
    """The fallback dict served by /describe when an app has no config
    must match the spec's stated keys (opening_statement, suggested_questions,
    user_input_form, file_upload, system_parameters)."""
    expected_keys = {
        "opening_statement",
        "suggested_questions",
        "user_input_form",
        "file_upload",
        "system_parameters",
    }
    assert set(_EMPTY_PARAMETERS.keys()) == expected_keys
    expected_values = {
        "suggested_questions": [],
        "user_input_form": [],
        "opening_statement": None,
        "file_upload": None,
        "system_parameters": {},
    }
    for key, want in expected_values.items():
        assert _EMPTY_PARAMETERS[key] == want

View File

@ -0,0 +1,45 @@
import pytest
from werkzeug.exceptions import InternalServerError
from controllers.openapi.app_run import (
_DISPATCH,
AppRunRequest,
_unpack_blocking,
)
from models.model import AppMode
def test_dispatch_covers_runnable_modes():
    # The dispatch table covers every runnable mode and nothing else.
    runnable = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW}
    assert set(_DISPATCH) == runnable


def test_unpack_blocking_passes_through_mapping():
    assert _unpack_blocking({"a": 1}) == {"a": 1}


def test_unpack_blocking_unwraps_tuple():
    # A (body, status) tuple is unwrapped to the body alone.
    assert _unpack_blocking(({"a": 1}, 200)) == {"a": 1}


def test_unpack_blocking_rejects_non_mapping():
    with pytest.raises(InternalServerError):
        _unpack_blocking("not a mapping")


def test_app_run_request_strips_blank_conversation_id():
    assert AppRunRequest(inputs={}, conversation_id=" ").conversation_id is None


def test_app_run_request_rejects_invalid_uuid_conversation_id():
    from pydantic import ValidationError

    with pytest.raises(ValidationError, match="conversation_id must be a valid UUID"):
        AppRunRequest(inputs={}, conversation_id="not-a-uuid")


def test_app_run_request_accepts_valid_uuid_conversation_id():
    import uuid as _uuid

    valid_id = str(_uuid.uuid4())
    assert AppRunRequest(inputs={}, conversation_id=valid_id).conversation_id == valid_id

View File

@ -0,0 +1,53 @@
"""Unit tests for PermittedExternalAppsListQuery — the
/permitted-external-apps query validator.
Strict ConfigDict(extra='forbid'): cross-tenant tag/workspace_id are
unresolvable, so the model must reject them as 422 instead of silently
dropping them. Mode/name/page/limit have the same shape as AppListQuery.
"""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from controllers.openapi.apps_permitted_external import PermittedExternalAppsListQuery
def test_query_defaults_match_apps_list():
    query = PermittedExternalAppsListQuery.model_validate({})
    assert query.page == 1
    assert query.limit == 20
    assert query.mode is None
    assert query.name is None


def test_query_rejects_workspace_id():
    """workspace_id is meaningless for /permitted-external-apps (cross-tenant);
    rejecting it forces CLI authors to drop the param rather than send it
    silently."""
    with pytest.raises(ValidationError):
        PermittedExternalAppsListQuery.model_validate({"workspace_id": "ws-1"})


def test_query_rejects_tag():
    """Tags are tenant-scoped; cross-tenant tag resolution is undefined."""
    with pytest.raises(ValidationError):
        PermittedExternalAppsListQuery.model_validate({"tag": "prod"})


def test_query_validates_mode_against_app_mode():
    with pytest.raises(ValidationError):
        PermittedExternalAppsListQuery.model_validate({"mode": "not-a-mode"})


def test_query_clamps_limit_at_max():
    with pytest.raises(ValidationError):
        PermittedExternalAppsListQuery.model_validate({"limit": 500})


def test_query_accepts_valid_mode():
    """Pin the happy path: AppMode values pass."""
    query = PermittedExternalAppsListQuery.model_validate({"mode": "chat"})
    assert query.mode is not None
    assert query.mode.value == "chat"

View File

@ -0,0 +1,26 @@
import logging
from controllers.openapi._audit import EVENT_APP_RUN_OPENAPI, emit_app_run
def test_event_constant():
    assert EVENT_APP_RUN_OPENAPI == "app.run.openapi"


def test_emit_app_run_logs_with_audit_extra(caplog):
    """emit_app_run writes an INFO record carrying the audit extras."""
    with caplog.at_level(logging.INFO, logger="controllers.openapi._audit"):
        emit_app_run(
            app_id="app1",
            tenant_id="t1",
            caller_kind="account",
            mode="chat",
            surface="apps",
        )
    hit = next(r for r in caplog.records if r.message and "app.run.openapi" in r.message)
    assert hit.audit is True
    assert hit.event == EVENT_APP_RUN_OPENAPI
    assert hit.app_id == "app1"
    assert hit.tenant_id == "t1"
    assert hit.caller_kind == "account"
    assert hit.mode == "chat"
    assert hit.surface == "apps"

View File

@ -0,0 +1,127 @@
"""CORS posture for /openapi/v1/* — default empty allowlist (same-origin),
expandable via OPENAPI_CORS_ALLOW_ORIGINS. Cross-origin requests from
disallowed origins do not receive the Access-Control-Allow-Origin
header, which the browser then blocks.
Tests use a fresh Blueprint + Flask-CORS per case because the production
blueprint is a module-level singleton and can't be reconfigured once
registered.
"""
import builtins
from flask import Blueprint, Flask
from flask.views import MethodView
from flask_cors import CORS
from flask_restx import Resource
from configs import dify_config
from extensions.ext_blueprints import OPENAPI_HEADERS, OPENAPI_MAX_AGE_SECONDS
from libs.external_api import ExternalApi
# NOTE(review): presumably some imported module resolves MethodView as a
# builtin-like global name — confirm; the shim is a no-op when already set.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]
def _make_app(allowed_origins: list[str], blueprint_name: str) -> Flask:
    """Flask app carrying a fresh openapi-style blueprint whose CORS settings
    mirror production, parameterised on the origin allowlist."""
    blueprint = Blueprint(blueprint_name, __name__, url_prefix="/openapi/v1")
    api = ExternalApi(blueprint, version="1.0", title="OpenAPI Test", description="")

    @api.route("/_health")
    class _Health(Resource):
        def get(self):
            return {"ok": True}

    CORS(
        blueprint,
        resources={r"/*": {"origins": allowed_origins}},
        supports_credentials=True,
        allow_headers=list(OPENAPI_HEADERS),
        methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
        expose_headers=["X-Version"],
        max_age=OPENAPI_MAX_AGE_SECONDS,
    )
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    flask_app.register_blueprint(blueprint)
    return flask_app
def test_default_openapi_cors_allowlist_is_empty():
    """Default config admits no cross-origin caller until the operator opts in."""
    assert dify_config.OPENAPI_CORS_ALLOW_ORIGINS == []


def test_preflight_allowed_origin_returns_cors_headers():
    client = _make_app(["https://app.example.com"], "openapi_t1").test_client()
    preflight = client.options(
        "/openapi/v1/_health",
        headers={
            "Origin": "https://app.example.com",
            "Access-Control-Request-Method": "GET",
        },
    )
    assert preflight.headers.get("Access-Control-Allow-Origin") == "https://app.example.com"
    assert preflight.headers.get("Access-Control-Max-Age") == str(OPENAPI_MAX_AGE_SECONDS)


def test_preflight_disallowed_origin_omits_cors_headers():
    client = _make_app(["https://app.example.com"], "openapi_t2").test_client()
    preflight = client.options(
        "/openapi/v1/_health",
        headers={
            "Origin": "https://attacker.example",
            "Access-Control-Request-Method": "GET",
        },
    )
    # flask-cors omits Allow-Origin for disallowed origins; the browser blocks.
    assert "Access-Control-Allow-Origin" not in preflight.headers


def test_preflight_with_default_empty_allowlist_omits_cors_headers():
    client = _make_app([], "openapi_t3").test_client()
    preflight = client.options(
        "/openapi/v1/_health",
        headers={
            "Origin": "https://app.example.com",
            "Access-Control-Request-Method": "GET",
        },
    )
    assert "Access-Control-Allow-Origin" not in preflight.headers


def test_same_origin_request_succeeds_without_origin_header():
    client = _make_app(["https://app.example.com"], "openapi_t4").test_client()
    # Browsers don't send Origin on same-origin GETs.
    response = client.get("/openapi/v1/_health")
    assert response.status_code == 200
    assert response.get_json() == {"ok": True}


def test_authorization_header_is_in_allow_headers():
    """Bearer-authed routes need Authorization in the preflight response."""
    client = _make_app(["https://app.example.com"], "openapi_t5").test_client()
    preflight = client.options(
        "/openapi/v1/_health",
        headers={
            "Origin": "https://app.example.com",
            "Access-Control-Request-Method": "GET",
            "Access-Control-Request-Headers": "Authorization",
        },
    )
    granted = preflight.headers.get("Access-Control-Allow-Headers", "").lower()
    assert "authorization" in granted

View File

@ -0,0 +1,52 @@
"""Account-branch device-flow approve/deny under /openapi/v1."""
import builtins
import pytest
from flask import Flask
from flask.views import MethodView
from controllers.openapi import bp as openapi_bp
from controllers.openapi.oauth_device import DeviceApproveApi, DeviceDenyApi
# NOTE(review): presumably some imported module resolves MethodView as a
# builtin-like global name — confirm; the shim is a no-op when already set.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]
@pytest.fixture
def openapi_app() -> Flask:
    """Flask app with only the openapi blueprint registered."""
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    flask_app.register_blueprint(openapi_bp)
    return flask_app
def _rule(app: Flask, path: str):
return next(r for r in app.url_map.iter_rules() if r.rule == path)
def test_approve_route_registered(openapi_app: Flask):
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
assert "/openapi/v1/oauth/device/approve" in rules
def test_deny_route_registered(openapi_app: Flask):
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
assert "/openapi/v1/oauth/device/deny" in rules
def test_approve_dispatches_to_class(openapi_app: Flask):
rule = _rule(openapi_app, "/openapi/v1/oauth/device/approve")
assert openapi_app.view_functions[rule.endpoint].view_class is DeviceApproveApi
def test_deny_dispatches_to_class(openapi_app: Flask):
rule = _rule(openapi_app, "/openapi/v1/oauth/device/deny")
assert openapi_app.view_functions[rule.endpoint].view_class is DeviceDenyApi
def test_approve_and_deny_methods(openapi_app: Flask):
approve = _rule(openapi_app, "/openapi/v1/oauth/device/approve")
deny = _rule(openapi_app, "/openapi/v1/oauth/device/deny")
assert "POST" in approve.methods
assert "POST" in deny.methods

View File

@ -0,0 +1,47 @@
"""POST /openapi/v1/oauth/device/code is the canonical RFC 8628 device
authorization endpoint.
Tests verify URL routing without invoking the handler invoking would
require Redis, which the unit-test runtime does not initialise.
"""
import builtins
import pytest
from flask import Flask
from flask.views import MethodView
from controllers.openapi import bp as openapi_bp
from controllers.openapi.oauth_device import OAuthDeviceCodeApi
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@pytest.fixture
def openapi_app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
app.register_blueprint(openapi_bp)
return app
def test_openapi_route_registered(openapi_app: Flask):
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
assert "/openapi/v1/oauth/device/code" in rules
def test_route_dispatches_to_class(openapi_app: Flask):
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code")
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceCodeApi
def test_route_accepts_post(openapi_app: Flask):
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code")
assert "POST" in rule.methods
def test_known_client_ids_default_includes_difyctl():
from configs import dify_config
assert "difyctl" in dify_config.OPENAPI_KNOWN_CLIENT_IDS

View File

@ -0,0 +1,36 @@
"""GET /openapi/v1/oauth/device/lookup is the canonical user-code lookup."""
import builtins
import pytest
from flask import Flask
from flask.views import MethodView
from controllers.openapi import bp as openapi_bp
from controllers.openapi.oauth_device import OAuthDeviceLookupApi
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@pytest.fixture
def openapi_app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
app.register_blueprint(openapi_bp)
return app
def test_openapi_route_registered(openapi_app: Flask):
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
assert "/openapi/v1/oauth/device/lookup" in rules
def test_route_dispatches_to_class(openapi_app: Flask):
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup")
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceLookupApi
def test_route_accepts_get(openapi_app: Flask):
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup")
assert "GET" in rule.methods

View File

@ -0,0 +1,105 @@
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/."""
import builtins
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask.views import MethodView
from controllers.openapi import bp as openapi_bp
from controllers.openapi.oauth_device_sso import (
_email_belongs_to_dify_account,
approval_context,
approve_external,
sso_complete,
sso_initiate,
)
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@pytest.fixture
def openapi_app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
app.register_blueprint(openapi_bp)
return app
def _rule(app: Flask, path: str):
return next(r for r in app.url_map.iter_rules() if r.rule == path)
def test_sso_initiate_registered(openapi_app: Flask):
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
assert "/openapi/v1/oauth/device/sso-initiate" in rules
def test_sso_complete_registered(openapi_app: Flask):
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
assert "/openapi/v1/oauth/device/sso-complete" in rules
def test_approval_context_registered(openapi_app: Flask):
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
assert "/openapi/v1/oauth/device/approval-context" in rules
def test_approve_external_registered(openapi_app: Flask):
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
assert "/openapi/v1/oauth/device/approve-external" in rules
def test_sso_initiate_dispatches_to_function(openapi_app: Flask):
rule = _rule(openapi_app, "/openapi/v1/oauth/device/sso-initiate")
assert openapi_app.view_functions[rule.endpoint] is sso_initiate
def test_sso_complete_dispatches_to_function(openapi_app: Flask):
rule = _rule(openapi_app, "/openapi/v1/oauth/device/sso-complete")
assert openapi_app.view_functions[rule.endpoint] is sso_complete
def test_approval_context_dispatches_to_function(openapi_app: Flask):
rule = _rule(openapi_app, "/openapi/v1/oauth/device/approval-context")
assert openapi_app.view_functions[rule.endpoint] is approval_context
def test_approve_external_dispatches_to_function(openapi_app: Flask):
rule = _rule(openapi_app, "/openapi/v1/oauth/device/approve-external")
assert openapi_app.view_functions[rule.endpoint] is approve_external
def test_sso_complete_idp_callback_url_uses_canonical_path():
"""sso_initiate hardcodes the IdP callback URL — must point at the
canonical /openapi/v1/ path so IdP-side ACS configuration matches.
"""
from controllers.openapi import oauth_device_sso
assert oauth_device_sso._SSO_COMPLETE_PATH == "/openapi/v1/oauth/device/sso-complete"
@pytest.mark.parametrize(
("email", "row", "expected"),
[
("alice@example.com", "acc1", True),
("alice@example.com", None, False),
("Alice@Example.COM", "acc1", True), # case-insensitive lookup
(" alice@example.com ", "acc1", True), # surrounding whitespace stripped
("", "acc1", False),
(" ", "acc1", False),
("", None, False),
],
)
@patch("controllers.openapi.oauth_device_sso.db")
def test_email_belongs_to_dify_account(db_mock, email, row, expected):
exec_result = MagicMock()
exec_result.scalar_one_or_none.return_value = row
db_mock.session.execute.return_value = exec_result
assert _email_belongs_to_dify_account(email) is expected
if email.strip():
db_mock.session.execute.assert_called_once()
else:
db_mock.session.execute.assert_not_called()

View File

@ -0,0 +1,31 @@
"""POST /openapi/v1/oauth/device/token is the canonical poll endpoint."""
import builtins
import pytest
from flask import Flask
from flask.views import MethodView
from controllers.openapi import bp as openapi_bp
from controllers.openapi.oauth_device import OAuthDeviceTokenApi
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@pytest.fixture
def openapi_app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
app.register_blueprint(openapi_bp)
return app
def test_openapi_route_registered(openapi_app: Flask):
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
assert "/openapi/v1/oauth/device/token" in rules
def test_route_dispatches_to_class(openapi_app: Flask):
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/token")
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceTokenApi

View File

@ -0,0 +1,33 @@
import builtins

import pytest
from flask import Flask
from flask.views import MethodView

# Module-level side effect preserved: install MethodView as a builtin once.
from controllers.openapi import bp as openapi_bp

if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]


@pytest.fixture
def app() -> Flask:
    """Bare Flask app carrying only the openapi blueprint."""
    test_app = Flask(__name__)
    test_app.config["TESTING"] = True
    test_app.register_blueprint(openapi_bp)
    return test_app


def test_health_returns_ok(app: Flask):
    response = app.test_client().get("/openapi/v1/_health")
    assert response.status_code == 200
    assert response.get_json() == {"ok": True}


def test_health_path_is_under_openapi_v1_prefix(app: Flask):
    client = app.test_client()
    # Only the fully-prefixed path exists; bare and half-prefixed paths 404.
    for missing in ("/_health", "/v1/_health"):
        assert client.get(missing).status_code == 404
    assert client.get("/openapi/v1/_health").status_code == 200

View File

@ -0,0 +1,182 @@
"""Unit tests for input_schema derivation."""
from __future__ import annotations
import pytest
from controllers.openapi._input_schema import _form_to_jsonschema
def _wrap(component: dict) -> list[dict]:
"""user_input_form rows are single-key dicts: {"text-input": {...}}."""
return [component]
def test_text_input_required() -> None:
form = _wrap({"text-input": {"variable": "industry", "label": "Industry", "required": True, "max_length": 200}})
props, required = _form_to_jsonschema(form)
assert props == {"industry": {"type": "string", "title": "Industry", "maxLength": 200}}
assert required == ["industry"]
def test_paragraph_optional() -> None:
form = _wrap({"paragraph": {"variable": "context", "label": "Context", "required": False, "max_length": 4000}})
props, required = _form_to_jsonschema(form)
assert props["context"] == {"type": "string", "title": "Context", "maxLength": 4000}
assert required == []
def test_select_enum() -> None:
form = _wrap(
{
"select": {
"variable": "tier",
"label": "Tier",
"required": True,
"options": ["free", "pro", "enterprise"],
}
}
)
props, required = _form_to_jsonschema(form)
assert props == {"tier": {"type": "string", "title": "Tier", "enum": ["free", "pro", "enterprise"]}}
assert required == ["tier"]
def test_number() -> None:
form = _wrap({"number": {"variable": "count", "label": "Count", "required": False}})
props, _required = _form_to_jsonschema(form)
assert props["count"] == {"type": "number", "title": "Count"}
def test_file() -> None:
form = _wrap({"file": {"variable": "doc", "label": "Doc", "required": True}})
props, required = _form_to_jsonschema(form)
assert props["doc"]["type"] == "object"
assert "title" in props["doc"]
assert required == ["doc"]
def test_file_list() -> None:
form = _wrap({"file-list": {"variable": "attachments", "label": "Attachments", "required": False}})
props, _required = _form_to_jsonschema(form)
assert props["attachments"]["type"] == "array"
assert props["attachments"]["items"]["type"] == "object"
def test_unknown_type_skipped() -> None:
"""Forward-compat: unknown variable types are skipped, not 500'd."""
form = _wrap({"future-type": {"variable": "x", "label": "X", "required": False}})
props, required = _form_to_jsonschema(form)
assert props == {}
assert required == []
def test_required_order_preserved() -> None:
form = [
{"text-input": {"variable": "a", "label": "A", "required": True}},
{"text-input": {"variable": "b", "label": "B", "required": False}},
{"text-input": {"variable": "c", "label": "C", "required": True}},
]
_props, required = _form_to_jsonschema(form)
assert required == ["a", "c"]
def test_max_length_omitted_when_zero() -> None:
form = _wrap({"text-input": {"variable": "x", "label": "X", "required": False, "max_length": 0}})
props, _ = _form_to_jsonschema(form)
assert "maxLength" not in props["x"]
# NOTE(review): these imports sit mid-file (PEP 8 wants them at the top);
# hoisting them is a harmless cleanup to do alongside the imports block above.
from unittest.mock import MagicMock

from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema
from controllers.service_api.app.error import AppUnavailableError
from models.model import AppMode


def _stub_app(mode: AppMode, *, form: list[dict] | None = None, has_workflow: bool | None = None):
    """Returns a MagicMock whose .mode + workflow / app_model_config branch is wired up.

    Workflow-backed modes (WORKFLOW, ADVANCED_CHAT) read `app.workflow`;
    all other modes read `app.app_model_config`. Passing has_workflow=False
    nulls whichever attribute the mode consults, simulating a misconfigured app.
    """
    app = MagicMock()
    app.mode = mode
    if mode in (AppMode.WORKFLOW, AppMode.ADVANCED_CHAT):
        if has_workflow is False:
            app.workflow = None
        else:
            app.workflow = MagicMock()
            app.workflow.user_input_form.return_value = form or []
            app.workflow.features_dict = {}
    else:
        if has_workflow is False:
            app.app_model_config = None
        else:
            app.app_model_config = MagicMock()
            app.app_model_config.to_dict.return_value = {"user_input_form": form or []}
    return app


def test_chat_mode_includes_query() -> None:
    # Chat-style modes add a top-level required "query" string alongside "inputs".
    app = _stub_app(AppMode.CHAT, form=[{"text-input": {"variable": "x", "label": "X", "required": True}}])
    schema = build_input_schema(app)
    assert schema["$schema"] == "https://json-schema.org/draft/2020-12/schema"
    assert "query" in schema["properties"]
    assert schema["properties"]["query"]["type"] == "string"
    assert schema["properties"]["query"]["minLength"] == 1
    assert "query" in schema["required"]
    assert "inputs" in schema["required"]
    assert schema["properties"]["inputs"]["additionalProperties"] is False


def test_agent_chat_mode_includes_query() -> None:
    app = _stub_app(AppMode.AGENT_CHAT, form=[])
    schema = build_input_schema(app)
    assert "query" in schema["properties"]


def test_advanced_chat_mode_includes_query() -> None:
    app = _stub_app(AppMode.ADVANCED_CHAT, form=[])
    schema = build_input_schema(app)
    assert "query" in schema["properties"]


def test_workflow_mode_omits_query() -> None:
    # Non-conversational modes take only "inputs"; no "query" property.
    app = _stub_app(AppMode.WORKFLOW, form=[])
    schema = build_input_schema(app)
    assert "query" not in schema["properties"]
    assert schema["required"] == ["inputs"]


def test_completion_mode_omits_query() -> None:
    app = _stub_app(AppMode.COMPLETION, form=[])
    schema = build_input_schema(app)
    assert "query" not in schema["properties"]
    assert schema["required"] == ["inputs"]


def test_inputs_required_driven_by_form() -> None:
    # The nested inputs.required list mirrors the form's required flags.
    app = _stub_app(
        AppMode.CHAT,
        form=[
            {"text-input": {"variable": "industry", "label": "Industry", "required": True}},
            {"text-input": {"variable": "context", "label": "Context", "required": False}},
        ],
    )
    schema = build_input_schema(app)
    assert schema["properties"]["inputs"]["required"] == ["industry"]


def test_misconfigured_chat_raises_app_unavailable() -> None:
    # Missing app_model_config on a chat app → AppUnavailableError, not a 500.
    app = _stub_app(AppMode.CHAT, has_workflow=False)
    with pytest.raises(AppUnavailableError):
        build_input_schema(app)


def test_misconfigured_workflow_raises_app_unavailable() -> None:
    app = _stub_app(AppMode.WORKFLOW, has_workflow=False)
    with pytest.raises(AppUnavailableError):
        build_input_schema(app)


def test_empty_input_schema_sentinel_shape() -> None:
    # The sentinel is a valid-but-empty object schema.
    assert EMPTY_INPUT_SCHEMA["type"] == "object"
    assert EMPTY_INPUT_SCHEMA["properties"] == {}
    assert EMPTY_INPUT_SCHEMA["required"] == []

View File

@ -0,0 +1,31 @@
from controllers.openapi._models import MessageMetadata, UsageInfo


def test_usage_info_defaults_zero():
    """A bare UsageInfo starts with every counter at zero."""
    usage = UsageInfo()
    assert (usage.prompt_tokens, usage.completion_tokens, usage.total_tokens) == (0, 0, 0)


def test_message_metadata_accepts_partial():
    """Only the fields supplied are set; the rest fall back to defaults."""
    meta = MessageMetadata(usage=UsageInfo(total_tokens=10))
    assert meta.usage.total_tokens == 10
    assert meta.retriever_resources == []


def test_describe_response_all_blocks_optional() -> None:
    from controllers.openapi._models import AppDescribeResponse

    dumped = AppDescribeResponse().model_dump(mode="json", exclude_none=False)
    assert dumped == {"info": None, "parameters": None, "input_schema": None}


def test_describe_response_input_schema_field() -> None:
    from controllers.openapi._models import AppDescribeResponse

    schema = {"$schema": "https://json-schema.org/draft/2020-12/schema", "type": "object"}
    dumped = AppDescribeResponse(input_schema=schema).model_dump(mode="json", exclude_none=False)
    assert dumped["input_schema"] == schema
    assert dumped["info"] is None
    assert dumped["parameters"] is None
View File

@ -0,0 +1,127 @@
"""Unit tests for PaginationEnvelope generic Pydantic model."""
from __future__ import annotations
from pydantic import BaseModel
from controllers.openapi._models import PaginationEnvelope
class _Row(BaseModel):
id: str
name: str
def test_envelope_basic_fields():
env = PaginationEnvelope[_Row](page=1, limit=20, total=42, has_more=True, data=[_Row(id="a", name="A")])
dumped = env.model_dump(mode="json")
assert dumped == {
"page": 1,
"limit": 20,
"total": 42,
"has_more": True,
"data": [{"id": "a", "name": "A"}],
}
def test_envelope_empty_data_no_more():
env = PaginationEnvelope[_Row](page=1, limit=20, total=0, has_more=False, data=[])
assert env.model_dump(mode="json")["data"] == []
assert env.model_dump(mode="json")["has_more"] is False
def test_envelope_has_more_true_when_total_exceeds_page_window():
env = PaginationEnvelope[_Row].build(page=1, limit=20, total=42, items=[_Row(id="a", name="A")])
assert env.has_more is True
def test_envelope_has_more_false_when_total_within_page_window():
env = PaginationEnvelope[_Row].build(page=2, limit=20, total=22, items=[_Row(id="a", name="A")])
assert env.has_more is False
def test_envelope_has_more_false_for_last_page():
env = PaginationEnvelope[_Row].build(page=3, limit=20, total=42, items=[_Row(id="a", name="A")])
assert env.has_more is False
def test_max_page_limit_is_200():
from controllers.openapi._models import MAX_PAGE_LIMIT
assert MAX_PAGE_LIMIT == 200
def test_envelope_uses_pep695_generics():
"""Verify the class uses PEP 695 native generic syntax (not legacy Generic[T])."""
from controllers.openapi._models import PaginationEnvelope
# PEP 695 syntax populates __type_params__; the legacy Generic[T] form does not.
assert PaginationEnvelope.__type_params__, "expected PEP 695 native generic syntax"
fields = PaginationEnvelope.model_fields
assert {"page", "limit", "total", "has_more", "data"} <= set(fields)
def test_app_info_response_dump_matches_spec():
from controllers.openapi._models import AppInfoResponse
obj = AppInfoResponse(
id="app1",
name="X",
description="d",
mode="chat",
author="alice",
tags=[{"name": "prod"}],
)
assert obj.model_dump(mode="json") == {
"id": "app1",
"name": "X",
"description": "d",
"mode": "chat",
"author": "alice",
"tags": [{"name": "prod"}],
}
def test_app_describe_response_nests_info_and_parameters():
from controllers.openapi._models import AppDescribeInfo, AppDescribeResponse
info = AppDescribeInfo(
id="app1",
name="X",
mode="chat",
description=None,
tags=[],
author=None,
updated_at="2026-05-05T00:00:00+00:00",
service_api_enabled=True,
)
obj = AppDescribeResponse(info=info, parameters={"opening_statement": None})
dumped = obj.model_dump(mode="json")
assert dumped["info"]["service_api_enabled"] is True
assert dumped["parameters"]["opening_statement"] is None
def test_response_models_dump_per_mode():
from controllers.openapi._models import (
ChatMessageResponse,
CompletionMessageResponse,
WorkflowRunData,
WorkflowRunResponse,
)
chat = ChatMessageResponse(
event="message", task_id="t1", id="m1", message_id="m1",
conversation_id="c1", mode="chat", answer="hi", created_at=0,
)
assert chat.model_dump(mode="json")["mode"] == "chat"
wf = WorkflowRunResponse(
workflow_run_id="r1", task_id="t1",
data=WorkflowRunData(id="r1", workflow_id="w1", status="succeeded"),
)
assert wf.model_dump(mode="json")["data"]["status"] == "succeeded"
assert wf.model_dump(mode="json")["mode"] == "workflow"
comp = CompletionMessageResponse(
event="message", task_id="t2", id="m2", message_id="m2",
mode="completion", answer="ok", created_at=0,
)
assert comp.model_dump(mode="json")["mode"] == "completion"

View File

@ -0,0 +1,58 @@
"""Phase E step 17: workspace reads at /openapi/v1/workspaces. Bearer-authed
list + member-gated detail. No legacy /v1/ equivalent the cookie-authed
/console/api/workspaces is a separate consumer that stays in console.
"""
import builtins
import pytest
from flask import Flask
from flask.views import MethodView
from controllers.openapi import bp as openapi_bp
from controllers.openapi.workspaces import WorkspaceByIdApi, WorkspacesApi
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@pytest.fixture
def openapi_app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
app.register_blueprint(openapi_bp)
return app
def _rule(app: Flask, path: str):
return next(r for r in app.url_map.iter_rules() if r.rule == path)
def test_workspaces_list_route_registered(openapi_app: Flask):
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
assert "/openapi/v1/workspaces" in rules
def test_workspaces_list_dispatches_to_workspaces_api(openapi_app: Flask):
rule = _rule(openapi_app, "/openapi/v1/workspaces")
assert openapi_app.view_functions[rule.endpoint].view_class is WorkspacesApi
assert "GET" in rule.methods
def test_workspace_by_id_route_registered(openapi_app: Flask):
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
assert "/openapi/v1/workspaces/<string:workspace_id>" in rules
def test_workspace_by_id_dispatches_to_correct_class(openapi_app: Flask):
rule = _rule(openapi_app, "/openapi/v1/workspaces/<string:workspace_id>")
assert openapi_app.view_functions[rule.endpoint].view_class is WorkspaceByIdApi
assert "GET" in rule.methods
def test_console_legacy_workspaces_route_not_remounted_on_openapi(openapi_app: Flask):
"""Phase E only adds the bearer-authed mounts on /openapi/v1/.
The cookie-authed /console/api/workspaces stays where it is.
"""
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
assert "/console/api/workspaces" not in rules

View File

@ -0,0 +1,9 @@
from core.app.entities.app_invoke_entities import InvokeFrom


def test_openapi_variant_present():
    """The enum exposes a dedicated member whose value is "openapi"."""
    assert InvokeFrom.OPENAPI.value == "openapi"


def test_openapi_distinct_from_service_api():
    """OPENAPI is its own member, not an alias of SERVICE_API."""
    assert InvokeFrom.OPENAPI is not InvokeFrom.SERVICE_API
    assert InvokeFrom.OPENAPI != InvokeFrom.SERVICE_API

View File

@ -0,0 +1,29 @@
"""Unit tests for the openapi bearer-scope catalog and TokenKind registry."""
from __future__ import annotations
from unittest.mock import MagicMock
def test_apps_read_permitted_external_scope_present():
from libs.oauth_bearer import Scope
assert Scope.APPS_READ_PERMITTED_EXTERNAL.value == "apps:read:permitted-external"
def test_dfoe_token_kind_carries_apps_read_permitted_external():
from libs.oauth_bearer import Scope, build_registry
registry = build_registry(MagicMock(), MagicMock())
dfoe = next(k for k in registry.kinds() if k.prefix == "dfoe_")
assert Scope.APPS_READ_PERMITTED_EXTERNAL in dfoe.scopes
def test_dfoa_token_kind_does_not_carry_apps_read_permitted_external():
"""dfoa_ relies on Scope.FULL umbrella; the explicit permitted scope
is reserved for dfoe_."""
from libs.oauth_bearer import Scope, build_registry
registry = build_registry(MagicMock(), MagicMock())
dfoa = next(k for k in registry.kinds() if k.prefix == "dfoa_")
assert Scope.APPS_READ_PERMITTED_EXTERNAL not in dfoa.scopes

View File

@ -0,0 +1,94 @@
"""Unit tests for record_layer0_verdict — merge L0 verdict into AuthContext cache."""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
import pytest
from libs.oauth_bearer import record_layer0_verdict
@pytest.fixture
def mock_redis():
return MagicMock()
@patch("libs.oauth_bearer.redis_client")
def test_no_op_when_cache_entry_missing(mock_redis):
mock_redis.get.return_value = None
record_layer0_verdict("h1", "t1", True)
mock_redis.setex.assert_not_called()
@patch("libs.oauth_bearer.redis_client")
def test_no_op_when_cache_entry_invalid_marker(mock_redis):
mock_redis.get.return_value = b"invalid"
record_layer0_verdict("h1", "t1", True)
mock_redis.setex.assert_not_called()
@patch("libs.oauth_bearer.redis_client")
def test_no_op_when_json_malformed(mock_redis):
mock_redis.get.return_value = b"not json"
record_layer0_verdict("h1", "t1", True)
mock_redis.setex.assert_not_called()
@patch("libs.oauth_bearer.redis_client")
def test_no_op_when_ttl_expired(mock_redis):
mock_redis.get.return_value = json.dumps(
{
"subject_email": "e",
"subject_issuer": None,
"account_id": None,
"token_id": "tid",
"expires_at": None,
}
).encode()
mock_redis.ttl.return_value = -1
record_layer0_verdict("h1", "t1", True)
mock_redis.setex.assert_not_called()
@patch("libs.oauth_bearer.redis_client")
def test_merges_new_tenant_verdict(mock_redis):
mock_redis.get.return_value = json.dumps(
{
"subject_email": "e",
"subject_issuer": None,
"account_id": None,
"token_id": "tid",
"expires_at": None,
"verified_tenants": {"t0": True},
}
).encode()
mock_redis.ttl.return_value = 42
record_layer0_verdict("h1", "t1", False)
mock_redis.setex.assert_called_once()
args = mock_redis.setex.call_args
assert args.args[0] == "auth:token:h1"
assert args.args[1] == 42 # remaining TTL preserved
written = json.loads(args.args[2])
assert written["verified_tenants"] == {"t0": True, "t1": False}
@patch("libs.oauth_bearer.redis_client")
def test_merges_when_field_absent_from_legacy_entry(mock_redis):
"""Backward compat: legacy cache entry without verified_tenants field."""
mock_redis.get.return_value = json.dumps(
{
"subject_email": "e",
"subject_issuer": None,
"account_id": None,
"token_id": "tid",
"expires_at": None,
}
).encode()
mock_redis.ttl.return_value = 42
record_layer0_verdict("h1", "t1", True)
written = json.loads(mock_redis.setex.call_args.args[2])
assert written["verified_tenants"] == {"t1": True}

View File

@ -0,0 +1,85 @@
"""require_scope is a route-level gate run after validate_bearer.
Tests use a fake auth_ctx attached directly to flask.g no
authenticator wiring needed.
"""
from __future__ import annotations
import uuid
import pytest
from flask import Flask, g
from werkzeug.exceptions import Forbidden
from libs.oauth_bearer import (
AuthContext,
Scope,
SubjectType,
require_scope,
)
@pytest.fixture
def app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
return app
def _ctx(scopes) -> AuthContext:
return AuthContext(
subject_type=SubjectType.ACCOUNT,
subject_email="user@example.com",
subject_issuer="dify:account",
account_id=uuid.uuid4(),
client_id="difyctl",
scopes=scopes,
token_id=uuid.uuid4(),
source="oauth_account",
expires_at=None,
token_hash="h1",
verified_tenants={},
)
def test_require_scope_allows_when_scope_present(app: Flask):
@require_scope("apps:read")
def view():
return "ok"
with app.test_request_context():
g.auth_ctx = _ctx(frozenset({"apps:read"}))
assert view() == "ok"
def test_require_scope_rejects_when_scope_missing(app: Flask):
@require_scope("apps:write")
def view():
return "ok"
with app.test_request_context():
g.auth_ctx = _ctx(frozenset({"apps:read"}))
with pytest.raises(Forbidden) as exc:
view()
assert "insufficient_scope: apps:write" in str(exc.value.description)
def test_require_scope_full_passes_any_check(app: Flask):
@require_scope("apps:write")
def view():
return "ok"
with app.test_request_context():
g.auth_ctx = _ctx(frozenset({Scope.FULL}))
assert view() == "ok"
def test_require_scope_without_validate_bearer_raises_runtime_error(app: Flask):
@require_scope("apps:read")
def view():
return "ok"
with app.test_request_context():
# No g.auth_ctx — validate_bearer was forgotten
with pytest.raises(RuntimeError, match="stack @validate_bearer above @require_scope"):
view()

View File

@ -0,0 +1,74 @@
"""Unit tests for the per-token bearer rate limit primitive."""
from __future__ import annotations
from datetime import timedelta
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import TooManyRequests
from libs.helper import RateLimiter
from libs.rate_limit import (
LIMIT_BEARER_PER_TOKEN,
enforce_bearer_rate_limit,
)
@pytest.fixture
def mock_redis():
return MagicMock()
def test_limit_bearer_per_token_uses_60_per_minute_default():
assert LIMIT_BEARER_PER_TOKEN.limit == 60
assert LIMIT_BEARER_PER_TOKEN.window == timedelta(minutes=1)
def test_seconds_until_available_returns_remaining_window(mock_redis):
"""ZSET oldest entry score = 100; window = 60s; now = 130s → remaining = 30s."""
rl = RateLimiter("rl:bearer:token", max_attempts=60, time_window=60, redis_client=mock_redis)
mock_redis.zrange.return_value = [(b"member-1", 100.0)]
with patch("libs.helper.time.time", return_value=130):
assert rl.seconds_until_available("k1") == 30
def test_seconds_until_available_floor_one_second(mock_redis):
"""Even when math says <1s remaining, return at least 1 so client backs off measurably."""
rl = RateLimiter("rl:bearer:token", max_attempts=60, time_window=60, redis_client=mock_redis)
mock_redis.zrange.return_value = [(b"member-1", 119.5)]
with patch("libs.helper.time.time", return_value=180):
# window expired (180 > 119.5+60=179.5 by 0.5s) — bucket is actually free now
# but this method only called when is_rate_limited() == True; defensive floor.
assert rl.seconds_until_available("k1") >= 1
def test_seconds_until_available_empty_bucket(mock_redis):
"""No entries → 1s sentinel (defensive; should not be reached when limited)."""
rl = RateLimiter("rl:bearer:token", max_attempts=60, time_window=60, redis_client=mock_redis)
mock_redis.zrange.return_value = []
assert rl.seconds_until_available("k1") == 1
@patch("libs.rate_limit._build_limiter")
def test_enforce_bearer_rate_limit_passes_under_limit(mock_build):
limiter = MagicMock()
limiter.is_rate_limited.return_value = False
mock_build.return_value = limiter
enforce_bearer_rate_limit("hash-1")
limiter.increment_rate_limit.assert_called_once_with("token:hash-1")
@patch("libs.rate_limit._build_limiter")
def test_enforce_bearer_rate_limit_raises_429_with_retry_after(mock_build):
limiter = MagicMock()
limiter.is_rate_limited.return_value = True
limiter.seconds_until_available.return_value = 23
mock_build.return_value = limiter
with pytest.raises(TooManyRequests) as exc:
enforce_bearer_rate_limit("hash-1")
headers = dict(exc.value.get_response().headers)
assert headers.get("Retry-After") == "23"
body = exc.value.get_response().get_json() or {}
assert body.get("error") == "rate_limited"
assert body.get("retry_after_ms") == 23000

View File

@ -0,0 +1,94 @@
"""Unit tests for require_workspace_member."""
from __future__ import annotations
import uuid
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from libs.oauth_bearer import AuthContext, Scope, SubjectType, require_workspace_member
def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> AuthContext:
return AuthContext(
subject_type=SubjectType.ACCOUNT if account else SubjectType.EXTERNAL_SSO,
subject_email="e@example.com",
subject_issuer=None,
account_id=uuid.uuid4() if account else None,
client_id="difyctl",
scopes=frozenset({Scope.FULL}),
token_id=uuid.uuid4(),
source="oauth_account",
expires_at=None,
token_hash="h1",
verified_tenants=dict(verified or {}),
)
@patch("libs.oauth_bearer.dify_config")
def test_skips_when_enterprise_enabled(mock_cfg):
mock_cfg.ENTERPRISE_ENABLED = True
require_workspace_member(_ctx(), "t1")
@patch("libs.oauth_bearer.dify_config")
def test_skips_for_external_sso(mock_cfg):
mock_cfg.ENTERPRISE_ENABLED = False
require_workspace_member(_ctx(account=False), "t1")
@patch("libs.oauth_bearer.db")
@patch("libs.oauth_bearer.dify_config")
def test_uses_cached_ok_no_db_access(mock_cfg, mock_db):
mock_cfg.ENTERPRISE_ENABLED = False
require_workspace_member(_ctx({"t1": True}), "t1")
mock_db.session.execute.assert_not_called()
@patch("libs.oauth_bearer.db")
@patch("libs.oauth_bearer.dify_config")
def test_uses_cached_denied(mock_cfg, mock_db):
mock_cfg.ENTERPRISE_ENABLED = False
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
require_workspace_member(_ctx({"t1": False}), "t1")
mock_db.session.execute.assert_not_called()
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
@patch("libs.oauth_bearer.dify_config")
def test_denies_when_no_membership(mock_cfg, mock_db, mock_record):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
require_workspace_member(_ctx({}), "t1")
mock_record.assert_called_once_with("h1", "t1", False)
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
@patch("libs.oauth_bearer.dify_config")
def test_denies_when_account_inactive(config_mock, db_mock, record_mock):
    """Membership exists but the account is not active -> Forbidden."""
    config_mock.ENTERPRISE_ENABLED = False
    # First execute(): membership lookup finds a join row.
    join_result = MagicMock(scalar_one_or_none=MagicMock(return_value="join-id"))
    # Second execute(): account status lookup returns a non-active status.
    status_result = MagicMock(scalar_one_or_none=MagicMock(return_value="banned"))
    db_mock.session.execute.side_effect = [join_result, status_result]
    with pytest.raises(Forbidden, match="workspace_membership_revoked"):
        require_workspace_member(_ctx({}), "t1")
    record_mock.assert_called_once_with("h1", "t1", False)
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
@patch("libs.oauth_bearer.dify_config")
def test_allows_active_member(config_mock, db_mock, record_mock):
    """An existing membership with an active account passes the check."""
    config_mock.ENTERPRISE_ENABLED = False
    # Membership lookup hits, then the account status comes back "active".
    join_result = MagicMock(scalar_one_or_none=MagicMock(return_value="join-id"))
    status_result = MagicMock(scalar_one_or_none=MagicMock(return_value="active"))
    db_mock.session.execute.side_effect = [join_result, status_result]
    require_workspace_member(_ctx({}), "t1")
    # The positive verdict is persisted against the token hash and tenant.
    record_mock.assert_called_once_with("h1", "t1", True)

View File

@ -0,0 +1,57 @@
from unittest.mock import patch
import pytest
from services.enterprise.app_permitted_service import PermittedAppsPage, list_permitted_apps
from services.errors.enterprise import EnterpriseAPIError
WRAPPER = "services.enterprise.app_permitted_service.EnterpriseService.WebAppAuth.list_externally_accessible_apps"
def test_list_permitted_apps_decodes_camelcase_response():
    """camelCase wrapper payload is mapped onto the snake_case page object."""
    payload = {
        "data": [{"appId": "a"}, {"appId": "b"}],
        "total": 2,
        "hasMore": False,
    }
    with patch(WRAPPER, return_value=payload) as wrapper_mock:
        result = list_permitted_apps(page=1, limit=10)
    assert isinstance(result, PermittedAppsPage)
    assert result.total == 2
    assert result.has_more is False
    assert result.app_ids == ["a", "b"]
    # Omitted filters are forwarded explicitly as None.
    wrapper_mock.assert_called_once_with(page=1, limit=10, mode=None, name=None)
def test_list_permitted_apps_passes_filters_to_wrapper():
    """mode/name filters are forwarded to the enterprise wrapper unchanged."""
    empty_payload = {"data": [], "total": 0, "hasMore": False}
    with patch(WRAPPER, return_value=empty_payload) as wrapper_mock:
        list_permitted_apps(page=2, limit=5, mode="workflow", name="alpha")
    wrapper_mock.assert_called_once_with(page=2, limit=5, mode="workflow", name="alpha")
def test_list_permitted_apps_503_on_ee_error():
    """An enterprise-side 5xx error surfaces as ServiceUnavailable."""
    from werkzeug.exceptions import ServiceUnavailable

    failure = EnterpriseAPIError("boom", status_code=500)
    with patch(WRAPPER, side_effect=failure), pytest.raises(ServiceUnavailable):
        list_permitted_apps(page=1, limit=10)
def test_list_permitted_apps_503_on_status_error():
    """A 4xx from the enterprise API also maps to ServiceUnavailable."""
    from werkzeug.exceptions import ServiceUnavailable

    failure = EnterpriseAPIError("bad key", status_code=401)
    with patch(WRAPPER, side_effect=failure), pytest.raises(ServiceUnavailable):
        list_permitted_apps(page=1, limit=10)
def test_list_permitted_apps_handles_empty_response():
    """An empty wrapper payload yields an empty, exhausted page."""
    empty_payload = {"data": [], "total": 0, "hasMore": False}
    with patch(WRAPPER, return_value=empty_payload):
        result = list_permitted_apps(page=1, limit=10)
    assert result.app_ids == []
    assert result.total == 0
    assert result.has_more is False

View File

@ -188,6 +188,31 @@ class TestWebAppAuth:
req.send_request.assert_called_once_with("DELETE", "/webapp/clean", params={"appId": "a1"})
def test_list_externally_accessible_apps_minimal_call(self):
    """Without filters, only page/limit are sent to the enterprise endpoint."""
    empty_page = {"data": [], "total": 0, "hasMore": False}
    with patch(f"{MODULE}.EnterpriseRequest") as request_mock:
        request_mock.send_request.return_value = empty_page
        result = EnterpriseService.WebAppAuth.list_externally_accessible_apps(page=1, limit=10)
    # The wrapper returns the raw response body untouched.
    assert result == {"data": [], "total": 0, "hasMore": False}
    request_mock.send_request.assert_called_once_with(
        "POST",
        "/webapp/externally-accessible-apps",
        json={"page": 1, "limit": 10},
        timeout=5.0,
    )
def test_list_externally_accessible_apps_with_filters(self):
    """mode/name filters are included in the request body when provided."""
    with patch(f"{MODULE}.EnterpriseRequest") as request_mock:
        request_mock.send_request.return_value = {"data": [], "total": 0, "hasMore": False}
        EnterpriseService.WebAppAuth.list_externally_accessible_apps(
            page=2, limit=5, mode="workflow", name="alpha"
        )
    request_mock.send_request.assert_called_once_with(
        "POST",
        "/webapp/externally-accessible-apps",
        json={"page": 2, "limit": 5, "mode": "workflow", "name": "alpha"},
        timeout=5.0,
    )
class TestJoinDefaultWorkspace:
def test_join_default_workspace_success(self):

Some files were not shown because too many files have changed in this diff Show More