From cf5ebe9430e9f14f92dc3a7c3db6541ef4e823cf Mon Sep 17 00:00:00 2001 From: GareArc Date: Mon, 27 Apr 2026 17:25:17 -0700 Subject: [PATCH] feat(openapi): app-run endpoints with auth pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ports service_api/app/{completion,workflow}.py to bearer-authed /openapi/v1/apps//{info,chat-messages,completion-messages,workflows/run}. Architecture: - New controllers/openapi/auth/ package: Pipeline + Step protocol over one mutable Context. Endpoints attach via @APP_PIPELINE.guard(scope=...) — single attachment point; forgetting auth is structurally impossible. - Pipeline order: BearerCheck -> ScopeCheck -> AppResolver -> AppAuthzCheck -> CallerMount. - Strategies vary along independent axes: AclStrategy (EE webapp-auth inner API) vs MembershipStrategy (CE TenantAccountJoin); AccountMounter vs EndUserMounter dispatched by SubjectType. - App is in URL path (not header). Each non-GET has typed Pydantic Request; each non-SSE response has typed Pydantic Response. Bearer-as-identity: body 'user' field stripped, ignored if present. Adds InvokeFrom.OPENAPI enum variant. Emits app.run.openapi audit log on successful invocation via standard logger extra={"audit": True, ...} convention. --- api/controllers/openapi/__init__.py | 16 +- api/controllers/openapi/_audit.py | 32 ++++ api/controllers/openapi/_models.py | 17 ++ api/controllers/openapi/app_info.py | 36 ++++ api/controllers/openapi/auth/__init__.py | 3 + api/controllers/openapi/auth/composition.py | 37 ++++ api/controllers/openapi/auth/context.py | 39 ++++ api/controllers/openapi/auth/pipeline.py | 39 ++++ api/controllers/openapi/auth/steps.py | 112 ++++++++++++ api/controllers/openapi/auth/strategies.py | 103 +++++++++++ api/controllers/openapi/chat_messages.py | 167 ++++++++++++++++++ .../openapi/completion_messages.py | 124 +++++++++++++ api/controllers/openapi/workflow_run.py | 132 ++++++++++++++ api/core/app/entities/app_invoke_entities.py | 1 + .../controllers/openapi/auth/__init__.py | 0 .../openapi/auth/test_composition.py | 49 +++++ .../controllers/openapi/auth/test_context.py | 21 +++ .../controllers/openapi/auth/test_pipeline.py | 61 +++++++ .../openapi/auth/test_step_app_resolver.py | 64 +++++++ .../openapi/auth/test_step_authz.py | 58 ++++++ .../openapi/auth/test_step_bearer.py | 58 ++++++ .../openapi/auth/test_step_mount.py | 77 ++++++++ .../openapi/auth/test_step_scope.py | 27 +++ .../controllers/openapi/conftest.py | 14 ++ .../controllers/openapi/test_app_info.py | 46 +++++ .../controllers/openapi/test_audit_app_run.py | 19 ++ .../controllers/openapi/test_chat_messages.py | 89 ++++++++++ .../openapi/test_completion_messages.py | 54 ++++++ .../controllers/openapi/test_models.py | 14 ++ .../controllers/openapi/test_workflow_run.py | 50 ++++++ .../unit_tests/core/app/test_invoke_from.py | 9 + 31 files changed, 1567 insertions(+), 1 deletion(-) create mode 100644 api/controllers/openapi/_audit.py create mode 100644 api/controllers/openapi/_models.py create mode 100644 api/controllers/openapi/app_info.py create mode 100644 api/controllers/openapi/auth/__init__.py create mode 100644 api/controllers/openapi/auth/composition.py create mode 100644 api/controllers/openapi/auth/context.py create mode 100644 api/controllers/openapi/auth/pipeline.py create mode 100644 api/controllers/openapi/auth/steps.py create mode 100644 api/controllers/openapi/auth/strategies.py create mode 100644 api/controllers/openapi/chat_messages.py create mode 100644 api/controllers/openapi/completion_messages.py create mode 100644 api/controllers/openapi/workflow_run.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/__init__.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_composition.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_context.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_app_resolver.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_mount.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_scope.py create mode 100644 api/tests/unit_tests/controllers/openapi/conftest.py create mode 100644 api/tests/unit_tests/controllers/openapi/test_app_info.py create mode 100644 api/tests/unit_tests/controllers/openapi/test_audit_app_run.py create mode 100644 api/tests/unit_tests/controllers/openapi/test_chat_messages.py create mode 100644 api/tests/unit_tests/controllers/openapi/test_completion_messages.py create mode 100644 api/tests/unit_tests/controllers/openapi/test_models.py create mode 100644 api/tests/unit_tests/controllers/openapi/test_workflow_run.py create mode 100644 api/tests/unit_tests/core/app/test_invoke_from.py diff --git a/api/controllers/openapi/__init__.py b/api/controllers/openapi/__init__.py index bb0829e20b..cf30ece801 100644 --- a/api/controllers/openapi/__init__.py +++ b/api/controllers/openapi/__init__.py @@ -16,13 +16,27 @@ api = ExternalApi( openapi_ns = Namespace("openapi", description="User-scoped operations", path="/") -from . import account, index, oauth_device, oauth_device_sso, workspaces +from . import ( + account, + app_info, + chat_messages, + completion_messages, + index, + oauth_device, + oauth_device_sso, + workflow_run, + workspaces, +) __all__ = [ "account", + "app_info", + "chat_messages", + "completion_messages", "index", "oauth_device", "oauth_device_sso", + "workflow_run", "workspaces", ] diff --git a/api/controllers/openapi/_audit.py b/api/controllers/openapi/_audit.py new file mode 100644 index 0000000000..30c3e1d143 --- /dev/null +++ b/api/controllers/openapi/_audit.py @@ -0,0 +1,32 @@ +"""Audit emission for openapi app-run endpoints. + +Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...} +matches the existing oauth_device convention. The EE OTel exporter consults +its own allowlist to decide whether to ship the line. +""" +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + +EVENT_APP_RUN_OPENAPI = "app.run.openapi" + + +def emit_app_run(*, app_id: str, tenant_id: str, caller_kind: str, mode: str) -> None: + logger.info( + "audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s", + EVENT_APP_RUN_OPENAPI, + app_id, + tenant_id, + caller_kind, + mode, + extra={ + "audit": True, + "event": EVENT_APP_RUN_OPENAPI, + "app_id": app_id, + "tenant_id": tenant_id, + "caller_kind": caller_kind, + "mode": mode, + }, + ) diff --git a/api/controllers/openapi/_models.py b/api/controllers/openapi/_models.py new file mode 100644 index 0000000000..5971f59e42 --- /dev/null +++ b/api/controllers/openapi/_models.py @@ -0,0 +1,17 @@ +"""Shared response substructures for openapi endpoints.""" +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + +class MessageMetadata(BaseModel): + usage: UsageInfo | None = None + retriever_resources: list[dict[str, Any]] = [] diff --git a/api/controllers/openapi/app_info.py b/api/controllers/openapi/app_info.py new file mode 100644 index 0000000000..aa9b8d20a7 --- /dev/null +++ b/api/controllers/openapi/app_info.py @@ -0,0 +1,36 @@ +"""GET /openapi/v1/apps//info — port of service_api/app/app.py:AppInfoApi.""" +from __future__ import annotations + +from flask_restx import Resource +from pydantic import BaseModel + +from controllers.openapi import openapi_ns +from controllers.openapi.auth.composition import APP_PIPELINE + + +class AppInfoResponse(BaseModel): + id: str + name: str + description: str | None = None + mode: str + author_name: str | None = None + tags: list[str] = [] + + +def _unpack_app(app_model): + return app_model + + +@openapi_ns.route("/apps//info") +class AppInfoApi(Resource): + @APP_PIPELINE.guard(scope="apps:run") + def get(self, app_id, app_model, caller, caller_kind): + app = _unpack_app(app_model) + return AppInfoResponse( + id=app.id, + name=app.name, + description=app.description, + mode=app.mode, + author_name=app.author_name, + tags=[t.name for t in app.tags], + ).model_dump(mode="json") diff --git a/api/controllers/openapi/auth/__init__.py b/api/controllers/openapi/auth/__init__.py new file mode 100644 index 0000000000..ef255d2491 --- /dev/null +++ b/api/controllers/openapi/auth/__init__.py @@ -0,0 +1,3 @@ +from controllers.openapi.auth.composition import APP_PIPELINE + +__all__ = ["APP_PIPELINE"] diff --git a/api/controllers/openapi/auth/composition.py b/api/controllers/openapi/auth/composition.py new file mode 100644 index 0000000000..fa78f07e3a --- /dev/null +++ b/api/controllers/openapi/auth/composition.py @@ -0,0 +1,37 @@ +"""APP_PIPELINE — the only auth scheme for openapi app endpoints. + +Endpoints attach via @APP_PIPELINE.guard(scope=…). No alternative paths. +""" +from __future__ import annotations + +from controllers.openapi.auth.pipeline import Pipeline +from controllers.openapi.auth.steps import ( + AppAuthzCheck, + AppResolver, + BearerCheck, + CallerMount, + ScopeCheck, +) +from controllers.openapi.auth.strategies import ( + AccountMounter, + AclStrategy, + AppAuthzStrategy, + EndUserMounter, + MembershipStrategy, +) +from services.feature_service import FeatureService + + +def _resolve_app_authz_strategy() -> AppAuthzStrategy: + if FeatureService.get_system_features().webapp_auth.enabled: + return AclStrategy() + return MembershipStrategy() + + +APP_PIPELINE = Pipeline( + BearerCheck(), + ScopeCheck(), + AppResolver(), + AppAuthzCheck(_resolve_app_authz_strategy), + CallerMount(AccountMounter(), EndUserMounter()), +) diff --git a/api/controllers/openapi/auth/context.py b/api/controllers/openapi/auth/context.py new file mode 100644 index 0000000000..a23e4e981d --- /dev/null +++ b/api/controllers/openapi/auth/context.py @@ -0,0 +1,39 @@ +"""Mutable per-request context for the openapi auth pipeline. + +Every field starts None / empty and is filled in by a step. The pipeline +is the only thing that should construct or mutate Context — handlers +read populated values via the decorator's kwargs unpacking. +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Literal, Protocol + +from flask import Request + +from libs.oauth_bearer import SubjectType + + +@dataclass +class Context: + request: Request + required_scope: str + subject_type: SubjectType | None = None + subject_email: str | None = None + subject_issuer: str | None = None + account_id: str | None = None + scopes: frozenset[str] = field(default_factory=frozenset) + token_id: str | None = None + source: str | None = None + expires_at: datetime | None = None + app: object | None = None + tenant: object | None = None + caller: object | None = None + caller_kind: Literal["account", "end_user"] | None = None + + +class Step(Protocol): + """One responsibility. Mutate ctx or raise to short-circuit.""" + + def __call__(self, ctx: Context) -> None: ... diff --git a/api/controllers/openapi/auth/pipeline.py b/api/controllers/openapi/auth/pipeline.py new file mode 100644 index 0000000000..c0df85367e --- /dev/null +++ b/api/controllers/openapi/auth/pipeline.py @@ -0,0 +1,39 @@ +"""Pipeline IS the auth scheme. + +`Pipeline.guard(scope=…)` is the only attachment point for endpoints — +that is the design lock-in: forgetting an auth layer is structurally +impossible because there is no "sometimes wrap, sometimes don't" choice. +""" +from __future__ import annotations + +from functools import wraps + +from flask import request + +from controllers.openapi.auth.context import Context, Step + + +class Pipeline: + def __init__(self, *steps: Step) -> None: + self._steps = steps + + def run(self, ctx: Context) -> None: + for step in self._steps: + step(ctx) + + def guard(self, *, scope: str): + def decorator(view): + @wraps(view) + def decorated(*args, **kwargs): + ctx = Context(request=request, required_scope=scope) + self.run(ctx) + kwargs.update( + app_model=ctx.app, + caller=ctx.caller, + caller_kind=ctx.caller_kind, + ) + return view(*args, **kwargs) + + return decorated + + return decorator diff --git a/api/controllers/openapi/auth/steps.py b/api/controllers/openapi/auth/steps.py new file mode 100644 index 0000000000..bf64cc5472 --- /dev/null +++ b/api/controllers/openapi/auth/steps.py @@ -0,0 +1,112 @@ +"""Pipeline steps. Each is one responsibility. + +BearerCheck is the only step that touches the token registry; downstream +steps see only the populated Context. +""" +from __future__ import annotations + +from typing import Callable + +from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized + +from controllers.openapi.auth.context import Context +from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter +from extensions.ext_database import db +from libs.oauth_bearer import TokenExpired, get_authenticator, sha256_hex +from models import App, Tenant, TenantStatus + + +def _registry(): + return get_authenticator()._registry # noqa: SLF001 + + +def _extract_bearer(req) -> str | None: + auth = req.headers.get("Authorization") + if not auth or not auth.lower().startswith("bearer "): + return None + return auth.split(None, 1)[1].strip() or None + + +def _hash_token(token: str) -> str: + return sha256_hex(token) + + +class BearerCheck: + """Resolve bearer → populate identity fields.""" + + def __call__(self, ctx: Context) -> None: + token = _extract_bearer(ctx.request) + if not token: + raise Unauthorized("bearer required") + + kind = _registry().find(token) + if kind is None: + raise Unauthorized("invalid bearer prefix") + + try: + row = kind.resolver.resolve(_hash_token(token)) + except TokenExpired: + raise Unauthorized("token expired") + if row is None: + raise Unauthorized("invalid bearer") + + ctx.subject_type = kind.subject_type + ctx.subject_email = row.subject_email + ctx.subject_issuer = row.subject_issuer + ctx.account_id = row.account_id + ctx.scopes = kind.scopes + ctx.source = kind.source + ctx.token_id = row.token_id + ctx.expires_at = row.expires_at + + +class ScopeCheck: + """Verify ctx.scopes (already populated by BearerCheck) covers required.""" + + def __call__(self, ctx: Context) -> None: + if "full" in ctx.scopes or ctx.required_scope in ctx.scopes: + return + raise Forbidden("insufficient_scope") + + +class AppResolver: + """Read app_id from request.view_args, populate ctx.app + ctx.tenant. + + Every endpoint using APP_PIPELINE must declare ```` in + its route — that is the design lock-in (no body / header coupling). + """ + + def __call__(self, ctx: Context) -> None: + app_id = (ctx.request.view_args or {}).get("app_id") + if not app_id: + raise BadRequest("app_id is required in path") + app = db.session.get(App, app_id) + if not app or app.status != "normal": + raise NotFound("app not found") + if not app.enable_api: + raise Forbidden("service_api_disabled") + tenant = db.session.get(Tenant, app.tenant_id) + if tenant is None or tenant.status == TenantStatus.ARCHIVE: + raise Forbidden("workspace unavailable") + ctx.app, ctx.tenant = app, tenant + + +class AppAuthzCheck: + def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None: + self._resolve = resolve_strategy + + def __call__(self, ctx: Context) -> None: + if not self._resolve().authorize(ctx): + raise Forbidden("subject_no_app_access") + + +class CallerMount: + def __init__(self, *mounters: CallerMounter) -> None: + self._mounters = mounters + + def __call__(self, ctx: Context) -> None: + for m in self._mounters: + if m.applies_to(ctx.subject_type): + m.mount(ctx) + return + raise Unauthorized("no caller mounter for subject type") diff --git a/api/controllers/openapi/auth/strategies.py b/api/controllers/openapi/auth/strategies.py new file mode 100644 index 0000000000..cda2e2ae51 --- /dev/null +++ b/api/controllers/openapi/auth/strategies.py @@ -0,0 +1,103 @@ +"""Strategy classes for the openapi auth pipeline. + +App authorization (Acl/Membership) and caller mounting (Account/EndUser) +vary along independent axes; each strategy is one class so the pipeline +composition stays a flat list. +""" +from __future__ import annotations + +from typing import Protocol + +from flask import current_app +from flask_login import user_logged_in +from sqlalchemy import select + +from controllers.openapi.auth.context import Context +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from libs.oauth_bearer import SubjectType +from models import Account, TenantAccountJoin +from services.end_user_service import EndUserService +from services.enterprise.enterprise_service import EnterpriseService + + +class AppAuthzStrategy(Protocol): + def authorize(self, ctx: Context) -> bool: ... + + +class AclStrategy: + """Per-app ACL via the workspace-auth inner API. + + Used when webapp-auth is enabled (EE deployment). The inner-API + allowlist is the source of truth. + """ + + def authorize(self, ctx: Context) -> bool: + return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( + user_id=ctx.subject_email, + app_id=ctx.app.id, + ) + + +class MembershipStrategy: + """Tenant-membership fallback. + + Used when webapp-auth is disabled (CE deployment). Account-bearing + subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is + denied (it requires the webapp-auth surface). + """ + + def authorize(self, ctx: Context) -> bool: + if ctx.subject_type == SubjectType.EXTERNAL_SSO: + return False + return _has_tenant_membership(ctx.account_id, ctx.tenant.id) + + +def _has_tenant_membership(account_id: str | None, tenant_id: str) -> bool: + if not account_id: + return False + row = db.session.execute( + select(TenantAccountJoin.id).where( + TenantAccountJoin.tenant_id == tenant_id, + TenantAccountJoin.account_id == account_id, + ) + ).scalar_one_or_none() + return row is not None + + +def _login_as(user) -> None: + """Set Flask-Login request user so downstream services see the caller.""" + current_app.login_manager._update_request_context_with_user(user) # noqa: SLF001 + user_logged_in.send(current_app._get_current_object(), user=user) # noqa: SLF001 + + +class CallerMounter(Protocol): + def applies_to(self, subject_type: SubjectType) -> bool: ... + + def mount(self, ctx: Context) -> None: ... + + +class AccountMounter: + def applies_to(self, st: SubjectType) -> bool: + return st == SubjectType.ACCOUNT + + def mount(self, ctx: Context) -> None: + account = db.session.get(Account, ctx.account_id) + account.current_tenant = ctx.tenant + _login_as(account) + ctx.caller, ctx.caller_kind = account, "account" + + +class EndUserMounter: + def applies_to(self, st: SubjectType) -> bool: + return st == SubjectType.EXTERNAL_SSO + + def mount(self, ctx: Context) -> None: + end_user = EndUserService.get_or_create_end_user_by_type( + InvokeFrom.OPENAPI, + tenant_id=ctx.tenant.id, + app_id=ctx.app.id, + user_id=ctx.subject_email, + ) + _login_as(end_user) + ctx.caller, ctx.caller_kind = end_user, "end_user" diff --git a/api/controllers/openapi/chat_messages.py b/api/controllers/openapi/chat_messages.py new file mode 100644 index 0000000000..2335e59da9 --- /dev/null +++ b/api/controllers/openapi/chat_messages.py @@ -0,0 +1,167 @@ +"""POST /openapi/v1/apps//chat-messages — port of +service_api/app/completion.py:ChatApi. + +Differences from service_api: +- App is in URL path, not header. +- One decorator: @APP_PIPELINE.guard(scope="apps:run"). +- Request body has no `user` field (Model 2: identity is the bearer). +- Typed Request and Response models. +- invoke_from = InvokeFrom.OPENAPI. +""" +from __future__ import annotations + +import logging +from typing import Any, Literal +from uuid import UUID + +from flask import request +from flask_restx import Resource +from pydantic import BaseModel, Field, ValidationError, field_validator +from werkzeug.exceptions import BadRequest, InternalServerError, NotFound + +import services +from controllers.openapi import openapi_ns +from controllers.openapi._audit import emit_app_run +from controllers.openapi._models import MessageMetadata +from controllers.openapi.auth.composition import APP_PIPELINE +from controllers.service_api.app.error import ( + AppUnavailableError, + CompletionRequestError, + ConversationCompletedError, + NotChatAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from core.app.entities.app_invoke_entities import InvokeFrom +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from graphon.model_runtime.errors.invoke import InvokeError +from libs import helper +from libs.helper import UUIDStrOrEmpty +from models.model import App, AppMode +from services.app_generate_service import AppGenerateService +from services.errors.app import ( + IsDraftWorkflowError, + WorkflowIdFormatError, + WorkflowNotFoundError, +) +from services.errors.llm import InvokeRateLimitError + +logger = logging.getLogger(__name__) + + +class ChatMessageRequest(BaseModel): + inputs: dict[str, Any] + query: str + files: list[dict[str, Any]] | None = None + response_mode: Literal["blocking", "streaming"] | None = None + conversation_id: UUIDStrOrEmpty | None = Field(default=None) + auto_generate_name: bool = Field(default=True) + workflow_id: str | None = Field(default=None) + + @field_validator("conversation_id", mode="before") + @classmethod + def normalize_conversation_id(cls, value: str | UUID | None) -> str | None: + if isinstance(value, str): + value = value.strip() + if not value: + return None + try: + return helper.uuid_value(value) + except ValueError as exc: + raise ValueError("conversation_id must be a valid UUID") from exc + + +class ChatMessageResponse(BaseModel): + event: str + task_id: str + id: str + message_id: str + conversation_id: str + mode: str + answer: str + metadata: MessageMetadata = Field(default_factory=MessageMetadata) + created_at: int + + +def _unpack_app(app_model): + return app_model + + +def _unpack_caller(caller): + return caller + + +@openapi_ns.route("/apps//chat-messages") +class ChatMessagesApi(Resource): + @APP_PIPELINE.guard(scope="apps:run") + def post(self, app_id: str, app_model: App, caller, caller_kind: str): + app = _unpack_app(app_model) + if AppMode.value_of(app.mode) not in { + AppMode.CHAT, + AppMode.AGENT_CHAT, + AppMode.ADVANCED_CHAT, + }: + raise NotChatAppError() + + body = request.get_json(silent=True) or {} + body.pop("user", None) + try: + payload = ChatMessageRequest.model_validate(body) + except ValidationError as exc: + raise BadRequest(str(exc)) + args = payload.model_dump(exclude_none=True) + streaming = payload.response_mode == "streaming" + + try: + response = AppGenerateService.generate( + app_model=app, + user=_unpack_caller(caller), + args=args, + invoke_from=InvokeFrom.OPENAPI, + streaming=streaming, + ) + except WorkflowNotFoundError as ex: + raise NotFound(str(ex)) + except (IsDraftWorkflowError, WorkflowIdFormatError) as ex: + raise BadRequest(str(ex)) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except services.errors.app_model_config.AppModelConfigBrokenError: + logger.exception("App model config broken.") + raise AppUnavailableError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError: + raise + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + emit_app_run( + app_id=app.id, + tenant_id=app.tenant_id, + caller_kind=caller_kind, + mode=str(app.mode), + ) + + if streaming: + return helper.compact_generate_response(response) + + body_dict = response[0] if isinstance(response, tuple) else response + return ChatMessageResponse.model_validate(body_dict).model_dump(mode="json"), 200 diff --git a/api/controllers/openapi/completion_messages.py b/api/controllers/openapi/completion_messages.py new file mode 100644 index 0000000000..9085297793 --- /dev/null +++ b/api/controllers/openapi/completion_messages.py @@ -0,0 +1,124 @@ +"""POST /openapi/v1/apps//completion-messages — port of +service_api/app/completion.py:CompletionApi.""" +from __future__ import annotations + +import logging +from typing import Any, Literal + +from flask import request +from flask_restx import Resource +from pydantic import BaseModel, Field, ValidationError +from werkzeug.exceptions import BadRequest, InternalServerError, NotFound + +import services +from controllers.openapi import openapi_ns +from controllers.openapi._audit import emit_app_run +from controllers.openapi._models import MessageMetadata +from controllers.openapi.auth.composition import APP_PIPELINE +from controllers.service_api.app.error import ( + AppUnavailableError, + CompletionRequestError, + ConversationCompletedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from core.app.entities.app_invoke_entities import InvokeFrom +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from graphon.model_runtime.errors.invoke import InvokeError +from libs import helper +from models.model import App, AppMode +from services.app_generate_service import AppGenerateService + +logger = logging.getLogger(__name__) + + +class CompletionMessageRequest(BaseModel): + inputs: dict[str, Any] + query: str = Field(default="") + files: list[dict[str, Any]] | None = None + response_mode: Literal["blocking", "streaming"] | None = None + + +class CompletionMessageResponse(BaseModel): + event: str + task_id: str + id: str + message_id: str + mode: str + answer: str + metadata: MessageMetadata = Field(default_factory=MessageMetadata) + created_at: int + + +def _unpack_app(app_model): + return app_model + + +def _unpack_caller(caller): + return caller + + +@openapi_ns.route("/apps//completion-messages") +class CompletionMessagesApi(Resource): + @APP_PIPELINE.guard(scope="apps:run") + def post(self, app_id: str, app_model: App, caller, caller_kind: str): + app = _unpack_app(app_model) + if AppMode.value_of(app.mode) != AppMode.COMPLETION: + raise AppUnavailableError() + + body = request.get_json(silent=True) or {} + body.pop("user", None) + try: + payload = CompletionMessageRequest.model_validate(body) + except ValidationError as exc: + raise BadRequest(str(exc)) + args = payload.model_dump(exclude_none=True) + args["auto_generate_name"] = False + streaming = payload.response_mode == "streaming" + + try: + response = AppGenerateService.generate( + app_model=app, + user=_unpack_caller(caller), + args=args, + invoke_from=InvokeFrom.OPENAPI, + streaming=streaming, + ) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except services.errors.app_model_config.AppModelConfigBrokenError: + logger.exception("App model config broken.") + raise AppUnavailableError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError: + raise + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + emit_app_run( + app_id=app.id, + tenant_id=app.tenant_id, + caller_kind=caller_kind, + mode=str(app.mode), + ) + + if streaming: + return helper.compact_generate_response(response) + + body_dict = response[0] if isinstance(response, tuple) else response + return CompletionMessageResponse.model_validate(body_dict).model_dump(mode="json"), 200 diff --git a/api/controllers/openapi/workflow_run.py b/api/controllers/openapi/workflow_run.py new file mode 100644 index 0000000000..d76ff553de --- /dev/null +++ b/api/controllers/openapi/workflow_run.py @@ -0,0 +1,132 @@ +"""POST /openapi/v1/apps//workflows/run — port of +service_api/app/workflow.py:WorkflowRunApi.""" +from __future__ import annotations + +import logging +from typing import Any, Literal + +from flask import request +from flask_restx import Resource +from pydantic import BaseModel, ValidationError +from werkzeug.exceptions import BadRequest, InternalServerError, NotFound + +from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase +from controllers.openapi import openapi_ns +from controllers.openapi._audit import emit_app_run +from controllers.openapi.auth.composition import APP_PIPELINE +from controllers.service_api.app.error import ( + CompletionRequestError, + NotWorkflowAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from core.app.entities.app_invoke_entities import InvokeFrom +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from graphon.model_runtime.errors.invoke import InvokeError +from libs import helper +from models.model import App, AppMode +from services.app_generate_service import AppGenerateService +from services.errors.app import ( + IsDraftWorkflowError, + WorkflowIdFormatError, + WorkflowNotFoundError, +) +from services.errors.llm import InvokeRateLimitError + +logger = logging.getLogger(__name__) + + +class WorkflowRunRequest(WorkflowRunPayloadBase): + response_mode: Literal["blocking", "streaming"] | None = None + + +class WorkflowRunData(BaseModel): + id: str + workflow_id: str + status: str + outputs: dict[str, Any] = {} + error: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_at: int | None = None + finished_at: int | None = None + + +class WorkflowRunResponse(BaseModel): + workflow_run_id: str + task_id: str + data: WorkflowRunData + + +def _unpack_app(app_model): + return app_model + + +def _unpack_caller(caller): + return caller + + +@openapi_ns.route("/apps//workflows/run") +class WorkflowRunApi(Resource): + @APP_PIPELINE.guard(scope="apps:run") + def post(self, app_id: str, app_model: App, caller, caller_kind: str): + app = _unpack_app(app_model) + if AppMode.value_of(app.mode) != AppMode.WORKFLOW: + raise NotWorkflowAppError() + + body = request.get_json(silent=True) or {} + body.pop("user", None) + try: + payload = WorkflowRunRequest.model_validate(body) + except ValidationError as exc: + raise BadRequest(str(exc)) + args = payload.model_dump(exclude_none=True) + streaming = payload.response_mode == "streaming" + + try: + response = AppGenerateService.generate( + app_model=app, + user=_unpack_caller(caller), + args=args, + invoke_from=InvokeFrom.OPENAPI, + streaming=streaming, + ) + except WorkflowNotFoundError as ex: + raise NotFound(str(ex)) + except (IsDraftWorkflowError, WorkflowIdFormatError) as ex: + raise BadRequest(str(ex)) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError: + raise + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + emit_app_run( + app_id=app.id, + tenant_id=app.tenant_id, + caller_kind=caller_kind, + mode=str(app.mode), + ) + + if streaming: + return helper.compact_generate_response(response) + + body_dict = response[0] if isinstance(response, tuple) else response + return WorkflowRunResponse.model_validate(body_dict).model_dump(mode="json"), 200 diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 09992f4bbf..0c4d184f1e 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -24,6 +24,7 @@ class UserFrom(StrEnum): class InvokeFrom(StrEnum): SERVICE_API = "service-api" + OPENAPI = "openapi" WEB_APP = "web-app" TRIGGER = "trigger" EXPLORE = "explore" diff --git a/api/tests/unit_tests/controllers/openapi/auth/__init__.py b/api/tests/unit_tests/controllers/openapi/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_composition.py b/api/tests/unit_tests/controllers/openapi/auth/test_composition.py new file mode 100644 index 0000000000..48fe5fd6aa --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_composition.py @@ -0,0 +1,49 @@ +from unittest.mock import patch + +from controllers.openapi.auth.composition import APP_PIPELINE, _resolve_app_authz_strategy +from controllers.openapi.auth.pipeline import Pipeline +from controllers.openapi.auth.steps import ( + AppAuthzCheck, + AppResolver, + BearerCheck, + CallerMount, + ScopeCheck, +) +from controllers.openapi.auth.strategies import ( + AccountMounter, + AclStrategy, + EndUserMounter, + MembershipStrategy, +) + + +def test_app_pipeline_is_composed(): + assert isinstance(APP_PIPELINE, Pipeline) + + +def test_app_pipeline_step_order(): + steps = APP_PIPELINE._steps + assert isinstance(steps[0], BearerCheck) + assert isinstance(steps[1], ScopeCheck) + assert isinstance(steps[2], AppResolver) + assert isinstance(steps[3], AppAuthzCheck) + assert isinstance(steps[4], CallerMount) + + +def test_caller_mount_has_both_mounters(): + cm = APP_PIPELINE._steps[4] + kinds = {type(m) for m in cm._mounters} + assert AccountMounter in kinds + assert EndUserMounter in kinds + + +@patch("controllers.openapi.auth.composition.FeatureService") +def test_strategy_resolver_picks_acl_when_enabled(fs): + fs.get_system_features.return_value.webapp_auth.enabled = True + assert isinstance(_resolve_app_authz_strategy(), AclStrategy) + + +@patch("controllers.openapi.auth.composition.FeatureService") +def test_strategy_resolver_picks_membership_when_disabled(fs): + fs.get_system_features.return_value.webapp_auth.enabled = False + assert isinstance(_resolve_app_authz_strategy(), MembershipStrategy) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_context.py b/api/tests/unit_tests/controllers/openapi/auth/test_context.py new file mode 100644 index 0000000000..46e932af04 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_context.py @@ -0,0 +1,21 @@ +from unittest.mock import MagicMock + +from controllers.openapi.auth.context import Context + + +def test_context_starts_unpopulated(): + ctx = Context(request=MagicMock(), required_scope="apps:run") + assert ctx.subject_type is None + assert ctx.subject_email is None + assert ctx.account_id is None + assert ctx.scopes == frozenset() + assert ctx.app is None + assert ctx.tenant is None + assert ctx.caller is None + assert ctx.caller_kind is None + + +def test_context_fields_are_mutable(): + ctx = Context(request=MagicMock(), required_scope="apps:run") + ctx.scopes = frozenset({"full"}) + assert "full" in ctx.scopes diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py b/api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py new file mode 100644 index 0000000000..cfeaf86cfe --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py @@ -0,0 +1,61 @@ +from unittest.mock import MagicMock + +import pytest +from flask import Flask + +from controllers.openapi.auth.context import Context +from controllers.openapi.auth.pipeline import Pipeline + + +def test_run_invokes_each_step_in_order(): + calls = [] + + class S: + def __init__(self, tag): + self.tag = tag + + def __call__(self, ctx): + calls.append(self.tag) + + Pipeline(S("a"), S("b"), S("c")).run(Context(request=MagicMock(), required_scope="x")) + assert calls == ["a", "b", "c"] + + +def test_run_short_circuits_on_raise(): + calls = [] + + class Boom: + def __call__(self, ctx): + raise RuntimeError("boom") + + class Tail: + def __call__(self, ctx): + calls.append("ran") + + with pytest.raises(RuntimeError): + Pipeline(Boom(), Tail()).run(Context(request=MagicMock(), required_scope="x")) + assert calls == [] + + +def test_guard_decorator_runs_pipeline_and_unpacks_handler_kwargs(): + seen = {} + + class FakeStep: + def __call__(self, ctx): + ctx.app = "APP" + ctx.caller = "CALLER" + ctx.caller_kind = "account" + + pipeline = Pipeline(FakeStep()) + + @pipeline.guard(scope="apps:run") + def handler(app_model, caller, caller_kind): + seen["app_model"] = app_model + seen["caller"] = caller + seen["caller_kind"] = caller_kind + return "ok" + + app = Flask(__name__) + with app.test_request_context("/x", method="POST"): + assert handler() == "ok" + assert seen == {"app_model": "APP", "caller": "CALLER", "caller_kind": "account"} diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_app_resolver.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_app_resolver.py new file mode 100644 index 0000000000..4d64f4b881 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_step_app_resolver.py @@ -0,0 +1,64 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import BadRequest, Forbidden, NotFound + +from controllers.openapi.auth.context import Context +from controllers.openapi.auth.steps import AppResolver +from models import TenantStatus + + +def _ctx(view_args): + req = MagicMock() + req.view_args = view_args + return Context(request=req, required_scope="apps:run") + + +def _app(*, status="normal", enable_api=True): + return SimpleNamespace(id="app1", tenant_id="t1", status=status, enable_api=enable_api) + + +def _tenant(*, status=TenantStatus.NORMAL): + return SimpleNamespace(id="t1", status=status) + + +def test_resolver_rejects_missing_path_param(): + with pytest.raises(BadRequest): + AppResolver()(_ctx({})) + + +def test_resolver_rejects_none_view_args(): + with pytest.raises(BadRequest): + AppResolver()(_ctx(None)) + + +@patch("controllers.openapi.auth.steps.db") +def test_resolver_404_when_app_missing(db): + db.session.get.side_effect = [None] + with pytest.raises(NotFound): + AppResolver()(_ctx({"app_id": "x"})) + + +@patch("controllers.openapi.auth.steps.db") +def test_resolver_403_when_disabled(db): + db.session.get.side_effect = [_app(enable_api=False)] + with pytest.raises(Forbidden) as exc: + AppResolver()(_ctx({"app_id": "x"})) + assert "service_api_disabled" in str(exc.value.description) + + +@patch("controllers.openapi.auth.steps.db") +def test_resolver_403_when_tenant_archived(db): + db.session.get.side_effect = [_app(), _tenant(status=TenantStatus.ARCHIVE)] + with pytest.raises(Forbidden): + AppResolver()(_ctx({"app_id": "x"})) + + +@patch("controllers.openapi.auth.steps.db") +def test_resolver_populates_app_and_tenant(db): + db.session.get.side_effect = [_app(), _tenant()] + ctx = _ctx({"app_id": "x"}) + AppResolver()(ctx) + assert ctx.app.id == "app1" + assert ctx.tenant.id == "t1" diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py new file mode 100644 index 0000000000..e1f5114446 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py @@ -0,0 +1,58 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.openapi.auth.context import Context +from controllers.openapi.auth.steps import AppAuthzCheck +from controllers.openapi.auth.strategies import AclStrategy, MembershipStrategy +from libs.oauth_bearer import SubjectType + + +def _ctx(*, subject_type, account_id="acc1"): + c = Context(request=MagicMock(), required_scope="apps:run") + c.subject_type = subject_type + c.subject_email = "alice@example.com" + c.account_id = account_id + c.app = SimpleNamespace(id="app1") + c.tenant = SimpleNamespace(id="t1") + return c + + +@patch("controllers.openapi.auth.strategies.EnterpriseService") +def test_acl_strategy_calls_inner_api(ent): + ent.WebAppAuth.is_user_allowed_to_access_webapp.return_value = True + assert AclStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True + ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_called_once_with( + user_id="alice@example.com", + app_id="app1", + ) + + +@patch("controllers.openapi.auth.strategies._has_tenant_membership") +def test_membership_strategy_uses_join_lookup(member): + member.return_value = True + assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True + member.assert_called_once_with("acc1", "t1") + + +def test_membership_strategy_rejects_external_sso(): + assert ( + MembershipStrategy().authorize( + _ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None) + ) + is False + ) + + +def test_app_authz_check_raises_when_strategy_denies(): + deny = SimpleNamespace(authorize=lambda c: False) + with pytest.raises(Forbidden) as exc: + AppAuthzCheck(lambda: deny)(_ctx(subject_type=SubjectType.ACCOUNT)) + assert "subject_no_app_access" in str(exc.value.description) + + +def test_app_authz_check_passes_when_strategy_allows(): + allow = SimpleNamespace(authorize=lambda c: True) + AppAuthzCheck(lambda: allow)(_ctx(subject_type=SubjectType.ACCOUNT)) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py new file mode 100644 index 0000000000..f59120686b --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py @@ -0,0 +1,58 @@ +import uuid +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Unauthorized + +from controllers.openapi.auth.context import Context +from controllers.openapi.auth.steps import BearerCheck +from libs.oauth_bearer import ResolvedRow, SubjectType + + +def _ctx(headers): + req = MagicMock() + req.headers = headers + return Context(request=req, required_scope="apps:run") + + +def test_bearer_check_rejects_missing_header(): + with pytest.raises(Unauthorized): + BearerCheck()(_ctx({})) + + +@patch("controllers.openapi.auth.steps._registry") +def test_bearer_check_rejects_unknown_prefix(reg): + reg.return_value.find.return_value = None + with pytest.raises(Unauthorized): + BearerCheck()(_ctx({"Authorization": "Bearer xxx_abc"})) + + +@patch("controllers.openapi.auth.steps._registry") +def test_bearer_check_populates_context(reg): + tok_id = uuid.uuid4() + fake_resolver = MagicMock() + fake_resolver.resolve.return_value = ResolvedRow( + subject_email="a@x.com", + subject_issuer=None, + account_id=None, + token_id=tok_id, + expires_at=datetime.now(UTC), + ) + fake_kind = SimpleNamespace( + subject_type=SubjectType.ACCOUNT, + scopes=frozenset({"full"}), + source="oauth-account", + resolver=fake_resolver, + ) + reg.return_value.find.return_value = fake_kind + + ctx = _ctx({"Authorization": "Bearer dfoa_abc"}) + BearerCheck()(ctx) + + assert ctx.subject_type == SubjectType.ACCOUNT + assert ctx.subject_email == "a@x.com" + assert ctx.scopes == frozenset({"full"}) + assert ctx.source == "oauth-account" + assert ctx.token_id == tok_id diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_mount.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_mount.py new file mode 100644 index 0000000000..e3a4c6675b --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_step_mount.py @@ -0,0 +1,77 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Unauthorized + +from controllers.openapi.auth.context import Context +from controllers.openapi.auth.steps import CallerMount +from controllers.openapi.auth.strategies import AccountMounter, EndUserMounter +from core.app.entities.app_invoke_entities import InvokeFrom +from libs.oauth_bearer import SubjectType + + +def _ctx(*, subject_type, account_id=None, subject_email=None): + c = Context(request=MagicMock(), required_scope="apps:run") + c.subject_type = subject_type + c.account_id = account_id + c.subject_email = subject_email + c.app = SimpleNamespace(id="app1") + c.tenant = SimpleNamespace(id="t1") + return c + + +@patch("controllers.openapi.auth.strategies._login_as") +@patch("controllers.openapi.auth.strategies.db") +def test_account_mounter(db, login): + account = SimpleNamespace() + db.session.get.return_value = account + ctx = _ctx(subject_type=SubjectType.ACCOUNT, account_id="acc1") + AccountMounter().mount(ctx) + assert ctx.caller is account + assert ctx.caller.current_tenant is ctx.tenant + assert ctx.caller_kind == "account" + login.assert_called_once_with(account) + + +@patch("controllers.openapi.auth.strategies._login_as") +@patch("controllers.openapi.auth.strategies.EndUserService") +def test_end_user_mounter(svc, login): + eu = SimpleNamespace() + svc.get_or_create_end_user_by_type.return_value = eu + ctx = _ctx(subject_type=SubjectType.EXTERNAL_SSO, subject_email="a@x.com") + EndUserMounter().mount(ctx) + svc.get_or_create_end_user_by_type.assert_called_once_with( + InvokeFrom.OPENAPI, + tenant_id="t1", + app_id="app1", + user_id="a@x.com", + ) + assert ctx.caller is eu + assert ctx.caller_kind == "end_user" + + +def test_caller_mount_dispatches_by_subject_type(): + seen = {} + + class Fake: + def __init__(self, st, tag): + self._st, self._tag = st, tag + + def applies_to(self, st): + return st == self._st + + def mount(self, ctx): + seen["who"] = self._tag + + cm = CallerMount( + Fake(SubjectType.ACCOUNT, "acct"), + Fake(SubjectType.EXTERNAL_SSO, "sso"), + ) + cm(_ctx(subject_type=SubjectType.EXTERNAL_SSO)) + assert seen == {"who": "sso"} + + +def test_caller_mount_raises_when_none_applies(): + with pytest.raises(Unauthorized): + CallerMount()(_ctx(subject_type=SubjectType.ACCOUNT)) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_scope.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_scope.py new file mode 100644 index 0000000000..6e3044d73f --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_step_scope.py @@ -0,0 +1,27 @@ +from unittest.mock import MagicMock + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.openapi.auth.context import Context +from controllers.openapi.auth.steps import ScopeCheck + + +def _ctx(scopes, required): + c = Context(request=MagicMock(), required_scope=required) + c.scopes = frozenset(scopes) + return c + + +def test_scope_check_passes_on_full(): + ScopeCheck()(_ctx({"full"}, "apps:run")) + + +def test_scope_check_passes_on_explicit_match(): + ScopeCheck()(_ctx({"apps:run"}, "apps:run")) + + +def test_scope_check_rejects_when_missing(): + with pytest.raises(Forbidden) as exc: + ScopeCheck()(_ctx({"apps:read"}, "apps:run")) + assert "insufficient_scope" in str(exc.value.description) diff --git a/api/tests/unit_tests/controllers/openapi/conftest.py b/api/tests/unit_tests/controllers/openapi/conftest.py new file mode 100644 index 0000000000..42e3768a18 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/conftest.py @@ -0,0 +1,14 @@ +import pytest + +from controllers.openapi.auth.pipeline import Pipeline + + +@pytest.fixture +def bypass_pipeline(monkeypatch): + """Stub Pipeline.run so endpoint decoration does not invoke real auth. + + Module-level @APP_PIPELINE.guard(...) captures the real APP_PIPELINE at + import time; mocking the module attribute does not undo that. Patching + Pipeline.run on the class is the bypass that actually works. + """ + monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None) diff --git a/api/tests/unit_tests/controllers/openapi/test_app_info.py b/api/tests/unit_tests/controllers/openapi/test_app_info.py new file mode 100644 index 0000000000..3aeec7f0ca --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_app_info.py @@ -0,0 +1,46 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from flask import Flask +from flask_restx import Api + + +def _client(): + from controllers.openapi import app_info # noqa: F401 + from controllers.openapi import openapi_ns + + app = Flask(__name__) + api = Api(app) + api.add_namespace(openapi_ns, path="/openapi/v1") + return app.test_client() + + +def test_app_info_returns_response_model(bypass_pipeline): + app_obj = SimpleNamespace( + id="app1", + name="X", + description="d", + mode="chat", + author_name="alice", + tags=[SimpleNamespace(name="prod")], + ) + with patch("controllers.openapi.app_info._unpack_app", return_value=app_obj): + r = _client().get("/openapi/v1/apps/app1/info") + assert r.status_code == 200 + body = r.get_json() + assert body == { + "id": "app1", + "name": "X", + "description": "d", + "mode": "chat", + "author_name": "alice", + "tags": ["prod"], + } + + +def test_app_info_response_model_validates(): + from controllers.openapi.app_info import AppInfoResponse + + m = AppInfoResponse(id="x", name="N", mode="chat") + assert m.tags == [] + assert m.description is None diff --git a/api/tests/unit_tests/controllers/openapi/test_audit_app_run.py b/api/tests/unit_tests/controllers/openapi/test_audit_app_run.py new file mode 100644 index 0000000000..a2e2539dfd --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_audit_app_run.py @@ -0,0 +1,19 @@ +import logging + +from controllers.openapi._audit import EVENT_APP_RUN_OPENAPI, emit_app_run + + +def test_event_constant(): + assert EVENT_APP_RUN_OPENAPI == "app.run.openapi" + + +def test_emit_app_run_logs_with_audit_extra(caplog): + with caplog.at_level(logging.INFO, logger="controllers.openapi._audit"): + emit_app_run(app_id="app1", tenant_id="t1", caller_kind="account", mode="chat") + record = next(r for r in caplog.records if r.message and "app.run.openapi" in r.message) + assert record.audit is True + assert record.event == EVENT_APP_RUN_OPENAPI + assert record.app_id == "app1" + assert record.tenant_id == "t1" + assert record.caller_kind == "account" + assert record.mode == "chat" diff --git a/api/tests/unit_tests/controllers/openapi/test_chat_messages.py b/api/tests/unit_tests/controllers/openapi/test_chat_messages.py new file mode 100644 index 0000000000..35ec43cdd2 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_chat_messages.py @@ -0,0 +1,89 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from flask import Flask +from flask_restx import Api + + +def _client(): + from controllers.openapi import chat_messages # noqa: F401 + from controllers.openapi import openapi_ns + + app = Flask(__name__) + api = Api(app) + api.add_namespace(openapi_ns, path="/openapi/v1") + return app.test_client() + + +@patch("controllers.openapi.chat_messages.AppGenerateService") +def test_chat_dispatches_and_returns_response_model(svc, bypass_pipeline): + svc.generate.return_value = ( + { + "event": "message", + "task_id": "tk1", + "id": "m1", + "message_id": "m1", + "conversation_id": "c1", + "mode": "chat", + "answer": "hi", + "metadata": {}, + "created_at": 1700000000, + }, + 200, + ) + fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1") + with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake), patch( + "controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace() + ): + r = _client().post( + "/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}} + ) + assert r.status_code == 200 + body = r.get_json() + assert body["conversation_id"] == "c1" + assert body["answer"] == "hi" + assert svc.generate.call_args.kwargs["invoke_from"].value == "openapi" + + +@patch("controllers.openapi.chat_messages.AppGenerateService") +def test_chat_strips_user_field_from_body(svc, bypass_pipeline): + svc.generate.return_value = ( + { + "event": "message", + "task_id": "tk1", + "id": "m1", + "message_id": "m1", + "conversation_id": "c1", + "mode": "chat", + "answer": "hi", + "metadata": {}, + "created_at": 1700000000, + }, + 200, + ) + fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1") + with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake), patch( + "controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace() + ): + _client().post( + "/openapi/v1/apps/app1/chat-messages", + json={"query": "hi", "inputs": {}, "user": "spoof@x.com"}, + ) + args = svc.generate.call_args.kwargs["args"] + assert "user" not in args + + +def test_chat_rejects_non_chat_mode(bypass_pipeline): + fake = SimpleNamespace(mode="completion") + with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake): + r = _client().post( + "/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}} + ) + assert r.status_code in (400, 403) + + +def test_chat_rejects_invalid_body(bypass_pipeline): + fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1") + with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake): + r = _client().post("/openapi/v1/apps/app1/chat-messages", json={"query": "hi"}) + assert r.status_code in (400, 422) diff --git a/api/tests/unit_tests/controllers/openapi/test_completion_messages.py b/api/tests/unit_tests/controllers/openapi/test_completion_messages.py new file mode 100644 index 0000000000..84fe214a26 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_completion_messages.py @@ -0,0 +1,54 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from flask import Flask +from flask_restx import Api + + +def _client(): + from controllers.openapi import completion_messages # noqa: F401 + from controllers.openapi import openapi_ns + + app = Flask(__name__) + api = Api(app) + api.add_namespace(openapi_ns, path="/openapi/v1") + return app.test_client() + + +@patch("controllers.openapi.completion_messages.AppGenerateService") +def test_completion_returns_response_model(svc, bypass_pipeline): + svc.generate.return_value = ( + { + "event": "message", + "task_id": "tk", + "id": "m1", + "message_id": "m1", + "mode": "completion", + "answer": "ok", + "metadata": {}, + "created_at": 1700000000, + }, + 200, + ) + fake = SimpleNamespace(mode="completion", id="app1", tenant_id="t1") + with patch("controllers.openapi.completion_messages._unpack_app", return_value=fake), patch( + "controllers.openapi.completion_messages._unpack_caller", return_value=SimpleNamespace() + ): + r = _client().post( + "/openapi/v1/apps/app1/completion-messages", + json={"inputs": {"x": 1}, "query": "hi"}, + ) + assert r.status_code == 200 + body = r.get_json() + assert body["answer"] == "ok" + assert svc.generate.call_args.kwargs["invoke_from"].value == "openapi" + + +def test_completion_rejects_chat_mode(bypass_pipeline): + fake = SimpleNamespace(mode="chat") + with patch("controllers.openapi.completion_messages._unpack_app", return_value=fake): + r = _client().post( + "/openapi/v1/apps/app1/completion-messages", + json={"inputs": {}, "query": "hi"}, + ) + assert r.status_code in (400, 403) diff --git a/api/tests/unit_tests/controllers/openapi/test_models.py b/api/tests/unit_tests/controllers/openapi/test_models.py new file mode 100644 index 0000000000..5cca6131cc --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_models.py @@ -0,0 +1,14 @@ +from controllers.openapi._models import MessageMetadata, UsageInfo + + +def test_usage_info_defaults_zero(): + u = UsageInfo() + assert u.prompt_tokens == 0 + assert u.completion_tokens == 0 + assert u.total_tokens == 0 + + +def test_message_metadata_accepts_partial(): + m = MessageMetadata(usage=UsageInfo(total_tokens=10)) + assert m.usage.total_tokens == 10 + assert m.retriever_resources == [] diff --git a/api/tests/unit_tests/controllers/openapi/test_workflow_run.py b/api/tests/unit_tests/controllers/openapi/test_workflow_run.py new file mode 100644 index 0000000000..ce0114a507 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_workflow_run.py @@ -0,0 +1,50 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from flask import Flask +from flask_restx import Api + + +def _client(): + from controllers.openapi import openapi_ns + from controllers.openapi import workflow_run # noqa: F401 + + app = Flask(__name__) + api = Api(app) + api.add_namespace(openapi_ns, path="/openapi/v1") + return app.test_client() + + +@patch("controllers.openapi.workflow_run.AppGenerateService") +def test_workflow_run_returns_response_model(svc, bypass_pipeline): + svc.generate.return_value = ( + { + "workflow_run_id": "wr1", + "task_id": "tk", + "data": { + "id": "wr1", + "workflow_id": "wf1", + "status": "succeeded", + "outputs": {"result": "ok"}, + "elapsed_time": 1.0, + }, + }, + 200, + ) + fake = SimpleNamespace(mode="workflow", id="app1", tenant_id="t1") + with patch("controllers.openapi.workflow_run._unpack_app", return_value=fake), patch( + "controllers.openapi.workflow_run._unpack_caller", return_value=SimpleNamespace() + ): + r = _client().post("/openapi/v1/apps/app1/workflows/run", json={"inputs": {"x": 1}}) + assert r.status_code == 200 + body = r.get_json() + assert body["workflow_run_id"] == "wr1" + assert body["data"]["status"] == "succeeded" + assert svc.generate.call_args.kwargs["invoke_from"].value == "openapi" + + +def test_workflow_run_rejects_non_workflow(bypass_pipeline): + fake = SimpleNamespace(mode="chat") + with patch("controllers.openapi.workflow_run._unpack_app", return_value=fake): + r = _client().post("/openapi/v1/apps/app1/workflows/run", json={"inputs": {}}) + assert r.status_code in (400, 403) diff --git a/api/tests/unit_tests/core/app/test_invoke_from.py b/api/tests/unit_tests/core/app/test_invoke_from.py new file mode 100644 index 0000000000..e0a8344d2f --- /dev/null +++ b/api/tests/unit_tests/core/app/test_invoke_from.py @@ -0,0 +1,9 @@ +from core.app.entities.app_invoke_entities import InvokeFrom + + +def test_openapi_variant_present(): + assert InvokeFrom.OPENAPI.value == "openapi" + + +def test_openapi_distinct_from_service_api(): + assert InvokeFrom.OPENAPI != InvokeFrom.SERVICE_API