mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
feat(openapi): app-run endpoints with auth pipeline
Ports service_api/app/{completion,workflow}.py to bearer-authed
/openapi/v1/apps/<app_id>/{info,chat-messages,completion-messages,workflows/run}.
Architecture:
- New controllers/openapi/auth/ package: Pipeline + Step protocol over
one mutable Context. Endpoints attach via @APP_PIPELINE.guard(scope=...)
— single attachment point; forgetting auth is structurally impossible.
- Pipeline order: BearerCheck -> ScopeCheck -> AppResolver -> AppAuthzCheck
-> CallerMount.
- Strategies vary along independent axes: AclStrategy (EE webapp-auth inner
API) vs MembershipStrategy (CE TenantAccountJoin); AccountMounter vs
EndUserMounter dispatched by SubjectType.
- App is in URL path (not header). Each non-GET has typed Pydantic Request;
each non-SSE response has typed Pydantic Response. Bearer-as-identity:
body 'user' field stripped, ignored if present.
Adds InvokeFrom.OPENAPI enum variant. Emits app.run.openapi audit log
on successful invocation via standard logger extra={"audit": True, ...}
convention.
This commit is contained in:
parent
85c3f9cbf8
commit
cf5ebe9430
@ -16,13 +16,27 @@ api = ExternalApi(
|
||||
|
||||
openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")
|
||||
|
||||
from . import account, index, oauth_device, oauth_device_sso, workspaces
|
||||
from . import (
|
||||
account,
|
||||
app_info,
|
||||
chat_messages,
|
||||
completion_messages,
|
||||
index,
|
||||
oauth_device,
|
||||
oauth_device_sso,
|
||||
workflow_run,
|
||||
workspaces,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"account",
|
||||
"app_info",
|
||||
"chat_messages",
|
||||
"completion_messages",
|
||||
"index",
|
||||
"oauth_device",
|
||||
"oauth_device_sso",
|
||||
"workflow_run",
|
||||
"workspaces",
|
||||
]
|
||||
|
||||
|
||||
32
api/controllers/openapi/_audit.py
Normal file
32
api/controllers/openapi/_audit.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""Audit emission for openapi app-run endpoints.
|
||||
|
||||
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
|
||||
matches the existing oauth_device convention. The EE OTel exporter consults
|
||||
its own allowlist to decide whether to ship the line.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
|
||||
|
||||
|
||||
def emit_app_run(*, app_id: str, tenant_id: str, caller_kind: str, mode: str) -> None:
|
||||
logger.info(
|
||||
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s",
|
||||
EVENT_APP_RUN_OPENAPI,
|
||||
app_id,
|
||||
tenant_id,
|
||||
caller_kind,
|
||||
mode,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_APP_RUN_OPENAPI,
|
||||
"app_id": app_id,
|
||||
"tenant_id": tenant_id,
|
||||
"caller_kind": caller_kind,
|
||||
"mode": mode,
|
||||
},
|
||||
)
|
||||
17
api/controllers/openapi/_models.py
Normal file
17
api/controllers/openapi/_models.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""Shared response substructures for openapi endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class MessageMetadata(BaseModel):
|
||||
usage: UsageInfo | None = None
|
||||
retriever_resources: list[dict[str, Any]] = []
|
||||
36
api/controllers/openapi/app_info.py
Normal file
36
api/controllers/openapi/app_info.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""GET /openapi/v1/apps/<app_id>/info — port of service_api/app/app.py:AppInfoApi."""
|
||||
from __future__ import annotations
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import APP_PIPELINE
|
||||
|
||||
|
||||
class AppInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: str
|
||||
author_name: str | None = None
|
||||
tags: list[str] = []
|
||||
|
||||
|
||||
def _unpack_app(app_model):
|
||||
return app_model
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/info")
|
||||
class AppInfoApi(Resource):
|
||||
@APP_PIPELINE.guard(scope="apps:run")
|
||||
def get(self, app_id, app_model, caller, caller_kind):
|
||||
app = _unpack_app(app_model)
|
||||
return AppInfoResponse(
|
||||
id=app.id,
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
author_name=app.author_name,
|
||||
tags=[t.name for t in app.tags],
|
||||
).model_dump(mode="json")
|
||||
3
api/controllers/openapi/auth/__init__.py
Normal file
3
api/controllers/openapi/auth/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from controllers.openapi.auth.composition import APP_PIPELINE
|
||||
|
||||
__all__ = ["APP_PIPELINE"]
|
||||
37
api/controllers/openapi/auth/composition.py
Normal file
37
api/controllers/openapi/auth/composition.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""APP_PIPELINE — the only auth scheme for openapi app endpoints.
|
||||
|
||||
Endpoints attach via @APP_PIPELINE.guard(scope=…). No alternative paths.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
AppAuthzStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
return AclStrategy()
|
||||
return MembershipStrategy()
|
||||
|
||||
|
||||
APP_PIPELINE = Pipeline(
|
||||
BearerCheck(),
|
||||
ScopeCheck(),
|
||||
AppResolver(),
|
||||
AppAuthzCheck(_resolve_app_authz_strategy),
|
||||
CallerMount(AccountMounter(), EndUserMounter()),
|
||||
)
|
||||
39
api/controllers/openapi/auth/context.py
Normal file
39
api/controllers/openapi/auth/context.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""Mutable per-request context for the openapi auth pipeline.
|
||||
|
||||
Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Literal, Protocol
|
||||
|
||||
from flask import Request
|
||||
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
request: Request
|
||||
required_scope: str
|
||||
subject_type: SubjectType | None = None
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account_id: str | None = None
|
||||
scopes: frozenset[str] = field(default_factory=frozenset)
|
||||
token_id: str | None = None
|
||||
source: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
app: object | None = None
|
||||
tenant: object | None = None
|
||||
caller: object | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
|
||||
|
||||
class Step(Protocol):
|
||||
"""One responsibility. Mutate ctx or raise to short-circuit."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None: ...
|
||||
39
api/controllers/openapi/auth/pipeline.py
Normal file
39
api/controllers/openapi/auth/pipeline.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""Pipeline IS the auth scheme.
|
||||
|
||||
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
|
||||
from controllers.openapi.auth.context import Context, Step
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, *steps: Step) -> None:
|
||||
self._steps = steps
|
||||
|
||||
def run(self, ctx: Context) -> None:
|
||||
for step in self._steps:
|
||||
step(ctx)
|
||||
|
||||
def guard(self, *, scope: str):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
ctx = Context(request=request, required_scope=scope)
|
||||
self.run(ctx)
|
||||
kwargs.update(
|
||||
app_model=ctx.app,
|
||||
caller=ctx.caller,
|
||||
caller_kind=ctx.caller_kind,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
112
api/controllers/openapi/auth/steps.py
Normal file
112
api/controllers/openapi/auth/steps.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""Pipeline steps. Each is one responsibility.
|
||||
|
||||
BearerCheck is the only step that touches the token registry; downstream
|
||||
steps see only the populated Context.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import TokenExpired, get_authenticator, sha256_hex
|
||||
from models import App, Tenant, TenantStatus
|
||||
|
||||
|
||||
def _registry():
|
||||
return get_authenticator()._registry # noqa: SLF001
|
||||
|
||||
|
||||
def _extract_bearer(req) -> str | None:
|
||||
auth = req.headers.get("Authorization")
|
||||
if not auth or not auth.lower().startswith("bearer "):
|
||||
return None
|
||||
return auth.split(None, 1)[1].strip() or None
|
||||
|
||||
|
||||
def _hash_token(token: str) -> str:
|
||||
return sha256_hex(token)
|
||||
|
||||
|
||||
class BearerCheck:
|
||||
"""Resolve bearer → populate identity fields."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
token = _extract_bearer(ctx.request)
|
||||
if not token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
kind = _registry().find(token)
|
||||
if kind is None:
|
||||
raise Unauthorized("invalid bearer prefix")
|
||||
|
||||
try:
|
||||
row = kind.resolver.resolve(_hash_token(token))
|
||||
except TokenExpired:
|
||||
raise Unauthorized("token expired")
|
||||
if row is None:
|
||||
raise Unauthorized("invalid bearer")
|
||||
|
||||
ctx.subject_type = kind.subject_type
|
||||
ctx.subject_email = row.subject_email
|
||||
ctx.subject_issuer = row.subject_issuer
|
||||
ctx.account_id = row.account_id
|
||||
ctx.scopes = kind.scopes
|
||||
ctx.source = kind.source
|
||||
ctx.token_id = row.token_id
|
||||
ctx.expires_at = row.expires_at
|
||||
|
||||
|
||||
class ScopeCheck:
|
||||
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if "full" in ctx.scopes or ctx.required_scope in ctx.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
class AppResolver:
|
||||
"""Read app_id from request.view_args, populate ctx.app + ctx.tenant.
|
||||
|
||||
Every endpoint using APP_PIPELINE must declare ``<string:app_id>`` in
|
||||
its route — that is the design lock-in (no body / header coupling).
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
app_id = (ctx.request.view_args or {}).get("app_id")
|
||||
if not app_id:
|
||||
raise BadRequest("app_id is required in path")
|
||||
app = db.session.get(App, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
tenant = db.session.get(Tenant, app.tenant_id)
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
|
||||
class AppAuthzCheck:
|
||||
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
|
||||
self._resolve = resolve_strategy
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not self._resolve().authorize(ctx):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
class CallerMount:
|
||||
def __init__(self, *mounters: CallerMounter) -> None:
|
||||
self._mounters = mounters
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
for m in self._mounters:
|
||||
if m.applies_to(ctx.subject_type):
|
||||
m.mount(ctx)
|
||||
return
|
||||
raise Unauthorized("no caller mounter for subject type")
|
||||
103
api/controllers/openapi/auth/strategies.py
Normal file
103
api/controllers/openapi/auth/strategies.py
Normal file
@ -0,0 +1,103 @@
|
||||
"""Strategy classes for the openapi auth pipeline.
|
||||
|
||||
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from models import Account, TenantAccountJoin
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
|
||||
class AppAuthzStrategy(Protocol):
|
||||
def authorize(self, ctx: Context) -> bool: ...
|
||||
|
||||
|
||||
class AclStrategy:
|
||||
"""Per-app ACL via the workspace-auth inner API.
|
||||
|
||||
Used when webapp-auth is enabled (EE deployment). The inner-API
|
||||
allowlist is the source of truth.
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=ctx.subject_email,
|
||||
app_id=ctx.app.id,
|
||||
)
|
||||
|
||||
|
||||
class MembershipStrategy:
|
||||
"""Tenant-membership fallback.
|
||||
|
||||
Used when webapp-auth is disabled (CE deployment). Account-bearing
|
||||
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
|
||||
denied (it requires the webapp-auth surface).
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return False
|
||||
return _has_tenant_membership(ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _has_tenant_membership(account_id: str | None, tenant_id: str) -> bool:
|
||||
if not account_id:
|
||||
return False
|
||||
row = db.session.execute(
|
||||
select(TenantAccountJoin.id).where(
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
TenantAccountJoin.account_id == account_id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
|
||||
def _login_as(user) -> None:
|
||||
"""Set Flask-Login request user so downstream services see the caller."""
|
||||
current_app.login_manager._update_request_context_with_user(user) # noqa: SLF001
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # noqa: SLF001
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
|
||||
def applies_to(self, subject_type: SubjectType) -> bool: ...
|
||||
|
||||
def mount(self, ctx: Context) -> None: ...
|
||||
|
||||
|
||||
class AccountMounter:
|
||||
def applies_to(self, st: SubjectType) -> bool:
|
||||
return st == SubjectType.ACCOUNT
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
account = db.session.get(Account, ctx.account_id)
|
||||
account.current_tenant = ctx.tenant
|
||||
_login_as(account)
|
||||
ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
|
||||
def applies_to(self, st: SubjectType) -> bool:
|
||||
return st == SubjectType.EXTERNAL_SSO
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=ctx.tenant.id,
|
||||
app_id=ctx.app.id,
|
||||
user_id=ctx.subject_email,
|
||||
)
|
||||
_login_as(end_user)
|
||||
ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||
167
api/controllers/openapi/chat_messages.py
Normal file
167
api/controllers/openapi/chat_messages.py
Normal file
@ -0,0 +1,167 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/chat-messages — port of
|
||||
service_api/app/completion.py:ChatApi.
|
||||
|
||||
Differences from service_api:
|
||||
- App is in URL path, not header.
|
||||
- One decorator: @APP_PIPELINE.guard(scope="apps:run").
|
||||
- Request body has no `user` field (Model 2: identity is the bearer).
|
||||
- Typed Request and Response models.
|
||||
- invoke_from = InvokeFrom.OPENAPI.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import MessageMetadata
|
||||
from controllers.openapi.auth.composition import APP_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NotChatAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import (
|
||||
IsDraftWorkflowError,
|
||||
WorkflowIdFormatError,
|
||||
WorkflowNotFoundError,
|
||||
)
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatMessageRequest(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
conversation_id: UUIDStrOrEmpty | None = Field(default=None)
|
||||
auto_generate_name: bool = Field(default=True)
|
||||
workflow_id: str | None = Field(default=None)
|
||||
|
||||
@field_validator("conversation_id", mode="before")
|
||||
@classmethod
|
||||
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return helper.uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
def _unpack_app(app_model):
|
||||
return app_model
|
||||
|
||||
|
||||
def _unpack_caller(caller):
|
||||
return caller
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/chat-messages")
|
||||
class ChatMessagesApi(Resource):
|
||||
@APP_PIPELINE.guard(scope="apps:run")
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
app = _unpack_app(app_model)
|
||||
if AppMode.value_of(app.mode) not in {
|
||||
AppMode.CHAT,
|
||||
AppMode.AGENT_CHAT,
|
||||
AppMode.ADVANCED_CHAT,
|
||||
}:
|
||||
raise NotChatAppError()
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
body.pop("user", None)
|
||||
try:
|
||||
payload = ChatMessageRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=_unpack_caller(caller),
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.OPENAPI,
|
||||
streaming=streaming,
|
||||
)
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app.mode),
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
body_dict = response[0] if isinstance(response, tuple) else response
|
||||
return ChatMessageResponse.model_validate(body_dict).model_dump(mode="json"), 200
|
||||
124
api/controllers/openapi/completion_messages.py
Normal file
124
api/controllers/openapi/completion_messages.py
Normal file
@ -0,0 +1,124 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/completion-messages — port of
|
||||
service_api/app/completion.py:CompletionApi."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import MessageMetadata
|
||||
from controllers.openapi.auth.composition import APP_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompletionMessageRequest(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str = Field(default="")
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
|
||||
|
||||
class CompletionMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
def _unpack_app(app_model):
|
||||
return app_model
|
||||
|
||||
|
||||
def _unpack_caller(caller):
|
||||
return caller
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/completion-messages")
|
||||
class CompletionMessagesApi(Resource):
|
||||
@APP_PIPELINE.guard(scope="apps:run")
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
app = _unpack_app(app_model)
|
||||
if AppMode.value_of(app.mode) != AppMode.COMPLETION:
|
||||
raise AppUnavailableError()
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
body.pop("user", None)
|
||||
try:
|
||||
payload = CompletionMessageRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args["auto_generate_name"] = False
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=_unpack_caller(caller),
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.OPENAPI,
|
||||
streaming=streaming,
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app.mode),
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
body_dict = response[0] if isinstance(response, tuple) else response
|
||||
return CompletionMessageResponse.model_validate(body_dict).model_dump(mode="json"), 200
|
||||
132
api/controllers/openapi/workflow_run.py
Normal file
132
api/controllers/openapi/workflow_run.py
Normal file
@ -0,0 +1,132 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/workflows/run — port of
|
||||
service_api/app/workflow.py:WorkflowRunApi."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi.auth.composition import APP_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
CompletionRequestError,
|
||||
NotWorkflowAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import (
|
||||
IsDraftWorkflowError,
|
||||
WorkflowIdFormatError,
|
||||
WorkflowNotFoundError,
|
||||
)
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunRequest(WorkflowRunPayloadBase):
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
|
||||
|
||||
class WorkflowRunData(BaseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
outputs: dict[str, Any] = {}
|
||||
error: str | None = None
|
||||
elapsed_time: float | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
|
||||
|
||||
class WorkflowRunResponse(BaseModel):
|
||||
workflow_run_id: str
|
||||
task_id: str
|
||||
data: WorkflowRunData
|
||||
|
||||
|
||||
def _unpack_app(app_model):
|
||||
return app_model
|
||||
|
||||
|
||||
def _unpack_caller(caller):
|
||||
return caller
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/workflows/run")
|
||||
class WorkflowRunApi(Resource):
|
||||
@APP_PIPELINE.guard(scope="apps:run")
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
app = _unpack_app(app_model)
|
||||
if AppMode.value_of(app.mode) != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
body.pop("user", None)
|
||||
try:
|
||||
payload = WorkflowRunRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=_unpack_caller(caller),
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.OPENAPI,
|
||||
streaming=streaming,
|
||||
)
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app.mode),
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
body_dict = response[0] if isinstance(response, tuple) else response
|
||||
return WorkflowRunResponse.model_validate(body_dict).model_dump(mode="json"), 200
|
||||
@ -24,6 +24,7 @@ class UserFrom(StrEnum):
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
SERVICE_API = "service-api"
|
||||
OPENAPI = "openapi"
|
||||
WEB_APP = "web-app"
|
||||
TRIGGER = "trigger"
|
||||
EXPLORE = "explore"
|
||||
|
||||
@ -0,0 +1,49 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.openapi.auth.composition import APP_PIPELINE, _resolve_app_authz_strategy
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
|
||||
|
||||
def test_app_pipeline_is_composed():
|
||||
assert isinstance(APP_PIPELINE, Pipeline)
|
||||
|
||||
|
||||
def test_app_pipeline_step_order():
|
||||
steps = APP_PIPELINE._steps
|
||||
assert isinstance(steps[0], BearerCheck)
|
||||
assert isinstance(steps[1], ScopeCheck)
|
||||
assert isinstance(steps[2], AppResolver)
|
||||
assert isinstance(steps[3], AppAuthzCheck)
|
||||
assert isinstance(steps[4], CallerMount)
|
||||
|
||||
|
||||
def test_caller_mount_has_both_mounters():
|
||||
cm = APP_PIPELINE._steps[4]
|
||||
kinds = {type(m) for m in cm._mounters}
|
||||
assert AccountMounter in kinds
|
||||
assert EndUserMounter in kinds
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.composition.FeatureService")
|
||||
def test_strategy_resolver_picks_acl_when_enabled(fs):
|
||||
fs.get_system_features.return_value.webapp_auth.enabled = True
|
||||
assert isinstance(_resolve_app_authz_strategy(), AclStrategy)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.composition.FeatureService")
|
||||
def test_strategy_resolver_picks_membership_when_disabled(fs):
|
||||
fs.get_system_features.return_value.webapp_auth.enabled = False
|
||||
assert isinstance(_resolve_app_authz_strategy(), MembershipStrategy)
|
||||
@ -0,0 +1,21 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
|
||||
|
||||
def test_context_starts_unpopulated():
|
||||
ctx = Context(request=MagicMock(), required_scope="apps:run")
|
||||
assert ctx.subject_type is None
|
||||
assert ctx.subject_email is None
|
||||
assert ctx.account_id is None
|
||||
assert ctx.scopes == frozenset()
|
||||
assert ctx.app is None
|
||||
assert ctx.tenant is None
|
||||
assert ctx.caller is None
|
||||
assert ctx.caller_kind is None
|
||||
|
||||
|
||||
def test_context_fields_are_mutable():
|
||||
ctx = Context(request=MagicMock(), required_scope="apps:run")
|
||||
ctx.scopes = frozenset({"full"})
|
||||
assert "full" in ctx.scopes
|
||||
@ -0,0 +1,61 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
|
||||
|
||||
def test_run_invokes_each_step_in_order():
|
||||
calls = []
|
||||
|
||||
class S:
|
||||
def __init__(self, tag):
|
||||
self.tag = tag
|
||||
|
||||
def __call__(self, ctx):
|
||||
calls.append(self.tag)
|
||||
|
||||
Pipeline(S("a"), S("b"), S("c")).run(Context(request=MagicMock(), required_scope="x"))
|
||||
assert calls == ["a", "b", "c"]
|
||||
|
||||
|
||||
def test_run_short_circuits_on_raise():
|
||||
calls = []
|
||||
|
||||
class Boom:
|
||||
def __call__(self, ctx):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
class Tail:
|
||||
def __call__(self, ctx):
|
||||
calls.append("ran")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
Pipeline(Boom(), Tail()).run(Context(request=MagicMock(), required_scope="x"))
|
||||
assert calls == []
|
||||
|
||||
|
||||
def test_guard_decorator_runs_pipeline_and_unpacks_handler_kwargs():
|
||||
seen = {}
|
||||
|
||||
class FakeStep:
|
||||
def __call__(self, ctx):
|
||||
ctx.app = "APP"
|
||||
ctx.caller = "CALLER"
|
||||
ctx.caller_kind = "account"
|
||||
|
||||
pipeline = Pipeline(FakeStep())
|
||||
|
||||
@pipeline.guard(scope="apps:run")
|
||||
def handler(app_model, caller, caller_kind):
|
||||
seen["app_model"] = app_model
|
||||
seen["caller"] = caller
|
||||
seen["caller_kind"] = caller_kind
|
||||
return "ok"
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/x", method="POST"):
|
||||
assert handler() == "ok"
|
||||
assert seen == {"app_model": "APP", "caller": "CALLER", "caller_kind": "account"}
|
||||
@ -0,0 +1,64 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import AppResolver
|
||||
from models import TenantStatus
|
||||
|
||||
|
||||
def _ctx(view_args):
|
||||
req = MagicMock()
|
||||
req.view_args = view_args
|
||||
return Context(request=req, required_scope="apps:run")
|
||||
|
||||
|
||||
def _app(*, status="normal", enable_api=True):
|
||||
return SimpleNamespace(id="app1", tenant_id="t1", status=status, enable_api=enable_api)
|
||||
|
||||
|
||||
def _tenant(*, status=TenantStatus.NORMAL):
|
||||
return SimpleNamespace(id="t1", status=status)
|
||||
|
||||
|
||||
def test_resolver_rejects_missing_path_param():
|
||||
with pytest.raises(BadRequest):
|
||||
AppResolver()(_ctx({}))
|
||||
|
||||
|
||||
def test_resolver_rejects_none_view_args():
|
||||
with pytest.raises(BadRequest):
|
||||
AppResolver()(_ctx(None))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_404_when_app_missing(db):
|
||||
db.session.get.side_effect = [None]
|
||||
with pytest.raises(NotFound):
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_403_when_disabled(db):
|
||||
db.session.get.side_effect = [_app(enable_api=False)]
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
assert "service_api_disabled" in str(exc.value.description)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_403_when_tenant_archived(db):
|
||||
db.session.get.side_effect = [_app(), _tenant(status=TenantStatus.ARCHIVE)]
|
||||
with pytest.raises(Forbidden):
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_populates_app_and_tenant(db):
|
||||
db.session.get.side_effect = [_app(), _tenant()]
|
||||
ctx = _ctx({"app_id": "x"})
|
||||
AppResolver()(ctx)
|
||||
assert ctx.app.id == "app1"
|
||||
assert ctx.tenant.id == "t1"
|
||||
@ -0,0 +1,58 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import AppAuthzCheck
|
||||
from controllers.openapi.auth.strategies import AclStrategy, MembershipStrategy
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id="acc1"):
|
||||
c = Context(request=MagicMock(), required_scope="apps:run")
|
||||
c.subject_type = subject_type
|
||||
c.subject_email = "alice@example.com"
|
||||
c.account_id = account_id
|
||||
c.app = SimpleNamespace(id="app1")
|
||||
c.tenant = SimpleNamespace(id="t1")
|
||||
return c
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies.EnterpriseService")
|
||||
def test_acl_strategy_calls_inner_api(ent):
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.return_value = True
|
||||
assert AclStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_called_once_with(
|
||||
user_id="alice@example.com",
|
||||
app_id="app1",
|
||||
)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._has_tenant_membership")
|
||||
def test_membership_strategy_uses_join_lookup(member):
|
||||
member.return_value = True
|
||||
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
|
||||
member.assert_called_once_with("acc1", "t1")
|
||||
|
||||
|
||||
def test_membership_strategy_rejects_external_sso():
|
||||
assert (
|
||||
MembershipStrategy().authorize(
|
||||
_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_app_authz_check_raises_when_strategy_denies():
|
||||
deny = SimpleNamespace(authorize=lambda c: False)
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
AppAuthzCheck(lambda: deny)(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
assert "subject_no_app_access" in str(exc.value.description)
|
||||
|
||||
|
||||
def test_app_authz_check_passes_when_strategy_allows():
|
||||
allow = SimpleNamespace(authorize=lambda c: True)
|
||||
AppAuthzCheck(lambda: allow)(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
@ -0,0 +1,58 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import BearerCheck
|
||||
from libs.oauth_bearer import ResolvedRow, SubjectType
|
||||
|
||||
|
||||
def _ctx(headers):
|
||||
req = MagicMock()
|
||||
req.headers = headers
|
||||
return Context(request=req, required_scope="apps:run")
|
||||
|
||||
|
||||
def test_bearer_check_rejects_missing_header():
|
||||
with pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx({}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps._registry")
|
||||
def test_bearer_check_rejects_unknown_prefix(reg):
|
||||
reg.return_value.find.return_value = None
|
||||
with pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx({"Authorization": "Bearer xxx_abc"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps._registry")
|
||||
def test_bearer_check_populates_context(reg):
|
||||
tok_id = uuid.uuid4()
|
||||
fake_resolver = MagicMock()
|
||||
fake_resolver.resolve.return_value = ResolvedRow(
|
||||
subject_email="a@x.com",
|
||||
subject_issuer=None,
|
||||
account_id=None,
|
||||
token_id=tok_id,
|
||||
expires_at=datetime.now(UTC),
|
||||
)
|
||||
fake_kind = SimpleNamespace(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
scopes=frozenset({"full"}),
|
||||
source="oauth-account",
|
||||
resolver=fake_resolver,
|
||||
)
|
||||
reg.return_value.find.return_value = fake_kind
|
||||
|
||||
ctx = _ctx({"Authorization": "Bearer dfoa_abc"})
|
||||
BearerCheck()(ctx)
|
||||
|
||||
assert ctx.subject_type == SubjectType.ACCOUNT
|
||||
assert ctx.subject_email == "a@x.com"
|
||||
assert ctx.scopes == frozenset({"full"})
|
||||
assert ctx.source == "oauth-account"
|
||||
assert ctx.token_id == tok_id
|
||||
@ -0,0 +1,77 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import CallerMount
|
||||
from controllers.openapi.auth.strategies import AccountMounter, EndUserMounter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id=None, subject_email=None):
|
||||
c = Context(request=MagicMock(), required_scope="apps:run")
|
||||
c.subject_type = subject_type
|
||||
c.account_id = account_id
|
||||
c.subject_email = subject_email
|
||||
c.app = SimpleNamespace(id="app1")
|
||||
c.tenant = SimpleNamespace(id="t1")
|
||||
return c
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._login_as")
|
||||
@patch("controllers.openapi.auth.strategies.db")
|
||||
def test_account_mounter(db, login):
|
||||
account = SimpleNamespace()
|
||||
db.session.get.return_value = account
|
||||
ctx = _ctx(subject_type=SubjectType.ACCOUNT, account_id="acc1")
|
||||
AccountMounter().mount(ctx)
|
||||
assert ctx.caller is account
|
||||
assert ctx.caller.current_tenant is ctx.tenant
|
||||
assert ctx.caller_kind == "account"
|
||||
login.assert_called_once_with(account)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._login_as")
|
||||
@patch("controllers.openapi.auth.strategies.EndUserService")
|
||||
def test_end_user_mounter(svc, login):
|
||||
eu = SimpleNamespace()
|
||||
svc.get_or_create_end_user_by_type.return_value = eu
|
||||
ctx = _ctx(subject_type=SubjectType.EXTERNAL_SSO, subject_email="a@x.com")
|
||||
EndUserMounter().mount(ctx)
|
||||
svc.get_or_create_end_user_by_type.assert_called_once_with(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id="t1",
|
||||
app_id="app1",
|
||||
user_id="a@x.com",
|
||||
)
|
||||
assert ctx.caller is eu
|
||||
assert ctx.caller_kind == "end_user"
|
||||
|
||||
|
||||
def test_caller_mount_dispatches_by_subject_type():
|
||||
seen = {}
|
||||
|
||||
class Fake:
|
||||
def __init__(self, st, tag):
|
||||
self._st, self._tag = st, tag
|
||||
|
||||
def applies_to(self, st):
|
||||
return st == self._st
|
||||
|
||||
def mount(self, ctx):
|
||||
seen["who"] = self._tag
|
||||
|
||||
cm = CallerMount(
|
||||
Fake(SubjectType.ACCOUNT, "acct"),
|
||||
Fake(SubjectType.EXTERNAL_SSO, "sso"),
|
||||
)
|
||||
cm(_ctx(subject_type=SubjectType.EXTERNAL_SSO))
|
||||
assert seen == {"who": "sso"}
|
||||
|
||||
|
||||
def test_caller_mount_raises_when_none_applies():
|
||||
with pytest.raises(Unauthorized):
|
||||
CallerMount()(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
@ -0,0 +1,27 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import ScopeCheck
|
||||
|
||||
|
||||
def _ctx(scopes, required):
|
||||
c = Context(request=MagicMock(), required_scope=required)
|
||||
c.scopes = frozenset(scopes)
|
||||
return c
|
||||
|
||||
|
||||
def test_scope_check_passes_on_full():
|
||||
ScopeCheck()(_ctx({"full"}, "apps:run"))
|
||||
|
||||
|
||||
def test_scope_check_passes_on_explicit_match():
|
||||
ScopeCheck()(_ctx({"apps:run"}, "apps:run"))
|
||||
|
||||
|
||||
def test_scope_check_rejects_when_missing():
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
ScopeCheck()(_ctx({"apps:read"}, "apps:run"))
|
||||
assert "insufficient_scope" in str(exc.value.description)
|
||||
14
api/tests/unit_tests/controllers/openapi/conftest.py
Normal file
14
api/tests/unit_tests/controllers/openapi/conftest.py
Normal file
@ -0,0 +1,14 @@
|
||||
import pytest
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bypass_pipeline(monkeypatch):
|
||||
"""Stub Pipeline.run so endpoint decoration does not invoke real auth.
|
||||
|
||||
Module-level @APP_PIPELINE.guard(...) captures the real APP_PIPELINE at
|
||||
import time; mocking the module attribute does not undo that. Patching
|
||||
Pipeline.run on the class is the bypass that actually works.
|
||||
"""
|
||||
monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None)
|
||||
46
api/tests/unit_tests/controllers/openapi/test_app_info.py
Normal file
46
api/tests/unit_tests/controllers/openapi/test_app_info.py
Normal file
@ -0,0 +1,46 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
|
||||
def _client():
|
||||
from controllers.openapi import app_info # noqa: F401
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
app = Flask(__name__)
|
||||
api = Api(app)
|
||||
api.add_namespace(openapi_ns, path="/openapi/v1")
|
||||
return app.test_client()
|
||||
|
||||
|
||||
def test_app_info_returns_response_model(bypass_pipeline):
|
||||
app_obj = SimpleNamespace(
|
||||
id="app1",
|
||||
name="X",
|
||||
description="d",
|
||||
mode="chat",
|
||||
author_name="alice",
|
||||
tags=[SimpleNamespace(name="prod")],
|
||||
)
|
||||
with patch("controllers.openapi.app_info._unpack_app", return_value=app_obj):
|
||||
r = _client().get("/openapi/v1/apps/app1/info")
|
||||
assert r.status_code == 200
|
||||
body = r.get_json()
|
||||
assert body == {
|
||||
"id": "app1",
|
||||
"name": "X",
|
||||
"description": "d",
|
||||
"mode": "chat",
|
||||
"author_name": "alice",
|
||||
"tags": ["prod"],
|
||||
}
|
||||
|
||||
|
||||
def test_app_info_response_model_validates():
|
||||
from controllers.openapi.app_info import AppInfoResponse
|
||||
|
||||
m = AppInfoResponse(id="x", name="N", mode="chat")
|
||||
assert m.tags == []
|
||||
assert m.description is None
|
||||
@ -0,0 +1,19 @@
|
||||
import logging
|
||||
|
||||
from controllers.openapi._audit import EVENT_APP_RUN_OPENAPI, emit_app_run
|
||||
|
||||
|
||||
def test_event_constant():
|
||||
assert EVENT_APP_RUN_OPENAPI == "app.run.openapi"
|
||||
|
||||
|
||||
def test_emit_app_run_logs_with_audit_extra(caplog):
|
||||
with caplog.at_level(logging.INFO, logger="controllers.openapi._audit"):
|
||||
emit_app_run(app_id="app1", tenant_id="t1", caller_kind="account", mode="chat")
|
||||
record = next(r for r in caplog.records if r.message and "app.run.openapi" in r.message)
|
||||
assert record.audit is True
|
||||
assert record.event == EVENT_APP_RUN_OPENAPI
|
||||
assert record.app_id == "app1"
|
||||
assert record.tenant_id == "t1"
|
||||
assert record.caller_kind == "account"
|
||||
assert record.mode == "chat"
|
||||
@ -0,0 +1,89 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
|
||||
def _client():
|
||||
from controllers.openapi import chat_messages # noqa: F401
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
app = Flask(__name__)
|
||||
api = Api(app)
|
||||
api.add_namespace(openapi_ns, path="/openapi/v1")
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@patch("controllers.openapi.chat_messages.AppGenerateService")
|
||||
def test_chat_dispatches_and_returns_response_model(svc, bypass_pipeline):
|
||||
svc.generate.return_value = (
|
||||
{
|
||||
"event": "message",
|
||||
"task_id": "tk1",
|
||||
"id": "m1",
|
||||
"message_id": "m1",
|
||||
"conversation_id": "c1",
|
||||
"mode": "chat",
|
||||
"answer": "hi",
|
||||
"metadata": {},
|
||||
"created_at": 1700000000,
|
||||
},
|
||||
200,
|
||||
)
|
||||
fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1")
|
||||
with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake), patch(
|
||||
"controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()
|
||||
):
|
||||
r = _client().post(
|
||||
"/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}}
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.get_json()
|
||||
assert body["conversation_id"] == "c1"
|
||||
assert body["answer"] == "hi"
|
||||
assert svc.generate.call_args.kwargs["invoke_from"].value == "openapi"
|
||||
|
||||
|
||||
@patch("controllers.openapi.chat_messages.AppGenerateService")
|
||||
def test_chat_strips_user_field_from_body(svc, bypass_pipeline):
|
||||
svc.generate.return_value = (
|
||||
{
|
||||
"event": "message",
|
||||
"task_id": "tk1",
|
||||
"id": "m1",
|
||||
"message_id": "m1",
|
||||
"conversation_id": "c1",
|
||||
"mode": "chat",
|
||||
"answer": "hi",
|
||||
"metadata": {},
|
||||
"created_at": 1700000000,
|
||||
},
|
||||
200,
|
||||
)
|
||||
fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1")
|
||||
with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake), patch(
|
||||
"controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()
|
||||
):
|
||||
_client().post(
|
||||
"/openapi/v1/apps/app1/chat-messages",
|
||||
json={"query": "hi", "inputs": {}, "user": "spoof@x.com"},
|
||||
)
|
||||
args = svc.generate.call_args.kwargs["args"]
|
||||
assert "user" not in args
|
||||
|
||||
|
||||
def test_chat_rejects_non_chat_mode(bypass_pipeline):
|
||||
fake = SimpleNamespace(mode="completion")
|
||||
with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake):
|
||||
r = _client().post(
|
||||
"/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}}
|
||||
)
|
||||
assert r.status_code in (400, 403)
|
||||
|
||||
|
||||
def test_chat_rejects_invalid_body(bypass_pipeline):
|
||||
fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1")
|
||||
with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake):
|
||||
r = _client().post("/openapi/v1/apps/app1/chat-messages", json={"query": "hi"})
|
||||
assert r.status_code in (400, 422)
|
||||
@ -0,0 +1,54 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
|
||||
def _client():
|
||||
from controllers.openapi import completion_messages # noqa: F401
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
app = Flask(__name__)
|
||||
api = Api(app)
|
||||
api.add_namespace(openapi_ns, path="/openapi/v1")
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@patch("controllers.openapi.completion_messages.AppGenerateService")
|
||||
def test_completion_returns_response_model(svc, bypass_pipeline):
|
||||
svc.generate.return_value = (
|
||||
{
|
||||
"event": "message",
|
||||
"task_id": "tk",
|
||||
"id": "m1",
|
||||
"message_id": "m1",
|
||||
"mode": "completion",
|
||||
"answer": "ok",
|
||||
"metadata": {},
|
||||
"created_at": 1700000000,
|
||||
},
|
||||
200,
|
||||
)
|
||||
fake = SimpleNamespace(mode="completion", id="app1", tenant_id="t1")
|
||||
with patch("controllers.openapi.completion_messages._unpack_app", return_value=fake), patch(
|
||||
"controllers.openapi.completion_messages._unpack_caller", return_value=SimpleNamespace()
|
||||
):
|
||||
r = _client().post(
|
||||
"/openapi/v1/apps/app1/completion-messages",
|
||||
json={"inputs": {"x": 1}, "query": "hi"},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.get_json()
|
||||
assert body["answer"] == "ok"
|
||||
assert svc.generate.call_args.kwargs["invoke_from"].value == "openapi"
|
||||
|
||||
|
||||
def test_completion_rejects_chat_mode(bypass_pipeline):
|
||||
fake = SimpleNamespace(mode="chat")
|
||||
with patch("controllers.openapi.completion_messages._unpack_app", return_value=fake):
|
||||
r = _client().post(
|
||||
"/openapi/v1/apps/app1/completion-messages",
|
||||
json={"inputs": {}, "query": "hi"},
|
||||
)
|
||||
assert r.status_code in (400, 403)
|
||||
14
api/tests/unit_tests/controllers/openapi/test_models.py
Normal file
14
api/tests/unit_tests/controllers/openapi/test_models.py
Normal file
@ -0,0 +1,14 @@
|
||||
from controllers.openapi._models import MessageMetadata, UsageInfo
|
||||
|
||||
|
||||
def test_usage_info_defaults_zero():
|
||||
u = UsageInfo()
|
||||
assert u.prompt_tokens == 0
|
||||
assert u.completion_tokens == 0
|
||||
assert u.total_tokens == 0
|
||||
|
||||
|
||||
def test_message_metadata_accepts_partial():
|
||||
m = MessageMetadata(usage=UsageInfo(total_tokens=10))
|
||||
assert m.usage.total_tokens == 10
|
||||
assert m.retriever_resources == []
|
||||
@ -0,0 +1,50 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
|
||||
def _client():
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi import workflow_run # noqa: F401
|
||||
|
||||
app = Flask(__name__)
|
||||
api = Api(app)
|
||||
api.add_namespace(openapi_ns, path="/openapi/v1")
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@patch("controllers.openapi.workflow_run.AppGenerateService")
|
||||
def test_workflow_run_returns_response_model(svc, bypass_pipeline):
|
||||
svc.generate.return_value = (
|
||||
{
|
||||
"workflow_run_id": "wr1",
|
||||
"task_id": "tk",
|
||||
"data": {
|
||||
"id": "wr1",
|
||||
"workflow_id": "wf1",
|
||||
"status": "succeeded",
|
||||
"outputs": {"result": "ok"},
|
||||
"elapsed_time": 1.0,
|
||||
},
|
||||
},
|
||||
200,
|
||||
)
|
||||
fake = SimpleNamespace(mode="workflow", id="app1", tenant_id="t1")
|
||||
with patch("controllers.openapi.workflow_run._unpack_app", return_value=fake), patch(
|
||||
"controllers.openapi.workflow_run._unpack_caller", return_value=SimpleNamespace()
|
||||
):
|
||||
r = _client().post("/openapi/v1/apps/app1/workflows/run", json={"inputs": {"x": 1}})
|
||||
assert r.status_code == 200
|
||||
body = r.get_json()
|
||||
assert body["workflow_run_id"] == "wr1"
|
||||
assert body["data"]["status"] == "succeeded"
|
||||
assert svc.generate.call_args.kwargs["invoke_from"].value == "openapi"
|
||||
|
||||
|
||||
def test_workflow_run_rejects_non_workflow(bypass_pipeline):
|
||||
fake = SimpleNamespace(mode="chat")
|
||||
with patch("controllers.openapi.workflow_run._unpack_app", return_value=fake):
|
||||
r = _client().post("/openapi/v1/apps/app1/workflows/run", json={"inputs": {}})
|
||||
assert r.status_code in (400, 403)
|
||||
9
api/tests/unit_tests/core/app/test_invoke_from.py
Normal file
9
api/tests/unit_tests/core/app/test_invoke_from.py
Normal file
@ -0,0 +1,9 @@
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
|
||||
def test_openapi_variant_present():
|
||||
assert InvokeFrom.OPENAPI.value == "openapi"
|
||||
|
||||
|
||||
def test_openapi_distinct_from_service_api():
|
||||
assert InvokeFrom.OPENAPI != InvokeFrom.SERVICE_API
|
||||
Loading…
Reference in New Issue
Block a user