mirror of
https://github.com/langgenius/dify.git
synced 2026-06-24 13:01:16 +08:00
Merge branch 'main' into feat/refine-snippet-siderbar
This commit is contained in:
commit
ce9d1c74af
@ -36,6 +36,7 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in query and API data flow.
|
||||
- Do not replace prop drilling with one top-level hook that returns a large view model and then thread that object through section props. Move each hook, query, derived value, and handler to the concrete section that consumes it, or use feature-scoped Jotai atoms for simple shared form/UI state when siblings need the same source of truth.
|
||||
- When using feature-scoped Jotai state for a form, drawer, or other secondary surface, scope the store to that surface instance when stale cross-instance state is possible. Initialize stable config at the owning boundary, then let descendants read only the atoms or purpose-named hooks they actually need.
|
||||
- For Jotai-backed surfaces, put shared query atoms, mutation atoms, derived state, and write actions in the feature state file when they coordinate multiple descendants. The lowest-owner rule still applies to independent visual surfaces that do not participate in shared state.
|
||||
- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action.
|
||||
- Prefer uncontrolled DOM state and CSS variables before adding controlled props.
|
||||
|
||||
@ -45,8 +46,10 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Atom order in the state file follows the dependency graph: types/constants, editable primitives, query atoms, query-data derived atoms, readiness/business derived atoms, write actions, mutation atoms, submission orchestration, provider exports.
|
||||
- Derived atom names read as business facts. Write atom names read as user or workflow commands.
|
||||
- UI components read and write the exact atom they use with `useAtomValue` or `useSetAtom`. Repeated workflow semantics live in named derived atoms or write atoms.
|
||||
- Non-query derived atoms return a narrow value with a clear domain name. Query atoms expose the TanStack Query result object so loading, error, fetch, and pagination state stay attached to the query contract.
|
||||
- Write-only atoms own state transitions that update multiple primitives, reset dependent state, guard stale async work, or advance the workflow.
|
||||
- Non-query derived atoms return a narrow value with a clear domain name; avoid pass-through aliases or bundling unrelated UI facts. Query atoms expose the TanStack Query result object so loading, error, fetch, and pagination state stay attached to the query contract.
|
||||
- Write-only atoms own synchronous state transitions that update multiple primitives, reset dependent state, or advance the workflow. Async work with loading, error, caching, retry, or stale-result concerns should be modeled as query or mutation atoms, with write atoms only changing the inputs that drive them.
|
||||
- Avoid feature hooks that aggregate form values, query results, derived state, and commands for sibling components. Prefer named derived atoms and write atoms so UI components read the exact shared fact or command they need.
|
||||
- When a form library owns validation, keep submit orchestration in feature state when post-submit result or error state is shared by the surface. Avoid duplicating validation gates or request shaping in UI hooks.
|
||||
- `jotai-tanstack-query` atoms use the same QueryClient as the React Query provider. Query atoms belong in feature state when atoms are the feature's local state surface.
|
||||
- Jotai scope is an optional instance-isolation tool for secondary surfaces with independent local state. Query atoms keep shared cache behavior through the shared QueryClient.
|
||||
|
||||
@ -108,7 +111,7 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow.
|
||||
- Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment.
|
||||
- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible.
|
||||
- Avoid shallow wrappers, hook-to-props adapter components, layout-only render-prop wrappers, and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. If a component only calls a hook and forwards every returned field to one child, move the hook into that child or make the wrapper own a real surface.
|
||||
- Avoid shallow wrappers, hook-to-props adapter components, layout-only render-prop wrappers, children-as-pass-through composition, and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. If a component only calls a hook, forwards props, or passes trigger/content through to one child, move the logic into that child or make the wrapper own a real surface.
|
||||
|
||||
## You Might Not Need An Effect
|
||||
|
||||
|
||||
@ -768,7 +768,6 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode while use redis as event bus.
|
||||
# It's highly recommended to enable this for large deployments.
|
||||
EVENT_BUS_REDIS_USE_CLUSTERS=false
|
||||
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
|
||||
|
||||
# Whether to Enable human input timeout check task
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
||||
|
||||
@ -2,7 +2,6 @@ from typing import Literal, Protocol, cast
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic.types import NonNegativeInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@ -71,24 +70,6 @@ class RedisPubSubConfig(BaseSettings):
|
||||
default=600,
|
||||
)
|
||||
|
||||
PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field(
|
||||
validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"),
|
||||
description=(
|
||||
"Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to "
|
||||
"finish before returning. Bounds the tail latency between a terminal event being delivered to "
|
||||
"an SSE client and the response stream actually closing.\n\n"
|
||||
"The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout "
|
||||
"for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for "
|
||||
"the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() "
|
||||
"return promptly while the daemon listener thread cleans itself up on the next poll "
|
||||
"boundary - safe because the listener holds no critical state and exits within one poll "
|
||||
"window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up "
|
||||
"and logs a warning. Default 2000ms preserves the pre-change behaviour.\n\n"
|
||||
"Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
|
||||
),
|
||||
default=2000,
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = _redis_defaults(self)
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
|
||||
107
api/controllers/common/app_access.py
Normal file
107
api/controllers/common/app_access.py
Normal file
@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from services.enterprise import rbac_service as enterprise_rbac_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.app_service import AppListBaseParams
|
||||
from services.enterprise.rbac_service import MyPermissionsResponse
|
||||
|
||||
# Permission keys (dot-notation, from MyPermissionsResponse) that grant
|
||||
# list/preview access to an app. Keep this the single source of truth for both
|
||||
# the console and OpenAPI app-list endpoints.
|
||||
APP_LIST_PERMISSION_KEYS: frozenset[str] = frozenset({"app.preview", "app.acl.preview", "app.full_access"})
|
||||
|
||||
# Workspace permission key that lets a caller see apps they maintain even when
|
||||
# those apps are not in their preview whitelist.
|
||||
_MANAGE_OWN_APPS_PERMISSION_KEY = "app.create_and_management"
|
||||
|
||||
|
||||
def has_app_list_permission(permission_keys: Sequence[str]) -> bool:
|
||||
"""Return True if any of ``permission_keys`` grants app list/preview access."""
|
||||
return any(permission_key in APP_LIST_PERMISSION_KEYS for permission_key in permission_keys)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AppAccessFilter:
|
||||
"""Resolved RBAC visibility for app list/read endpoints.
|
||||
|
||||
``accessible_app_ids`` of ``None`` means the caller can see every app in the
|
||||
workspace (unrestricted). Otherwise it is the exact set of app ids the
|
||||
caller may preview; combined with ``can_manage_own_apps`` it also covers
|
||||
apps the caller maintains.
|
||||
"""
|
||||
|
||||
accessible_app_ids: set[str] | None
|
||||
can_manage_own_apps: bool
|
||||
|
||||
@classmethod
|
||||
def unrestricted(cls) -> AppAccessFilter:
|
||||
"""Filter that imposes no restriction (RBAC disabled / not applicable)."""
|
||||
return cls(accessible_app_ids=None, can_manage_own_apps=False)
|
||||
|
||||
def is_app_accessible(self, app_id: str, maintainer: str | None, account_id: str) -> bool:
|
||||
"""Whether a single app is visible to the caller under this filter.
|
||||
|
||||
Mirrors the service-layer query gate: an app is visible when the filter
|
||||
is unrestricted, the app id is whitelisted, or the caller maintains it
|
||||
and holds ``app.create_and_management``.
|
||||
"""
|
||||
if self.accessible_app_ids is None:
|
||||
return True
|
||||
if app_id in self.accessible_app_ids:
|
||||
return True
|
||||
return self.can_manage_own_apps and maintainer is not None and maintainer == account_id
|
||||
|
||||
def apply_to_params(self, params: AppListBaseParams) -> None:
|
||||
if self.accessible_app_ids is None:
|
||||
return
|
||||
params.accessible_app_ids = sorted(self.accessible_app_ids)
|
||||
params.include_own_apps = self.can_manage_own_apps
|
||||
|
||||
|
||||
def resolve_app_access_filter(
|
||||
tenant_id: str,
|
||||
account_id: str,
|
||||
*,
|
||||
permissions: MyPermissionsResponse | None = None,
|
||||
) -> AppAccessFilter:
|
||||
"""Compute the RBAC app-access filter for ``account_id`` in ``tenant_id``.
|
||||
|
||||
Pass ``permissions`` when the caller has already fetched the snapshot (the
|
||||
console controller reuses it for per-app permission keys) to avoid a second
|
||||
inner-API round trip; otherwise it is fetched here.
|
||||
"""
|
||||
if permissions is None:
|
||||
permissions = enterprise_rbac_service.RBACService.MyPermissions.get(tenant_id, account_id)
|
||||
whitelist_scope = enterprise_rbac_service.RBACService.AppAccess.whitelist_resources(tenant_id, account_id)
|
||||
|
||||
can_manage_own_apps = _MANAGE_OWN_APPS_PERMISSION_KEY in permissions.workspace.permission_keys
|
||||
has_default_preview = has_app_list_permission(permissions.app.default_permission_keys) or has_app_list_permission(
|
||||
permissions.workspace.permission_keys
|
||||
)
|
||||
|
||||
permission_app_ids: set[str] | None = None
|
||||
if not has_default_preview:
|
||||
# Collect apps the caller can preview via per-app permission overrides.
|
||||
permission_app_ids = {
|
||||
override.resource_id
|
||||
for override in permissions.app.overrides
|
||||
if has_app_list_permission(override.permission_keys)
|
||||
}
|
||||
|
||||
accessible_app_ids: set[str] | None
|
||||
if getattr(whitelist_scope, "unrestricted", False):
|
||||
accessible_app_ids = permission_app_ids
|
||||
else:
|
||||
accessible_app_ids = set(whitelist_scope.resource_ids)
|
||||
if permission_app_ids is not None:
|
||||
accessible_app_ids |= permission_app_ids
|
||||
elif has_default_preview:
|
||||
# Default preview overrides the whitelist restriction.
|
||||
accessible_app_ids = None
|
||||
|
||||
return AppAccessFilter(accessible_app_ids=accessible_app_ids, can_manage_own_apps=can_manage_own_apps)
|
||||
@ -1,23 +1,3 @@
|
||||
"""Shared decorator utilities for Dify controller layers.
|
||||
|
||||
This module provides decorators that are not tied to any single API group (e.g.
|
||||
console, inner, service). Currently it exposes the RBAC permission gate, which
|
||||
can be applied to any blueprint.
|
||||
|
||||
Key exports
|
||||
-----------
|
||||
``rbac_permission_required`` – decorator that enforces enterprise RBAC access
|
||||
control. When ``RBAC_ENABLED`` is ``False`` it is a no-op.
|
||||
|
||||
``RBACPermission``, ``RBACResourceScope`` – re-exported from ``core.rbac`` so
|
||||
callers only need a single import site.
|
||||
|
||||
Private helpers
|
||||
---------------
|
||||
``_extract_resource_id``, ``_is_resource_owned_by_current_user`` – kept module-
|
||||
private but accessible via the module namespace for unit-test patching.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
@ -32,7 +12,57 @@ from models.dataset import Dataset
|
||||
from models.model import App
|
||||
from services.enterprise.rbac_service import RBACService
|
||||
|
||||
__all__ = ["RBACPermission", "RBACResourceScope", "rbac_permission_required"]
|
||||
__all__ = ["RBACPermission", "RBACResourceScope", "enforce_rbac_access", "rbac_permission_required"]
|
||||
|
||||
|
||||
def enforce_rbac_access(
|
||||
*,
|
||||
tenant_id: str,
|
||||
account_id: str,
|
||||
resource_type: RBACResourceScope,
|
||||
scene: RBACPermission,
|
||||
resource_required: bool = True,
|
||||
path_args: dict[str, object] | None = None,
|
||||
) -> None:
|
||||
"""Enforce enterprise RBAC for an explicit account/tenant pair.
|
||||
|
||||
This is the flask-login-independent core of the RBAC gate so it can run
|
||||
inside request-handling layers that resolve the caller themselves (e.g. the
|
||||
openapi auth pipeline, which has the account on ``AuthData`` before
|
||||
flask-login is mounted).
|
||||
|
||||
No-op when ``RBAC_ENABLED`` is ``False``. For resource-scoped checks the
|
||||
resource ID is taken from ``path_args`` merged with ``request.view_args``;
|
||||
resource ownership short-circuits the check. Raises ``Forbidden`` when
|
||||
access is denied. For workspace-level checks pass ``resource_required=False``
|
||||
so the RBAC request omits ``resource_id``.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant the access is evaluated against.
|
||||
account_id: The account requesting access.
|
||||
resource_type: The :class:`RBACResourceScope` member (app/dataset/workspace).
|
||||
scene: The :class:`RBACPermission` permission point, e.g. ``RBACPermission.APP_DELETE``.
|
||||
resource_required: Whether a concrete resource ID is required.
|
||||
path_args: Extra path arguments to merge with ``request.view_args``.
|
||||
"""
|
||||
if not dify_config.RBAC_ENABLED:
|
||||
return
|
||||
|
||||
check_resource_type = None if resource_type == RBACResourceScope.WORKSPACE else resource_type
|
||||
resource_id = None
|
||||
if resource_required and check_resource_type:
|
||||
resource_id = _extract_resource_id(resource_type, path_args)
|
||||
if _is_resource_owned_by_current_user(tenant_id, account_id, resource_type, resource_id):
|
||||
return
|
||||
allowed = RBACService.CheckAccess.check(
|
||||
tenant_id,
|
||||
account_id,
|
||||
scene=scene,
|
||||
resource_type=check_resource_type,
|
||||
resource_id=resource_id,
|
||||
)
|
||||
if not allowed:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
def rbac_permission_required[**P, R](
|
||||
@ -41,14 +71,12 @@ def rbac_permission_required[**P, R](
|
||||
*,
|
||||
resource_required: bool = True,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""Check enterprise RBAC permissions for the current user.
|
||||
"""Check enterprise RBAC permissions for the current flask-login user.
|
||||
|
||||
When ``RBAC_ENABLED`` is ``False`` the decorator is a no-op and the
|
||||
request passes through unchanged. When enabled it extracts the resource ID
|
||||
from ``request.view_args`` for resource-scoped checks, calls the RBAC
|
||||
service ``check-access`` endpoint, and raises ``Forbidden`` if the access
|
||||
is denied. For workspace-level checks, set ``resource_required=False`` so
|
||||
the RBAC request omits ``resource_id``.
|
||||
request passes through unchanged. When enabled it resolves the current
|
||||
account/tenant and delegates to :func:`enforce_rbac_access`, raising
|
||||
``Forbidden`` if access is denied.
|
||||
|
||||
Args:
|
||||
resource_type: The :class:`RBACResourceScope` member (app/dataset/workspace).
|
||||
@ -63,23 +91,14 @@ def rbac_permission_required[**P, R](
|
||||
return view(*args, **kwargs)
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
check_resource_type = None if resource_type == RBACResourceScope.WORKSPACE else resource_type
|
||||
resource_id = None
|
||||
if resource_required and check_resource_type:
|
||||
resource_id = _extract_resource_id(resource_type, kwargs)
|
||||
if _is_resource_owned_by_current_user(current_tenant_id, current_user.id, resource_type, resource_id):
|
||||
return view(*args, **kwargs)
|
||||
allowed = RBACService.CheckAccess.check(
|
||||
current_tenant_id,
|
||||
current_user.id,
|
||||
enforce_rbac_access(
|
||||
tenant_id=current_tenant_id,
|
||||
account_id=current_user.id,
|
||||
resource_type=resource_type,
|
||||
scene=scene,
|
||||
resource_type=check_resource_type,
|
||||
resource_id=resource_id,
|
||||
resource_required=resource_required,
|
||||
path_args=kwargs,
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise Forbidden()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
@ -3,10 +3,12 @@ from uuid import UUID
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import AliasChoices, BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.agent.app_helpers import resolve_agent_app_model
|
||||
from controllers.console.apikey import ApiKeyItem, ApiKeyList, BaseApiKeyListResource, BaseApiKeyResource
|
||||
from controllers.console.app.app import (
|
||||
AppDetailWithSite as GenericAppDetailWithSite,
|
||||
)
|
||||
@ -25,9 +27,13 @@ from controllers.console.app.app import (
|
||||
UpdateAppPayload as GenericUpdateAppPayload,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
RBACPermission,
|
||||
RBACResourceScope,
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
rbac_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
@ -36,6 +42,7 @@ from extensions.ext_database import db
|
||||
from fields.agent_fields import (
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
AgentConfigSnapshotListResponse,
|
||||
AgentConfigSnapshotRestoreResponse,
|
||||
AgentInviteOptionsResponse,
|
||||
AgentLogListResponse,
|
||||
AgentLogMessageListResponse,
|
||||
@ -48,7 +55,8 @@ from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import IconType
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App, IconType
|
||||
from services.agent.errors import AgentNotFoundError
|
||||
from services.agent.observability_service import (
|
||||
AgentLogQueryParams,
|
||||
@ -102,6 +110,27 @@ class AgentAppUpdatePayload(GenericUpdateAppPayload):
|
||||
return role
|
||||
|
||||
|
||||
class AgentApiStatusPayload(BaseModel):
|
||||
enable_api: bool = Field(..., description="Enable or disable Agent service API")
|
||||
|
||||
|
||||
class AgentApiAccessResponse(BaseModel):
|
||||
enabled: bool
|
||||
service_api_base_url: str
|
||||
streaming_only: bool = True
|
||||
chat_endpoint: str
|
||||
stop_endpoint: str
|
||||
conversations_endpoint: str
|
||||
messages_endpoint: str
|
||||
files_upload_endpoint: str
|
||||
parameters_endpoint: str
|
||||
info_endpoint: str
|
||||
meta_endpoint: str
|
||||
api_rpm: int
|
||||
api_rph: int
|
||||
api_key_count: int
|
||||
|
||||
|
||||
class AgentAppPublishedReferenceResponse(BaseModel):
|
||||
app_id: str
|
||||
app_name: str
|
||||
@ -185,6 +214,7 @@ class AgentStatisticsQuery(BaseModel):
|
||||
|
||||
class AgentAppPartial(GenericAppPartial):
|
||||
app_id: str | None = None
|
||||
debug_conversation_id: str | None = None
|
||||
role: str | None = None
|
||||
active_config_is_published: bool = False
|
||||
published_reference_count: int = 0
|
||||
@ -193,6 +223,7 @@ class AgentAppPartial(GenericAppPartial):
|
||||
|
||||
class AgentAppDetailWithSite(GenericAppDetailWithSite):
|
||||
app_id: str | None = None
|
||||
debug_conversation_id: str | None = None
|
||||
role: str | None = None
|
||||
active_config_is_published: bool = False
|
||||
|
||||
@ -207,6 +238,7 @@ register_schema_models(
|
||||
console_ns,
|
||||
AgentAppCreatePayload,
|
||||
AgentAppUpdatePayload,
|
||||
AgentApiStatusPayload,
|
||||
CopyAppPayload,
|
||||
AgentInviteOptionsQuery,
|
||||
AgentLogsQuery,
|
||||
@ -218,11 +250,13 @@ register_schema_models(
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AgentAppPagination,
|
||||
AgentApiAccessResponse,
|
||||
AgentAppPublishedReferenceResponse,
|
||||
AgentAppDetailWithSite,
|
||||
AgentAppPartial,
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
AgentConfigSnapshotListResponse,
|
||||
AgentConfigSnapshotRestoreResponse,
|
||||
AgentInviteOptionsResponse,
|
||||
AgentLogListResponse,
|
||||
AgentLogMessageListResponse,
|
||||
@ -237,7 +271,7 @@ def _agent_roster_service() -> AgentRosterService:
|
||||
return AgentRosterService(db.session)
|
||||
|
||||
|
||||
def _serialize_agent_app_detail(app_model) -> dict:
|
||||
def _serialize_agent_app_detail(app_model, *, current_user: Account) -> dict:
|
||||
"""Serialize an Agent App detail using roster-only DTOs.
|
||||
|
||||
`/agent` responses are roster-shaped rather than raw app-shaped: `id`
|
||||
@ -260,6 +294,11 @@ def _serialize_agent_app_detail(app_model) -> dict:
|
||||
payload.pop("bound_agent_id", None)
|
||||
payload["app_id"] = str(app_model.id)
|
||||
payload["id"] = agent.id
|
||||
payload["debug_conversation_id"] = roster_service.get_or_create_agent_app_debug_conversation_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
agent_id=agent.id,
|
||||
account_id=current_user.id,
|
||||
)
|
||||
payload["role"] = agent.role or ""
|
||||
payload["active_config_is_published"] = roster_service.active_config_is_published(
|
||||
tenant_id=app_model.tenant_id,
|
||||
@ -268,7 +307,7 @@ def _serialize_agent_app_detail(app_model) -> dict:
|
||||
return payload
|
||||
|
||||
|
||||
def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str) -> dict:
|
||||
def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str, current_user: Account) -> dict:
|
||||
"""Serialize Agent App lists with roster-shaped items.
|
||||
|
||||
Each item starts from the shared App list shape, then drops
|
||||
@ -291,6 +330,11 @@ def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str) -> dict:
|
||||
tenant_id=tenant_id,
|
||||
agent_ids=[agent.id for agent in agents_by_app_id.values()],
|
||||
)
|
||||
debug_conversation_ids_by_agent_id = roster_service.load_or_create_agent_app_debug_conversation_ids_by_agent_id(
|
||||
tenant_id=tenant_id,
|
||||
agents=list(agents_by_app_id.values()),
|
||||
account_id=current_user.id,
|
||||
)
|
||||
payload = AgentAppPagination.model_validate(app_pagination, from_attributes=True).model_dump(mode="json")
|
||||
for item in payload["data"]:
|
||||
app_id = item["id"]
|
||||
@ -299,6 +343,7 @@ def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str) -> dict:
|
||||
if agent:
|
||||
item["app_id"] = app_id
|
||||
item["id"] = agent.id
|
||||
item["debug_conversation_id"] = debug_conversation_ids_by_agent_id.get(agent.id)
|
||||
item["role"] = agent.role or ""
|
||||
item["active_config_is_published"] = active_config_is_published_by_agent_id.get(agent.id, False)
|
||||
published_references = published_references_by_agent_id.get(agent.id, [])
|
||||
@ -323,6 +368,38 @@ def _resolve_agent_app_model(*, tenant_id: str, agent_id: UUID):
|
||||
return resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
|
||||
|
||||
def _agent_api_key_count(app_id: str) -> int:
|
||||
return (
|
||||
db.session.scalar(
|
||||
select(func.count(ApiToken.id)).where(
|
||||
ApiToken.type == ApiTokenType.APP,
|
||||
ApiToken.app_id == app_id,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
|
||||
def _serialize_agent_api_access(app_model: App) -> dict:
|
||||
base_url = app_model.api_base_url
|
||||
response = AgentApiAccessResponse(
|
||||
enabled=bool(app_model.enable_api),
|
||||
service_api_base_url=base_url,
|
||||
chat_endpoint=f"{base_url}/chat-messages",
|
||||
stop_endpoint=f"{base_url}/chat-messages/{{task_id}}/stop",
|
||||
conversations_endpoint=f"{base_url}/conversations",
|
||||
messages_endpoint=f"{base_url}/messages",
|
||||
files_upload_endpoint=f"{base_url}/files/upload",
|
||||
parameters_endpoint=f"{base_url}/parameters",
|
||||
info_endpoint=f"{base_url}/info",
|
||||
meta_endpoint=f"{base_url}/meta",
|
||||
api_rpm=app_model.api_rpm or 0,
|
||||
api_rph=app_model.api_rph or 0,
|
||||
api_key_count=_agent_api_key_count(str(app_model.id)),
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
|
||||
def _agent_observability_service() -> AgentObservabilityService:
|
||||
return AgentObservabilityService(db.session)
|
||||
|
||||
@ -374,7 +451,11 @@ class AgentAppListApi(Resource):
|
||||
empty = AgentAppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json")
|
||||
|
||||
return _serialize_agent_app_pagination(app_pagination, tenant_id=current_tenant_id)
|
||||
return _serialize_agent_app_pagination(
|
||||
app_pagination,
|
||||
tenant_id=current_tenant_id,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[AgentAppCreatePayload.__name__])
|
||||
@console_ns.response(201, "Agent app created successfully", console_ns.models[AgentAppDetailWithSite.__name__])
|
||||
@ -399,7 +480,7 @@ class AgentAppListApi(Resource):
|
||||
)
|
||||
|
||||
app = AppService().create_app(current_tenant_id, params, current_user)
|
||||
return _serialize_agent_app_detail(app), 201
|
||||
return _serialize_agent_app_detail(app, current_user=current_user), 201
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>")
|
||||
@ -409,10 +490,11 @@ class AgentAppApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
def get(self, tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
return _serialize_agent_app_detail(app_model)
|
||||
return _serialize_agent_app_detail(app_model, current_user=current_user)
|
||||
|
||||
@console_ns.expect(console_ns.models[AgentAppUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Agent app updated successfully", console_ns.models[AgentAppDetailWithSite.__name__])
|
||||
@ -422,8 +504,9 @@ class AgentAppApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def put(self, tenant_id: str, agent_id: UUID):
|
||||
def put(self, tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
args = AgentAppUpdatePayload.model_validate(console_ns.payload)
|
||||
args_dict: AppService.ArgsDict = {
|
||||
@ -437,7 +520,7 @@ class AgentAppApi(Resource):
|
||||
"role": args.role,
|
||||
}
|
||||
updated = AppService().update_app(app_model, args_dict)
|
||||
return _serialize_agent_app_detail(updated)
|
||||
return _serialize_agent_app_detail(updated, current_user=current_user)
|
||||
|
||||
@console_ns.response(204, "Agent app deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@ -476,7 +559,76 @@ class AgentAppCopyApi(Resource):
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
return _serialize_agent_app_detail(copied_app), 201
|
||||
return _serialize_agent_app_detail(copied_app, current_user=current_user), 201
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/api-access")
|
||||
class AgentApiAccessApi(Resource):
|
||||
@console_ns.response(200, "Agent service API access", console_ns.models[AgentApiAccessResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
return _serialize_agent_api_access(app_model)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/api-enable")
|
||||
class AgentApiStatusApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AgentApiStatusPayload.__name__])
|
||||
@console_ns.response(200, "Agent service API status updated", console_ns.models[AgentApiAccessResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_RELEASE_AND_VERSION)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, agent_id: UUID):
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
args = AgentApiStatusPayload.model_validate(console_ns.payload)
|
||||
app_model = AppService().update_app_api_status(app_model, args.enable_api)
|
||||
return _serialize_agent_api_access(app_model)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/api-keys")
|
||||
class AgentApiKeyListApi(BaseApiKeyListResource):
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
|
||||
@console_ns.response(200, "Agent service API keys", console_ns.models[ApiKeyList.__name__])
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID) -> dict[str, object]:
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(str(app_model.id), tenant_id))
|
||||
|
||||
@console_ns.response(201, "Agent service API key created", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
@with_current_tenant_id
|
||||
@edit_permission_required
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_RELEASE_AND_VERSION)
|
||||
def post(self, tenant_id: str, agent_id: UUID) -> tuple[dict[str, object], int]:
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
return dump_response(ApiKeyItem, self._create_api_key(str(app_model.id), tenant_id)), 201
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/api-keys/<uuid:api_key_id>")
|
||||
class AgentApiKeyApi(BaseApiKeyResource):
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
|
||||
@console_ns.response(204, "Agent service API key deleted")
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_RELEASE_AND_VERSION)
|
||||
def delete(self, tenant_id: str, current_user: Account, agent_id: UUID, api_key_id: UUID) -> tuple[str, int]:
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
self._delete_api_key(str(app_model.id), str(api_key_id), tenant_id, current_user)
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/agent/invite-options")
|
||||
@ -649,3 +801,24 @@ class AgentRosterVersionDetailApi(Resource):
|
||||
version_id=str(version_id),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/versions/<uuid:version_id>/restore")
|
||||
class AgentRosterVersionRestoreApi(Resource):
|
||||
@console_ns.response(200, "Agent version restored", console_ns.models[AgentConfigSnapshotRestoreResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, current_user: Account, agent_id: UUID, version_id: UUID):
|
||||
return dump_response(
|
||||
AgentConfigSnapshotRestoreResponse,
|
||||
_agent_roster_service().restore_agent_version(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
version_id=str(version_id),
|
||||
account_id=current_user.id,
|
||||
),
|
||||
)
|
||||
|
||||
@ -10,8 +10,12 @@ backend — drive data lives in the API's own DB/storage, served straight from
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -49,6 +53,10 @@ class AgentDriveFileByAgentQuery(BaseModel):
|
||||
key: str = Field(min_length=1, description="Drive key, e.g. tender-analyzer/SKILL.md")
|
||||
|
||||
|
||||
class AgentDriveSkillInspectQuery(BaseModel):
|
||||
node_id: str | None = Field(default=None, description="Workflow node ID (workflow composer variant)")
|
||||
|
||||
|
||||
class AgentDriveItemResponse(ResponseModel):
|
||||
key: str
|
||||
size: int | None = None
|
||||
@ -56,12 +64,63 @@ class AgentDriveItemResponse(ResponseModel):
|
||||
hash: str | None = None
|
||||
file_kind: str
|
||||
created_at: int | None = None
|
||||
is_skill: bool | None = None
|
||||
skill_metadata: str | None = None
|
||||
|
||||
|
||||
class AgentDriveListResponse(ResponseModel):
|
||||
items: list[AgentDriveItemResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentDriveSkillItemResponse(ResponseModel):
|
||||
path: str
|
||||
skill_md_key: str
|
||||
archive_key: str | None = None
|
||||
name: str
|
||||
description: str
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
hash: str | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
|
||||
class AgentDriveSkillListResponse(ResponseModel):
|
||||
items: list[AgentDriveSkillItemResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentDriveSkillFileResponse(ResponseModel):
|
||||
path: str
|
||||
name: str
|
||||
type: str
|
||||
drive_key: str | None = None
|
||||
available_in_drive: bool
|
||||
|
||||
|
||||
class AgentDriveSkillMarkdownResponse(ResponseModel):
|
||||
key: str
|
||||
size: int | None = None
|
||||
truncated: bool
|
||||
binary: bool
|
||||
text: str | None = None
|
||||
|
||||
|
||||
class AgentDriveSkillInspectResponse(ResponseModel):
|
||||
path: str
|
||||
skill_md_key: str
|
||||
archive_key: str | None = None
|
||||
name: str
|
||||
description: str
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
hash: str | None = None
|
||||
created_at: int | None = None
|
||||
source: str
|
||||
files: list[AgentDriveSkillFileResponse] = Field(default_factory=list)
|
||||
file_tree: list[dict[str, Any]] = Field(default_factory=list)
|
||||
skill_md: AgentDriveSkillMarkdownResponse
|
||||
warnings: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentDrivePreviewResponse(ResponseModel):
|
||||
key: str
|
||||
size: int | None = None
|
||||
@ -75,7 +134,12 @@ class AgentDriveDownloadResponse(ResponseModel):
|
||||
|
||||
|
||||
register_response_schema_models(
|
||||
console_ns, AgentDriveListResponse, AgentDrivePreviewResponse, AgentDriveDownloadResponse
|
||||
console_ns,
|
||||
AgentDriveDownloadResponse,
|
||||
AgentDriveListResponse,
|
||||
AgentDrivePreviewResponse,
|
||||
AgentDriveSkillInspectResponse,
|
||||
AgentDriveSkillListResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -96,6 +160,13 @@ def _handle(exc: AgentDriveError) -> tuple[dict[str, object], int]:
|
||||
return {"code": exc.code, "message": exc.message}, exc.status_code
|
||||
|
||||
|
||||
def _json_response(data: Mapping[str, Any]):
|
||||
return Response(
|
||||
response=json.dumps(data, ensure_ascii=False, separators=(",", ":")),
|
||||
content_type="application/json; charset=utf-8",
|
||||
)
|
||||
|
||||
|
||||
_WORKFLOW_APP_MODES = [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]
|
||||
|
||||
|
||||
@ -119,6 +190,49 @@ class AgentDriveListByAgentApi(Resource):
|
||||
return {"items": [{k: v for k, v in item.items() if k != "file_id"} for item in items]}
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/drive/skills")
|
||||
class AgentDriveSkillListByAgentApi(Resource):
|
||||
@console_ns.doc("list_agent_drive_skills_by_agent")
|
||||
@console_ns.doc(description="List drive-backed skills for an Agent App")
|
||||
@console_ns.doc(params={"agent_id": "Agent ID"})
|
||||
@console_ns.response(200, "Drive skills", console_ns.models[AgentDriveSkillListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
try:
|
||||
items = AgentDriveService().list_skills(tenant_id=tenant_id, agent_id=str(agent_id))
|
||||
except AgentDriveError as exc:
|
||||
return _handle(exc)
|
||||
return {"items": items}
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/drive/skills/<path:skill_path>/inspect")
|
||||
class AgentDriveSkillInspectByAgentApi(Resource):
|
||||
@console_ns.doc("inspect_agent_drive_skill_by_agent")
|
||||
@console_ns.doc(description="Inspect one drive-backed skill for slash-menu hover/detail UI")
|
||||
@console_ns.doc(params={"agent_id": "Agent ID", "skill_path": "Skill path/slug, e.g. tender-analyzer"})
|
||||
@console_ns.response(200, "Drive skill inspect view", console_ns.models[AgentDriveSkillInspectResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID, skill_path: str):
|
||||
resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
try:
|
||||
return _json_response(
|
||||
AgentDriveService().inspect_skill(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
skill_path=skill_path,
|
||||
)
|
||||
)
|
||||
except AgentDriveError as exc:
|
||||
return _handle(exc)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/drive/files/preview")
|
||||
class AgentDrivePreviewByAgentApi(Resource):
|
||||
@console_ns.doc("preview_agent_drive_file_by_agent")
|
||||
@ -182,6 +296,61 @@ class AgentDriveListApi(Resource):
|
||||
return {"items": [{k: v for k, v in item.items() if k != "file_id"} for item in items]}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/drive/skills")
|
||||
class AgentDriveSkillListApi(Resource):
|
||||
@console_ns.doc("list_agent_drive_skills")
|
||||
@console_ns.doc(description="List drive-backed skills for the bound agent")
|
||||
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentDriveListQuery)})
|
||||
@console_ns.response(200, "Drive skills", console_ns.models[AgentDriveSkillListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=_WORKFLOW_APP_MODES)
|
||||
def get(self, app_model: App):
|
||||
query = query_params_from_request(AgentDriveListQuery)
|
||||
agent_id = _resolve_agent_id(app_model, query.node_id)
|
||||
if not agent_id:
|
||||
return _agent_not_bound()
|
||||
try:
|
||||
items = AgentDriveService().list_skills(tenant_id=app_model.tenant_id, agent_id=agent_id)
|
||||
except AgentDriveError as exc:
|
||||
return _handle(exc)
|
||||
return {"items": items}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/drive/skills/<path:skill_path>/inspect")
|
||||
class AgentDriveSkillInspectApi(Resource):
|
||||
@console_ns.doc("inspect_agent_drive_skill")
|
||||
@console_ns.doc(description="Inspect one drive-backed skill for slash-menu hover/detail UI")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"skill_path": "Skill path/slug, e.g. tender-analyzer",
|
||||
**query_params_from_model(AgentDriveSkillInspectQuery),
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Drive skill inspect view", console_ns.models[AgentDriveSkillInspectResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=_WORKFLOW_APP_MODES)
|
||||
def get(self, app_model: App, skill_path: str):
|
||||
query = query_params_from_request(AgentDriveSkillInspectQuery)
|
||||
agent_id = _resolve_agent_id(app_model, query.node_id)
|
||||
if not agent_id:
|
||||
return _agent_not_bound()
|
||||
try:
|
||||
return _json_response(
|
||||
AgentDriveService().inspect_skill(
|
||||
tenant_id=app_model.tenant_id,
|
||||
agent_id=agent_id,
|
||||
skill_path=skill_path,
|
||||
)
|
||||
)
|
||||
except AgentDriveError as exc:
|
||||
return _handle(exc)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/drive/files/preview")
|
||||
class AgentDrivePreviewApi(Resource):
|
||||
@console_ns.doc("preview_agent_drive_file")
|
||||
@ -232,4 +401,8 @@ __all__ = [
|
||||
"AgentDriveListByAgentApi",
|
||||
"AgentDrivePreviewApi",
|
||||
"AgentDrivePreviewByAgentApi",
|
||||
"AgentDriveSkillInspectApi",
|
||||
"AgentDriveSkillInspectByAgentApi",
|
||||
"AgentDriveSkillListApi",
|
||||
"AgentDriveSkillListByAgentApi",
|
||||
]
|
||||
|
||||
@ -14,6 +14,7 @@ from werkzeug.datastructures import MultiDict
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.app_access import resolve_app_access_filter
|
||||
from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
|
||||
from controllers.common.helpers import FileInfo
|
||||
from controllers.common.schema import (
|
||||
@ -78,7 +79,6 @@ _TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
_CREATOR_IDS_BRACKET_PATTERN = re.compile(r"^creator_ids\[(\d+)\]$")
|
||||
AppListMode = Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"]
|
||||
DEFAULT_APP_LIST_MODE: AppListMode = "all"
|
||||
APP_LIST_PERMISSION_KEYS = frozenset({"app.preview", "app.acl.preview", "app.full_access"})
|
||||
|
||||
|
||||
class AppListBaseQuery(BaseModel):
|
||||
@ -167,10 +167,6 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
|
||||
return normalized
|
||||
|
||||
|
||||
def _has_app_list_permission(permission_keys: Sequence[str]) -> bool:
|
||||
return any(permission_key in APP_LIST_PERMISSION_KEYS for permission_key in permission_keys)
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
@ -612,38 +608,12 @@ class AppListApi(Resource):
|
||||
current_user_id,
|
||||
)
|
||||
if dify_config.RBAC_ENABLED:
|
||||
whitelist_scope = enterprise_rbac_service.RBACService.AppAccess.whitelist_resources(
|
||||
access_filter = resolve_app_access_filter(
|
||||
str(current_tenant_id),
|
||||
current_user_id,
|
||||
permissions=permissions,
|
||||
)
|
||||
can_manage_own_apps = "app.create_and_management" in permissions.workspace.permission_keys
|
||||
has_default_preview = _has_app_list_permission(
|
||||
permissions.app.default_permission_keys
|
||||
) or _has_app_list_permission(permissions.workspace.permission_keys)
|
||||
permission_app_ids: set[str] | None = None
|
||||
if not has_default_preview:
|
||||
permission_app_ids = {
|
||||
override.resource_id
|
||||
for override in permissions.app.overrides
|
||||
if _has_app_list_permission(override.permission_keys)
|
||||
}
|
||||
|
||||
if getattr(whitelist_scope, "unrestricted", False):
|
||||
accessible_app_ids = permission_app_ids
|
||||
else:
|
||||
accessible_app_ids = set(whitelist_scope.resource_ids)
|
||||
if permission_app_ids is not None:
|
||||
accessible_app_ids |= permission_app_ids
|
||||
elif has_default_preview:
|
||||
accessible_app_ids = None
|
||||
|
||||
if accessible_app_ids:
|
||||
params.accessible_app_ids = sorted(accessible_app_ids)
|
||||
params.include_own_apps = can_manage_own_apps
|
||||
elif accessible_app_ids is not None and can_manage_own_apps:
|
||||
params.is_created_by_me = True
|
||||
elif accessible_app_ids is not None:
|
||||
params.accessible_app_ids = []
|
||||
access_filter.apply_to_params(params)
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
|
||||
@ -40,12 +40,15 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import App, AppMode
|
||||
from services.agent.errors import AgentNotFoundError
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -191,10 +194,11 @@ class ChatMessageApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
return _create_chat_message(current_user=current_user, app_model=app_model)
|
||||
def post(self, current_tenant_id: str, current_user: Account, app_model: App):
|
||||
return _create_chat_message(current_tenant_id=current_tenant_id, current_user=current_user, app_model=app_model)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/chat-messages")
|
||||
@ -215,7 +219,12 @@ class AgentChatMessageApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
app_model = resolve_agent_app_model(tenant_id=current_tenant_id, agent_id=agent_id)
|
||||
return _create_chat_message(current_user=current_user, app_model=app_model)
|
||||
return _create_chat_message(
|
||||
current_tenant_id=current_tenant_id,
|
||||
current_user=current_user,
|
||||
app_model=app_model,
|
||||
agent_id=str(agent_id),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
|
||||
@ -249,11 +258,45 @@ class AgentChatMessageStopApi(Resource):
|
||||
return _stop_chat_message(current_user_id=current_user_id, app_model=app_model, task_id=task_id)
|
||||
|
||||
|
||||
def _create_chat_message(*, current_user: Account, app_model: App):
|
||||
def _resolve_current_user_agent_debug_conversation_id(
|
||||
*, current_tenant_id: str, current_user: Account, app_model: App, agent_id: str | None
|
||||
) -> str:
|
||||
roster_service = AgentRosterService(db.session)
|
||||
if agent_id:
|
||||
return roster_service.get_or_create_agent_app_debug_conversation_id(
|
||||
tenant_id=current_tenant_id,
|
||||
agent_id=agent_id,
|
||||
account_id=current_user.id,
|
||||
)
|
||||
|
||||
agent = roster_service.get_app_backing_agent(tenant_id=current_tenant_id, app_id=str(app_model.id))
|
||||
if agent is None:
|
||||
raise AgentNotFoundError()
|
||||
return roster_service.get_or_create_agent_app_debug_conversation_id(
|
||||
tenant_id=current_tenant_id,
|
||||
agent_id=agent.id,
|
||||
account_id=current_user.id,
|
||||
)
|
||||
|
||||
|
||||
def _create_chat_message(
|
||||
*, current_user: Account, app_model: App, current_tenant_id: str | None = None, agent_id: str | None = None
|
||||
):
|
||||
raw_payload = console_ns.payload or {}
|
||||
args_model = ChatMessagePayload.model_validate(raw_payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
if AppMode.value_of(app_model.mode) == AppMode.AGENT:
|
||||
debug_conversation_id = _resolve_current_user_agent_debug_conversation_id(
|
||||
current_tenant_id=current_tenant_id or app_model.tenant_id,
|
||||
current_user=current_user,
|
||||
app_model=app_model,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
if args_model.conversation_id and args_model.conversation_id != debug_conversation_id:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
args["conversation_id"] = debug_conversation_id
|
||||
|
||||
streaming = _resolve_debugger_chat_streaming(
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
response_mode=args_model.response_mode,
|
||||
|
||||
@ -53,6 +53,7 @@ from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService, attach_message_extra_contents
|
||||
@ -186,10 +187,11 @@ class ChatMessageListApi(Resource):
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_VIEW_LAYOUT)
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
return _list_chat_messages(app_model=app_model)
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
return _list_chat_messages(app_model=app_model, current_user=current_user)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/chat-messages")
|
||||
@ -205,10 +207,11 @@ class AgentChatMessageListApi(Resource):
|
||||
@setup_required
|
||||
@edit_permission_required
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_VIEW_LAYOUT)
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, agent_id: UUID):
|
||||
def get(self, current_tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
app_model = resolve_agent_app_model(tenant_id=current_tenant_id, agent_id=agent_id)
|
||||
return _list_chat_messages(app_model=app_model)
|
||||
return _list_chat_messages(app_model=app_model, current_user=current_user)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||
@ -390,14 +393,24 @@ class AgentMessageApi(Resource):
|
||||
return _get_message_detail(app_model=app_model, message_id=message_id)
|
||||
|
||||
|
||||
def _list_chat_messages(*, app_model: App):
|
||||
def _list_chat_messages(*, app_model: App, current_user: Account | None = None):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
|
||||
conversation = db.session.scalar(
|
||||
select(Conversation)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.limit(1)
|
||||
)
|
||||
if AppMode.value_of(app_model.mode) == AppMode.AGENT and current_user is not None:
|
||||
try:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model,
|
||||
conversation_id=args.conversation_id,
|
||||
user=current_user,
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
else:
|
||||
conversation = db.session.scalar(
|
||||
select(Conversation)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@ -83,7 +83,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_CREATE, resource_required=False)
|
||||
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
|
||||
@ -222,7 +222,7 @@ class DatasourceAuth(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.CREDENTIAL_CREATE, resource_required=False)
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -5,6 +5,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.account import TenantPluginPermission
|
||||
@ -17,6 +18,9 @@ def plugin_permission_required(
|
||||
def interceptor[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if dify_config.RBAC_ENABLED:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@ -169,7 +169,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_CREATE, resource_required=False)
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
@ -244,7 +244,7 @@ class ModelProviderCredentialSwitchApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_USE, resource_required=False)
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
@ -326,7 +326,7 @@ class PreferredProviderTypeUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_USE, resource_required=False)
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
|
||||
@ -395,7 +395,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_CREATE, resource_required=False)
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
@ -481,7 +481,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_USE, resource_required=False)
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
|
||||
@ -469,6 +469,7 @@ class PluginDebuggingKeyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_DEBUG, resource_required=False)
|
||||
@plugin_permission_required(debug_required=True)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
@ -614,6 +615,7 @@ class PluginUploadFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -634,6 +636,7 @@ class PluginUploadFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -653,6 +656,7 @@ class PluginUploadFromBundleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -673,6 +677,7 @@ class PluginInstallFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -693,6 +698,7 @@ class PluginInstallFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -719,6 +725,7 @@ class PluginInstallFromMarketplaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -739,6 +746,7 @@ class PluginFetchMarketplacePkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
@ -764,6 +772,7 @@ class PluginFetchManifestApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
@ -784,6 +793,7 @@ class PluginFetchInstallTasksApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
@ -801,6 +811,7 @@ class PluginFetchInstallTaskApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, task_id: str):
|
||||
@ -816,6 +827,7 @@ class PluginDeleteInstallTaskApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, task_id: str):
|
||||
@ -831,6 +843,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -846,6 +859,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, task_id: str, identifier: str):
|
||||
@ -862,6 +876,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -884,6 +899,7 @@ class PluginUpgradeFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -911,6 +927,7 @@ class PluginUninstallApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -1041,10 +1058,11 @@ class PluginChangeAutoUpgradeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_PREFERENCES, resource_required=False)
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
if not user.is_admin_or_owner:
|
||||
if not dify_config.RBAC_ENABLED and not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
args = ParserAutoUpgradeChange.model_validate(console_ns.payload)
|
||||
@ -1097,6 +1115,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_PREFERENCES, resource_required=False)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
# exclude one single plugin
|
||||
|
||||
@ -211,7 +211,7 @@ def _legacy_workspace_roles(
|
||||
name=role_name,
|
||||
description="",
|
||||
is_builtin=True,
|
||||
permission_keys=list(_LEGACY_ROLE_PERMISSION_KEYS[role_name]),
|
||||
permission_keys=list(dict.fromkeys(_LEGACY_ROLE_PERMISSION_KEYS[role_name])),
|
||||
role_tag="owner" if role_name == "owner" else "",
|
||||
)
|
||||
for role_name in ("owner", "admin", "editor", "normal", "dataset_operator")
|
||||
@ -244,11 +244,6 @@ def _legacy_workspace_roles(
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Permission catalogs.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog")
|
||||
class RBACWorkspaceCatalogApi(Resource):
|
||||
@login_required
|
||||
@ -375,30 +370,6 @@ class RBACRoleCopyApi(Resource):
|
||||
return _dump(role), 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>/members")
|
||||
class RBACRoleMembersApi(Resource):
|
||||
@login_required
|
||||
@rbac_permission_required(
|
||||
RBACResourceScope.WORKSPACE, RBACPermission.WORKSPACE_ROLE_MANAGE, resource_required=False
|
||||
)
|
||||
@console_ns.response(200, "Success", console_ns.models[_RBACRoleAccountList.__name__])
|
||||
def get(self, role_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.Roles.members(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(role_id),
|
||||
options=_pagination_options(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access policies (tenant-level permission sets).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _AccessPolicyCreateRequest(BaseModel):
|
||||
name: str
|
||||
resource_type: svc.RBACResourceType
|
||||
@ -788,11 +759,6 @@ class RBACDatasetMemberBindingsApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workspace-level access (Settings > Access Rules).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policy")
|
||||
class RBACWorkspaceAppMatrixApi(Resource):
|
||||
@login_required
|
||||
|
||||
@ -971,7 +971,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_USE, resource_required=False)
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
@ -1070,6 +1070,7 @@ class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
@ -1125,6 +1126,7 @@ class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str):
|
||||
payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
@ -1178,6 +1180,7 @@ class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str):
|
||||
payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {})
|
||||
@ -1196,6 +1199,7 @@ class ToolMCPAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
payload = MCPAuthPayload.model_validate(console_ns.payload or {})
|
||||
@ -1300,6 +1304,7 @@ class ToolMCPUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider_id: str):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
|
||||
@ -31,7 +31,7 @@ from controllers.openapi._models import (
|
||||
AppDslExportQuery,
|
||||
AppDslExportResponse,
|
||||
AppDslImportPayload,
|
||||
AppInfoResponse,
|
||||
AppInfo,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
@ -62,7 +62,6 @@ from controllers.openapi._models import (
|
||||
SessionListQuery,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
TagItem,
|
||||
TaskStopResponse,
|
||||
UsageInfo,
|
||||
WorkflowRunData,
|
||||
@ -96,12 +95,11 @@ register_response_schema_models(
|
||||
openapi_ns,
|
||||
ErrorBody,
|
||||
EventStreamResponse,
|
||||
TagItem,
|
||||
UsageInfo,
|
||||
MessageMetadata,
|
||||
AppListRow,
|
||||
AppListResponse,
|
||||
AppInfoResponse,
|
||||
AppInfo,
|
||||
AppDescribeInfo,
|
||||
AppDescribeResponse,
|
||||
AppDslExportResponse,
|
||||
|
||||
@ -63,6 +63,8 @@ class OpenApiErrorCode(StrEnum):
|
||||
FILE_EXTENSION_BLOCKED = "file_extension_blocked"
|
||||
MEMBER_LIMIT_EXCEEDED = "member_limit_exceeded"
|
||||
MEMBER_LICENSE_EXCEEDED = "member_license_exceeded"
|
||||
HUMAN_INPUT_FORM_NOT_FOUND = "form_not_found"
|
||||
RECIPIENT_SURFACE_MISMATCH = "recipient_surface_mismatch"
|
||||
|
||||
|
||||
class ErrorDetail(BaseModel):
|
||||
@ -239,3 +241,16 @@ class MemberLicenseExceeded(OpenApiError): # noqa: N818
|
||||
error_code = OpenApiErrorCode.MEMBER_LICENSE_EXCEEDED
|
||||
description = "Workspace member license capacity reached."
|
||||
hint = "Contact your workspace administrator to expand the license seat count."
|
||||
|
||||
|
||||
class HumanInputFormNotFound(OpenApiError): # noqa: N818
|
||||
code = 404
|
||||
error_code = OpenApiErrorCode.HUMAN_INPUT_FORM_NOT_FOUND
|
||||
description = "No human-input form matches this token. It may be wrong, expired, or already submitted."
|
||||
|
||||
|
||||
class RecipientSurfaceMismatch(OpenApiError): # noqa: N818
|
||||
code = 403
|
||||
error_code = OpenApiErrorCode.RECIPIENT_SURFACE_MISMATCH
|
||||
description = "This form's recipient can't be submitted via the OpenAPI surface."
|
||||
hint = "Action it through its channel (web app or console)."
|
||||
|
||||
@ -38,18 +38,12 @@ class PaginationEnvelope[T](BaseModel):
|
||||
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
|
||||
|
||||
|
||||
class TagItem(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class AppListRow(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: AppMode
|
||||
tags: list[TagItem] = []
|
||||
updated_at: str | None = None
|
||||
created_by_name: str | None = None
|
||||
workspace_id: str | None = None
|
||||
workspace_name: str | None = None
|
||||
|
||||
@ -70,16 +64,14 @@ class PermittedExternalAppsListResponse(BaseModel):
|
||||
data: list[AppListRow]
|
||||
|
||||
|
||||
class AppInfoResponse(BaseModel):
|
||||
class AppInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: str
|
||||
author: str | None = None
|
||||
tags: list[TagItem] = []
|
||||
|
||||
|
||||
class AppDescribeInfo(AppInfoResponse):
|
||||
class AppDescribeInfo(AppInfo):
|
||||
updated_at: str | None = None
|
||||
service_api_enabled: bool
|
||||
is_agent: bool = False
|
||||
@ -294,7 +286,6 @@ class AppListQuery(BaseModel):
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
tag: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
class AppRunRequest(BaseModel):
|
||||
|
||||
@ -5,11 +5,12 @@ from typing import cast
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import AppDslExportQuery, AppDslExportResponse, AppDslImportPayload
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.data import AuthData, RBACRequirement
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import Account, App
|
||||
@ -37,6 +38,11 @@ class AppDslImportApi(Resource):
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
|
||||
rbac=RBACRequirement(
|
||||
resource_type=RBACResourceScope.APP,
|
||||
scene=RBACPermission.APP_IMPORT_EXPORT_DSL,
|
||||
resource_required=False,
|
||||
),
|
||||
)
|
||||
@returns(200, Import, "Import completed")
|
||||
@returns(202, Import, "Import pending confirmation")
|
||||
@ -89,6 +95,11 @@ class AppDslImportConfirmApi(Resource):
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
|
||||
rbac=RBACRequirement(
|
||||
resource_type=RBACResourceScope.APP,
|
||||
scene=RBACPermission.APP_IMPORT_EXPORT_DSL,
|
||||
resource_required=False,
|
||||
),
|
||||
)
|
||||
@returns(200, Import, "Import confirmed")
|
||||
@returns(400, Import, "Import failed")
|
||||
@ -125,6 +136,7 @@ class AppDslExportApi(Resource):
|
||||
scope=Scope.APPS_READ,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_IMPORT_EXPORT_DSL),
|
||||
)
|
||||
@accepts(query=AppDslExportQuery)
|
||||
@returns(200, AppDslExportResponse, "Export successful")
|
||||
@ -155,6 +167,7 @@ class AppDslCheckDependenciesApi(Resource):
|
||||
scope=Scope.APPS_READ,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_IMPORT_EXPORT_DSL),
|
||||
)
|
||||
@returns(200, CheckDependenciesResult, "Dependencies checked")
|
||||
def get(self, app_id: str, *, auth_data: AuthData):
|
||||
|
||||
@ -19,12 +19,13 @@ from werkzeug.exceptions import (
|
||||
|
||||
import services
|
||||
from controllers.common.fields import EventStreamResponse
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import AppRunRequest, TaskStopResponse
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.data import AuthData, RBACRequirement
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
@ -136,7 +137,10 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||
class AppRunApi(Resource):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_RUN,
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
|
||||
)
|
||||
@openapi_ns.response(200, "Run result (SSE stream)", openapi_ns.models[EventStreamResponse.__name__])
|
||||
@accepts(body=AppRunRequest)
|
||||
def post(self, app_id: str, *, auth_data: AuthData, body: AppRunRequest):
|
||||
@ -167,7 +171,10 @@ class AppRunApi(Resource):
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_RUN,
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
|
||||
)
|
||||
@returns(200, TaskStopResponse, description="Task stopped")
|
||||
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
|
||||
@ -8,7 +8,10 @@ from typing import Any, cast
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.app_access import AppAccessFilter, resolve_app_access_filter
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||
@ -19,18 +22,17 @@ from controllers.openapi._models import (
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
TagItem,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.data import AuthData, RBACRequirement
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppListParams, AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
@ -84,54 +86,55 @@ def parameters_payload(app: App) -> dict:
|
||||
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
def build_app_describe_response(app: App, fields: set[str] | None) -> AppDescribeResponse:
|
||||
"""Public projection of an app (name / params / input schema) — never internal config."""
|
||||
want_info = fields is None or "info" in fields
|
||||
want_params = fields is None or "parameters" in fields
|
||||
want_schema = fields is None or "input_schema" in fields
|
||||
|
||||
info = (
|
||||
AppDescribeInfo(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
mode=app.mode,
|
||||
description=app.description,
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
service_api_enabled=bool(app.enable_api),
|
||||
is_agent=app.mode in (AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT),
|
||||
)
|
||||
if want_info
|
||||
else None
|
||||
)
|
||||
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
if want_params:
|
||||
try:
|
||||
parameters = parameters_payload(app)
|
||||
except AppUnavailableError:
|
||||
parameters = dict(_EMPTY_PARAMETERS)
|
||||
if want_schema:
|
||||
try:
|
||||
input_schema = build_input_schema(app)
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return AppDescribeResponse(info=info, parameters=parameters, input_schema=input_schema)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/describe")
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_VIEW_LAYOUT),
|
||||
)
|
||||
@returns(200, AppDescribeResponse, description="App description")
|
||||
@accepts(query=AppDescribeQuery)
|
||||
def get(self, app_id: str, *, auth_data: AuthData, query: AppDescribeQuery):
|
||||
# describe is UUID-only (workspace_id query param dropped in #37212).
|
||||
app = self._load(app_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
want_params = requested is None or "parameters" in requested
|
||||
want_schema = requested is None or "input_schema" in requested
|
||||
|
||||
info = (
|
||||
AppDescribeInfo(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
mode=app.mode,
|
||||
description=app.description,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
author=app.author_name,
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
service_api_enabled=bool(app.enable_api),
|
||||
is_agent=app.mode in ("agent-chat", "advanced-chat"),
|
||||
)
|
||||
if want_info
|
||||
else None
|
||||
)
|
||||
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
if want_params:
|
||||
try:
|
||||
parameters = parameters_payload(app)
|
||||
except AppUnavailableError:
|
||||
parameters = dict(_EMPTY_PARAMETERS)
|
||||
if want_schema:
|
||||
try:
|
||||
input_schema = build_input_schema(app)
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
)
|
||||
return build_app_describe_response(app, query.fields)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
@ -152,45 +155,55 @@ class AppListApi(Resource):
|
||||
else:
|
||||
parsed_uuid = None
|
||||
|
||||
# Compute RBAC-accessible app IDs when RBAC is enabled and the caller is an account.
|
||||
# ``None`` means unrestricted (caller can see all apps in the workspace);
|
||||
# an empty set or list means the caller has no accessible apps.
|
||||
# End-users bypass RBAC here — their access is controlled by scope upstream.
|
||||
apply_rbac_filter = (
|
||||
dify_config.RBAC_ENABLED and auth_data.caller_kind != "end_user" and auth_data.account_id is not None
|
||||
)
|
||||
access_filter = AppAccessFilter.unrestricted()
|
||||
if apply_rbac_filter:
|
||||
access_filter = resolve_app_access_filter(workspace_id, str(auth_data.account_id))
|
||||
|
||||
tenant_name: str | None = None
|
||||
if parsed_uuid is not None:
|
||||
app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
|
||||
if app is None or str(app.tenant_id) != workspace_id:
|
||||
return empty
|
||||
# Apply RBAC visibility to the UUID fast-path the same way the service
|
||||
# layer does for paginated queries (id in accessible set OR own app).
|
||||
if apply_rbac_filter and not access_filter.is_app_accessible(
|
||||
str(app.id), str(app.maintainer) if app.maintainer else None, str(auth_data.account_id)
|
||||
):
|
||||
return empty
|
||||
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
|
||||
item = AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=getattr(app, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
|
||||
return env
|
||||
|
||||
tag_ids: list[str] | None = None
|
||||
if query.tag:
|
||||
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag, db.session)
|
||||
if not tags:
|
||||
return empty
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
|
||||
params = AppListParams(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else "all", # type:ignore
|
||||
name=query.name,
|
||||
tag_ids=tag_ids,
|
||||
status="normal",
|
||||
# Visibility gate pushed into the query — pagination.total stays
|
||||
# consistent across pages because invisible rows never count.
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
if apply_rbac_filter:
|
||||
access_filter.apply_to_params(params)
|
||||
|
||||
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params, db.session)
|
||||
if pagination is None:
|
||||
return empty
|
||||
@ -205,9 +218,7 @@ class AppListApi(Resource):
|
||||
name=r.name,
|
||||
description=r.description,
|
||||
mode=r.mode,
|
||||
tags=[TagItem(name=t.name) for t in r.tags],
|
||||
updated_at=r.updated_at.isoformat() if r.updated_at else None,
|
||||
created_by_name=getattr(r, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
|
||||
@ -8,14 +8,18 @@ EE blueprint chain so this module is unreachable there.
|
||||
from __future__ import annotations
|
||||
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import (
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppListRow,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
)
|
||||
from controllers.openapi.apps import build_app_describe_response
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData, Edition
|
||||
from extensions.ext_database import db
|
||||
@ -67,9 +71,7 @@ class PermittedExternalAppsListApi(Resource):
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[], # tenant-scoped; not surfaced cross-tenant
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=None, # cross-tenant author leak prevention
|
||||
workspace_id=str(app.tenant_id),
|
||||
workspace_name=tenant.name if tenant else None,
|
||||
)
|
||||
@ -82,3 +84,20 @@ class PermittedExternalAppsListApi(Resource):
|
||||
data=items,
|
||||
)
|
||||
return env
|
||||
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps/<string:app_id>/describe")
|
||||
class PermittedExternalAppDescribeApi(Resource):
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
|
||||
edition=frozenset({Edition.EE}),
|
||||
)
|
||||
@returns(200, AppDescribeResponse, description="Permitted external app description")
|
||||
@accepts(query=AppDescribeQuery)
|
||||
def get(self, app_id: str, *, auth_data: AuthData, query: AppDescribeQuery):
|
||||
# App already loaded and ACL-checked by the external_sso pipeline; project it.
|
||||
app = auth_data.app
|
||||
if app is None:
|
||||
raise NotFound("app not found")
|
||||
return build_app_describe_response(app, query.fields)
|
||||
|
||||
@ -3,9 +3,11 @@ from __future__ import annotations
|
||||
from controllers.openapi.auth.conditions import (
|
||||
EDITION_EE,
|
||||
HAS_ALLOWED_ROLES,
|
||||
HAS_RBAC,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
WEBAPP_RUN_SCOPED,
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED,
|
||||
WORKSPACE_SCOPED,
|
||||
)
|
||||
@ -25,6 +27,7 @@ from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_api_enabled,
|
||||
check_private_app_permission,
|
||||
check_rbac_permission,
|
||||
check_scope,
|
||||
check_workspace_member,
|
||||
check_workspace_mismatch,
|
||||
@ -47,8 +50,9 @@ account_pipeline = AuthPipeline(
|
||||
When(WORKSPACE_SCOPED, then=check_workspace_member),
|
||||
When(PATH_HAS_APP_ID, then=check_workspace_mismatch),
|
||||
When(HAS_ALLOWED_ROLES, then=check_workspace_role),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
When(HAS_RBAC, then=check_rbac_permission),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED & WEBAPP_RUN_SCOPED, then=check_acl),
|
||||
When(EDITION_EE & LOADED_APP_IS_PRIVATE & WEBAPP_RUN_SCOPED, then=check_private_app_permission),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections.abc import Callable
|
||||
|
||||
from controllers.openapi.auth.data import AuthData, Edition, RequestContext, current_edition
|
||||
from libs.oauth_bearer import TokenType
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@ -50,8 +50,11 @@ EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
|
||||
|
||||
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
|
||||
|
||||
WEBAPP_RUN_SCOPED = request_cond(lambda ctx: ctx.scope == Scope.APPS_RUN)
|
||||
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED = request_cond(lambda ctx: ctx.workspace_membership)
|
||||
HAS_ALLOWED_ROLES = request_cond(lambda ctx: ctx.allowed_roles is not None)
|
||||
HAS_RBAC = request_cond(lambda ctx: ctx.rbac is not None)
|
||||
|
||||
# Caller must belong to the resolved tenant: either an app-scoped path (tenant
|
||||
# from the app) or an explicit workspace-membership path (tenant from request).
|
||||
|
||||
@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from configs import dify_config
|
||||
from core.rbac import RBACPermission, RBACResourceScope
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Account, Tenant, TenantAccountRole
|
||||
from models.model import App, EndUser
|
||||
@ -35,6 +36,14 @@ class ExternalIdentity(BaseModel):
|
||||
issuer: str | None = None
|
||||
|
||||
|
||||
class RBACRequirement(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
resource_type: RBACResourceScope
|
||||
scene: RBACPermission
|
||||
resource_required: bool = True
|
||||
|
||||
|
||||
class RequestContext(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
@ -43,6 +52,7 @@ class RequestContext(BaseModel):
|
||||
path_params: dict[str, str]
|
||||
workspace_membership: bool = False
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None
|
||||
rbac: RBACRequirement | None = None
|
||||
|
||||
|
||||
class AuthData(BaseModel):
|
||||
@ -59,6 +69,7 @@ class AuthData(BaseModel):
|
||||
path_params: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None
|
||||
rbac: RBACRequirement | None = None
|
||||
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
|
||||
@ -21,6 +21,7 @@ from controllers.openapi.auth.data import (
|
||||
AuthData,
|
||||
Edition,
|
||||
ExternalIdentity,
|
||||
RBACRequirement,
|
||||
RequestContext,
|
||||
current_edition,
|
||||
)
|
||||
@ -59,6 +60,7 @@ class AuthPipeline:
|
||||
scope: Scope | None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
rbac: RBACRequirement | None = None,
|
||||
) -> Any:
|
||||
req_ctx = RequestContext(
|
||||
token_type=identity.token_type,
|
||||
@ -66,6 +68,7 @@ class AuthPipeline:
|
||||
path_params=dict(request.view_args or {}),
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
)
|
||||
|
||||
data = AuthData(
|
||||
@ -77,6 +80,7 @@ class AuthPipeline:
|
||||
tenants=dict(identity.verified_tenants),
|
||||
required_scope=scope,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
path_params=dict(req_ctx.path_params),
|
||||
external_identity=(
|
||||
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
|
||||
@ -129,6 +133,7 @@ class PipelineRouter:
|
||||
edition: frozenset[Edition] | None = None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
rbac: RBACRequirement | None = None,
|
||||
) -> Callable:
|
||||
return self._make_decorator(
|
||||
scope=scope,
|
||||
@ -136,6 +141,7 @@ class PipelineRouter:
|
||||
edition=edition,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
)
|
||||
|
||||
def guard_workspace(
|
||||
@ -145,6 +151,7 @@ class PipelineRouter:
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
rbac: RBACRequirement | None = None,
|
||||
) -> Callable:
|
||||
return self._make_decorator(
|
||||
scope=scope,
|
||||
@ -152,6 +159,7 @@ class PipelineRouter:
|
||||
edition=edition,
|
||||
workspace_membership=True,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
)
|
||||
|
||||
def _make_decorator(
|
||||
@ -162,6 +170,7 @@ class PipelineRouter:
|
||||
edition: frozenset[Edition] | None,
|
||||
workspace_membership: bool,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None,
|
||||
rbac: RBACRequirement | None,
|
||||
) -> Callable:
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
@ -175,6 +184,7 @@ class PipelineRouter:
|
||||
edition=edition,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
)
|
||||
|
||||
return decorated
|
||||
@ -192,6 +202,7 @@ class PipelineRouter:
|
||||
edition: frozenset[Edition] | None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
rbac: RBACRequirement | None = None,
|
||||
) -> Any:
|
||||
# 404 not 403 — this edition doesn't expose the feature at all
|
||||
if edition is not None and current_edition() not in edition:
|
||||
@ -235,6 +246,7 @@ class PipelineRouter:
|
||||
scope=scope,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden, NotFound, UnprocessableEntity
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.wraps import enforce_rbac_access
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
@ -38,6 +40,9 @@ def check_workspace_mismatch(data: AuthData) -> None:
|
||||
|
||||
|
||||
def check_workspace_role(data: AuthData) -> None:
|
||||
if dify_config.RBAC_ENABLED and data.rbac is not None:
|
||||
# fine-grained permission check is performed by RBAC
|
||||
return
|
||||
if data.allowed_roles is None:
|
||||
return
|
||||
if data.tenant_role is None:
|
||||
@ -46,6 +51,27 @@ def check_workspace_role(data: AuthData) -> None:
|
||||
raise Forbidden("insufficient workspace role")
|
||||
|
||||
|
||||
def check_rbac_permission(data: AuthData) -> None:
|
||||
req = data.rbac
|
||||
if req is None:
|
||||
return
|
||||
if not dify_config.RBAC_ENABLED:
|
||||
return
|
||||
# Only account callers are subject to RBAC; end_user access is scope-controlled.
|
||||
if data.caller_kind != "account":
|
||||
return
|
||||
if data.account_id is None or data.tenant is None:
|
||||
raise Forbidden("rbac context missing")
|
||||
enforce_rbac_access(
|
||||
tenant_id=str(data.tenant.id),
|
||||
account_id=str(data.account_id),
|
||||
resource_type=req.resource_type,
|
||||
scene=req.scene,
|
||||
resource_required=req.resource_required,
|
||||
path_args=dict(data.path_params),
|
||||
)
|
||||
|
||||
|
||||
def check_app_api_enabled(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
return
|
||||
|
||||
@ -12,16 +12,21 @@ import logging
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._errors import HumanInputFormNotFound, RecipientSurfaceMismatch
|
||||
from controllers.openapi._models import FormSubmitResponse, HumanInputFormDefinitionResponse
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from controllers.openapi.auth.data import AuthData, RBACRequirement
|
||||
from core.workflow.human_input_policy import (
|
||||
HumanInputSurface,
|
||||
is_recipient_type_allowed_for_surface,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import to_timestamp
|
||||
from libs.oauth_bearer import Scope
|
||||
@ -47,31 +52,37 @@ def _jsonify_form_definition(form) -> Response:
|
||||
|
||||
def _ensure_form_belongs_to_app(form, app_model: App) -> None:
|
||||
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
|
||||
raise NotFound("Form not found")
|
||||
raise HumanInputFormNotFound()
|
||||
|
||||
|
||||
def _ensure_form_is_allowed_for_openapi(form) -> None:
|
||||
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.OPENAPI):
|
||||
raise NotFound("Form not found")
|
||||
raise RecipientSurfaceMismatch()
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
|
||||
class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
@openapi_ns.response(200, "Form definition", openapi_ns.models[HumanInputFormDefinitionResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_RUN,
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
|
||||
)
|
||||
def get(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
app_model, _caller, _caller_kind = auth_data.require_app_context()
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
raise HumanInputFormNotFound()
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_openapi(form)
|
||||
service.ensure_form_active(form)
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_RUN,
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
|
||||
)
|
||||
@returns(200, FormSubmitResponse, description="Form submitted")
|
||||
@accepts(body=HumanInputFormSubmitPayload)
|
||||
def post(self, app_id: str, form_token: str, *, auth_data: AuthData, body: HumanInputFormSubmitPayload):
|
||||
@ -80,7 +91,7 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
raise HumanInputFormNotFound()
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_openapi(form)
|
||||
@ -106,6 +117,6 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFound("Form not found")
|
||||
raise HumanInputFormNotFound()
|
||||
|
||||
return FormSubmitResponse()
|
||||
|
||||
@ -19,9 +19,10 @@ from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import EventStreamResponse
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.data import AuthData, RBACRequirement
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
@ -46,7 +47,10 @@ class WorkflowEventsQuery(BaseModel):
|
||||
class OpenApiWorkflowEventsApi(Resource):
|
||||
@openapi_ns.doc(params=query_params_from_model(WorkflowEventsQuery))
|
||||
@openapi_ns.response(200, "SSE event stream", openapi_ns.models[EventStreamResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_RUN,
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
|
||||
)
|
||||
def get(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
@ -2,6 +2,7 @@ from typing import Any, cast
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
@ -9,7 +10,11 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.agent_app.app_variable_projection import agent_app_variables_to_user_input_form
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from models.agent import Agent, AgentConfigSnapshot, AgentScope, AgentSource, AgentStatus
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
from models.model import App, AppMode
|
||||
from services.app_service import AppService
|
||||
|
||||
@ -29,6 +34,40 @@ class AppMetaResponse(ResponseModel):
|
||||
register_response_schema_models(service_api_ns, Parameters, AppMetaResponse, AppInfoResponse)
|
||||
|
||||
|
||||
def _get_agent_app_feature_dict_and_user_input_form(app_model: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
app_model_config = app_model.app_model_config
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict()) if app_model_config is not None else {}
|
||||
|
||||
agent = db.session.scalar(
|
||||
select(Agent)
|
||||
.where(
|
||||
Agent.tenant_id == app_model.tenant_id,
|
||||
Agent.app_id == app_model.id,
|
||||
Agent.scope == AgentScope.ROSTER,
|
||||
Agent.source == AgentSource.AGENT_APP,
|
||||
Agent.status == AgentStatus.ACTIVE,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if agent is None or not agent.active_config_snapshot_id:
|
||||
raise AppUnavailableError()
|
||||
|
||||
snapshot = db.session.scalar(
|
||||
select(AgentConfigSnapshot)
|
||||
.where(
|
||||
AgentConfigSnapshot.tenant_id == app_model.tenant_id,
|
||||
AgentConfigSnapshot.agent_id == agent.id,
|
||||
AgentConfigSnapshot.id == agent.active_config_snapshot_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if snapshot is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
agent_soul = AgentSoulConfig.model_validate(snapshot.config_snapshot_dict)
|
||||
return features_dict, agent_app_variables_to_user_input_form(agent_soul.app_variables)
|
||||
|
||||
|
||||
@service_api_ns.route("/parameters")
|
||||
class AppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
@ -61,12 +100,16 @@ class AppParameterApi(Resource):
|
||||
|
||||
Returns the input form parameters and configuration for the application.
|
||||
"""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
features_dict: dict[str, Any]
|
||||
user_input_form: list[dict[str, Any]]
|
||||
if app_model.mode == AppMode.AGENT:
|
||||
features_dict, user_input_form = _get_agent_app_feature_dict_and_user_input_form(app_model)
|
||||
elif app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
|
||||
@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -22,7 +22,7 @@ from core.app.entities.queue_entities import (
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.db.session_factory import create_session, session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
@ -107,7 +107,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_id=self.application_generate_entity.workflow_run_id,
|
||||
)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with create_session() as session:
|
||||
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
|
||||
|
||||
if not app_record:
|
||||
@ -204,6 +204,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
|
||||
)
|
||||
|
||||
# Release the Flask scoped session before workflow execution so a checked-out DB connection
|
||||
# is not held for the lifetime of the graph run.
|
||||
db.session.close()
|
||||
|
||||
# RUN WORKFLOW
|
||||
@ -368,7 +370,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
:return: List of conversation variables ready for use
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
with create_session() as session, session.begin():
|
||||
existing_variables = self._load_existing_conversation_variables(session)
|
||||
|
||||
if not existing_variables:
|
||||
|
||||
@ -21,6 +21,7 @@ from core.app.app_config.entities import (
|
||||
EasyUIBasedAppModelConfigFrom,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.app.apps.agent_app.app_variable_projection import agent_app_variables_to_user_input_form
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
|
||||
|
||||
@ -98,8 +99,7 @@ class AgentAppConfigManager(BaseAppConfigManager):
|
||||
# pipeline's bookkeeping (token counting, persistence).
|
||||
base["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
|
||||
base["pre_prompt"] = agent_soul.prompt.system_prompt or ""
|
||||
# Agent App takes the user message directly; no completion-style inputs form.
|
||||
base.setdefault("user_input_form", [])
|
||||
base["user_input_form"] = agent_app_variables_to_user_input_form(agent_soul.app_variables)
|
||||
return base
|
||||
|
||||
|
||||
|
||||
37
api/core/app/apps/agent_app/app_variable_projection.py
Normal file
37
api/core/app/apps/agent_app/app_variable_projection.py
Normal file
@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from models.agent_config_entities import AppVariableConfig
|
||||
|
||||
|
||||
def agent_app_variables_to_user_input_form(app_variables: Sequence[AppVariableConfig]) -> list[dict[str, Any]]:
|
||||
"""Project Agent Soul app variables into the legacy service-API parameter form."""
|
||||
|
||||
user_input_form: list[dict[str, Any]] = []
|
||||
for variable in app_variables:
|
||||
form_type = _form_type_for_agent_variable(variable.type)
|
||||
form_item: dict[str, Any] = {
|
||||
"label": variable.name,
|
||||
"variable": variable.name,
|
||||
"required": variable.required,
|
||||
}
|
||||
if variable.default is not None:
|
||||
form_item["default"] = variable.default
|
||||
user_input_form.append({form_type: form_item})
|
||||
return user_input_form
|
||||
|
||||
|
||||
def _form_type_for_agent_variable(variable_type: str) -> str:
|
||||
normalized = variable_type.strip().lower()
|
||||
if normalized in {"number", "integer", "float"}:
|
||||
return "number"
|
||||
if normalized in {"boolean", "bool"}:
|
||||
return "checkbox"
|
||||
if normalized in {"paragraph", "long_text", "multiline"}:
|
||||
return "paragraph"
|
||||
return "text-input"
|
||||
|
||||
|
||||
__all__ = ["agent_app_variables_to_user_input_form"]
|
||||
@ -12,10 +12,10 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.db.session_factory import create_session
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
@ -47,7 +47,10 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(AgentChatAppConfig, app_config)
|
||||
app_stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(app_stmt)
|
||||
with create_session() as session:
|
||||
app_record = session.scalar(app_stmt)
|
||||
if app_record:
|
||||
session.expunge(app_record)
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
@ -185,14 +188,18 @@ class AgentChatAppRunner(AppRunner):
|
||||
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
|
||||
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
|
||||
conversation_result = db.session.scalar(conversation_stmt)
|
||||
if conversation_result is None:
|
||||
raise ValueError("Conversation not found")
|
||||
msg_stmt = select(Message).where(Message.id == message.id)
|
||||
message_result = db.session.scalar(msg_stmt)
|
||||
with create_session() as session:
|
||||
conversation_result = session.scalar(conversation_stmt)
|
||||
if conversation_result is None:
|
||||
raise ValueError("Conversation not found")
|
||||
|
||||
message_result = session.scalar(msg_stmt)
|
||||
if message_result is not None:
|
||||
session.expunge(message_result)
|
||||
session.expunge(conversation_result)
|
||||
if message_result is None:
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
|
||||
# start agent runner
|
||||
|
||||
@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.db.session_factory import create_session
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationError
|
||||
@ -46,7 +47,10 @@ class ChatAppRunner(AppRunner):
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(ChatAppConfig, app_config)
|
||||
stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(stmt)
|
||||
with create_session() as session:
|
||||
app_record = session.scalar(stmt)
|
||||
if app_record:
|
||||
session.expunge(app_record)
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
@ -216,6 +220,8 @@ class ChatAppRunner(AppRunner):
|
||||
model=application_generate_entity.model_conf.model,
|
||||
)
|
||||
|
||||
# Release the Flask scoped session before LLM streaming so a checked-out DB connection
|
||||
# is not held for the lifetime of the provider response.
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
|
||||
@ -51,8 +51,11 @@ from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
||||
from core.workflow.human_input_forms import (
|
||||
load_form_dispositions_by_form_id,
|
||||
)
|
||||
from core.workflow.human_input_policy import (
|
||||
FormDisposition,
|
||||
HumanInputSurface,
|
||||
enrich_human_input_pause_reasons,
|
||||
resolve_human_input_pause_reason_inputs,
|
||||
@ -340,13 +343,14 @@ class WorkflowResponseConverter:
|
||||
human_input_form_ids = [reason.form_id for reason in resolved_reasons if isinstance(reason, HumanInputRequired)]
|
||||
expiration_times_by_form_id: dict[str, datetime] = {}
|
||||
display_in_ui_by_form_id: dict[str, bool] = {}
|
||||
form_token_by_form_id: dict[str, str] = {}
|
||||
dispositions_by_form_id: dict[str, FormDisposition] = {}
|
||||
if human_input_form_ids:
|
||||
stmt = select(
|
||||
HumanInputForm.id,
|
||||
HumanInputForm.expiration_time,
|
||||
HumanInputForm.form_definition,
|
||||
).where(HumanInputForm.id.in_(human_input_form_ids))
|
||||
hitl_surface = _INVOKE_FROM_TO_HITL_SURFACE.get(self._application_generate_entity.invoke_from)
|
||||
with Session(bind=db.engine) as session:
|
||||
for form_id, expiration_time, form_definition in session.execute(stmt):
|
||||
expiration_times_by_form_id[str(form_id)] = expiration_time
|
||||
@ -355,17 +359,17 @@ class WorkflowResponseConverter:
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
definition_payload = {}
|
||||
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
|
||||
form_token_by_form_id = load_form_tokens_by_form_id(
|
||||
dispositions_by_form_id = load_form_dispositions_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=_INVOKE_FROM_TO_HITL_SURFACE.get(self._application_generate_entity.invoke_from),
|
||||
surface=hitl_surface,
|
||||
)
|
||||
|
||||
# Reconnect paths must preserve the same pause-reason contract as live streams;
|
||||
# otherwise clients see schema drift after resume.
|
||||
pause_reasons = enrich_human_input_pause_reasons(
|
||||
pause_reasons,
|
||||
form_tokens_by_form_id=form_token_by_form_id,
|
||||
dispositions_by_form_id=dispositions_by_form_id,
|
||||
expiration_times_by_form_id={
|
||||
form_id: int(expiration_time.timestamp())
|
||||
for form_id, expiration_time in expiration_times_by_form_id.items()
|
||||
@ -379,6 +383,7 @@ class WorkflowResponseConverter:
|
||||
expiration_time = expiration_times_by_form_id.get(reason.form_id)
|
||||
if expiration_time is None:
|
||||
raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}")
|
||||
disposition = dispositions_by_form_id.get(reason.form_id)
|
||||
responses.append(
|
||||
HumanInputRequiredResponse(
|
||||
task_id=task_id,
|
||||
@ -391,7 +396,8 @@ class WorkflowResponseConverter:
|
||||
inputs=reason.inputs,
|
||||
actions=reason.actions,
|
||||
display_in_ui=display_in_ui_by_form_id.get(reason.form_id, False),
|
||||
form_token=form_token_by_form_id.get(reason.form_id),
|
||||
form_token=disposition.form_token if disposition else None,
|
||||
approval_channels=list(disposition.approval_channels) if disposition else [],
|
||||
resolved_default_values=reason.resolved_default_values,
|
||||
expiration_time=int(expiration_time.timestamp()),
|
||||
),
|
||||
|
||||
@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
CompletionAppGenerateEntity,
|
||||
)
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.db.session_factory import create_session
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
@ -39,7 +40,10 @@ class CompletionAppRunner(AppRunner):
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(CompletionAppConfig, app_config)
|
||||
stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(stmt)
|
||||
with create_session() as session:
|
||||
app_record = session.scalar(stmt)
|
||||
if app_record:
|
||||
session.expunge(app_record)
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
@ -174,6 +178,8 @@ class CompletionAppRunner(AppRunner):
|
||||
model=application_generate_entity.model_conf.model,
|
||||
)
|
||||
|
||||
# Release the Flask scoped session before LLM streaming so a checked-out DB connection
|
||||
# is not held for the lifetime of the provider response.
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
|
||||
@ -11,6 +11,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
)
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class MessageBasedAppQueueManager(AppQueueManager):
|
||||
@ -47,4 +48,6 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
if self._app_mode == AppMode.ADVANCED_CHAT.value:
|
||||
return
|
||||
raise GenerateTaskStoppedError()
|
||||
|
||||
@ -3,6 +3,7 @@ import time
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
|
||||
@ -14,12 +15,12 @@ from core.app.entities.app_invoke_entities import (
|
||||
build_dify_run_context,
|
||||
)
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.db.session_factory import create_session
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import WorkflowType
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
@ -83,22 +84,24 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
user_from = self._resolve_user_from(invoke_from)
|
||||
|
||||
user_id = None
|
||||
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = db.session.get(EndUser, self.application_generate_entity.user_id)
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
with create_session() as session:
|
||||
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = session.get(EndUser, self.application_generate_entity.user_id)
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
pipeline = db.session.get(Pipeline, app_config.app_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
pipeline = session.get(Pipeline, app_config.app_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
workflow = self.get_workflow(session=session, pipeline=pipeline, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
db.session.close()
|
||||
session.expunge(pipeline)
|
||||
session.expunge(workflow)
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
@ -208,12 +211,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
|
||||
def get_workflow(self, session: Session, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = db.session.scalar(
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
|
||||
.limit(1)
|
||||
@ -298,11 +301,11 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
if document_id and dataset_id:
|
||||
document = db.session.scalar(
|
||||
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = event.error or "Unknown error"
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
with create_session() as session, session.begin():
|
||||
document = session.scalar(
|
||||
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = event.error or "Unknown error"
|
||||
session.add(document)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@ -43,6 +42,3 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
| QueueWorkflowPartialSuccessEvent,
|
||||
):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedError()
|
||||
|
||||
@ -288,6 +288,7 @@ class HumanInputRequiredResponse(StreamResponse):
|
||||
actions: Sequence[UserActionConfig] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
approval_channels: list[str] = Field(default_factory=list)
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int = Field(..., description="Unix timestamp in seconds")
|
||||
|
||||
@ -311,6 +312,7 @@ class HumanInputRequiredPauseReasonPayload(BaseModel):
|
||||
actions: Sequence[UserActionConfig] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
approval_channels: list[str] = Field(default_factory=list)
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int
|
||||
|
||||
@ -325,6 +327,7 @@ class HumanInputRequiredPauseReasonPayload(BaseModel):
|
||||
actions=data.actions,
|
||||
display_in_ui=data.display_in_ui,
|
||||
form_token=data.form_token,
|
||||
approval_channels=data.approval_channels,
|
||||
resolved_default_values=data.resolved_default_values,
|
||||
expiration_time=data.expiration_time,
|
||||
)
|
||||
|
||||
@ -3,7 +3,6 @@ from collections.abc import Generator, Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
@ -13,10 +12,19 @@ from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
||||
from core.db.session_factory import create_session
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models import Account, TenantAccountJoin
|
||||
from models.model import (
|
||||
App,
|
||||
AppMode,
|
||||
AppModelConfig,
|
||||
AppModelConfigDict,
|
||||
EndUser,
|
||||
load_annotation_reply_config,
|
||||
)
|
||||
from models.workflow import Workflow
|
||||
from services.end_user_service import EndUserService
|
||||
|
||||
|
||||
@ -30,18 +38,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
|
||||
"""Retrieve app parameters."""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app.workflow
|
||||
workflow = cls._get_workflow(app)
|
||||
if workflow is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
app_model_config_dict = cls._get_app_model_config_dict(app)
|
||||
if app_model_config_dict is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
features_dict = cast(dict[str, Any], app_model_config_dict)
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
@ -68,7 +76,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
if not user_id:
|
||||
user = EndUserService.get_or_create_end_user(app)
|
||||
else:
|
||||
user = cls._get_user(user_id)
|
||||
user = cls._get_user(user_id, app)
|
||||
|
||||
conversation_id = conversation_id or ""
|
||||
|
||||
@ -79,7 +87,10 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
|
||||
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
|
||||
case AppMode.WORKFLOW:
|
||||
return cls.invoke_workflow_app(app, user, stream, inputs, files)
|
||||
workflow = cls._get_workflow(app)
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
return cls.invoke_workflow_app(app, workflow, user, stream, inputs, files)
|
||||
case AppMode.COMPLETION:
|
||||
return cls.invoke_completion_app(app, user, stream, inputs, files)
|
||||
case _:
|
||||
@ -101,7 +112,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
match app.mode:
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
workflow = app.workflow
|
||||
workflow = cls._get_workflow(app)
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
@ -158,6 +169,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
def invoke_workflow_app(
|
||||
cls,
|
||||
app: App,
|
||||
workflow: Workflow,
|
||||
user: EndUser | Account,
|
||||
stream: bool,
|
||||
inputs: Mapping,
|
||||
@ -166,10 +178,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke workflow app
|
||||
"""
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=db.engine,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
@ -207,16 +215,26 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_user(cls, user_id: str) -> EndUser | Account:
|
||||
def _get_user(cls, user_id: str, app: App) -> EndUser | Account:
|
||||
"""
|
||||
get the user by user id
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(EndUser).where(EndUser.id == user_id)
|
||||
with create_session() as session:
|
||||
stmt = select(EndUser).where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == app.tenant_id,
|
||||
EndUser.app_id == app.id,
|
||||
)
|
||||
user = session.scalar(stmt)
|
||||
if not user:
|
||||
stmt = select(Account).where(Account.id == user_id)
|
||||
stmt = select(Account).where(
|
||||
Account.id == user_id,
|
||||
Account.id == TenantAccountJoin.account_id,
|
||||
TenantAccountJoin.tenant_id == app.tenant_id,
|
||||
)
|
||||
user = session.scalar(stmt)
|
||||
if user:
|
||||
session.expunge(user)
|
||||
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
@ -229,7 +247,10 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
get app
|
||||
"""
|
||||
try:
|
||||
app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
|
||||
with create_session() as session:
|
||||
app = session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
|
||||
if app:
|
||||
session.expunge(app)
|
||||
except Exception:
|
||||
raise ValueError("app not found")
|
||||
|
||||
@ -237,3 +258,41 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
raise ValueError("app not found")
|
||||
|
||||
return app
|
||||
|
||||
@classmethod
|
||||
def _get_workflow(cls, app: App) -> Workflow | None:
|
||||
"""
|
||||
get workflow without relying on App.workflow's request-scoped session property
|
||||
"""
|
||||
if not app.workflow_id:
|
||||
return None
|
||||
|
||||
with create_session() as session:
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
.where(Workflow.id == app.workflow_id, Workflow.tenant_id == app.tenant_id, Workflow.app_id == app.id)
|
||||
.limit(1)
|
||||
)
|
||||
if workflow:
|
||||
session.expunge(workflow)
|
||||
return workflow
|
||||
|
||||
@classmethod
|
||||
def _get_app_model_config_dict(cls, app: App) -> AppModelConfigDict | None:
|
||||
"""
|
||||
get app model config features without relying on request-scoped session-backed model properties
|
||||
"""
|
||||
if not app.app_model_config_id:
|
||||
return None
|
||||
|
||||
with create_session() as session:
|
||||
app_model_config = session.scalar(
|
||||
select(AppModelConfig)
|
||||
.where(AppModelConfig.id == app.app_model_config_id, AppModelConfig.app_id == app.id)
|
||||
.limit(1)
|
||||
)
|
||||
if app_model_config is None:
|
||||
return None
|
||||
|
||||
annotation_reply = load_annotation_reply_config(session, app_model_config.app_id)
|
||||
return app_model_config.to_dict(annotation_reply=annotation_reply)
|
||||
|
||||
@ -22,23 +22,35 @@ class RBACPermission(StrEnum):
|
||||
|
||||
APP_VIEW_LAYOUT = "app_view_layout"
|
||||
APP_TEST_AND_RUN = "app_test_and_run"
|
||||
APP_PREVIEW = "app_preview"
|
||||
APP_CREATE_AND_MANAGEMENT = "app_create_and_management"
|
||||
APP_RELEASE_AND_VERSION = "app_release_and_version"
|
||||
APP_IMPORT_EXPORT_DSL = "app_import_export_dsl"
|
||||
APP_EDIT = "app_edit"
|
||||
APP_MONITOR = "app_monitor"
|
||||
APP_DELETE = "app_delete"
|
||||
APP_ACCESS_CONFIG = "app_access_config"
|
||||
|
||||
DATASET_PREVIEW = "dataset_preview"
|
||||
DATASET_READONLY = "dataset_readonly"
|
||||
DATASET_EDIT = "dataset_edit"
|
||||
DATASET_CREATE_AND_MANAGEMENT = "dataset_create_and_management"
|
||||
DATASET_PIPELINE_TEST = "dataset_pipeline_test"
|
||||
DATASET_DOCUMENT_DOWNLOAD = "dataset_document_download"
|
||||
DATASET_RETRIEVAL_RECALL = "dataset_retrieval_recall"
|
||||
DATASET_USE = "dataset_use"
|
||||
DATASET_DELETE_FILE = "dataset_delete_file"
|
||||
DATASET_PIPELINE_RELEASE = "dataset_pipeline_release"
|
||||
DATASET_DELETE = "dataset_delete"
|
||||
DATASET_ACCESS_CONFIG = "dataset_access_config"
|
||||
DATASET_API_KEY_MANAGE = "dataset_api_key_manage"
|
||||
DATASET_EXTERNAL_CONNECT = "dataset_external_connect"
|
||||
DATASET_IMPORT_EXPORT_DSL = "dataset_import_export_dsl"
|
||||
|
||||
WORKSPACE_MEMBER_MANAGE = "workspace_member_manage"
|
||||
WORKSPACE_ROLE_MANAGE = "workspace_role_manage"
|
||||
API_EXTENSION_MANAGE = "api_extension_manage"
|
||||
CUSTOMIZATION_MANAGE = "customization_manage"
|
||||
|
||||
SNIPPETS_CREATE_AND_MODIFY = "snippets_create_and_modify"
|
||||
SNIPPETS_MANAGE = "snippets_management"
|
||||
@ -49,6 +61,7 @@ class RBACPermission(StrEnum):
|
||||
PLUGIN_DEBUG = "plugin_debug"
|
||||
|
||||
CREDENTIAL_USE = "credential_use"
|
||||
CREDENTIAL_CREATE = "credential_create"
|
||||
CREDENTIAL_MANAGE = "credential_manage"
|
||||
|
||||
TOOL_MANAGE = "tool_manage"
|
||||
|
||||
@ -12,60 +12,61 @@ from collections.abc import Sequence
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
|
||||
from core.workflow.human_input_policy import (
|
||||
FormDisposition,
|
||||
HumanInputSurface,
|
||||
disposition_for_surface,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
|
||||
|
||||
def load_form_dispositions_by_form_id(
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
session: Session | None = None,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, FormDisposition]:
|
||||
"""Resolve each paused form's resume token and approval channels for `surface`."""
|
||||
unique_form_ids = list(dict.fromkeys(form_ids))
|
||||
if not unique_form_ids:
|
||||
return {}
|
||||
|
||||
if session is not None:
|
||||
return _load_form_dispositions_by_form_id(session, unique_form_ids, surface=surface)
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as new_session:
|
||||
return _load_form_dispositions_by_form_id(new_session, unique_form_ids, surface=surface)
|
||||
|
||||
|
||||
def _load_form_dispositions_by_form_id(
|
||||
session: Session,
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
surface: HumanInputSurface | None,
|
||||
) -> dict[str, FormDisposition]:
|
||||
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
|
||||
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
for recipient in session.scalars(stmt):
|
||||
recipients_by_form_id.setdefault(recipient.form_id, []).append(
|
||||
(recipient.recipient_type, recipient.access_token or "")
|
||||
)
|
||||
return {
|
||||
form_id: disposition_for_surface(recipients, surface=surface)
|
||||
for form_id, recipients in recipients_by_form_id.items()
|
||||
}
|
||||
|
||||
|
||||
def load_form_tokens_by_form_id(
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
session: Session | None = None,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Load the preferred access token for each human input form."""
|
||||
unique_form_ids = list(dict.fromkeys(form_ids))
|
||||
if not unique_form_ids:
|
||||
return {}
|
||||
|
||||
if session is not None:
|
||||
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as new_session:
|
||||
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
|
||||
|
||||
|
||||
def _load_form_tokens_by_form_id(
|
||||
session: Session,
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, str]:
|
||||
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
|
||||
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
for recipient in session.scalars(stmt):
|
||||
if not recipient.access_token:
|
||||
continue
|
||||
recipients_by_form_id.setdefault(recipient.form_id, []).append(
|
||||
(recipient.recipient_type, recipient.access_token)
|
||||
)
|
||||
|
||||
tokens_by_form_id: dict[str, str] = {}
|
||||
for form_id, recipients in recipients_by_form_id.items():
|
||||
token = _get_surface_form_token(recipients, surface=surface)
|
||||
if token is not None:
|
||||
tokens_by_form_id[form_id] = token
|
||||
return tokens_by_form_id
|
||||
|
||||
|
||||
def _get_surface_form_token(
|
||||
recipients: Sequence[tuple[RecipientType, str]],
|
||||
*,
|
||||
surface: HumanInputSurface | None,
|
||||
) -> str | None:
|
||||
if surface in {HumanInputSurface.SERVICE_API, HumanInputSurface.OPENAPI}:
|
||||
for recipient_type, token in recipients:
|
||||
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
|
||||
return token
|
||||
|
||||
return get_preferred_form_token(recipients)
|
||||
"""Resume tokens only, for callers that don't surface approval channels."""
|
||||
dispositions = load_form_dispositions_by_form_id(form_ids, session=session, surface=surface)
|
||||
return {
|
||||
form_id: disposition.form_token
|
||||
for form_id, disposition in dispositions.items()
|
||||
if disposition.form_token is not None
|
||||
}
|
||||
|
||||
@ -2,14 +2,14 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType
|
||||
from graphon.nodes.human_input.entities import FormInputConfig, SelectInputConfig
|
||||
from graphon.nodes.human_input.enums import ValueSourceType
|
||||
from graphon.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
|
||||
from graphon.variables import ArrayStringSegment
|
||||
from models.human_input import RecipientType
|
||||
from models.human_input import ApprovalChannel, RecipientType
|
||||
|
||||
|
||||
class HumanInputSurface(StrEnum):
|
||||
@ -20,7 +20,7 @@ class HumanInputSurface(StrEnum):
|
||||
|
||||
# SERVICE_API and OPENAPI are intentionally narrower than CONSOLE: token callers
|
||||
# should only be able to act on end-user web forms, not internal console flows.
|
||||
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
|
||||
ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
|
||||
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
|
||||
HumanInputSurface.OPENAPI: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||
@ -41,7 +41,7 @@ def is_recipient_type_allowed_for_surface(
|
||||
) -> bool:
|
||||
if recipient_type is None:
|
||||
return False
|
||||
return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
|
||||
return recipient_type in ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
|
||||
|
||||
|
||||
def get_preferred_form_token(
|
||||
@ -59,10 +59,39 @@ def get_preferred_form_token(
|
||||
return chosen_token
|
||||
|
||||
|
||||
class FormDisposition(NamedTuple):
|
||||
"""How a paused form resolves for one API surface.
|
||||
|
||||
A form's recipients split into those the surface may act on (yielding a resume
|
||||
`form_token`) and those it may not (their channels named in `approval_channels`
|
||||
so the caller is told where approval actually happens instead).
|
||||
"""
|
||||
|
||||
form_token: str | None
|
||||
approval_channels: list[ApprovalChannel]
|
||||
|
||||
|
||||
def disposition_for_surface(
|
||||
recipients: Sequence[tuple[RecipientType, str]],
|
||||
*,
|
||||
surface: HumanInputSurface | None,
|
||||
) -> FormDisposition:
|
||||
if surface is None:
|
||||
return FormDisposition(form_token=get_preferred_form_token(recipients), approval_channels=[])
|
||||
allowed = ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
|
||||
actionable = [(recipient_type, token) for recipient_type, token in recipients if recipient_type in allowed]
|
||||
return FormDisposition(
|
||||
form_token=get_preferred_form_token(actionable),
|
||||
approval_channels=sorted(
|
||||
{recipient_type.approval_channel for recipient_type, _ in recipients if recipient_type not in allowed}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def enrich_human_input_pause_reasons(
|
||||
reasons: Sequence[Mapping[str, Any]],
|
||||
*,
|
||||
form_tokens_by_form_id: Mapping[str, str],
|
||||
dispositions_by_form_id: Mapping[str, FormDisposition],
|
||||
expiration_times_by_form_id: Mapping[str, int],
|
||||
) -> list[dict[str, Any]]:
|
||||
enriched: list[dict[str, Any]] = []
|
||||
@ -71,7 +100,9 @@ def enrich_human_input_pause_reasons(
|
||||
if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
form_id = updated.get("form_id")
|
||||
if isinstance(form_id, str):
|
||||
updated["form_token"] = form_tokens_by_form_id.get(form_id)
|
||||
disposition = dispositions_by_form_id.get(form_id)
|
||||
updated["form_token"] = disposition.form_token if disposition else None
|
||||
updated["approval_channels"] = list(disposition.approval_channels) if disposition else []
|
||||
expiration_time = expiration_times_by_form_id.get(form_id)
|
||||
if expiration_time is not None:
|
||||
updated["expiration_time"] = expiration_time
|
||||
|
||||
@ -25,7 +25,7 @@ from extensions.redis_names import (
|
||||
serialize_redis_name_args,
|
||||
)
|
||||
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.pubsub_channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
|
||||
|
||||
@ -457,16 +457,14 @@ def init_app(app: DifyApp):
|
||||
|
||||
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
||||
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
|
||||
join_timeout_ms = dify_config.PUBSUB_LISTENER_JOIN_TIMEOUT_MS
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
||||
return ShardedRedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
|
||||
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams":
|
||||
return StreamsBroadcastChannel(
|
||||
_pubsub_redis_client,
|
||||
retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS,
|
||||
join_timeout_ms=join_timeout_ms,
|
||||
)
|
||||
return RedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
|
||||
return RedisBroadcastChannel(_pubsub_redis_client)
|
||||
|
||||
|
||||
def redis_fallback[T](default_return: T | None = None): # type: ignore
|
||||
|
||||
@ -291,6 +291,11 @@ class AgentConfigSnapshotListResponse(ResponseModel):
|
||||
data: list[AgentConfigSnapshotSummaryResponse]
|
||||
|
||||
|
||||
class AgentConfigSnapshotRestoreResponse(ResponseModel):
|
||||
result: Literal["success"]
|
||||
active_config_snapshot_id: str
|
||||
|
||||
|
||||
class AgentComposerAgentResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .channel import BroadcastChannel
|
||||
from .pubsub_channel import BroadcastChannel
|
||||
from .sharded_channel import ShardedRedisBroadcastChannel
|
||||
|
||||
__all__ = ["BroadcastChannel", "ShardedRedisBroadcastChannel"]
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Any, Self, override
|
||||
|
||||
from libs.broadcast_channel.channel import Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.signals import SIG_CLOSE
|
||||
from redis import Redis, RedisCluster
|
||||
from redis.client import PubSub
|
||||
|
||||
@ -26,8 +27,6 @@ class RedisSubscriptionBase(Subscription):
|
||||
client: Redis | RedisCluster,
|
||||
pubsub: PubSub,
|
||||
topic: str,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
# The _pubsub is None only if the subscription is closed.
|
||||
self._client = client
|
||||
@ -39,11 +38,6 @@ class RedisSubscriptionBase(Subscription):
|
||||
self._listener_thread: threading.Thread | None = None
|
||||
self._start_lock = threading.Lock()
|
||||
self._started = False
|
||||
# Max time close() will wait for the listener thread to finish before
|
||||
# returning. Bounds SSE close tail latency. The listener is a daemon
|
||||
# and exits on its own within one poll window (~1s), so a low value
|
||||
# here just means close() returns sooner without breaking anything.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def _start_if_needed(self) -> None:
|
||||
"""Start the subscription if not already started."""
|
||||
@ -90,6 +84,11 @@ class RedisSubscriptionBase(Subscription):
|
||||
if raw_message is None:
|
||||
continue
|
||||
|
||||
# If close() sent a control event to unblock us, exit immediately
|
||||
# without processing any message — the subscription is shutting down.
|
||||
if self._closed.is_set():
|
||||
break
|
||||
|
||||
if raw_message.get("type") != self._get_message_type():
|
||||
continue
|
||||
|
||||
@ -119,6 +118,8 @@ class RedisSubscriptionBase(Subscription):
|
||||
continue
|
||||
|
||||
self._enqueue_message(payload_bytes)
|
||||
if payload_bytes == SIG_CLOSE:
|
||||
break
|
||||
|
||||
_logger.debug("%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic)
|
||||
try:
|
||||
@ -212,13 +213,16 @@ class RedisSubscriptionBase(Subscription):
|
||||
return
|
||||
|
||||
self._closed.set()
|
||||
# Send a control event on the same Redis channel to unblock the
|
||||
self._publish_close_event()
|
||||
|
||||
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
|
||||
# message retrieval method should NOT be called concurrently.
|
||||
#
|
||||
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
|
||||
listener = self._listener_thread
|
||||
if listener is not None:
|
||||
listener.join(timeout=self._join_timeout_ms / 1000.0)
|
||||
listener.join(timeout=2)
|
||||
self._listener_thread = None
|
||||
|
||||
# Abstract methods to be implemented by subclasses
|
||||
@ -226,6 +230,15 @@ class RedisSubscriptionBase(Subscription):
|
||||
"""Return the subscription type (e.g., 'regular' or 'sharded')."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _publish_close_event(self) -> None:
|
||||
"""Publish a control event on the Redis channel to unblock the listener.
|
||||
|
||||
This is called by close() after setting _closed. The subclass should
|
||||
publish an empty message on the same topic so that a blocking
|
||||
get_message() call in the listener thread returns promptly.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _subscribe(self) -> None:
|
||||
"""Subscribe to the Redis topic using the appropriate command."""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, override
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.signals import SIG_CLOSE
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BroadcastChannel:
|
||||
"""
|
||||
@ -22,16 +26,11 @@ class BroadcastChannel:
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
# See `RedisSubscriptionBase._join_timeout_ms`: how long close()
|
||||
# waits for the listener thread before returning.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> Topic:
|
||||
return Topic(self._client, topic, join_timeout_ms=self._join_timeout_ms)
|
||||
return Topic(self._client, topic)
|
||||
|
||||
|
||||
class Topic:
|
||||
@ -39,13 +38,10 @@ class Topic:
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
topic: str,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
@ -61,7 +57,6 @@ class Topic:
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._redis_topic,
|
||||
join_timeout_ms=self._join_timeout_ms,
|
||||
)
|
||||
|
||||
|
||||
@ -72,6 +67,13 @@ class _RedisSubscription(RedisSubscriptionBase):
|
||||
def _get_subscription_type(self) -> str:
|
||||
return "regular"
|
||||
|
||||
@override
|
||||
def _publish_close_event(self) -> None:
|
||||
try:
|
||||
self._client.publish(self._topic, SIG_CLOSE)
|
||||
except Exception:
|
||||
logger.exception("failed to publish close event")
|
||||
|
||||
@override
|
||||
def _subscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
@ -1,13 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, override
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.signals import SIG_CLOSE
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ShardedRedisBroadcastChannel:
|
||||
"""
|
||||
@ -20,14 +24,11 @@ class ShardedRedisBroadcastChannel:
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> ShardedTopic:
|
||||
return ShardedTopic(self._client, topic, join_timeout_ms=self._join_timeout_ms)
|
||||
return ShardedTopic(self._client, topic)
|
||||
|
||||
|
||||
class ShardedTopic:
|
||||
@ -35,13 +36,10 @@ class ShardedTopic:
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
topic: str,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
@ -57,7 +55,6 @@ class ShardedTopic:
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._redis_topic,
|
||||
join_timeout_ms=self._join_timeout_ms,
|
||||
)
|
||||
|
||||
|
||||
@ -68,6 +65,13 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
||||
def _get_subscription_type(self) -> str:
|
||||
return "sharded"
|
||||
|
||||
@override
|
||||
def _publish_close_event(self) -> None:
|
||||
try:
|
||||
self._client.spublish(self._topic, SIG_CLOSE) # type: ignore[attr-defined,union-attr]
|
||||
except Exception:
|
||||
logger.exception("failed to publish close event")
|
||||
|
||||
@override
|
||||
def _subscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
|
||||
@ -9,6 +9,7 @@ from typing import Self, override
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.signals import SIG_CLOSE
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -29,20 +30,15 @@ class StreamsBroadcastChannel:
|
||||
redis_client: Redis | RedisCluster,
|
||||
*,
|
||||
retention_seconds: int = 600,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._retention_seconds = max(int(retention_seconds or 0), 0)
|
||||
# Max time close() will wait for the listener thread to finish.
|
||||
# See `_StreamsSubscription._join_timeout_ms` for the rationale.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> StreamsTopic:
|
||||
return StreamsTopic(
|
||||
self._client,
|
||||
topic,
|
||||
retention_seconds=self._retention_seconds,
|
||||
join_timeout_ms=self._join_timeout_ms,
|
||||
)
|
||||
|
||||
|
||||
@ -53,13 +49,11 @@ class StreamsTopic:
|
||||
topic: str,
|
||||
*,
|
||||
retention_seconds: int = 600,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._key = serialize_redis_name(f"stream:{topic}")
|
||||
self._retention_seconds = retention_seconds
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
self.max_length = 5000
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
@ -77,23 +71,15 @@ class StreamsTopic:
|
||||
return self
|
||||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _StreamsSubscription(self._client, self._key, join_timeout_ms=self._join_timeout_ms)
|
||||
return _StreamsSubscription(self._client, self._key)
|
||||
|
||||
|
||||
class _StreamsSubscription(Subscription):
|
||||
_SENTINEL = object()
|
||||
|
||||
def __init__(self, client: Redis | RedisCluster, key: str, *, join_timeout_ms: int = 2000):
|
||||
def __init__(self, client: Redis | RedisCluster, key: str):
|
||||
self._client = client
|
||||
self._key = key
|
||||
# Max time close() will wait for the listener thread to finish before
|
||||
# returning. Bounds SSE close tail latency: the listener blocks on
|
||||
# XREAD with BLOCK=1000ms, so close() naturally waits up to ~1s for
|
||||
# the thread to notice _closed. Setting this lower lets close()
|
||||
# return promptly while the daemon listener exits on its own within
|
||||
# one BLOCK window - safe because the listener holds no critical
|
||||
# state. ``0`` means close() does not wait at all.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
self._queue: queue.Queue[object] = queue.Queue()
|
||||
|
||||
@ -106,7 +92,6 @@ class _StreamsSubscription(Subscription):
|
||||
# reading and writing the _listener / `_closed` attribute.
|
||||
self._lock = threading.Lock()
|
||||
self._closed: bool = False
|
||||
# self._closed = threading.Event()
|
||||
self._listener: threading.Thread | None = None
|
||||
|
||||
def _listen(self) -> None:
|
||||
@ -144,6 +129,8 @@ class _StreamsSubscription(Subscription):
|
||||
case bytes() | bytearray():
|
||||
data_bytes = bytes(data)
|
||||
if data_bytes is not None:
|
||||
if data_bytes == SIG_CLOSE:
|
||||
break
|
||||
self._queue.put_nowait(data_bytes)
|
||||
last_id = entry_id
|
||||
finally:
|
||||
@ -203,6 +190,13 @@ class _StreamsSubscription(Subscription):
|
||||
assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue"
|
||||
return bytes(item)
|
||||
|
||||
def _publish_close_event(self) -> None:
|
||||
"""Publish an empty message to the stream to unblock the listener's xread."""
|
||||
try:
|
||||
self._client.xadd(self._key, {b"data": SIG_CLOSE})
|
||||
except Exception:
|
||||
logger.exception("failed to publish close event")
|
||||
|
||||
@override
|
||||
def close(self) -> None:
|
||||
with self._lock:
|
||||
@ -212,16 +206,17 @@ class _StreamsSubscription(Subscription):
|
||||
listener = self._listener
|
||||
if listener is not None:
|
||||
self._listener = None
|
||||
# We close the listener outside of the with block to avoid holding the
|
||||
# lock for a long time.
|
||||
|
||||
if listener is not None:
|
||||
self._publish_close_event()
|
||||
|
||||
if listener is not None and listener.is_alive():
|
||||
listener.join(timeout=self._join_timeout_ms / 1000.0)
|
||||
listener.join(timeout=2)
|
||||
if listener.is_alive():
|
||||
logger.debug(
|
||||
"Streams subscription listener for key %s did not stop within %dms; "
|
||||
"Streams subscription listener for key %s did not stop after join; "
|
||||
"daemon thread will exit on its own within one poll window.",
|
||||
self._key,
|
||||
self._join_timeout_ms,
|
||||
)
|
||||
|
||||
# Context manager helpers
|
||||
|
||||
1
api/libs/broadcast_channel/signals.py
Normal file
1
api/libs/broadcast_channel/signals.py
Normal file
@ -0,0 +1 @@
|
||||
SIG_CLOSE = b"__closed__"
|
||||
@ -0,0 +1,39 @@
|
||||
"""agent drive skill metadata refactor
|
||||
|
||||
Revision ID: b2515f9d4c2a
|
||||
Revises: 4f7b2c8d9a10
|
||||
Create Date: 2026-06-18 23:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b2515f9d4c2a"
|
||||
down_revision = "4f7b2c8d9a10"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"agent_drive_files",
|
||||
sa.Column("is_skill", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
)
|
||||
op.add_column(
|
||||
"agent_drive_files",
|
||||
sa.Column("skill_metadata", sa.Text().with_variant(mysql.LONGTEXT(), "mysql"), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
"agent_drive_files_tenant_agent_is_skill_key_idx",
|
||||
"agent_drive_files",
|
||||
["tenant_id", "agent_id", "is_skill", "key"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("agent_drive_files_tenant_agent_is_skill_key_idx", table_name="agent_drive_files")
|
||||
op.drop_column("agent_drive_files", "skill_metadata")
|
||||
op.drop_column("agent_drive_files", "is_skill")
|
||||
@ -0,0 +1,66 @@
|
||||
"""add agent debug conversations
|
||||
|
||||
Revision ID: c8f4a6b2d3e1
|
||||
Revises: b2515f9d4c2a
|
||||
Create Date: 2026-06-22 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c8f4a6b2d3e1"
|
||||
down_revision = "b2515f9d4c2a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _is_pg(conn) -> bool:
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
|
||||
def _uuid_column(name: str, *, nullable: bool = False, primary_key: bool = False) -> sa.Column:
|
||||
kwargs = {"nullable": nullable, "primary_key": primary_key}
|
||||
if primary_key and _is_pg(op.get_bind()):
|
||||
kwargs["server_default"] = sa.text("uuidv7()")
|
||||
return sa.Column(name, models.types.StringUUID(), **kwargs)
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"agent_debug_conversations",
|
||||
_uuid_column("id", primary_key=True),
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("agent_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("app_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("account_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("conversation_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("agent_debug_conversation_pkey")),
|
||||
sa.UniqueConstraint(
|
||||
"tenant_id",
|
||||
"agent_id",
|
||||
"account_id",
|
||||
name=op.f("agent_debug_conversation_agent_account_unique"),
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"agent_debug_conversation_conversation_idx",
|
||||
"agent_debug_conversations",
|
||||
["conversation_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"agent_debug_conversation_account_idx",
|
||||
"agent_debug_conversations",
|
||||
["tenant_id", "account_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index("agent_debug_conversation_account_idx", table_name="agent_debug_conversations")
|
||||
op.drop_index("agent_debug_conversation_conversation_idx", table_name="agent_debug_conversations")
|
||||
op.drop_table("agent_debug_conversations")
|
||||
@ -13,6 +13,7 @@ from .agent import (
|
||||
AgentConfigRevision,
|
||||
AgentConfigRevisionOperation,
|
||||
AgentConfigSnapshot,
|
||||
AgentDebugConversation,
|
||||
AgentDriveFile,
|
||||
AgentDriveFileKind,
|
||||
AgentIconType,
|
||||
@ -156,6 +157,7 @@ __all__ = [
|
||||
"AgentConfigRevision",
|
||||
"AgentConfigRevisionOperation",
|
||||
"AgentConfigSnapshot",
|
||||
"AgentDebugConversation",
|
||||
"AgentDriveFile",
|
||||
"AgentDriveFileKind",
|
||||
"AgentIconType",
|
||||
|
||||
@ -83,6 +83,8 @@ class AgentConfigRevisionOperation(StrEnum):
|
||||
SAVE_NEW_AGENT = "save_new_agent"
|
||||
# Promotes a workflow-only Agent into the reusable Agent Roster.
|
||||
SAVE_TO_ROSTER = "save_to_roster"
|
||||
# Switches the Agent's current published config back to an existing version.
|
||||
RESTORE_VERSION = "restore_version"
|
||||
|
||||
|
||||
class WorkflowAgentBindingType(StrEnum):
|
||||
@ -180,6 +182,34 @@ class Agent(DefaultFieldsMixin, Base):
|
||||
archived_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class AgentDebugConversation(DefaultFieldsMixin, Base):
|
||||
"""Per-account console debug conversation for an Agent App.
|
||||
|
||||
Agent App preview state must be isolated by editor account. The Agent row is
|
||||
shared by everyone in the workspace, so this table owns the user-specific
|
||||
conversation pointer used by console debug chat.
|
||||
"""
|
||||
|
||||
__tablename__ = "agent_debug_conversations"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="agent_debug_conversation_pkey"),
|
||||
UniqueConstraint(
|
||||
"tenant_id",
|
||||
"agent_id",
|
||||
"account_id",
|
||||
name="agent_debug_conversation_agent_account_unique",
|
||||
),
|
||||
Index("agent_debug_conversation_conversation_idx", "conversation_id"),
|
||||
Index("agent_debug_conversation_account_idx", "tenant_id", "account_id"),
|
||||
)
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
agent_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
|
||||
class AgentConfigSnapshot(DefaultFieldsMixin, Base):
|
||||
"""Immutable Agent Soul snapshot.
|
||||
|
||||
@ -430,14 +460,17 @@ class AgentDriveFile(DefaultFieldsMixin, Base):
|
||||
synced. ``value_owned_by_drive`` gates physical cleanup: only drive-owned values
|
||||
(created by the agent runtime or Skill standardization, not shared with other
|
||||
business records) have their storage object + record deleted when the KV entry is
|
||||
overwritten or removed; otherwise only the KV row is dropped. Lifecycle never relies
|
||||
on ``UploadFile.used/used_by`` (not a reliable refcount).
|
||||
overwritten or removed; otherwise only the KV row is dropped. Skills are represented
|
||||
by the canonical ``<path>/SKILL.md`` row with ``is_skill=True`` and a serialized
|
||||
``skill_metadata`` string. Lifecycle never relies on ``UploadFile.used/used_by``
|
||||
(not a reliable refcount).
|
||||
"""
|
||||
|
||||
__tablename__ = "agent_drive_files"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="agent_drive_file_pkey"),
|
||||
UniqueConstraint("tenant_id", "agent_id", "key", name="agent_drive_file_scope_key_unique"),
|
||||
Index("agent_drive_files_tenant_agent_is_skill_key_idx", "tenant_id", "agent_id", "is_skill", "key"),
|
||||
)
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
@ -453,6 +486,8 @@ class AgentDriveFile(DefaultFieldsMixin, Base):
|
||||
value_owned_by_drive: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, default=False, server_default=sa.text("false")
|
||||
)
|
||||
is_skill: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False, server_default=sa.text("false"))
|
||||
skill_metadata: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
size: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True)
|
||||
hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
mime_type: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
|
||||
@ -134,20 +134,40 @@ class HumanInputDelivery(DefaultFieldsMixin, Base):
|
||||
)
|
||||
|
||||
|
||||
class ApprovalChannel(StrEnum):
|
||||
"""Where a paused human input form can be approved, surfaced to API callers."""
|
||||
|
||||
EMAIL = "email"
|
||||
WEB_APP = "web_app"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
class RecipientType(StrEnum):
|
||||
# EMAIL_MEMBER member means that the
|
||||
EMAIL_MEMBER = "email_member"
|
||||
EMAIL_EXTERNAL = "email_external"
|
||||
# Second value = the approval channel this recipient maps to (surfaced in `approval_channels`).
|
||||
EMAIL_MEMBER = "email_member", ApprovalChannel.EMAIL
|
||||
EMAIL_EXTERNAL = "email_external", ApprovalChannel.EMAIL
|
||||
# STANDALONE_WEB_APP is used by the standalone web app.
|
||||
#
|
||||
# It's not used while running workflows / chatflows containing HumanInput
|
||||
# node inside console.
|
||||
STANDALONE_WEB_APP = "standalone_web_app"
|
||||
STANDALONE_WEB_APP = "standalone_web_app", ApprovalChannel.WEB_APP
|
||||
# CONSOLE is used while running workflows / chatflows containing HumanInput
|
||||
# node inside console. (E.G. running installed apps or debugging workflows / chatflows)
|
||||
CONSOLE = "console"
|
||||
CONSOLE = "console", ApprovalChannel.CONSOLE
|
||||
# BACKSTAGE is used for backstage input inside console.
|
||||
BACKSTAGE = "backstage"
|
||||
BACKSTAGE = "backstage", ApprovalChannel.CONSOLE
|
||||
|
||||
_approval_channel: ApprovalChannel
|
||||
|
||||
def __new__(cls, value: str, approval_channel: ApprovalChannel) -> "RecipientType":
|
||||
member = str.__new__(cls, value)
|
||||
member._value_ = value
|
||||
member._approval_channel = approval_channel
|
||||
return member
|
||||
|
||||
@property
|
||||
def approval_channel(self) -> ApprovalChannel:
|
||||
return self._approval_channel
|
||||
|
||||
|
||||
@final
|
||||
|
||||
@ -774,26 +774,7 @@ class AppModelConfig(TypeBase):
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
annotation_setting = db.session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
|
||||
)
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
if not collection_binding_detail:
|
||||
raise ValueError("Collection binding detail not found")
|
||||
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
|
||||
else:
|
||||
return {"enabled": False}
|
||||
return load_annotation_reply_config(db.session(), self.app_id)
|
||||
|
||||
@property
|
||||
def more_like_this_dict(self) -> EnabledConfig:
|
||||
@ -864,7 +845,7 @@ class AppModelConfig(TypeBase):
|
||||
},
|
||||
)
|
||||
|
||||
def to_dict(self) -> AppModelConfigDict:
|
||||
def to_dict(self, *, annotation_reply: AnnotationReplyConfig | None = None) -> AppModelConfigDict:
|
||||
return {
|
||||
"opening_statement": self.opening_statement,
|
||||
"suggested_questions": self.suggested_questions_list,
|
||||
@ -872,7 +853,7 @@ class AppModelConfig(TypeBase):
|
||||
"speech_to_text": self.speech_to_text_dict,
|
||||
"text_to_speech": self.text_to_speech_dict,
|
||||
"retriever_resource": self.retriever_resource_dict,
|
||||
"annotation_reply": self.annotation_reply_dict,
|
||||
"annotation_reply": annotation_reply if annotation_reply is not None else self.annotation_reply_dict,
|
||||
"more_like_this": self.more_like_this_dict,
|
||||
"sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
|
||||
"external_data_tools": self.external_data_tools_list,
|
||||
@ -2038,6 +2019,30 @@ class AppAnnotationSetting(TypeBase):
|
||||
)
|
||||
|
||||
|
||||
def load_annotation_reply_config(session: Session, app_id: str) -> AnnotationReplyConfig:
|
||||
annotation_setting = session.scalar(select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id))
|
||||
if annotation_setting is None:
|
||||
return {"enabled": False}
|
||||
|
||||
from .dataset import DatasetCollectionBinding
|
||||
|
||||
collection_binding_detail = session.scalar(
|
||||
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == annotation_setting.collection_binding_id)
|
||||
)
|
||||
if collection_binding_detail is None:
|
||||
raise ValueError("Collection binding detail not found")
|
||||
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OperationLog(TypeBase):
|
||||
__tablename__ = "operation_logs"
|
||||
__table_args__ = (
|
||||
|
||||
@ -391,6 +391,80 @@ Check if activation token is valid
|
||||
| 400 | Invalid request parameters | |
|
||||
| 403 | Insufficient permissions | |
|
||||
|
||||
### [GET] /agent/{agent_id}/api-access
|
||||
#### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| agent_id | path | | Yes | string (uuid) |
|
||||
|
||||
#### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Agent service API access | **application/json**: [AgentApiAccessResponse](#agentapiaccessresponse)<br> |
|
||||
|
||||
### [POST] /agent/{agent_id}/api-enable
|
||||
#### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| agent_id | path | | Yes | string (uuid) |
|
||||
|
||||
#### Request Body
|
||||
|
||||
| Required | Schema |
|
||||
| -------- | ------ |
|
||||
| Yes | **application/json**: [AgentApiStatusPayload](#agentapistatuspayload)<br> |
|
||||
|
||||
#### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Agent service API status updated | **application/json**: [AgentApiAccessResponse](#agentapiaccessresponse)<br> |
|
||||
| 403 | Insufficient permissions | |
|
||||
|
||||
### [GET] /agent/{agent_id}/api-keys
|
||||
#### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| agent_id | path | | Yes | string (uuid) |
|
||||
|
||||
#### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Agent service API keys | **application/json**: [ApiKeyList](#apikeylist)<br> |
|
||||
|
||||
### [POST] /agent/{agent_id}/api-keys
|
||||
#### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| agent_id | path | | Yes | string (uuid) |
|
||||
|
||||
#### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 201 | Agent service API key created | **application/json**: [ApiKeyItem](#apikeyitem)<br> |
|
||||
| 400 | Maximum keys exceeded | |
|
||||
|
||||
### [DELETE] /agent/{agent_id}/api-keys/{api_key_id}
|
||||
#### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| agent_id | path | | Yes | string (uuid) |
|
||||
| api_key_id | path | | Yes | string (uuid) |
|
||||
|
||||
#### Responses
|
||||
|
||||
| Code | Description |
|
||||
| ---- | ----------- |
|
||||
| 204 | Agent service API key deleted |
|
||||
|
||||
### [GET] /agent/{agent_id}/chat-messages
|
||||
Get Agent App chat messages for a conversation with pagination
|
||||
|
||||
@ -576,6 +650,37 @@ Truncated text preview of one Agent App drive value
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Preview | **application/json**: [AgentDrivePreviewResponse](#agentdrivepreviewresponse)<br> |
|
||||
|
||||
### [GET] /agent/{agent_id}/drive/skills
|
||||
List drive-backed skills for an Agent App
|
||||
|
||||
#### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| agent_id | path | Agent ID | Yes | string (uuid) |
|
||||
|
||||
#### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Drive skills | **application/json**: [AgentDriveSkillListResponse](#agentdriveskilllistresponse)<br> |
|
||||
|
||||
### [GET] /agent/{agent_id}/drive/skills/{skill_path}/inspect
|
||||
Inspect one drive-backed skill for slash-menu hover/detail UI
|
||||
|
||||
#### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| agent_id | path | Agent ID | Yes | string (uuid) |
|
||||
| skill_path | path | Skill path/slug, e.g. tender-analyzer | Yes | string |
|
||||
|
||||
#### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Drive skill inspect view | **application/json**: [AgentDriveSkillInspectResponse](#agentdriveskillinspectresponse)<br> |
|
||||
|
||||
### [POST] /agent/{agent_id}/features
|
||||
Update an Agent App's presentation features (opener, follow-up, citations, ...)
|
||||
|
||||
@ -905,6 +1010,20 @@ Infer CLI tool + ENV suggestions from a standardized Agent App skill
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Agent version detail | **application/json**: [AgentConfigSnapshotDetailResponse](#agentconfigsnapshotdetailresponse)<br> |
|
||||
|
||||
### [POST] /agent/{agent_id}/versions/{version_id}/restore
|
||||
#### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| agent_id | path | | Yes | string (uuid) |
|
||||
| version_id | path | | Yes | string (uuid) |
|
||||
|
||||
#### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Agent version restored | **application/json**: [AgentConfigSnapshotRestoreResponse](#agentconfigsnapshotrestoreresponse)<br> |
|
||||
|
||||
### [GET] /all-workspaces
|
||||
#### Parameters
|
||||
|
||||
@ -1454,6 +1573,40 @@ Truncated text preview of one drive value (binary-safe; SKILL.md is the main cas
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Preview | **application/json**: [AgentDrivePreviewResponse](#agentdrivepreviewresponse)<br> |
|
||||
|
||||
### [GET] /apps/{app_id}/agent/drive/skills
|
||||
List drive-backed skills for the bound agent
|
||||
|
||||
#### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| app_id | path | Application ID | Yes | string (uuid) |
|
||||
| node_id | query | Workflow node ID (workflow composer variant) | No | string |
|
||||
| prefix | query | Key prefix filter: '<slug>/' for one skill, 'files/' for files | No | string |
|
||||
|
||||
#### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Drive skills | **application/json**: [AgentDriveSkillListResponse](#agentdriveskilllistresponse)<br> |
|
||||
|
||||
### [GET] /apps/{app_id}/agent/drive/skills/{skill_path}/inspect
|
||||
Inspect one drive-backed skill for slash-menu hover/detail UI
|
||||
|
||||
#### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| app_id | path | Application ID | Yes | string (uuid) |
|
||||
| skill_path | path | Skill path/slug, e.g. tender-analyzer | Yes | string |
|
||||
| node_id | query | Workflow node ID (workflow composer variant) | No | string |
|
||||
|
||||
#### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Drive skill inspect view | **application/json**: [AgentDriveSkillInspectResponse](#agentdriveskillinspectresponse)<br> |
|
||||
|
||||
### [DELETE] /apps/{app_id}/agent/files
|
||||
Delete one drive file by key; soul ref first, then the KV row (ENG-625 D5)
|
||||
|
||||
@ -11954,6 +12107,31 @@ Default namespace
|
||||
| chat_prompt_config | object | | No |
|
||||
| completion_prompt_config | object | | No |
|
||||
|
||||
#### AgentApiAccessResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| api_key_count | integer | | Yes |
|
||||
| api_rph | integer | | Yes |
|
||||
| api_rpm | integer | | Yes |
|
||||
| chat_endpoint | string | | Yes |
|
||||
| conversations_endpoint | string | | Yes |
|
||||
| enabled | boolean | | Yes |
|
||||
| files_upload_endpoint | string | | Yes |
|
||||
| info_endpoint | string | | Yes |
|
||||
| messages_endpoint | string | | Yes |
|
||||
| meta_endpoint | string | | Yes |
|
||||
| parameters_endpoint | string | | Yes |
|
||||
| service_api_base_url | string | | Yes |
|
||||
| stop_endpoint | string | | Yes |
|
||||
| streaming_only | boolean, <br>**Default:** true | | No |
|
||||
|
||||
#### AgentApiStatusPayload
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| enable_api | boolean | Enable or disable Agent service API | Yes |
|
||||
|
||||
#### AgentAppComposerResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
@ -11987,6 +12165,7 @@ Default namespace
|
||||
| bound_agent_id | string | | No |
|
||||
| created_at | integer | | No |
|
||||
| created_by | string | | No |
|
||||
| debug_conversation_id | string | | No |
|
||||
| deleted_tools | [ [DeletedTool](#deletedtool) ] | | No |
|
||||
| description | string | | No |
|
||||
| enable_api | boolean | | Yes |
|
||||
@ -12050,6 +12229,7 @@ default (the config form sends the full desired feature state on save).
|
||||
| create_user_name | string | | No |
|
||||
| created_at | integer | | No |
|
||||
| created_by | string | | No |
|
||||
| debug_conversation_id | string | | No |
|
||||
| description | string | | No |
|
||||
| has_draft_trigger | boolean | | No |
|
||||
| icon | string | | No |
|
||||
@ -12340,6 +12520,13 @@ Audit operation recorded for Agent Soul version/revision changes.
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| data | [ [AgentConfigSnapshotSummaryResponse](#agentconfigsnapshotsummaryresponse) ] | | Yes |
|
||||
|
||||
#### AgentConfigSnapshotRestoreResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| active_config_snapshot_id | string | | Yes |
|
||||
| result | string | | Yes |
|
||||
|
||||
#### AgentConfigSnapshotSummaryResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
@ -12425,9 +12612,11 @@ Audit operation recorded for Agent Soul version/revision changes.
|
||||
| created_at | integer | | No |
|
||||
| file_kind | string | | Yes |
|
||||
| hash | string | | No |
|
||||
| is_skill | boolean | | No |
|
||||
| key | string | | Yes |
|
||||
| mime_type | string | | No |
|
||||
| size | integer | | No |
|
||||
| skill_metadata | string | | No |
|
||||
|
||||
#### AgentDriveListResponse
|
||||
|
||||
@ -12445,6 +12634,65 @@ Audit operation recorded for Agent Soul version/revision changes.
|
||||
| text | string | | No |
|
||||
| truncated | boolean | | Yes |
|
||||
|
||||
#### AgentDriveSkillFileResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| available_in_drive | boolean | | Yes |
|
||||
| drive_key | string | | No |
|
||||
| name | string | | Yes |
|
||||
| path | string | | Yes |
|
||||
| type | string | | Yes |
|
||||
|
||||
#### AgentDriveSkillInspectResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| archive_key | string | | No |
|
||||
| created_at | integer | | No |
|
||||
| description | string | | Yes |
|
||||
| file_tree | [ object ] | | No |
|
||||
| files | [ [AgentDriveSkillFileResponse](#agentdriveskillfileresponse) ] | | No |
|
||||
| hash | string | | No |
|
||||
| mime_type | string | | No |
|
||||
| name | string | | Yes |
|
||||
| path | string | | Yes |
|
||||
| size | integer | | No |
|
||||
| skill_md | [AgentDriveSkillMarkdownResponse](#agentdriveskillmarkdownresponse) | | Yes |
|
||||
| skill_md_key | string | | Yes |
|
||||
| source | string | | Yes |
|
||||
| warnings | [ string ] | | No |
|
||||
|
||||
#### AgentDriveSkillItemResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| archive_key | string | | No |
|
||||
| created_at | integer | | No |
|
||||
| description | string | | Yes |
|
||||
| hash | string | | No |
|
||||
| mime_type | string | | No |
|
||||
| name | string | | Yes |
|
||||
| path | string | | Yes |
|
||||
| size | integer | | No |
|
||||
| skill_md_key | string | | Yes |
|
||||
|
||||
#### AgentDriveSkillListResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| items | [ [AgentDriveSkillItemResponse](#agentdriveskillitemresponse) ] | | No |
|
||||
|
||||
#### AgentDriveSkillMarkdownResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| binary | boolean | | Yes |
|
||||
| key | string | | Yes |
|
||||
| size | integer | | No |
|
||||
| text | string | | No |
|
||||
| truncated | boolean | | Yes |
|
||||
|
||||
#### AgentEnvVariableConfig
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
|
||||
@ -83,7 +83,6 @@ User-scoped operations
|
||||
| mode | query | | No | string, <br>**Available values:** "advanced-chat", "agent", "agent-chat", "channel", "chat", "completion", "rag-pipeline", "workflow" |
|
||||
| name | query | | No | string |
|
||||
| page | query | | No | integer, <br>**Default:** 1 |
|
||||
| tag | query | | No | string |
|
||||
| workspace_id | query | | Yes | string |
|
||||
|
||||
#### Responses
|
||||
@ -331,6 +330,22 @@ Upload a file to use as an input variable when running the app
|
||||
| 422 | Validation error | **application/json**: [ErrorBody](#errorbody)<br> |
|
||||
| default | Error | **application/json**: [ErrorBody](#errorbody)<br> |
|
||||
|
||||
### [GET] /permitted-external-apps/{app_id}/describe
|
||||
#### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| fields | query | | No | string |
|
||||
| app_id | path | | Yes | string |
|
||||
|
||||
#### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Permitted external app description | **application/json**: [AppDescribeResponse](#appdescriberesponse)<br> |
|
||||
| 422 | Validation error | **application/json**: [ErrorBody](#errorbody)<br> |
|
||||
| default | Error | **application/json**: [ErrorBody](#errorbody)<br> |
|
||||
|
||||
### [GET] /workspaces
|
||||
#### Responses
|
||||
|
||||
@ -507,14 +522,12 @@ Upload a file to use as an input variable when running the app
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| author | string | | No |
|
||||
| description | string | | No |
|
||||
| id | string | | Yes |
|
||||
| is_agent | boolean | | No |
|
||||
| mode | string | | Yes |
|
||||
| name | string | | Yes |
|
||||
| service_api_enabled | boolean | | Yes |
|
||||
| tags | [ [TagItem](#tagitem) ], <br>**Default:** | | No |
|
||||
| updated_at | string | | No |
|
||||
|
||||
#### AppDescribeQuery
|
||||
@ -568,16 +581,14 @@ Request body for POST /workspaces/<workspace_id>/apps/imports.
|
||||
| yaml_content | string | Inline YAML DSL string (required when mode is yaml-content) | No |
|
||||
| yaml_url | string | Remote URL to fetch YAML from (required when mode is yaml-url) | No |
|
||||
|
||||
#### AppInfoResponse
|
||||
#### AppInfo
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| author | string | | No |
|
||||
| description | string | | No |
|
||||
| id | string | | Yes |
|
||||
| mode | string | | Yes |
|
||||
| name | string | | Yes |
|
||||
| tags | [ [TagItem](#tagitem) ], <br>**Default:** | | No |
|
||||
|
||||
#### AppListQuery
|
||||
|
||||
@ -589,7 +600,6 @@ mode is a closed enum.
|
||||
| mode | [AppMode](#appmode) | | No |
|
||||
| name | string | | No |
|
||||
| page | integer, <br>**Default:** 1 | | No |
|
||||
| tag | string | | No |
|
||||
| workspace_id | string | | Yes |
|
||||
|
||||
#### AppListResponse
|
||||
@ -606,12 +616,10 @@ mode is a closed enum.
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| created_by_name | string | | No |
|
||||
| description | string | | No |
|
||||
| id | string | | Yes |
|
||||
| mode | [AppMode](#appmode) | | Yes |
|
||||
| name | string | | Yes |
|
||||
| tags | [ [TagItem](#tagitem) ], <br>**Default:** | | No |
|
||||
| updated_at | string | | No |
|
||||
| workspace_id | string | | No |
|
||||
| workspace_name | string | | No |
|
||||
@ -982,12 +990,6 @@ Pagination for GET /account/sessions. Strict (extra='forbid').
|
||||
| last_used_at | string | | No |
|
||||
| prefix | string | | Yes |
|
||||
|
||||
#### TagItem
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| name | string | | Yes |
|
||||
|
||||
#### TaskStopResponse
|
||||
|
||||
200 body for POST /apps/<id>/tasks/<task_id>/stop. The handler always returns
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
@ -142,9 +143,8 @@ class TestTraceClient:
|
||||
mock_notify.assert_called_once()
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.logger")
|
||||
def test_add_span_queue_full(
|
||||
self, mock_logger: MagicMock, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]
|
||||
self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient], caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)
|
||||
|
||||
@ -164,12 +164,15 @@ class TestTraceClient:
|
||||
client.add_span(span_data)
|
||||
assert len(client.queue) == 1
|
||||
|
||||
client.add_span(span_data)
|
||||
assert len(client.queue) == 1
|
||||
mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")
|
||||
with caplog.at_level(logging.WARNING):
|
||||
client.add_span(span_data)
|
||||
assert len(client.queue) == 1
|
||||
assert "Queue is full, likely spans will be dropped." in caplog.text
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_export_batch_error(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
|
||||
def test_export_batch_error(
|
||||
self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient], caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
mock_exporter = mock_exporter_class.return_value
|
||||
mock_exporter.export.side_effect = Exception("Export failed")
|
||||
|
||||
@ -177,9 +180,9 @@ class TestTraceClient:
|
||||
mock_span = MagicMock(spec=ReadableSpan)
|
||||
client.queue.append(mock_span)
|
||||
|
||||
with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
client._export_batch()
|
||||
mock_logger.warning.assert_called()
|
||||
assert "Error exporting spans" in caplog.text
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_worker_loop(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
|
||||
|
||||
@ -307,13 +307,12 @@ class TestGetProjectUrl:
|
||||
monkeypatch.setattr(trace_instance, "entity", None)
|
||||
monkeypatch.setattr(trace_instance, "project_name", None)
|
||||
# Force an error by making string formatting fail
|
||||
with patch("dify_trace_weave.weave_trace.logger") as mock_logger:
|
||||
# Simulate exception via property
|
||||
original_entity = trace_instance.entity
|
||||
trace_instance.entity = None
|
||||
trace_instance.project_name = None
|
||||
url = trace_instance.get_project_url()
|
||||
assert "https://wandb.ai/" in url
|
||||
# Simulate exception via property
|
||||
original_entity = trace_instance.entity
|
||||
trace_instance.entity = None
|
||||
trace_instance.project_name = None
|
||||
url = trace_instance.get_project_url()
|
||||
assert "https://wandb.ai/" in url
|
||||
|
||||
|
||||
# ── TestTraceDispatcher ─────────────────────────────────────────────────────
|
||||
|
||||
@ -830,6 +830,16 @@ class AgentComposerService:
|
||||
) -> WorkflowAgentNodeBinding:
|
||||
node_job = payload.node_job or WorkflowNodeJobConfig()
|
||||
if binding:
|
||||
if cls._is_start_from_scratch_request(binding=binding, payload=payload):
|
||||
return cls._switch_roster_binding_to_inline_agent(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
node_id=node_id,
|
||||
account_id=account_id,
|
||||
binding=binding,
|
||||
payload=payload,
|
||||
)
|
||||
binding.node_job_config = node_job
|
||||
if payload.agent_soul is not None and binding.binding_type == WorkflowAgentBindingType.INLINE_AGENT:
|
||||
current_snapshot = cls._require_version(
|
||||
@ -880,6 +890,46 @@ class AgentComposerService:
|
||||
db.session.flush()
|
||||
return binding
|
||||
|
||||
@classmethod
|
||||
def _is_start_from_scratch_request(cls, *, binding: WorkflowAgentNodeBinding, payload: ComposerSavePayload) -> bool:
|
||||
return (
|
||||
binding.binding_type == WorkflowAgentBindingType.ROSTER_AGENT
|
||||
and payload.binding is not None
|
||||
and payload.binding.binding_type == WorkflowAgentBindingType.INLINE_AGENT.value
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _switch_roster_binding_to_inline_agent(
|
||||
cls,
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
node_id: str,
|
||||
account_id: str,
|
||||
binding: WorkflowAgentNodeBinding,
|
||||
payload: ComposerSavePayload,
|
||||
) -> WorkflowAgentNodeBinding:
|
||||
if payload.binding and (payload.binding.agent_id or payload.binding.current_snapshot_id):
|
||||
raise ValueError("Start from Scratch must not provide an existing inline agent binding.")
|
||||
|
||||
agent_soul = payload.agent_soul or AgentSoulConfig()
|
||||
agent = cls._create_workflow_only_agent(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
node_id=node_id,
|
||||
account_id=account_id,
|
||||
agent_soul=agent_soul,
|
||||
)
|
||||
binding.binding_type = WorkflowAgentBindingType.INLINE_AGENT
|
||||
binding.agent_id = agent.id
|
||||
binding.current_snapshot_id = agent.active_config_snapshot_id
|
||||
binding.node_job_config = payload.node_job or binding.node_job_config
|
||||
binding.updated_by = account_id
|
||||
db.session.flush()
|
||||
return binding
|
||||
|
||||
@classmethod
|
||||
def _save_to_current_version(
|
||||
cls,
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any, TypedDict
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from models.agent import (
|
||||
@ -10,6 +11,7 @@ from models.agent import (
|
||||
AgentConfigRevision,
|
||||
AgentConfigRevisionOperation,
|
||||
AgentConfigSnapshot,
|
||||
AgentDebugConversation,
|
||||
AgentKind,
|
||||
AgentScope,
|
||||
AgentSource,
|
||||
@ -18,8 +20,8 @@ from models.agent import (
|
||||
WorkflowAgentNodeBinding,
|
||||
)
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
from models.enums import AppStatus
|
||||
from models.model import App, AppMode, IconType
|
||||
from models.enums import AppStatus, ConversationFromSource, ConversationStatus
|
||||
from models.model import App, AppMode, Conversation, IconType
|
||||
from models.workflow import Workflow
|
||||
from services.agent.agent_soul_state import agent_soul_has_model
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
@ -96,6 +98,7 @@ class AgentRosterService:
|
||||
"scope": agent.scope.value,
|
||||
"source": agent.source.value,
|
||||
"app_id": agent.app_id,
|
||||
"debug_conversation_id": None,
|
||||
"workflow_id": agent.workflow_id,
|
||||
"workflow_node_id": agent.workflow_node_id,
|
||||
"active_config_snapshot_id": agent.active_config_snapshot_id,
|
||||
@ -392,8 +395,126 @@ class AgentRosterService:
|
||||
agent.active_config_snapshot_id = version.id
|
||||
agent.active_config_has_model = agent_soul_has_model(AgentSoulConfig())
|
||||
self._session.flush()
|
||||
self._get_or_create_agent_app_debug_conversation(agent=agent, account_id=account_id)
|
||||
return agent
|
||||
|
||||
def _create_agent_app_debug_conversation(self, *, app_id: str, account_id: str) -> str:
|
||||
"""Create one console debug conversation for an Agent App editor."""
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
app_model_config_id=None,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
mode=AppMode.AGENT,
|
||||
name="Agent Debugging Conversation",
|
||||
inputs={},
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status=ConversationStatus.NORMAL,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=None,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
self._session.add(conversation)
|
||||
self._session.flush()
|
||||
return conversation.id
|
||||
|
||||
def _get_or_create_agent_app_debug_conversation(self, *, agent: Agent, account_id: str) -> str:
|
||||
if not agent.app_id:
|
||||
raise AgentNotFoundError()
|
||||
|
||||
mapping = self._session.scalar(
|
||||
select(AgentDebugConversation).where(
|
||||
AgentDebugConversation.tenant_id == agent.tenant_id,
|
||||
AgentDebugConversation.agent_id == agent.id,
|
||||
AgentDebugConversation.account_id == account_id,
|
||||
)
|
||||
)
|
||||
if mapping is not None:
|
||||
conversation_id = self._session.scalar(
|
||||
select(Conversation.id).where(
|
||||
Conversation.id == mapping.conversation_id,
|
||||
Conversation.app_id == agent.app_id,
|
||||
Conversation.from_source == ConversationFromSource.CONSOLE,
|
||||
Conversation.from_account_id == account_id,
|
||||
Conversation.is_deleted.is_(False),
|
||||
)
|
||||
)
|
||||
if conversation_id:
|
||||
return conversation_id
|
||||
|
||||
mapping.conversation_id = self._create_agent_app_debug_conversation(
|
||||
app_id=agent.app_id,
|
||||
account_id=account_id,
|
||||
)
|
||||
self._session.flush()
|
||||
return mapping.conversation_id
|
||||
|
||||
conversation_id = self._create_agent_app_debug_conversation(
|
||||
app_id=agent.app_id,
|
||||
account_id=account_id,
|
||||
)
|
||||
self._session.add(
|
||||
AgentDebugConversation(
|
||||
tenant_id=agent.tenant_id,
|
||||
agent_id=agent.id,
|
||||
app_id=agent.app_id,
|
||||
account_id=account_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
)
|
||||
self._session.flush()
|
||||
return conversation_id
|
||||
|
||||
def get_or_create_agent_app_debug_conversation_id(
|
||||
self, *, tenant_id: str, agent_id: str, account_id: str, commit: bool = True
|
||||
) -> str:
|
||||
"""Return the current editor's debug conversation for an Agent App."""
|
||||
|
||||
agent = self._session.scalar(
|
||||
select(Agent).where(
|
||||
Agent.tenant_id == tenant_id,
|
||||
Agent.id == agent_id,
|
||||
Agent.scope == AgentScope.ROSTER,
|
||||
Agent.source == AgentSource.AGENT_APP,
|
||||
Agent.status == AgentStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
if agent is None:
|
||||
raise AgentNotFoundError()
|
||||
|
||||
conversation_id = self._get_or_create_agent_app_debug_conversation(agent=agent, account_id=account_id)
|
||||
if commit:
|
||||
self._session.commit()
|
||||
return conversation_id
|
||||
|
||||
def load_or_create_agent_app_debug_conversation_ids_by_agent_id(
|
||||
self, *, tenant_id: str, agents: list[Agent], account_id: str
|
||||
) -> dict[str, str]:
|
||||
"""Return per-account debug conversations for a page of Agent Apps."""
|
||||
|
||||
conversation_ids_by_agent_id: dict[str, str] = {}
|
||||
changed = False
|
||||
for agent in agents:
|
||||
if (
|
||||
agent.tenant_id != tenant_id
|
||||
or agent.scope != AgentScope.ROSTER
|
||||
or agent.source != AgentSource.AGENT_APP
|
||||
):
|
||||
continue
|
||||
conversation_ids_by_agent_id[agent.id] = self._get_or_create_agent_app_debug_conversation(
|
||||
agent=agent,
|
||||
account_id=account_id,
|
||||
)
|
||||
changed = True
|
||||
if changed:
|
||||
self._session.commit()
|
||||
return conversation_ids_by_agent_id
|
||||
|
||||
def load_app_backing_agents_by_app_id(self, *, tenant_id: str, app_ids: list[str]) -> dict[str, Agent]:
|
||||
"""Return active app-backed Agents keyed by Agent App id."""
|
||||
if not app_ids:
|
||||
@ -666,12 +787,16 @@ class AgentRosterService:
|
||||
@staticmethod
|
||||
def _visible_version_operations(agent: Agent) -> set[AgentConfigRevisionOperation]:
|
||||
if agent.source == AgentSource.AGENT_APP:
|
||||
return {AgentConfigRevisionOperation.SAVE_NEW_VERSION}
|
||||
return {
|
||||
AgentConfigRevisionOperation.SAVE_NEW_VERSION,
|
||||
AgentConfigRevisionOperation.RESTORE_VERSION,
|
||||
}
|
||||
return {
|
||||
AgentConfigRevisionOperation.CREATE_VERSION,
|
||||
AgentConfigRevisionOperation.SAVE_NEW_VERSION,
|
||||
AgentConfigRevisionOperation.SAVE_NEW_AGENT,
|
||||
AgentConfigRevisionOperation.SAVE_TO_ROSTER,
|
||||
AgentConfigRevisionOperation.RESTORE_VERSION,
|
||||
}
|
||||
|
||||
def active_config_is_published(self, *, tenant_id: str, agent: Agent) -> bool:
|
||||
@ -764,6 +889,46 @@ class AgentRosterService:
|
||||
]
|
||||
return result
|
||||
|
||||
def restore_agent_version(
|
||||
self, *, tenant_id: str, agent_id: str, version_id: str, account_id: str
|
||||
) -> dict[str, Any]:
|
||||
agent = self._get_agent(tenant_id=tenant_id, agent_id=agent_id, roster_only=True)
|
||||
visible_version_ids = self._visible_version_ids_stmt(tenant_id=tenant_id, agent_id=agent_id, agent=agent)
|
||||
visible_version_id = self._session.scalar(
|
||||
select(AgentConfigSnapshot.id)
|
||||
.where(
|
||||
AgentConfigSnapshot.tenant_id == tenant_id,
|
||||
AgentConfigSnapshot.agent_id == agent_id,
|
||||
AgentConfigSnapshot.id == version_id,
|
||||
AgentConfigSnapshot.id.in_(select(visible_version_ids.c.current_snapshot_id)),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if not visible_version_id:
|
||||
raise AgentVersionNotFoundError()
|
||||
|
||||
version = self._get_version(tenant_id=tenant_id, agent_id=agent_id, version_id=version_id)
|
||||
if agent.active_config_snapshot_id == version.id:
|
||||
return {"result": "success", "active_config_snapshot_id": version.id}
|
||||
|
||||
previous_snapshot_id = agent.active_config_snapshot_id
|
||||
agent.active_config_snapshot_id = version.id
|
||||
agent.active_config_has_model = agent_soul_has_model(version.config_snapshot)
|
||||
agent.updated_by = account_id
|
||||
self._session.add(
|
||||
AgentConfigRevision(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
previous_snapshot_id=previous_snapshot_id,
|
||||
current_snapshot_id=version.id,
|
||||
revision=self._next_revision(tenant_id=tenant_id, agent_id=agent_id),
|
||||
operation=AgentConfigRevisionOperation.RESTORE_VERSION,
|
||||
created_by=account_id,
|
||||
)
|
||||
)
|
||||
self._session.commit()
|
||||
return {"result": "success", "active_config_snapshot_id": version.id}
|
||||
|
||||
def _get_agent(self, *, tenant_id: str, agent_id: str, roster_only: bool = False) -> Agent:
|
||||
stmt = select(Agent).where(Agent.tenant_id == tenant_id, Agent.id == agent_id)
|
||||
if roster_only:
|
||||
@ -789,6 +954,17 @@ class AgentRosterService:
|
||||
raise AgentVersionNotFoundError()
|
||||
return version
|
||||
|
||||
def _next_revision(self, *, tenant_id: str, agent_id: str) -> int:
|
||||
return (
|
||||
self._session.scalar(
|
||||
select(func.max(AgentConfigRevision.revision)).where(
|
||||
AgentConfigRevision.tenant_id == tenant_id,
|
||||
AgentConfigRevision.agent_id == agent_id,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
) + 1
|
||||
|
||||
def _load_published_active_snapshot_agent_ids(self, *, tenant_id: str, agents: list[Agent]) -> set[str]:
|
||||
predicates = [
|
||||
and_(
|
||||
|
||||
@ -23,7 +23,7 @@ from typing import Any
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from models.agent_config_entities import AgentSkillRefConfig
|
||||
from services.agent.skill_package_service import SkillPackageService
|
||||
from services.agent_drive_service import AgentDriveService, DriveCommitItem, DriveFileRef
|
||||
from services.agent_drive_service import AgentDriveService, DriveCommitItem, DriveFileRef, DriveSkillMetadata
|
||||
|
||||
_FULL_ARCHIVE_NAME = ".DIFY-SKILL-FULL.zip"
|
||||
_SKILL_MD_NAME = "SKILL.md"
|
||||
@ -91,6 +91,12 @@ class SkillStandardizeService:
|
||||
key=skill_md_key,
|
||||
file_ref=DriveFileRef(kind="tool_file", id=md_tool_file.id),
|
||||
value_owned_by_drive=True,
|
||||
is_skill=True,
|
||||
skill_metadata=DriveSkillMetadata(
|
||||
name=manifest.name,
|
||||
description=manifest.description,
|
||||
manifest_files=manifest.files,
|
||||
),
|
||||
),
|
||||
DriveCommitItem(
|
||||
key=archive_key,
|
||||
|
||||
@ -17,12 +17,14 @@ ToolFile records (see ``AgentDriveFile``). This service is the control plane:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import urllib.parse
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, TypedDict
|
||||
from urllib.parse import unquote
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import DataError, SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
@ -41,6 +43,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_MAX_KEY_LENGTH = 512
|
||||
_DRIVE_REF_PREFIX = "agent-"
|
||||
_SKILL_MD_SUFFIX = "/SKILL.md"
|
||||
_SKILL_ARCHIVE_NAME = ".DIFY-SKILL-FULL.zip"
|
||||
|
||||
|
||||
class AgentDriveError(Exception):
|
||||
@ -58,16 +62,86 @@ class AgentDriveError(Exception):
|
||||
|
||||
|
||||
class DriveFileRef(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
kind: Literal["upload_file", "tool_file"]
|
||||
id: str
|
||||
|
||||
|
||||
class DriveSkillMetadata(BaseModel):
|
||||
"""Validated skill catalog metadata stored as a JSON string on the drive row."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str
|
||||
description: str = ""
|
||||
# Safe archive member paths captured during skill standardization. The drive
|
||||
# stores only canonical SKILL.md + full archive, so the UI uses this manifest
|
||||
# to show the original uploaded package contents.
|
||||
manifest_files: list[str] | None = None
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def _validate_name(cls, value: str) -> str:
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
raise ValueError("skill metadata name must not be blank")
|
||||
return normalized
|
||||
|
||||
|
||||
class DriveCommitItem(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
key: str
|
||||
file_ref: DriveFileRef
|
||||
# Drive-owned values may be physically cleaned on overwrite/removal; refs to
|
||||
# files shared with other business records should set this False.
|
||||
value_owned_by_drive: bool = True
|
||||
is_skill: bool = False
|
||||
skill_metadata: DriveSkillMetadata | None = None
|
||||
|
||||
|
||||
class AgentDriveSkillInfo(TypedDict):
|
||||
path: str
|
||||
skill_md_key: str
|
||||
archive_key: str | None
|
||||
name: str
|
||||
description: str
|
||||
size: int | None
|
||||
mime_type: str | None
|
||||
hash: str | None
|
||||
created_at: int | None
|
||||
|
||||
|
||||
class AgentDriveSkillFileInfo(TypedDict):
|
||||
path: str
|
||||
name: str
|
||||
type: str
|
||||
drive_key: str | None
|
||||
available_in_drive: bool
|
||||
|
||||
|
||||
class AgentDriveSkillInspectInfo(TypedDict):
|
||||
path: str
|
||||
skill_md_key: str
|
||||
archive_key: str | None
|
||||
name: str
|
||||
description: str
|
||||
size: int | None
|
||||
mime_type: str | None
|
||||
hash: str | None
|
||||
created_at: int | None
|
||||
source: str
|
||||
files: list[AgentDriveSkillFileInfo]
|
||||
file_tree: list[dict[str, Any]]
|
||||
skill_md: dict[str, Any]
|
||||
warnings: list[str]
|
||||
|
||||
|
||||
def decode_drive_mention_ref(ref_id: str) -> str:
|
||||
"""Decode the prompt token's URL-encoded drive-key field."""
|
||||
|
||||
return unquote(ref_id or "")
|
||||
|
||||
|
||||
def parse_agent_drive_ref(drive_ref: str) -> str:
|
||||
@ -132,6 +206,8 @@ class AgentDriveService:
|
||||
"mime_type": row.mime_type,
|
||||
"file_kind": row.file_kind.value,
|
||||
"file_id": row.file_id,
|
||||
"is_skill": row.is_skill,
|
||||
"skill_metadata": row.skill_metadata,
|
||||
"created_at": int(row.created_at.timestamp()) if row.created_at else None,
|
||||
}
|
||||
if include_download_url:
|
||||
@ -217,6 +293,87 @@ class AgentDriveService:
|
||||
self._delete_storage(storage_key)
|
||||
return removed_keys
|
||||
|
||||
def list_skills(self, *, tenant_id: str, agent_id: str) -> list[AgentDriveSkillInfo]:
|
||||
"""Return the drive-backed skill catalog derived from canonical ``SKILL.md`` rows."""
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
self._assert_agent_belongs_to_tenant(session, tenant_id=tenant_id, agent_id=agent_id)
|
||||
skill_rows = list(
|
||||
session.scalars(
|
||||
select(AgentDriveFile)
|
||||
.where(
|
||||
AgentDriveFile.tenant_id == tenant_id,
|
||||
AgentDriveFile.agent_id == agent_id,
|
||||
AgentDriveFile.is_skill.is_(True),
|
||||
)
|
||||
.order_by(AgentDriveFile.key)
|
||||
)
|
||||
)
|
||||
archive_keys = set(
|
||||
session.scalars(
|
||||
select(AgentDriveFile.key).where(
|
||||
AgentDriveFile.tenant_id == tenant_id,
|
||||
AgentDriveFile.agent_id == agent_id,
|
||||
AgentDriveFile.key.in_([self._skill_archive_key(row.key) for row in skill_rows]),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
skills: list[AgentDriveSkillInfo] = []
|
||||
for row in skill_rows:
|
||||
metadata = self._parse_skill_metadata(row.key, row.skill_metadata)
|
||||
archive_key = self._skill_archive_key(row.key)
|
||||
skills.append(
|
||||
{
|
||||
"path": self._skill_path_from_key(row.key),
|
||||
"skill_md_key": row.key,
|
||||
"archive_key": archive_key if archive_key in archive_keys else None,
|
||||
"name": metadata.name,
|
||||
"description": metadata.description,
|
||||
"size": row.size,
|
||||
"mime_type": row.mime_type,
|
||||
"hash": row.hash,
|
||||
"created_at": int(row.created_at.timestamp()) if row.created_at else None,
|
||||
}
|
||||
)
|
||||
return skills
|
||||
|
||||
def inspect_skill(self, *, tenant_id: str, agent_id: str, skill_path: str) -> AgentDriveSkillInspectInfo:
|
||||
"""Return the UI-facing skill inspect view for slash-menu hover/detail."""
|
||||
|
||||
skill_path = normalize_drive_key(skill_path)
|
||||
skill_md_key = skill_path if skill_path.endswith(_SKILL_MD_SUFFIX) else f"{skill_path}{_SKILL_MD_SUFFIX}"
|
||||
skill_path = self._skill_path_from_key(skill_md_key)
|
||||
catalog = next(
|
||||
(item for item in self.list_skills(tenant_id=tenant_id, agent_id=agent_id) if item["path"] == skill_path),
|
||||
None,
|
||||
)
|
||||
if catalog is None:
|
||||
raise AgentDriveError("skill_not_found", "no drive-backed skill for this path", status_code=404)
|
||||
|
||||
manifest_files = self._manifest_files_from_skill_metadata(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
skill_md_key=skill_md_key,
|
||||
)
|
||||
drive_items = self.manifest(tenant_id=tenant_id, agent_id=agent_id, prefix=f"{skill_path}/")
|
||||
drive_keys = {item["key"] for item in drive_items}
|
||||
preview = self.preview(tenant_id=tenant_id, agent_id=agent_id, key=skill_md_key)
|
||||
files, warnings = self._skill_file_entries(
|
||||
skill_path=skill_path,
|
||||
skill_md_key=skill_md_key,
|
||||
manifest_files=manifest_files,
|
||||
drive_keys=drive_keys,
|
||||
)
|
||||
return {
|
||||
**catalog,
|
||||
"source": "skill_md",
|
||||
"files": files,
|
||||
"file_tree": self._build_file_tree(files),
|
||||
"skill_md": preview,
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
def _commit_one(
|
||||
self,
|
||||
session: Session,
|
||||
@ -228,9 +385,10 @@ class AgentDriveService:
|
||||
pending_storage_deletes: list[str],
|
||||
) -> dict[str, Any]:
|
||||
key = normalize_drive_key(item.key)
|
||||
skill_metadata = self._validate_skill_commit_fields(key=key, item=item)
|
||||
file_kind = AgentDriveFileKind(item.file_ref.kind)
|
||||
file_id = item.file_ref.id
|
||||
size, mime_type = self._validate_source(
|
||||
size, mime_type, file_hash = self._validate_source(
|
||||
session, tenant_id=tenant_id, user_id=user_id, file_kind=file_kind, file_id=file_id
|
||||
)
|
||||
|
||||
@ -245,6 +403,11 @@ class AgentDriveService:
|
||||
# Idempotent re-commit of the same value: leave it (do not clean).
|
||||
if existing.file_kind == file_kind and existing.file_id == file_id:
|
||||
existing.value_owned_by_drive = item.value_owned_by_drive
|
||||
existing.is_skill = item.is_skill
|
||||
existing.skill_metadata = skill_metadata
|
||||
existing.size = size
|
||||
existing.mime_type = mime_type
|
||||
existing.hash = file_hash
|
||||
return self._row_dict(existing)
|
||||
# Overwrite: clean the previous drive-owned value if no longer referenced.
|
||||
if existing.value_owned_by_drive:
|
||||
@ -259,7 +422,10 @@ class AgentDriveService:
|
||||
existing.file_kind = file_kind
|
||||
existing.file_id = file_id
|
||||
existing.value_owned_by_drive = item.value_owned_by_drive
|
||||
existing.is_skill = item.is_skill
|
||||
existing.skill_metadata = skill_metadata
|
||||
existing.size = size
|
||||
existing.hash = file_hash
|
||||
existing.mime_type = mime_type
|
||||
return self._row_dict(existing)
|
||||
|
||||
@ -271,7 +437,10 @@ class AgentDriveService:
|
||||
file_kind=file_kind,
|
||||
file_id=file_id,
|
||||
value_owned_by_drive=item.value_owned_by_drive,
|
||||
is_skill=item.is_skill,
|
||||
skill_metadata=skill_metadata,
|
||||
size=size,
|
||||
hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
created_by=user_id,
|
||||
)
|
||||
@ -287,8 +456,187 @@ class AgentDriveService:
|
||||
"size": row.size,
|
||||
"mime_type": row.mime_type,
|
||||
"value_owned_by_drive": row.value_owned_by_drive,
|
||||
"is_skill": row.is_skill,
|
||||
"skill_metadata": row.skill_metadata,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _skill_path_from_key(key: str) -> str:
|
||||
if not key.endswith(_SKILL_MD_SUFFIX):
|
||||
raise AgentDriveError(
|
||||
"invalid_skill_key",
|
||||
"skill rows must use the canonical '<path>/SKILL.md' key",
|
||||
status_code=500,
|
||||
)
|
||||
path = key[: -len(_SKILL_MD_SUFFIX)]
|
||||
if not path:
|
||||
raise AgentDriveError(
|
||||
"invalid_skill_key",
|
||||
"skill rows must use the canonical '<path>/SKILL.md' key",
|
||||
status_code=500,
|
||||
)
|
||||
return path
|
||||
|
||||
@classmethod
|
||||
def _skill_archive_key(cls, key: str) -> str:
|
||||
return f"{cls._skill_path_from_key(key)}/{_SKILL_ARCHIVE_NAME}"
|
||||
|
||||
@classmethod
|
||||
def _validate_skill_commit_fields(cls, *, key: str, item: DriveCommitItem) -> str | None:
|
||||
if not item.is_skill:
|
||||
if item.skill_metadata is not None:
|
||||
raise AgentDriveError(
|
||||
"invalid_skill_metadata",
|
||||
"skill metadata is only allowed for canonical skill rows",
|
||||
status_code=400,
|
||||
)
|
||||
return None
|
||||
cls._skill_path_from_key(key)
|
||||
if item.skill_metadata is None:
|
||||
raise AgentDriveError(
|
||||
"invalid_skill_metadata",
|
||||
"skill metadata is required for canonical skill rows",
|
||||
status_code=400,
|
||||
)
|
||||
return json.dumps(
|
||||
item.skill_metadata.model_dump(mode="json", exclude_none=True),
|
||||
separators=(",", ":"),
|
||||
sort_keys=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_skill_metadata(key: str, raw_metadata: str | None) -> DriveSkillMetadata:
|
||||
if raw_metadata is None:
|
||||
raise AgentDriveError(
|
||||
"invalid_skill_metadata",
|
||||
f"skill row '{key}' is missing required metadata",
|
||||
status_code=500,
|
||||
)
|
||||
try:
|
||||
return DriveSkillMetadata.model_validate(json.loads(raw_metadata))
|
||||
except (ValueError, TypeError) as exc:
|
||||
raise AgentDriveError(
|
||||
"invalid_skill_metadata",
|
||||
f"skill row '{key}' has invalid stored metadata",
|
||||
status_code=500,
|
||||
) from exc
|
||||
|
||||
@staticmethod
|
||||
def _manifest_files_from_skill_metadata(*, tenant_id: str, agent_id: str, skill_md_key: str) -> list[str] | None:
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(
|
||||
select(AgentDriveFile).where(
|
||||
AgentDriveFile.tenant_id == tenant_id,
|
||||
AgentDriveFile.agent_id == agent_id,
|
||||
AgentDriveFile.key == skill_md_key,
|
||||
AgentDriveFile.is_skill.is_(True),
|
||||
)
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
try:
|
||||
metadata = AgentDriveService._parse_skill_metadata(row.key, row.skill_metadata)
|
||||
except Exception:
|
||||
logger.warning("drive skill inspect: malformed skill metadata for %s", skill_md_key, exc_info=True)
|
||||
return None
|
||||
return [str(item) for item in (metadata.manifest_files or []) if str(item).strip()] or None
|
||||
|
||||
@classmethod
|
||||
def _skill_file_entries(
|
||||
cls,
|
||||
*,
|
||||
skill_path: str,
|
||||
skill_md_key: str,
|
||||
manifest_files: list[str] | None,
|
||||
drive_keys: set[str],
|
||||
) -> tuple[list[AgentDriveSkillFileInfo], list[str]]:
|
||||
warnings: list[str] = []
|
||||
if manifest_files:
|
||||
paths = sorted({normalize_drive_key(path) for path in manifest_files})
|
||||
else:
|
||||
paths = sorted(
|
||||
{
|
||||
key.removeprefix(f"{skill_path}/")
|
||||
for key in drive_keys
|
||||
if not key.endswith(f"/{_SKILL_ARCHIVE_NAME}")
|
||||
}
|
||||
)
|
||||
warnings.append("manifest_files_unavailable")
|
||||
|
||||
files: list[AgentDriveSkillFileInfo] = []
|
||||
for path in paths:
|
||||
if path == _SKILL_ARCHIVE_NAME:
|
||||
continue
|
||||
drive_key = f"{skill_path}/{path}"
|
||||
files.append(
|
||||
{
|
||||
"path": path,
|
||||
"name": path.rsplit("/", 1)[-1],
|
||||
"type": "file",
|
||||
"drive_key": drive_key if drive_key in drive_keys else None,
|
||||
"available_in_drive": drive_key in drive_keys,
|
||||
}
|
||||
)
|
||||
if "SKILL.md" not in {file["path"] for file in files}:
|
||||
files.insert(
|
||||
0,
|
||||
{
|
||||
"path": "SKILL.md",
|
||||
"name": "SKILL.md",
|
||||
"type": "file",
|
||||
"drive_key": skill_md_key,
|
||||
"available_in_drive": skill_md_key in drive_keys,
|
||||
},
|
||||
)
|
||||
return files, warnings
|
||||
|
||||
@staticmethod
|
||||
def _build_file_tree(files: list[AgentDriveSkillFileInfo]) -> list[dict[str, Any]]:
|
||||
root: dict[str, Any] = {}
|
||||
for file in files:
|
||||
cursor = root
|
||||
parts = [part for part in file["path"].split("/") if part]
|
||||
path_parts: list[str] = []
|
||||
for part in parts[:-1]:
|
||||
path_parts.append(part)
|
||||
directory = cursor.setdefault(
|
||||
part,
|
||||
{
|
||||
"name": part,
|
||||
"path": "/".join(path_parts),
|
||||
"type": "directory",
|
||||
"children": {},
|
||||
},
|
||||
)
|
||||
cursor = directory["children"]
|
||||
leaf_name = parts[-1] if parts else file["name"]
|
||||
cursor[leaf_name] = {
|
||||
"name": leaf_name,
|
||||
"path": file["path"],
|
||||
"type": file["type"],
|
||||
"drive_key": file["drive_key"],
|
||||
"available_in_drive": file["available_in_drive"],
|
||||
}
|
||||
|
||||
def serialize(node: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
result: list[dict[str, Any]] = []
|
||||
for item in sorted(node.values(), key=lambda value: (value["type"] != "directory", value["name"])):
|
||||
if item["type"] == "directory":
|
||||
children = serialize(item["children"])
|
||||
result.append(
|
||||
{
|
||||
"name": item["name"],
|
||||
"path": item["path"],
|
||||
"type": "directory",
|
||||
"children": children,
|
||||
}
|
||||
)
|
||||
else:
|
||||
result.append(item)
|
||||
return result
|
||||
|
||||
return serialize(root)
|
||||
|
||||
@staticmethod
|
||||
def _assert_agent_belongs_to_tenant(session: Session, *, tenant_id: str, agent_id: str) -> None:
|
||||
try:
|
||||
@ -309,7 +657,7 @@ class AgentDriveService:
|
||||
user_id: str,
|
||||
file_kind: AgentDriveFileKind,
|
||||
file_id: str,
|
||||
) -> tuple[int | None, str | None]:
|
||||
) -> tuple[int | None, str | None, str | None]:
|
||||
"""Verify the source file exists for the tenant (and user, for ToolFile).
|
||||
|
||||
Malformed ids (e.g. a non-UUID hitting a UUID column) are treated as a
|
||||
@ -328,7 +676,7 @@ class AgentDriveService:
|
||||
raise AgentDriveError(
|
||||
"source_not_found", "source ToolFile not found for this tenant/user", status_code=404
|
||||
)
|
||||
return tool_file.size, tool_file.mimetype
|
||||
return tool_file.size, tool_file.mimetype, None
|
||||
upload_file = session.scalar(
|
||||
select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id)
|
||||
)
|
||||
@ -337,7 +685,7 @@ class AgentDriveService:
|
||||
raise AgentDriveError("source_not_found", "source file ref is invalid", status_code=404) from exc
|
||||
if upload_file is None:
|
||||
raise AgentDriveError("source_not_found", "source UploadFile not found for this tenant", status_code=404)
|
||||
return upload_file.size, upload_file.mime_type
|
||||
return upload_file.size, upload_file.mime_type, upload_file.hash
|
||||
|
||||
def _cleanup_value(
|
||||
self,
|
||||
@ -509,6 +857,8 @@ __all__ = [
|
||||
"AgentDriveService",
|
||||
"DriveCommitItem",
|
||||
"DriveFileRef",
|
||||
"DriveSkillMetadata",
|
||||
"decode_drive_mention_ref",
|
||||
"normalize_drive_key",
|
||||
"parse_agent_drive_ref",
|
||||
]
|
||||
|
||||
@ -1,4 +1,12 @@
|
||||
"""Tenant credit pool accounting.
|
||||
|
||||
Credit deductions are guarded by a tenant-level Redis lock before the database
|
||||
row lock is acquired. This keeps concurrent usage accounting for one tenant
|
||||
from piling up database transactions while preserving cross-tenant concurrency.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -7,13 +15,44 @@ from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.errors.error import QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CREDIT_POOL_TENANT_LOCK_TIMEOUT_SECONDS = 10
|
||||
CREDIT_POOL_TENANT_LOCK_BLOCKING_TIMEOUT_SECONDS = 5
|
||||
|
||||
|
||||
class CreditPoolService:
|
||||
@staticmethod
|
||||
def _get_tenant_lock_key(tenant_id: str) -> str:
|
||||
return f"credit_pool:tenant:{tenant_id}:deduct_lock"
|
||||
|
||||
@classmethod
|
||||
def _deduct_with_tenant_lock(cls, tenant_id: str, deduct: Callable[[], int]) -> int:
|
||||
lock_key = cls._get_tenant_lock_key(tenant_id)
|
||||
lock = redis_client.lock(
|
||||
lock_key,
|
||||
timeout=CREDIT_POOL_TENANT_LOCK_TIMEOUT_SECONDS,
|
||||
blocking_timeout=CREDIT_POOL_TENANT_LOCK_BLOCKING_TIMEOUT_SECONDS,
|
||||
)
|
||||
lock_acquired = False
|
||||
|
||||
try:
|
||||
lock_acquired = lock.acquire(blocking=True)
|
||||
if not lock_acquired:
|
||||
raise QuotaExceededError("Failed to acquire credit pool lock")
|
||||
|
||||
return deduct()
|
||||
finally:
|
||||
if lock_acquired:
|
||||
try:
|
||||
lock.release()
|
||||
except Exception:
|
||||
logger.warning("Failed to release credit pool lock, tenant_id=%s", tenant_id, exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_locked_pool(session: Session, tenant_id: str, pool_type: str) -> TenantCreditPool | None:
|
||||
return session.scalar(
|
||||
@ -76,7 +115,7 @@ class CreditPoolService:
|
||||
if credits_required <= 0:
|
||||
return 0
|
||||
|
||||
try:
|
||||
def deduct() -> int:
|
||||
with session_factory.get_session_maker().begin() as session:
|
||||
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
|
||||
if not pool:
|
||||
@ -89,14 +128,16 @@ class CreditPoolService:
|
||||
raise QuotaExceededError("Insufficient credits remaining")
|
||||
|
||||
pool.quota_used += credits_required
|
||||
return credits_required
|
||||
|
||||
try:
|
||||
return cls._deduct_with_tenant_lock(tenant_id, deduct)
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
|
||||
raise QuotaExceededError("Failed to deduct credits")
|
||||
|
||||
return credits_required
|
||||
|
||||
@classmethod
|
||||
def deduct_credits_capped(
|
||||
cls,
|
||||
@ -108,7 +149,7 @@ class CreditPoolService:
|
||||
if credits_required <= 0:
|
||||
return 0
|
||||
|
||||
try:
|
||||
def deduct() -> int:
|
||||
with session_factory.get_session_maker().begin() as session:
|
||||
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
|
||||
if not pool:
|
||||
@ -121,6 +162,9 @@ class CreditPoolService:
|
||||
|
||||
pool.quota_used += deducted_credits
|
||||
return deducted_credits
|
||||
|
||||
try:
|
||||
return cls._deduct_with_tenant_lock(tenant_id, deduct)
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception:
|
||||
|
||||
@ -182,6 +182,10 @@ class EnterpriseRequest(BaseRequest):
|
||||
inner_headers: dict[str, str] = {INNER_TENANT_ID_HEADER: tenant_id}
|
||||
if account_id:
|
||||
inner_headers[INNER_ACCOUNT_ID_HEADER] = account_id
|
||||
|
||||
if not cls.base_url.startswith("http") or not cls.base_url.startswith("https") or not cls.base_url:
|
||||
raise ValueError("ENTERPRISE_RBAC_API_URL is required when RBAC_ENABLED=true")
|
||||
|
||||
url = f"{cls.rbac_base_url}{endpoint}"
|
||||
mounts = cls._build_mounts()
|
||||
|
||||
|
||||
@ -312,15 +312,26 @@ _LEGACY_WORKSPACE_OWNER_KEYS: list[str] = [
|
||||
"plugin.manage",
|
||||
"plugin.debug",
|
||||
"credential.use",
|
||||
"credential.create",
|
||||
"credential.manage",
|
||||
"billing.view",
|
||||
"billing.subscription.manage",
|
||||
"billing.manage",
|
||||
"app.acl.preview",
|
||||
"app_library.access",
|
||||
"app.create_and_management",
|
||||
"app.tag.manage",
|
||||
"dataset.acl.preview",
|
||||
"dataset.create_and_management",
|
||||
"dataset.tag.manage",
|
||||
"dataset.external.connect",
|
||||
"dataset.api_key.manage",
|
||||
"snippets.create_and_modify",
|
||||
"snippets.management",
|
||||
"tool.manage",
|
||||
"mcp.manage",
|
||||
"snippets.create_and_modify",
|
||||
"snippets.management",
|
||||
]
|
||||
|
||||
_LEGACY_WORKSPACE_ADMIN_KEYS: list[str] = [
|
||||
@ -334,15 +345,24 @@ _LEGACY_WORKSPACE_ADMIN_KEYS: list[str] = [
|
||||
"plugin.manage",
|
||||
"plugin.debug",
|
||||
"credential.use",
|
||||
"credential.create",
|
||||
"credential.manage",
|
||||
"billing.view",
|
||||
"billing.subscription.manage",
|
||||
"billing.manage",
|
||||
"app_library.access",
|
||||
"app.create_and_management",
|
||||
"app.tag.manage",
|
||||
"dataset.create_and_management",
|
||||
"dataset.tag.manage",
|
||||
"dataset.external.connect",
|
||||
"dataset.api_key.manage",
|
||||
"snippets.create_and_modify",
|
||||
"snippets.management",
|
||||
"tool.manage",
|
||||
"mcp.manage",
|
||||
"snippets.create_and_modify",
|
||||
"snippets.management",
|
||||
]
|
||||
|
||||
_LEGACY_WORKSPACE_EDITOR_KEYS: list[str] = [
|
||||
@ -356,7 +376,9 @@ _LEGACY_WORKSPACE_EDITOR_KEYS: list[str] = [
|
||||
"dataset.create_and_management",
|
||||
"dataset.tag.manage",
|
||||
"dataset.external.connect",
|
||||
"snippets.create_and_modify",
|
||||
"tool.manage",
|
||||
"snippets.create_and_modify",
|
||||
]
|
||||
|
||||
_LEGACY_WORKSPACE_NORMAL_KEYS: list[str] = [
|
||||
@ -373,6 +395,7 @@ _LEGACY_WORKSPACE_DATASET_OPERATOR_KEYS: list[str] = [
|
||||
]
|
||||
|
||||
_LEGACY_APP_OWNER_KEYS: list[str] = [
|
||||
"app.acl.preview",
|
||||
"app.acl.view_layout",
|
||||
"app.acl.test_and_run",
|
||||
"app.acl.edit",
|
||||
@ -384,6 +407,7 @@ _LEGACY_APP_OWNER_KEYS: list[str] = [
|
||||
]
|
||||
|
||||
_LEGACY_APP_ADMIN_KEYS: list[str] = [
|
||||
"app.acl.preview",
|
||||
"app.acl.view_layout",
|
||||
"app.acl.test_and_run",
|
||||
"app.acl.edit",
|
||||
@ -395,6 +419,7 @@ _LEGACY_APP_ADMIN_KEYS: list[str] = [
|
||||
]
|
||||
|
||||
_LEGACY_APP_EDITOR_KEYS: list[str] = [
|
||||
"app.acl.preview",
|
||||
"app.acl.view_layout",
|
||||
"app.acl.test_and_run",
|
||||
"app.acl.edit",
|
||||
@ -406,12 +431,14 @@ _LEGACY_APP_EDITOR_KEYS: list[str] = [
|
||||
]
|
||||
|
||||
_LEGACY_APP_NORMAL_KEYS: list[str] = [
|
||||
"app.acl.preview",
|
||||
"app.acl.view_layout",
|
||||
"app.acl.test_and_run",
|
||||
"app.acl.monitor",
|
||||
]
|
||||
|
||||
_LEGACY_DATASET_OWNER_KEYS: list[str] = [
|
||||
"dataset.acl.preview",
|
||||
"dataset.acl.readonly",
|
||||
"dataset.acl.edit",
|
||||
"dataset.acl.import_export_dsl",
|
||||
@ -427,6 +454,7 @@ _LEGACY_DATASET_OWNER_KEYS: list[str] = [
|
||||
]
|
||||
|
||||
_LEGACY_DATASET_ADMIN_KEYS: list[str] = [
|
||||
"dataset.acl.preview",
|
||||
"dataset.acl.readonly",
|
||||
"dataset.acl.edit",
|
||||
"dataset.acl.import_export_dsl",
|
||||
@ -442,6 +470,7 @@ _LEGACY_DATASET_ADMIN_KEYS: list[str] = [
|
||||
]
|
||||
|
||||
_LEGACY_DATASET_EDITOR_KEYS: list[str] = [
|
||||
"dataset.acl.preview",
|
||||
"dataset.acl.readonly",
|
||||
"dataset.acl.edit",
|
||||
"dataset.acl.import_export_dsl",
|
||||
@ -492,6 +521,19 @@ _LEGACY_MY_PERMISSIONS: dict[TenantAccountRole, dict[str, list[str]]] = {
|
||||
}
|
||||
|
||||
|
||||
def _legacy_role_permission_keys(role: TenantAccountRole) -> list[str]:
|
||||
permissions = _LEGACY_MY_PERMISSIONS.get(role, {})
|
||||
return list(
|
||||
dict.fromkeys(
|
||||
[
|
||||
*permissions.get("workspace", []),
|
||||
*permissions.get("app", []),
|
||||
*permissions.get("dataset", []),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _legacy_my_permissions(tenant_id: str, account_id: str | None) -> MyPermissionsResponse:
|
||||
if not account_id:
|
||||
return MyPermissionsResponse()
|
||||
@ -1518,21 +1560,44 @@ class RBACService:
|
||||
)
|
||||
return AccessMatrixItem.model_validate(data or {})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Member ↔ role bindings (screenshot 3: Settings > Members > Assign roles).
|
||||
# ------------------------------------------------------------------
|
||||
class MemberRoles:
|
||||
@staticmethod
|
||||
def get(tenant_id: str, account_id: str | None, member_account_id: str) -> MemberRolesResponse:
|
||||
data = _inner_call(
|
||||
"GET",
|
||||
f"{_INNER_PREFIX}/members/rbac-roles",
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
params={"account_id": member_account_id},
|
||||
)
|
||||
rst = MemberRolesResponse.model_validate(data or {})
|
||||
return rst
|
||||
if dify_config.RBAC_ENABLED:
|
||||
data = _inner_call(
|
||||
"GET",
|
||||
f"{_INNER_PREFIX}/members/rbac-roles",
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
params={"account_id": member_account_id},
|
||||
)
|
||||
rst = MemberRolesResponse.model_validate(data or {})
|
||||
return rst
|
||||
else:
|
||||
with session_factory.create_session() as session:
|
||||
role = session.scalar(
|
||||
select(TenantAccountJoin.role).where(
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
TenantAccountJoin.account_id == member_account_id,
|
||||
)
|
||||
)
|
||||
return MemberRolesResponse(
|
||||
account_id=member_account_id,
|
||||
roles=[
|
||||
RBACRole(
|
||||
id="",
|
||||
name=role,
|
||||
description="",
|
||||
is_builtin=True,
|
||||
type="",
|
||||
permission_keys=_legacy_role_permission_keys(role),
|
||||
role_tag="owner" if role == "owner" else role,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
]
|
||||
if role
|
||||
else [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def batch_get(
|
||||
|
||||
@ -3,6 +3,8 @@ from typing import TypedDict
|
||||
|
||||
import httpx
|
||||
|
||||
OPERATION_REQUEST_TIMEOUT = httpx.Timeout(10.0, connect=3.0)
|
||||
|
||||
|
||||
class UtmInfo(TypedDict, total=False):
|
||||
"""Expected shape of the utm_info dict passed to record_utm.
|
||||
@ -26,7 +28,9 @@ class OperationService:
|
||||
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = httpx.request(method, url, json=json, params=params, headers=headers)
|
||||
response = httpx.request(
|
||||
method, url, json=json, params=params, headers=headers, timeout=OPERATION_REQUEST_TIMEOUT
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
||||
@ -23,8 +23,11 @@ from core.app.entities.task_entities import (
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
||||
from core.workflow.human_input_forms import (
|
||||
load_form_dispositions_by_form_id,
|
||||
)
|
||||
from core.workflow.human_input_policy import (
|
||||
FormDisposition,
|
||||
HumanInputSurface,
|
||||
enrich_human_input_pause_reasons,
|
||||
resolve_human_input_pause_reason_inputs,
|
||||
@ -359,7 +362,7 @@ def _build_human_input_required_events(
|
||||
|
||||
expiration_times_by_form_id: dict[str, int] = {}
|
||||
display_in_ui_by_form_id: dict[str, bool] = {}
|
||||
form_tokens_by_form_id: dict[str, str] = {}
|
||||
dispositions_by_form_id: dict[str, FormDisposition] = {}
|
||||
if human_input_form_ids and session_maker is not None:
|
||||
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time, HumanInputForm.form_definition).where(
|
||||
HumanInputForm.id.in_(human_input_form_ids)
|
||||
@ -372,7 +375,7 @@ def _build_human_input_required_events(
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
definition_payload = {}
|
||||
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
|
||||
form_tokens_by_form_id = load_form_tokens_by_form_id(
|
||||
dispositions_by_form_id = load_form_dispositions_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=human_input_surface,
|
||||
@ -393,6 +396,7 @@ def _build_human_input_required_events(
|
||||
reason.inputs,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
disposition = dispositions_by_form_id.get(form_id)
|
||||
|
||||
response = HumanInputRequiredResponse(
|
||||
task_id=task_id,
|
||||
@ -405,7 +409,8 @@ def _build_human_input_required_events(
|
||||
inputs=resolved_inputs,
|
||||
actions=reason.actions,
|
||||
display_in_ui=display_in_ui_by_form_id.get(form_id, False),
|
||||
form_token=form_tokens_by_form_id.get(form_id),
|
||||
form_token=disposition.form_token if disposition else None,
|
||||
approval_channels=list(disposition.approval_channels) if disposition else [],
|
||||
resolved_default_values=reason.resolved_default_values,
|
||||
expiration_time=expiration_time,
|
||||
),
|
||||
@ -493,11 +498,11 @@ def _build_pause_event(
|
||||
for form_id in [reason.get("form_id")]
|
||||
if isinstance(form_id, str)
|
||||
]
|
||||
form_tokens_by_form_id: dict[str, str] = {}
|
||||
dispositions_by_form_id: dict[str, FormDisposition] = {}
|
||||
expiration_times_by_form_id: dict[str, int] = {}
|
||||
if human_input_form_ids and session_maker is not None:
|
||||
with session_maker() as session:
|
||||
form_tokens_by_form_id = load_form_tokens_by_form_id(
|
||||
dispositions_by_form_id = load_form_dispositions_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=human_input_surface,
|
||||
@ -512,7 +517,7 @@ def _build_pause_event(
|
||||
# otherwise clients see schema drift after resume.
|
||||
reasons = enrich_human_input_pause_reasons(
|
||||
reasons,
|
||||
form_tokens_by_form_id=form_tokens_by_form_id,
|
||||
dispositions_by_form_id=dispositions_by_form_id,
|
||||
expiration_times_by_form_id=expiration_times_by_form_id,
|
||||
)
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ from testcontainers.redis import RedisContainer
|
||||
|
||||
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.pubsub_channel import BroadcastChannel as RedisBroadcastChannel
|
||||
|
||||
|
||||
class TestRedisBroadcastChannelIntegration:
|
||||
|
||||
161
api/tests/unit_tests/controllers/common/test_app_access.py
Normal file
161
api/tests/unit_tests/controllers/common/test_app_access.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""Unit tests for controllers.common.app_access RBAC app-id access filtering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.common.app_access import (
|
||||
APP_LIST_PERMISSION_KEYS,
|
||||
AppAccessFilter,
|
||||
has_app_list_permission,
|
||||
resolve_app_access_filter,
|
||||
)
|
||||
from services.app_service import AppListParams
|
||||
from services.enterprise.rbac_service import (
|
||||
MyPermissionsResponse,
|
||||
ResourcePermissionKeys,
|
||||
ResourcePermissionSnapshot,
|
||||
ResourceWhitelistResources,
|
||||
WorkspacePermissionSnapshot,
|
||||
)
|
||||
|
||||
_RBAC_MODULE = "controllers.common.app_access.enterprise_rbac_service"
|
||||
|
||||
|
||||
def _permissions(
|
||||
*,
|
||||
workspace_keys: list[str] | None = None,
|
||||
app_default_keys: list[str] | None = None,
|
||||
app_overrides: list[ResourcePermissionKeys] | None = None,
|
||||
) -> MyPermissionsResponse:
|
||||
return MyPermissionsResponse(
|
||||
workspace=WorkspacePermissionSnapshot(permission_keys=workspace_keys or []),
|
||||
app=ResourcePermissionSnapshot(
|
||||
default_permission_keys=app_default_keys or [],
|
||||
overrides=app_overrides or [],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestHasAppListPermission:
|
||||
def test_matches_known_preview_keys(self):
|
||||
for key in APP_LIST_PERMISSION_KEYS:
|
||||
assert has_app_list_permission([key])
|
||||
|
||||
def test_rejects_unknown_keys(self):
|
||||
assert not has_app_list_permission(["app.export", "app.delete"])
|
||||
assert not has_app_list_permission([])
|
||||
|
||||
|
||||
class TestAppAccessFilterIsAppAccessible:
|
||||
def test_unrestricted_sees_everything(self):
|
||||
flt = AppAccessFilter.unrestricted()
|
||||
assert flt.is_app_accessible("app-1", maintainer="someone", account_id="acc-1")
|
||||
|
||||
def test_whitelisted_app_is_visible(self):
|
||||
flt = AppAccessFilter(accessible_app_ids={"app-1"}, can_manage_own_apps=False)
|
||||
assert flt.is_app_accessible("app-1", maintainer=None, account_id="acc-1")
|
||||
assert not flt.is_app_accessible("app-2", maintainer=None, account_id="acc-1")
|
||||
|
||||
def test_own_app_visible_only_with_manage_permission(self):
|
||||
own = AppAccessFilter(accessible_app_ids=set(), can_manage_own_apps=True)
|
||||
assert own.is_app_accessible("app-1", maintainer="acc-1", account_id="acc-1")
|
||||
assert not own.is_app_accessible("app-1", maintainer="acc-2", account_id="acc-1")
|
||||
|
||||
no_manage = AppAccessFilter(accessible_app_ids=set(), can_manage_own_apps=False)
|
||||
assert not no_manage.is_app_accessible("app-1", maintainer="acc-1", account_id="acc-1")
|
||||
|
||||
|
||||
class TestAppAccessFilterApplyToParams:
|
||||
def test_unrestricted_leaves_params_untouched(self):
|
||||
params = AppListParams()
|
||||
AppAccessFilter.unrestricted().apply_to_params(params)
|
||||
assert params.accessible_app_ids is None
|
||||
assert params.include_own_apps is False
|
||||
assert params.is_created_by_me is None
|
||||
|
||||
def test_whitelisted_ids_are_sorted_with_own_apps_flag(self):
|
||||
params = AppListParams()
|
||||
AppAccessFilter(accessible_app_ids={"b", "a"}, can_manage_own_apps=True).apply_to_params(params)
|
||||
assert params.accessible_app_ids == ["a", "b"]
|
||||
assert params.include_own_apps is True
|
||||
|
||||
def test_empty_set_with_manage_falls_back_to_maintained_apps(self):
|
||||
# Own-app fallback must use maintainer (include_own_apps), consistent
|
||||
# with is_app_accessible — not created_by (is_created_by_me).
|
||||
params = AppListParams()
|
||||
AppAccessFilter(accessible_app_ids=set(), can_manage_own_apps=True).apply_to_params(params)
|
||||
assert params.accessible_app_ids == []
|
||||
assert params.include_own_apps is True
|
||||
assert params.is_created_by_me is None
|
||||
|
||||
def test_empty_set_without_manage_sees_nothing(self):
|
||||
params = AppListParams()
|
||||
AppAccessFilter(accessible_app_ids=set(), can_manage_own_apps=False).apply_to_params(params)
|
||||
assert params.accessible_app_ids == []
|
||||
assert params.include_own_apps is False
|
||||
assert params.is_created_by_me is None
|
||||
|
||||
|
||||
class TestResolveAppAccessFilter:
|
||||
def _patch_whitelist(self, monkeypatch: pytest.MonkeyPatch, whitelist: ResourceWhitelistResources) -> None:
|
||||
monkeypatch.setattr(
|
||||
f"{_RBAC_MODULE}.RBACService.AppAccess.whitelist_resources",
|
||||
lambda tenant_id, account_id: whitelist,
|
||||
)
|
||||
|
||||
def test_default_preview_is_unrestricted(self, monkeypatch: pytest.MonkeyPatch):
|
||||
self._patch_whitelist(monkeypatch, ResourceWhitelistResources(unrestricted=True))
|
||||
permissions = _permissions(app_default_keys=["app.preview"])
|
||||
|
||||
flt = resolve_app_access_filter("tenant-1", "acc-1", permissions=permissions)
|
||||
|
||||
assert flt.accessible_app_ids is None
|
||||
assert flt.can_manage_own_apps is False
|
||||
|
||||
def test_default_preview_overrides_whitelist_restriction(self, monkeypatch: pytest.MonkeyPatch):
|
||||
self._patch_whitelist(monkeypatch, ResourceWhitelistResources(unrestricted=False, resource_ids=["app-9"]))
|
||||
permissions = _permissions(
|
||||
workspace_keys=["app.full_access", "app.create_and_management"],
|
||||
)
|
||||
|
||||
flt = resolve_app_access_filter("tenant-1", "acc-1", permissions=permissions)
|
||||
|
||||
# Workspace-level preview grant defeats the whitelist restriction.
|
||||
assert flt.accessible_app_ids is None
|
||||
assert flt.can_manage_own_apps is True
|
||||
|
||||
def test_override_apps_collected_without_default_preview(self, monkeypatch: pytest.MonkeyPatch):
|
||||
self._patch_whitelist(monkeypatch, ResourceWhitelistResources(unrestricted=True))
|
||||
permissions = _permissions(
|
||||
app_overrides=[
|
||||
ResourcePermissionKeys(resource_id="app-1", permission_keys=["app.preview"]),
|
||||
ResourcePermissionKeys(resource_id="app-2", permission_keys=["app.export"]),
|
||||
],
|
||||
)
|
||||
|
||||
flt = resolve_app_access_filter("tenant-1", "acc-1", permissions=permissions)
|
||||
|
||||
assert flt.accessible_app_ids == {"app-1"}
|
||||
|
||||
def test_whitelist_union_with_override_apps(self, monkeypatch: pytest.MonkeyPatch):
|
||||
self._patch_whitelist(monkeypatch, ResourceWhitelistResources(unrestricted=False, resource_ids=["app-5"]))
|
||||
permissions = _permissions(
|
||||
app_overrides=[ResourcePermissionKeys(resource_id="app-1", permission_keys=["app.acl.preview"])],
|
||||
)
|
||||
|
||||
flt = resolve_app_access_filter("tenant-1", "acc-1", permissions=permissions)
|
||||
|
||||
assert flt.accessible_app_ids == {"app-1", "app-5"}
|
||||
|
||||
def test_fetches_permissions_when_not_supplied(self, monkeypatch: pytest.MonkeyPatch):
|
||||
self._patch_whitelist(monkeypatch, ResourceWhitelistResources(unrestricted=False, resource_ids=[]))
|
||||
monkeypatch.setattr(
|
||||
f"{_RBAC_MODULE}.RBACService.MyPermissions.get",
|
||||
lambda tenant_id, account_id: _permissions(workspace_keys=["app.create_and_management"]),
|
||||
)
|
||||
|
||||
flt = resolve_app_access_filter("tenant-1", "acc-1")
|
||||
|
||||
assert flt.accessible_app_ids == set()
|
||||
assert flt.can_manage_own_apps is True
|
||||
@ -20,6 +20,10 @@ from controllers.console.agent.composer import (
|
||||
WorkflowAgentComposerValidateApi,
|
||||
)
|
||||
from controllers.console.agent.roster import (
|
||||
AgentApiAccessApi,
|
||||
AgentApiKeyApi,
|
||||
AgentApiKeyListApi,
|
||||
AgentApiStatusApi,
|
||||
AgentAppApi,
|
||||
AgentAppCopyApi,
|
||||
AgentAppListApi,
|
||||
@ -28,6 +32,7 @@ from controllers.console.agent.roster import (
|
||||
AgentLogsApi,
|
||||
AgentLogSourcesApi,
|
||||
AgentRosterVersionDetailApi,
|
||||
AgentRosterVersionRestoreApi,
|
||||
AgentRosterVersionsApi,
|
||||
AgentStatisticsSummaryApi,
|
||||
)
|
||||
@ -149,6 +154,10 @@ def test_agent_v2_console_routes_are_agent_id_first() -> None:
|
||||
"/agent/<uuid:agent_id>/sandbox/files",
|
||||
"/agent/<uuid:agent_id>/skills/upload",
|
||||
"/agent/<uuid:agent_id>/files",
|
||||
"/agent/<uuid:agent_id>/api-access",
|
||||
"/agent/<uuid:agent_id>/api-enable",
|
||||
"/agent/<uuid:agent_id>/api-keys",
|
||||
"/agent/<uuid:agent_id>/api-keys/<uuid:api_key_id>",
|
||||
"/agent/<uuid:agent_id>/chat-messages",
|
||||
"/agent/<uuid:agent_id>/chat-messages/<string:task_id>/stop",
|
||||
"/agent/<uuid:agent_id>/feedbacks",
|
||||
@ -158,6 +167,9 @@ def test_agent_v2_console_routes_are_agent_id_first() -> None:
|
||||
"/agent/<uuid:agent_id>/logs/<uuid:conversation_id>/messages",
|
||||
"/agent/<uuid:agent_id>/log-sources",
|
||||
"/agent/<uuid:agent_id>/statistics/summary",
|
||||
"/agent/<uuid:agent_id>/versions",
|
||||
"/agent/<uuid:agent_id>/versions/<uuid:version_id>",
|
||||
"/agent/<uuid:agent_id>/versions/<uuid:version_id>/restore",
|
||||
"/agent/invite-options",
|
||||
):
|
||||
assert route in paths
|
||||
@ -173,6 +185,7 @@ def test_agent_v2_console_routes_are_agent_id_first() -> None:
|
||||
"/apps/<uuid:app_id>/agent-features",
|
||||
"/apps/<uuid:app_id>/agent-referencing-workflows",
|
||||
"/apps/<uuid:app_id>/agent-sandbox/files",
|
||||
"/apps/<uuid:agent_id>/api-access",
|
||||
):
|
||||
assert route not in paths
|
||||
|
||||
@ -215,16 +228,34 @@ def test_agent_app_list_and_create_use_agent_route(
|
||||
roster_controller.AgentRosterService,
|
||||
"load_app_backing_agents_by_app_id",
|
||||
lambda _self, **kwargs: {
|
||||
"app-list": SimpleNamespace(id="agent-list", role="List role", active_config_snapshot_id=None)
|
||||
"app-list": SimpleNamespace(
|
||||
id="agent-list",
|
||||
role="List role",
|
||||
debug_conversation_id="debug-conversation-list",
|
||||
active_config_snapshot_id=None,
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"get_app_backing_agent",
|
||||
lambda _self, **kwargs: SimpleNamespace(
|
||||
id="agent-created", role="Created role", active_config_snapshot_id=None
|
||||
id="agent-created",
|
||||
role="Created role",
|
||||
debug_conversation_id="debug-conversation-created",
|
||||
active_config_snapshot_id=None,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"get_or_create_agent_app_debug_conversation_id",
|
||||
lambda _self, **kwargs: "debug-conversation-detail",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"get_or_create_agent_app_debug_conversation_id",
|
||||
lambda _self, **kwargs: "debug-conversation-detail",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"load_published_references_by_agent_id",
|
||||
@ -245,6 +276,16 @@ def test_agent_app_list_and_create_use_agent_route(
|
||||
]
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"load_or_create_agent_app_debug_conversation_ids_by_agent_id",
|
||||
lambda _self, **kwargs: {"agent-list": "debug-conversation-list"},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"get_or_create_agent_app_debug_conversation_id",
|
||||
lambda _self, **kwargs: "debug-conversation-created",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.FeatureService,
|
||||
"get_system_features",
|
||||
@ -259,6 +300,7 @@ def test_agent_app_list_and_create_use_agent_route(
|
||||
assert listed["total"] == 1
|
||||
assert listed["data"][0]["id"] == "agent-list"
|
||||
assert listed["data"][0]["app_id"] == "app-list"
|
||||
assert listed["data"][0]["debug_conversation_id"] == "debug-conversation-list"
|
||||
assert listed["data"][0]["role"] == "List role"
|
||||
assert listed["data"][0]["active_config_is_published"] is False
|
||||
assert listed["data"][0]["published_reference_count"] == 1
|
||||
@ -292,6 +334,7 @@ def test_agent_app_list_and_create_use_agent_route(
|
||||
assert status == 201
|
||||
assert created["id"] == "agent-created"
|
||||
assert created["app_id"] == "app-created"
|
||||
assert created["debug_conversation_id"] == "debug-conversation-created"
|
||||
assert created["role"] == "Created role"
|
||||
assert created["active_config_is_published"] is False
|
||||
assert "bound_agent_id" not in created
|
||||
@ -332,7 +375,17 @@ def test_agent_app_detail_update_delete_resolve_app_from_agent_id(
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"get_app_backing_agent",
|
||||
lambda _self, **kwargs: SimpleNamespace(id=agent_id, role="Resolved role", active_config_snapshot_id=None),
|
||||
lambda _self, **kwargs: SimpleNamespace(
|
||||
id=agent_id,
|
||||
role="Resolved role",
|
||||
debug_conversation_id="debug-conversation-detail",
|
||||
active_config_snapshot_id=None,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"get_or_create_agent_app_debug_conversation_id",
|
||||
lambda _self, **kwargs: "debug-conversation-detail",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.FeatureService,
|
||||
@ -354,9 +407,10 @@ def test_agent_app_detail_update_delete_resolve_app_from_agent_id(
|
||||
|
||||
monkeypatch.setattr(roster_controller, "AppService", FakeAppService)
|
||||
|
||||
detail = unwrap(AgentAppApi.get)(AgentAppApi(), "tenant-1", agent_id)
|
||||
detail = unwrap(AgentAppApi.get)(AgentAppApi(), "tenant-1", SimpleNamespace(id=account_id), agent_id)
|
||||
assert detail["id"] == agent_id
|
||||
assert detail["app_id"] == "app-1"
|
||||
assert detail["debug_conversation_id"] == "debug-conversation-detail"
|
||||
assert detail["role"] == "Resolved role"
|
||||
assert detail["active_config_is_published"] is False
|
||||
assert "bound_agent_id" not in detail
|
||||
@ -365,11 +419,12 @@ def test_agent_app_detail_update_delete_resolve_app_from_agent_id(
|
||||
"/console/api/agent/00000000-0000-0000-0000-000000000001",
|
||||
json={"name": "Renamed", "description": "", "role": "Reviewer", "icon_type": "emoji", "icon": "R"},
|
||||
):
|
||||
updated = unwrap(AgentAppApi.put)(AgentAppApi(), "tenant-1", agent_id)
|
||||
updated = unwrap(AgentAppApi.put)(AgentAppApi(), "tenant-1", SimpleNamespace(id=account_id), agent_id)
|
||||
|
||||
assert updated["name"] == "Renamed"
|
||||
assert updated["id"] == agent_id
|
||||
assert updated["app_id"] == "app-1"
|
||||
assert updated["debug_conversation_id"] == "debug-conversation-detail"
|
||||
assert updated["role"] == "Resolved role"
|
||||
assert updated["active_config_is_published"] is False
|
||||
assert "bound_agent_id" not in updated
|
||||
@ -399,7 +454,7 @@ def test_agent_app_copy_uses_agent_id_and_returns_agent_detail(
|
||||
monkeypatch.setattr(
|
||||
roster_controller,
|
||||
"_serialize_agent_app_detail",
|
||||
lambda app_model: {"id": "copied-agent", "app_id": app_model.id, "name": app_model.name},
|
||||
lambda app_model, **_kwargs: {"id": "copied-agent", "app_id": app_model.id, "name": app_model.name},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
@ -428,6 +483,127 @@ def test_agent_app_copy_uses_agent_id_and_returns_agent_detail(
|
||||
}
|
||||
|
||||
|
||||
def test_agent_api_access_uses_agent_id_and_returns_service_api_metadata(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
agent_id = "00000000-0000-0000-0000-000000000001"
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
enable_api=True,
|
||||
api_base_url="https://api.example.test/v1",
|
||||
api_rpm=60,
|
||||
api_rph=600,
|
||||
)
|
||||
monkeypatch.setattr(roster_controller, "_resolve_agent_app_model", lambda **kwargs: app_model)
|
||||
monkeypatch.setattr(roster_controller, "_agent_api_key_count", lambda app_id: 2)
|
||||
|
||||
response = unwrap(AgentApiAccessApi.get)(AgentApiAccessApi(), "tenant-1", agent_id)
|
||||
|
||||
assert response == {
|
||||
"enabled": True,
|
||||
"service_api_base_url": "https://api.example.test/v1",
|
||||
"streaming_only": True,
|
||||
"chat_endpoint": "https://api.example.test/v1/chat-messages",
|
||||
"stop_endpoint": "https://api.example.test/v1/chat-messages/{task_id}/stop",
|
||||
"conversations_endpoint": "https://api.example.test/v1/conversations",
|
||||
"messages_endpoint": "https://api.example.test/v1/messages",
|
||||
"files_upload_endpoint": "https://api.example.test/v1/files/upload",
|
||||
"parameters_endpoint": "https://api.example.test/v1/parameters",
|
||||
"info_endpoint": "https://api.example.test/v1/info",
|
||||
"meta_endpoint": "https://api.example.test/v1/meta",
|
||||
"api_rpm": 60,
|
||||
"api_rph": 600,
|
||||
"api_key_count": 2,
|
||||
}
|
||||
|
||||
|
||||
def test_agent_api_status_and_key_routes_resolve_backing_app(
|
||||
app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
agent_id = "00000000-0000-0000-0000-000000000001"
|
||||
api_key_id = "00000000-0000-0000-0000-000000000002"
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
enable_api=False,
|
||||
api_base_url="https://api.example.test/v1",
|
||||
api_rpm=0,
|
||||
api_rph=0,
|
||||
)
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
monkeypatch.setattr(roster_controller, "_resolve_agent_app_model", lambda **kwargs: app_model)
|
||||
monkeypatch.setattr(roster_controller, "_agent_api_key_count", lambda app_id: 1)
|
||||
|
||||
class FakeAppService:
|
||||
def update_app_api_status(self, app_obj: object, enable_api: bool) -> object:
|
||||
captured["enable"] = {"app": app_obj, "enable_api": enable_api}
|
||||
app_model.enable_api = enable_api
|
||||
return app_model
|
||||
|
||||
monkeypatch.setattr(roster_controller, "AppService", FakeAppService)
|
||||
|
||||
def fake_get_api_key_list(self, resource_id: str, tenant_id: str):
|
||||
captured["list_keys"] = {"resource_id": resource_id, "tenant_id": tenant_id}
|
||||
return roster_controller.ApiKeyList(data=[])
|
||||
|
||||
def fake_create_api_key(self, resource_id: str, tenant_id: str):
|
||||
captured["create_key"] = {"resource_id": resource_id, "tenant_id": tenant_id}
|
||||
return SimpleNamespace(
|
||||
id=api_key_id,
|
||||
type="app",
|
||||
token="app-test-token",
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
|
||||
def fake_delete_api_key(self, resource_id: str, key_id: str, tenant_id: str, current_user: object) -> None:
|
||||
captured["delete_key"] = {
|
||||
"resource_id": resource_id,
|
||||
"api_key_id": key_id,
|
||||
"tenant_id": tenant_id,
|
||||
"current_user": current_user,
|
||||
}
|
||||
|
||||
monkeypatch.setattr(AgentApiKeyListApi, "_get_api_key_list", fake_get_api_key_list)
|
||||
monkeypatch.setattr(AgentApiKeyListApi, "_create_api_key", fake_create_api_key)
|
||||
monkeypatch.setattr(AgentApiKeyApi, "_delete_api_key", fake_delete_api_key)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/agent/00000000-0000-0000-0000-000000000001/api-enable",
|
||||
json={"enable_api": True},
|
||||
):
|
||||
enabled = unwrap(AgentApiStatusApi.post)(AgentApiStatusApi(), "tenant-1", agent_id)
|
||||
assert enabled["enabled"] is True
|
||||
assert captured["enable"] == {"app": app_model, "enable_api": True}
|
||||
|
||||
keys = unwrap(AgentApiKeyListApi.get)(AgentApiKeyListApi(), "tenant-1", agent_id)
|
||||
assert keys == {"data": []}
|
||||
assert captured["list_keys"] == {"resource_id": "app-1", "tenant_id": "tenant-1"}
|
||||
|
||||
created, status = unwrap(AgentApiKeyListApi.post)(AgentApiKeyListApi(), "tenant-1", agent_id)
|
||||
assert status == 201
|
||||
assert created["id"] == api_key_id
|
||||
assert created["token"] == "app-test-token"
|
||||
assert captured["create_key"] == {"resource_id": "app-1", "tenant_id": "tenant-1"}
|
||||
|
||||
current_user = SimpleNamespace(id="account-1", is_admin_or_owner=True)
|
||||
deleted, delete_status = unwrap(AgentApiKeyApi.delete)(
|
||||
AgentApiKeyApi(),
|
||||
"tenant-1",
|
||||
current_user,
|
||||
agent_id,
|
||||
api_key_id,
|
||||
)
|
||||
assert (deleted, delete_status) == ("", 204)
|
||||
assert captured["delete_key"] == {
|
||||
"resource_id": "app-1",
|
||||
"api_key_id": api_key_id,
|
||||
"tenant_id": "tenant-1",
|
||||
"current_user": current_user,
|
||||
}
|
||||
|
||||
|
||||
def test_agent_app_update_rejects_empty_role(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
agent_id = "00000000-0000-0000-0000-000000000001"
|
||||
app_model = _app_detail_obj(id="app-1", bound_agent_id=agent_id)
|
||||
@ -441,7 +617,12 @@ def test_agent_app_update_rejects_empty_role(app: Flask, monkeypatch: pytest.Mon
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"get_app_backing_agent",
|
||||
lambda _self, **kwargs: SimpleNamespace(id=agent_id, role="", active_config_snapshot_id=None),
|
||||
lambda _self, **kwargs: SimpleNamespace(
|
||||
id=agent_id,
|
||||
role="",
|
||||
debug_conversation_id="debug-conversation-detail",
|
||||
active_config_snapshot_id=None,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.FeatureService,
|
||||
@ -464,7 +645,7 @@ def test_agent_app_update_rejects_empty_role(app: Flask, monkeypatch: pytest.Mon
|
||||
json={"name": "Renamed", "description": "", "role": "", "icon_type": "emoji", "icon": "R"},
|
||||
):
|
||||
with pytest.raises(ValueError, match="String should have at least 1 character"):
|
||||
unwrap(AgentAppApi.put)(AgentAppApi(), "tenant-1", agent_id)
|
||||
unwrap(AgentAppApi.put)(AgentAppApi(), "tenant-1", SimpleNamespace(id="account-1"), agent_id)
|
||||
|
||||
|
||||
def test_invite_options_get_parses_app_id(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -513,6 +694,13 @@ def test_agent_versions_call_services(app: Flask, monkeypatch: pytest.MonkeyPatc
|
||||
],
|
||||
},
|
||||
)
|
||||
captured_restore: dict[str, object] = {}
|
||||
|
||||
def restore_agent_version(_self, **kwargs):
|
||||
captured_restore.update(kwargs)
|
||||
return {"result": "success", "active_config_snapshot_id": kwargs["version_id"]}
|
||||
|
||||
monkeypatch.setattr(roster_controller.AgentRosterService, "restore_agent_version", restore_agent_version)
|
||||
|
||||
assert (
|
||||
unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), "tenant-1", agent_id)["data"][0]["id"]
|
||||
@ -523,6 +711,16 @@ def test_agent_versions_call_services(app: Flask, monkeypatch: pytest.MonkeyPatc
|
||||
)
|
||||
assert version_detail["id"] == version_id
|
||||
assert version_detail["agent_id"] == agent_id
|
||||
restored = unwrap(AgentRosterVersionRestoreApi.post)(
|
||||
AgentRosterVersionRestoreApi(), "tenant-1", SimpleNamespace(id="account-1"), agent_id, version_id
|
||||
)
|
||||
assert restored == {"result": "success", "active_config_snapshot_id": version_id}
|
||||
assert captured_restore == {
|
||||
"tenant_id": "tenant-1",
|
||||
"agent_id": agent_id,
|
||||
"version_id": version_id,
|
||||
"account_id": "account-1",
|
||||
}
|
||||
|
||||
|
||||
def test_agent_observability_routes_resolve_app_from_agent_id(
|
||||
@ -870,7 +1068,7 @@ def test_agent_chat_generate_and_stop_routes_resolve_app_from_agent_id(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str
|
||||
) -> None:
|
||||
agent_id = "00000000-0000-0000-0000-000000000001"
|
||||
app_model = SimpleNamespace(id="app-1", mode="agent")
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="agent")
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def resolve_agent_app_model(**kwargs: object) -> object:
|
||||
@ -909,7 +1107,7 @@ def test_agent_chat_generate_and_stop_routes_resolve_app_from_agent_id(
|
||||
def test_agent_chat_helper_forces_agent_streaming_and_external_trace(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str
|
||||
) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode="agent")
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="agent")
|
||||
current_user = SimpleNamespace(id=account_id)
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
@ -918,6 +1116,11 @@ def test_agent_chat_helper_forces_agent_streaming_and_external_trace(
|
||||
return {"answer": "ok"}
|
||||
|
||||
monkeypatch.setattr(completion_controller.AppGenerateService, "generate", generate)
|
||||
monkeypatch.setattr(
|
||||
completion_controller,
|
||||
"_resolve_current_user_agent_debug_conversation_id",
|
||||
lambda **kwargs: "debug-conversation-1",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
completion_controller.helper,
|
||||
"compact_generate_response",
|
||||
@ -936,10 +1139,83 @@ def test_agent_chat_helper_forces_agent_streaming_and_external_trace(
|
||||
assert captured["streaming"] is True
|
||||
args = cast(dict[str, object], captured["args"])
|
||||
assert args["response_mode"] == "streaming"
|
||||
assert args["conversation_id"] == "debug-conversation-1"
|
||||
assert args["auto_generate_name"] is False
|
||||
assert args["external_trace_id"] == "trace-1"
|
||||
|
||||
|
||||
def test_agent_chat_helper_rejects_foreign_debug_conversation(
|
||||
app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
account_id: str,
|
||||
) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="agent")
|
||||
|
||||
monkeypatch.setattr(
|
||||
completion_controller,
|
||||
"_resolve_current_user_agent_debug_conversation_id",
|
||||
lambda **kwargs: "owned-conversation",
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
json={
|
||||
"inputs": {},
|
||||
"query": "hello",
|
||||
"response_mode": "streaming",
|
||||
"conversation_id": "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
completion_controller._create_chat_message(
|
||||
current_tenant_id="tenant-1",
|
||||
current_user=SimpleNamespace(id=account_id),
|
||||
app_model=app_model,
|
||||
agent_id="agent-1",
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_current_user_agent_debug_conversation_uses_agent_or_backing_app(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
calls: list[dict[str, object]] = []
|
||||
|
||||
class FakeRosterService:
|
||||
def __init__(self, session: object) -> None:
|
||||
calls.append({"session": session})
|
||||
|
||||
def get_or_create_agent_app_debug_conversation_id(self, **kwargs: object) -> str:
|
||||
calls.append({"get_or_create": kwargs})
|
||||
return f"debug-{kwargs['agent_id']}"
|
||||
|
||||
def get_app_backing_agent(self, **kwargs: object) -> object:
|
||||
calls.append({"get_app_backing_agent": kwargs})
|
||||
return SimpleNamespace(id="backing-agent")
|
||||
|
||||
monkeypatch.setattr(completion_controller, "AgentRosterService", FakeRosterService)
|
||||
monkeypatch.setattr(completion_controller, "db", SimpleNamespace(session="session-1"))
|
||||
|
||||
explicit_id = completion_controller._resolve_current_user_agent_debug_conversation_id(
|
||||
current_tenant_id="tenant-1",
|
||||
current_user=SimpleNamespace(id="account-1"),
|
||||
app_model=SimpleNamespace(id="app-1"),
|
||||
agent_id="agent-1",
|
||||
)
|
||||
fallback_id = completion_controller._resolve_current_user_agent_debug_conversation_id(
|
||||
current_tenant_id="tenant-1",
|
||||
current_user=SimpleNamespace(id="account-1"),
|
||||
app_model=SimpleNamespace(id="app-1"),
|
||||
agent_id=None,
|
||||
)
|
||||
|
||||
assert explicit_id == "debug-agent-1"
|
||||
assert fallback_id == "debug-backing-agent"
|
||||
assert calls[1] == {"get_or_create": {"tenant_id": "tenant-1", "agent_id": "agent-1", "account_id": "account-1"}}
|
||||
assert calls[3] == {"get_app_backing_agent": {"tenant_id": "tenant-1", "app_id": "app-1"}}
|
||||
assert calls[4] == {
|
||||
"get_or_create": {"tenant_id": "tenant-1", "agent_id": "backing-agent", "account_id": "account-1"}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("error", "expected"),
|
||||
[
|
||||
@ -986,7 +1262,7 @@ def test_agent_chat_helper_maps_generation_errors(
|
||||
def test_agent_chat_message_routes_resolve_app_from_agent_id(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
agent_id = "00000000-0000-0000-0000-000000000001"
|
||||
message_id = "00000000-0000-0000-0000-000000000002"
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
app_model = SimpleNamespace(id="app-1", mode="agent")
|
||||
current_user = SimpleNamespace(id="account-1")
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
@ -1016,7 +1292,9 @@ def test_agent_chat_message_routes_resolve_app_from_agent_id(app: Flask, monkeyp
|
||||
monkeypatch.setattr(message_controller, "_get_message_suggested_questions", get_message_suggested_questions)
|
||||
monkeypatch.setattr(message_controller, "_get_message_detail", get_message_detail)
|
||||
|
||||
assert unwrap(AgentChatMessageListApi.get)(AgentChatMessageListApi(), "tenant-1", agent_id) == {"data": []}
|
||||
assert unwrap(AgentChatMessageListApi.get)(AgentChatMessageListApi(), "tenant-1", current_user, agent_id) == {
|
||||
"data": []
|
||||
}
|
||||
assert cast(dict[str, object], captured["list"])["app_model"] is app_model
|
||||
|
||||
with app.test_request_context(json={"message_id": message_id, "rating": "like"}):
|
||||
@ -1073,11 +1351,73 @@ def test_list_chat_messages_supports_first_id_pagination(app: Flask, monkeypatch
|
||||
"/console/api/agent/agent-1/chat-messages"
|
||||
f"?conversation_id={conversation_id}&first_id={first_message_id}&limit=1"
|
||||
):
|
||||
result = message_controller._list_chat_messages(app_model=SimpleNamespace(id="app-1"))
|
||||
result = message_controller._list_chat_messages(app_model=SimpleNamespace(id="app-1", mode="chat"))
|
||||
|
||||
assert result == {"data": [older_message_id], "limit": 1, "has_more": True}
|
||||
|
||||
|
||||
def test_list_agent_chat_messages_uses_current_user_conversation(
|
||||
app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
conversation_id = "00000000-0000-0000-0000-000000000010"
|
||||
message_id = "00000000-0000-0000-0000-000000000011"
|
||||
conversation = SimpleNamespace(id=conversation_id)
|
||||
message = SimpleNamespace(id=message_id, created_at=1)
|
||||
current_user = SimpleNamespace(id="account-1")
|
||||
app_model = SimpleNamespace(id="app-1", mode="agent")
|
||||
captured: dict[str, object] = {}
|
||||
session = SimpleNamespace(
|
||||
scalar=lambda _stmt: False,
|
||||
scalars=lambda _stmt: SimpleNamespace(all=lambda: [message]),
|
||||
)
|
||||
|
||||
class FakeMessagePaginationResponse:
|
||||
@classmethod
|
||||
def model_validate(cls, pagination: object, from_attributes: bool = False) -> object:
|
||||
return SimpleNamespace(
|
||||
model_dump=lambda mode: {
|
||||
"data": [item.id for item in pagination.data],
|
||||
"limit": pagination.limit,
|
||||
"has_more": pagination.has_more,
|
||||
}
|
||||
)
|
||||
|
||||
def get_conversation(**kwargs: object) -> object:
|
||||
captured.update(kwargs)
|
||||
return conversation
|
||||
|
||||
monkeypatch.setattr(message_controller.ConversationService, "get_conversation", get_conversation)
|
||||
monkeypatch.setattr(message_controller, "db", SimpleNamespace(session=session))
|
||||
monkeypatch.setattr(message_controller, "attach_message_extra_contents", lambda messages: None)
|
||||
monkeypatch.setattr(message_controller, "MessageInfiniteScrollPaginationResponse", FakeMessagePaginationResponse)
|
||||
|
||||
with app.test_request_context(f"/console/api/agent/agent-1/chat-messages?conversation_id={conversation_id}"):
|
||||
result = message_controller._list_chat_messages(app_model=app_model, current_user=current_user)
|
||||
|
||||
assert result == {"data": [message_id], "limit": 20, "has_more": False}
|
||||
assert captured == {"app_model": app_model, "conversation_id": conversation_id, "user": current_user}
|
||||
|
||||
|
||||
def test_list_agent_chat_messages_rejects_foreign_conversation(
|
||||
app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
conversation_id = "00000000-0000-0000-0000-000000000010"
|
||||
monkeypatch.setattr(
|
||||
message_controller.ConversationService,
|
||||
"get_conversation",
|
||||
lambda **kwargs: (_ for _ in ()).throw(message_controller.ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
with app.test_request_context(f"/console/api/agent/agent-1/chat-messages?conversation_id={conversation_id}"):
|
||||
with pytest.raises(NotFound):
|
||||
message_controller._list_chat_messages(
|
||||
app_model=SimpleNamespace(id="app-1", mode="agent"),
|
||||
current_user=SimpleNamespace(id="account-1"),
|
||||
)
|
||||
|
||||
|
||||
def test_update_message_feedback_rejects_empty_rating_without_existing_feedback(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
|
||||
@ -20,6 +20,10 @@ from controllers.console.app.agent_drive_inspector import (
|
||||
AgentDriveListByAgentApi,
|
||||
AgentDrivePreviewApi,
|
||||
AgentDrivePreviewByAgentApi,
|
||||
AgentDriveSkillInspectApi,
|
||||
AgentDriveSkillInspectByAgentApi,
|
||||
AgentDriveSkillListApi,
|
||||
AgentDriveSkillListByAgentApi,
|
||||
)
|
||||
from services.agent_drive_service import AgentDriveError
|
||||
|
||||
@ -97,6 +101,124 @@ def test_list_resolves_workflow_node_binding_agent():
|
||||
assert composer.resolve_workflow_node_agent_id.call_args.kwargs["node_id"] == "agent-node-1"
|
||||
|
||||
|
||||
def test_skill_list_by_agent_calls_service():
|
||||
raw = _raw(AgentDriveSkillListByAgentApi.get)
|
||||
with app.test_request_context("/"):
|
||||
with (
|
||||
patch(f"{_MOD}.resolve_agent_app_model", return_value=_APP) as resolve_app,
|
||||
patch(f"{_MOD}.AgentDriveService") as drive,
|
||||
):
|
||||
drive.return_value.list_skills.return_value = [
|
||||
{
|
||||
"path": "pdf-toolkit",
|
||||
"skill_md_key": "pdf-toolkit/SKILL.md",
|
||||
"archive_key": "pdf-toolkit/.DIFY-SKILL-FULL.zip",
|
||||
"name": "PDF Toolkit",
|
||||
"description": "Work with PDFs.",
|
||||
"size": 5,
|
||||
"mime_type": "text/markdown",
|
||||
"hash": None,
|
||||
"created_at": 1718000000,
|
||||
}
|
||||
]
|
||||
body = raw(AgentDriveSkillListByAgentApi(), "tenant-1", "agent-1")
|
||||
|
||||
assert body["items"][0]["path"] == "pdf-toolkit"
|
||||
resolve_app.assert_called_once_with(tenant_id="tenant-1", agent_id="agent-1")
|
||||
assert drive.return_value.list_skills.call_args.kwargs["agent_id"] == "agent-1"
|
||||
|
||||
|
||||
def test_skill_list_resolves_workflow_node_binding_agent():
|
||||
raw = _raw(AgentDriveSkillListApi.get)
|
||||
with app.test_request_context("/?node_id=agent-node-1"):
|
||||
with (
|
||||
patch(f"{_MOD}.AgentComposerService") as composer,
|
||||
patch(f"{_MOD}.AgentDriveService") as drive,
|
||||
):
|
||||
composer.resolve_workflow_node_agent_id.return_value = "wf-agent-9"
|
||||
drive.return_value.list_skills.return_value = []
|
||||
body = raw(AgentDriveSkillListApi(), _APP)
|
||||
|
||||
assert body == {"items": []}
|
||||
assert drive.return_value.list_skills.call_args.kwargs["agent_id"] == "wf-agent-9"
|
||||
|
||||
|
||||
def test_skill_inspect_by_agent_returns_strict_json_response():
|
||||
raw = _raw(AgentDriveSkillInspectByAgentApi.get)
|
||||
payload = {
|
||||
"path": "pdf-toolkit",
|
||||
"skill_md_key": "pdf-toolkit/SKILL.md",
|
||||
"archive_key": "pdf-toolkit/.DIFY-SKILL-FULL.zip",
|
||||
"name": "PDF Toolkit",
|
||||
"description": "Work with PDFs.",
|
||||
"size": 5,
|
||||
"mime_type": "text/markdown",
|
||||
"hash": None,
|
||||
"created_at": 1718000000,
|
||||
"source": "skill_md",
|
||||
"files": [
|
||||
{
|
||||
"path": "SKILL.md",
|
||||
"name": "SKILL.md",
|
||||
"type": "file",
|
||||
"drive_key": "pdf-toolkit/SKILL.md",
|
||||
"available_in_drive": True,
|
||||
}
|
||||
],
|
||||
"file_tree": [],
|
||||
"skill_md": {
|
||||
"key": "pdf-toolkit/SKILL.md",
|
||||
"size": 5,
|
||||
"truncated": False,
|
||||
"binary": False,
|
||||
"text": "# PDF Toolkit\nUse it.\n",
|
||||
},
|
||||
"warnings": [],
|
||||
}
|
||||
with app.test_request_context("/"):
|
||||
with (
|
||||
patch(f"{_MOD}.resolve_agent_app_model", return_value=_APP),
|
||||
patch(f"{_MOD}.AgentDriveService") as drive,
|
||||
):
|
||||
drive.return_value.inspect_skill.return_value = payload
|
||||
response = raw(AgentDriveSkillInspectByAgentApi(), "tenant-1", "agent-1", "pdf-toolkit")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["skill_md"]["text"] == "# PDF Toolkit\nUse it.\n"
|
||||
assert b"# PDF Toolkit\\nUse it.\\n" in response.get_data()
|
||||
|
||||
|
||||
def test_skill_inspect_resolves_workflow_node_binding_agent():
|
||||
raw = _raw(AgentDriveSkillInspectApi.get)
|
||||
payload = {
|
||||
"path": "pdf-toolkit",
|
||||
"skill_md_key": "pdf-toolkit/SKILL.md",
|
||||
"archive_key": None,
|
||||
"name": "PDF Toolkit",
|
||||
"description": "",
|
||||
"size": 5,
|
||||
"mime_type": "text/markdown",
|
||||
"hash": None,
|
||||
"created_at": None,
|
||||
"source": "skill_md",
|
||||
"files": [],
|
||||
"file_tree": [],
|
||||
"skill_md": {"key": "pdf-toolkit/SKILL.md", "size": 5, "truncated": False, "binary": False, "text": "# hi"},
|
||||
"warnings": [],
|
||||
}
|
||||
with app.test_request_context("/?node_id=agent-node-1"):
|
||||
with (
|
||||
patch(f"{_MOD}.AgentComposerService") as composer,
|
||||
patch(f"{_MOD}.AgentDriveService") as drive,
|
||||
):
|
||||
composer.resolve_workflow_node_agent_id.return_value = "wf-agent-9"
|
||||
drive.return_value.inspect_skill.return_value = payload
|
||||
response = raw(AgentDriveSkillInspectApi(), _APP, "pdf-toolkit")
|
||||
|
||||
assert response.get_json()["path"] == "pdf-toolkit"
|
||||
assert drive.return_value.inspect_skill.call_args.kwargs["agent_id"] == "wf-agent-9"
|
||||
|
||||
|
||||
def test_list_400_when_no_agent_bound():
|
||||
raw = _raw(AgentDriveListApi.get)
|
||||
app_without_agent = SimpleNamespace(id="app-1", tenant_id="tenant-1", bound_agent_id=None)
|
||||
|
||||
@ -185,7 +185,7 @@ class TestPaginationMapping:
|
||||
"name": "owner",
|
||||
"description": "",
|
||||
"is_builtin": True,
|
||||
"permission_keys": list(rbac_mod._LEGACY_ROLE_PERMISSION_KEYS["owner"]),
|
||||
"permission_keys": list(dict.fromkeys(rbac_mod._LEGACY_ROLE_PERMISSION_KEYS["owner"])),
|
||||
"role_tag": "owner",
|
||||
},
|
||||
{
|
||||
@ -196,7 +196,7 @@ class TestPaginationMapping:
|
||||
"name": "admin",
|
||||
"description": "",
|
||||
"is_builtin": True,
|
||||
"permission_keys": list(rbac_mod._LEGACY_ROLE_PERMISSION_KEYS["admin"]),
|
||||
"permission_keys": list(dict.fromkeys(rbac_mod._LEGACY_ROLE_PERMISSION_KEYS["admin"])),
|
||||
"role_tag": "",
|
||||
},
|
||||
]
|
||||
@ -336,23 +336,6 @@ class TestResourceAccessScopeBindings:
|
||||
|
||||
|
||||
class TestPaginationForwarding:
|
||||
def test_role_members_get_forwards_outer_pagination_params(self, app):
|
||||
with (
|
||||
app.test_request_context("/workspaces/current/rbac/roles/role-1/members?page=2&limit=50&reverse=true"),
|
||||
patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")),
|
||||
patch("controllers.console.workspace.rbac.svc.RBACService.Roles.members") as mock_members,
|
||||
patch("controllers.console.workspace.rbac._dump", return_value={}),
|
||||
):
|
||||
inspect.unwrap(rbac_mod.RBACRoleMembersApi.get)(rbac_mod.RBACRoleMembersApi(), "role-1")
|
||||
|
||||
_, _, role_id = mock_members.call_args.args
|
||||
_, kwargs = mock_members.call_args
|
||||
assert role_id == "role-1"
|
||||
options = kwargs["options"]
|
||||
assert options.page_number == 2
|
||||
assert options.results_per_page == 50
|
||||
assert options.reverse is True
|
||||
|
||||
def test_access_policies_get_forwards_outer_pagination_params(self, app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
|
||||
@ -1,16 +1,21 @@
|
||||
import uuid
|
||||
|
||||
from controllers.openapi.auth.composition import account_pipeline, auth_router, external_sso_pipeline
|
||||
from controllers.openapi.auth.data import RequestContext
|
||||
from controllers.openapi.auth.data import RBACRequirement, RequestContext
|
||||
from controllers.openapi.auth.flow import When
|
||||
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_private_app_permission,
|
||||
check_rbac_permission,
|
||||
check_workspace_member,
|
||||
check_workspace_mismatch,
|
||||
check_workspace_role,
|
||||
)
|
||||
from libs.oauth_bearer import TokenType
|
||||
from core.rbac import RBACPermission, RBACResourceScope
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import TenantAccountRole
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
|
||||
def test_account_pipeline_is_auth_pipeline():
|
||||
@ -29,8 +34,8 @@ def test_account_pipeline_prepare_has_six_entries():
|
||||
assert len(account_pipeline._prepare) == 6
|
||||
|
||||
|
||||
def test_account_auth_list_has_seven_entries():
|
||||
assert len(account_pipeline._auth) == 7
|
||||
def test_account_auth_list_has_eight_entries():
|
||||
assert len(account_pipeline._auth) == 8
|
||||
|
||||
|
||||
def test_external_sso_pipeline_prepare_has_four_entries():
|
||||
@ -132,3 +137,89 @@ def test_app_path_selects_workspace_mismatch_check():
|
||||
def test_workspace_path_skips_workspace_mismatch_check():
|
||||
steps = _selected_auth_steps(app_id=False, workspace_membership=True, allowed_roles=None)
|
||||
assert check_workspace_mismatch not in steps
|
||||
|
||||
|
||||
def _selected_webapp_steps(*, scope, app_access_mode):
|
||||
"""Select auth steps for an EE, webapp-auth-enabled, app-scoped request.
|
||||
|
||||
Patches the config-backed conditions (edition + webapp_auth) so the gating
|
||||
reduces to PATH_HAS_APP_ID, LOADED_APP_IS_PRIVATE, and the request scope.
|
||||
"""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from controllers.openapi.auth.data import AuthData, Edition
|
||||
|
||||
ctx = RequestContext(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
scope=scope,
|
||||
path_params={"app_id": str(uuid.uuid4())},
|
||||
)
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
token_hash="x",
|
||||
scopes=frozenset({scope}) if scope is not None else frozenset(),
|
||||
app_access_mode=app_access_mode,
|
||||
)
|
||||
features = MagicMock()
|
||||
features.webapp_auth.enabled = True
|
||||
selected = []
|
||||
with (
|
||||
patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.EE),
|
||||
patch("controllers.openapi.auth.conditions.FeatureService.get_system_features", return_value=features),
|
||||
):
|
||||
for step in account_pipeline._auth:
|
||||
if isinstance(step, When):
|
||||
if step.applies(ctx, data):
|
||||
selected.append(step._step)
|
||||
else:
|
||||
selected.append(step)
|
||||
return selected
|
||||
|
||||
|
||||
def test_apps_run_scope_selects_webapp_checks():
|
||||
steps = _selected_webapp_steps(scope=Scope.APPS_RUN, app_access_mode=WebAppAccessMode.PRIVATE)
|
||||
assert check_acl in steps
|
||||
assert check_private_app_permission in steps
|
||||
|
||||
|
||||
def test_management_scope_skips_webapp_checks_on_private_app():
|
||||
# Export DSL et al. carry an app_id but use a management scope; the webapp
|
||||
# end-user ACL / private-app gate must not block workspace members.
|
||||
steps = _selected_webapp_steps(scope=Scope.APPS_READ, app_access_mode=WebAppAccessMode.PRIVATE)
|
||||
assert check_acl not in steps
|
||||
assert check_private_app_permission not in steps
|
||||
|
||||
|
||||
def _selected_auth_steps_with_rbac(rbac):
|
||||
ctx = RequestContext(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
scope=Scope.APPS_READ,
|
||||
path_params={"app_id": str(uuid.uuid4())},
|
||||
rbac=rbac,
|
||||
)
|
||||
selected = []
|
||||
for step in account_pipeline._auth:
|
||||
if isinstance(step, When):
|
||||
if step.applies(ctx, None):
|
||||
selected.append(step._step)
|
||||
else:
|
||||
selected.append(step)
|
||||
return selected
|
||||
|
||||
|
||||
def test_account_pipeline_selects_rbac_step_when_required():
|
||||
rbac = RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_VIEW_LAYOUT)
|
||||
assert check_rbac_permission in _selected_auth_steps_with_rbac(rbac)
|
||||
|
||||
|
||||
def test_account_pipeline_skips_rbac_step_without_requirement():
|
||||
assert check_rbac_permission not in _selected_auth_steps_with_rbac(None)
|
||||
|
||||
|
||||
def test_external_sso_pipeline_never_enforces_rbac():
|
||||
# RBAC is a console (account) concern; external SSO callers are scope-gated.
|
||||
rbac_steps = [
|
||||
s._step for s in external_sso_pipeline._auth if isinstance(s, When) and s._step is check_rbac_permission
|
||||
]
|
||||
assert rbac_steps == []
|
||||
assert check_rbac_permission not in external_sso_pipeline._auth
|
||||
|
||||
@ -5,19 +5,22 @@ from controllers.openapi.auth.conditions import (
|
||||
EDITION_EE,
|
||||
EDITION_SAAS,
|
||||
HAS_ALLOWED_ROLES,
|
||||
HAS_RBAC,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
TOKEN_IS_OAUTH_ACCOUNT,
|
||||
TOKEN_IS_OAUTH_EXTERNAL_SSO,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
WEBAPP_RUN_SCOPED,
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED,
|
||||
Cond,
|
||||
config_cond,
|
||||
data_cond,
|
||||
request_cond,
|
||||
)
|
||||
from controllers.openapi.auth.data import AuthData, Edition, RequestContext
|
||||
from libs.oauth_bearer import TokenType
|
||||
from controllers.openapi.auth.data import AuthData, Edition, RBACRequirement, RequestContext
|
||||
from core.rbac import RBACPermission, RBACResourceScope
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import TenantAccountRole
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
@ -137,6 +140,34 @@ def test_webapp_auth_enabled():
|
||||
assert WEBAPP_AUTH_ENABLED(_ctx()) is True
|
||||
|
||||
|
||||
def test_webapp_run_scoped_true_for_apps_run():
|
||||
assert WEBAPP_RUN_SCOPED(_ctx(scope=Scope.APPS_RUN)) is True
|
||||
|
||||
|
||||
def test_webapp_run_scoped_false_for_management_scope():
|
||||
assert WEBAPP_RUN_SCOPED(_ctx(scope=Scope.APPS_READ)) is False
|
||||
|
||||
|
||||
def test_webapp_run_scoped_false_when_scope_none():
|
||||
assert WEBAPP_RUN_SCOPED(_ctx()) is False
|
||||
|
||||
|
||||
def _rbac_req():
|
||||
return RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN)
|
||||
|
||||
|
||||
def test_has_rbac_true():
|
||||
assert HAS_RBAC(_ctx(rbac=_rbac_req())) is True
|
||||
|
||||
|
||||
def test_has_rbac_false():
|
||||
assert HAS_RBAC(_ctx(rbac=None)) is False
|
||||
|
||||
|
||||
def test_has_rbac_default():
|
||||
assert HAS_RBAC(_ctx()) is False
|
||||
|
||||
|
||||
def test_loaded_app_is_private():
|
||||
data_private = _data(app_access_mode=WebAppAccessMode.PRIVATE)
|
||||
data_public = _data(app_access_mode=WebAppAccessMode.PUBLIC)
|
||||
|
||||
@ -5,17 +5,19 @@ import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.data import AuthData, RBACRequirement
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_access,
|
||||
check_app_api_enabled,
|
||||
check_private_app_permission,
|
||||
check_rbac_permission,
|
||||
check_scope,
|
||||
check_workspace_member,
|
||||
check_workspace_mismatch,
|
||||
check_workspace_role,
|
||||
)
|
||||
from core.rbac import RBACPermission, RBACResourceScope
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Tenant, TenantAccountRole
|
||||
from models.model import App
|
||||
@ -75,6 +77,67 @@ def test_check_app_access_raises_when_not_member():
|
||||
check_app_access(data)
|
||||
|
||||
|
||||
# --- check_rbac_permission ---
|
||||
|
||||
_RBAC_REQ = RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_VIEW_LAYOUT)
|
||||
|
||||
|
||||
def test_check_rbac_noop_when_no_requirement():
|
||||
with patch("controllers.openapi.auth.verify.enforce_rbac_access") as mock_enforce:
|
||||
check_rbac_permission(_data(rbac=None, caller_kind="account"))
|
||||
mock_enforce.assert_not_called()
|
||||
|
||||
|
||||
def test_check_rbac_noop_when_rbac_disabled():
|
||||
with (
|
||||
patch("controllers.openapi.auth.verify.dify_config.RBAC_ENABLED", False),
|
||||
patch("controllers.openapi.auth.verify.enforce_rbac_access") as mock_enforce,
|
||||
):
|
||||
check_rbac_permission(_data(rbac=_RBAC_REQ, caller_kind="account"))
|
||||
mock_enforce.assert_not_called()
|
||||
|
||||
|
||||
def test_check_rbac_skips_end_user_caller():
|
||||
with (
|
||||
patch("controllers.openapi.auth.verify.dify_config.RBAC_ENABLED", True),
|
||||
patch("controllers.openapi.auth.verify.enforce_rbac_access") as mock_enforce,
|
||||
):
|
||||
check_rbac_permission(_data(rbac=_RBAC_REQ, caller_kind="end_user"))
|
||||
mock_enforce.assert_not_called()
|
||||
|
||||
|
||||
def test_check_rbac_raises_when_context_missing():
|
||||
with patch("controllers.openapi.auth.verify.dify_config.RBAC_ENABLED", True):
|
||||
with pytest.raises(Forbidden, match="rbac context missing"):
|
||||
check_rbac_permission(_data(rbac=_RBAC_REQ, caller_kind="account", account_id=None, tenant=None))
|
||||
|
||||
|
||||
def test_check_rbac_enforces_for_account_caller():
|
||||
tenant = MagicMock(spec=Tenant)
|
||||
tenant.id = "t1"
|
||||
account_id = uuid.uuid4()
|
||||
data = _data(
|
||||
rbac=_RBAC_REQ,
|
||||
caller_kind="account",
|
||||
account_id=account_id,
|
||||
tenant=tenant,
|
||||
path_params={"app_id": "app-1"},
|
||||
)
|
||||
with (
|
||||
patch("controllers.openapi.auth.verify.dify_config.RBAC_ENABLED", True),
|
||||
patch("controllers.openapi.auth.verify.enforce_rbac_access") as mock_enforce,
|
||||
):
|
||||
check_rbac_permission(data)
|
||||
mock_enforce.assert_called_once_with(
|
||||
tenant_id="t1",
|
||||
account_id=str(account_id),
|
||||
resource_type=RBACResourceScope.APP,
|
||||
scene=RBACPermission.APP_VIEW_LAYOUT,
|
||||
resource_required=True,
|
||||
path_args={"app_id": "app-1"},
|
||||
)
|
||||
|
||||
|
||||
def test_check_acl_raises_when_app_or_mode_missing():
|
||||
with pytest.raises(Forbidden):
|
||||
check_acl(_data(app=None, app_access_mode=None))
|
||||
|
||||
@ -20,6 +20,7 @@ def _stub_execute(
|
||||
edition=None,
|
||||
workspace_membership=False,
|
||||
allowed_roles=None,
|
||||
rbac=None,
|
||||
):
|
||||
"""Bypass all auth logic; inject minimal AuthData and call the view directly."""
|
||||
kwargs["auth_data"] = AuthData(
|
||||
@ -30,6 +31,7 @@ def _stub_execute(
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
required_scope=scope,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@ -0,0 +1,73 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA
|
||||
from controllers.openapi.apps import _EMPTY_PARAMETERS, build_app_describe_response
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
|
||||
|
||||
class _FakeApp(SimpleNamespace):
|
||||
pass
|
||||
|
||||
|
||||
def _app() -> _FakeApp:
|
||||
from datetime import datetime
|
||||
|
||||
return _FakeApp(
|
||||
id="11111111-1111-1111-1111-111111111111",
|
||||
name="Demo",
|
||||
mode="chat",
|
||||
description="d",
|
||||
tags=[],
|
||||
author_name="me",
|
||||
updated_at=datetime(2026, 1, 1),
|
||||
enable_api=True,
|
||||
)
|
||||
|
||||
|
||||
def test_fields_none_returns_all_blocks(monkeypatch):
|
||||
monkeypatch.setattr("controllers.openapi.apps.parameters_payload", lambda app: {"k": "v"})
|
||||
monkeypatch.setattr("controllers.openapi.apps.build_input_schema", lambda app: {"s": 1})
|
||||
resp = build_app_describe_response(_app(), None)
|
||||
assert resp.info is not None
|
||||
assert resp.info.name == "Demo"
|
||||
assert resp.parameters == {"k": "v"}
|
||||
assert resp.input_schema == {"s": 1}
|
||||
|
||||
|
||||
def test_fields_subset_limits_blocks(monkeypatch):
|
||||
monkeypatch.setattr("controllers.openapi.apps.parameters_payload", lambda app: {"k": "v"})
|
||||
monkeypatch.setattr("controllers.openapi.apps.build_input_schema", lambda app: {"s": 1})
|
||||
resp = build_app_describe_response(_app(), ["info"])
|
||||
assert resp.info is not None
|
||||
assert resp.parameters is None
|
||||
assert resp.input_schema is None
|
||||
|
||||
|
||||
def test_info_omits_author_and_tags(monkeypatch):
|
||||
monkeypatch.setattr("controllers.openapi.apps.parameters_payload", lambda app: {})
|
||||
monkeypatch.setattr("controllers.openapi.apps.build_input_schema", lambda app: {})
|
||||
resp = build_app_describe_response(_app(), ["info"])
|
||||
assert resp.info is not None
|
||||
# Usage-face describe must not expose creator identity or tags (cross-tenant leak).
|
||||
assert not hasattr(resp.info, "author")
|
||||
assert not hasattr(resp.info, "tags")
|
||||
|
||||
|
||||
def test_parameters_fallback_on_app_unavailable(monkeypatch):
|
||||
def _raise(app):
|
||||
raise AppUnavailableError()
|
||||
|
||||
monkeypatch.setattr("controllers.openapi.apps.parameters_payload", _raise)
|
||||
monkeypatch.setattr("controllers.openapi.apps.build_input_schema", lambda app: {"s": 1})
|
||||
resp = build_app_describe_response(_app(), ["parameters"])
|
||||
assert resp.parameters == dict(_EMPTY_PARAMETERS)
|
||||
|
||||
|
||||
def test_input_schema_fallback_on_app_unavailable(monkeypatch):
|
||||
def _raise(app):
|
||||
raise AppUnavailableError()
|
||||
|
||||
monkeypatch.setattr("controllers.openapi.apps.parameters_payload", lambda app: {"k": "v"})
|
||||
monkeypatch.setattr("controllers.openapi.apps.build_input_schema", _raise)
|
||||
resp = build_app_describe_response(_app(), ["input_schema"])
|
||||
assert resp.input_schema == dict(EMPTY_INPUT_SCHEMA)
|
||||
@ -5,7 +5,7 @@ Runs against the model directly, not the HTTP layer. Pins:
|
||||
- workspace_id is required.
|
||||
- numeric bounds enforced (page >= 1, limit in [1, MAX_PAGE_LIMIT]).
|
||||
- mode validates against the AppMode enum.
|
||||
- name and tag have length caps.
|
||||
- name has a length cap.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -24,7 +24,6 @@ def test_defaults():
|
||||
assert q.limit == 20
|
||||
assert q.mode is None
|
||||
assert q.name is None
|
||||
assert q.tag is None
|
||||
|
||||
|
||||
def test_workspace_id_required():
|
||||
@ -80,12 +79,6 @@ def test_name_length_capped():
|
||||
AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "name": "x" * 201})
|
||||
|
||||
|
||||
def test_tag_length_capped():
|
||||
AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "tag": "x" * 100})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "tag": "x" * 101})
|
||||
|
||||
|
||||
def test_all_fields_accept_valid_values():
|
||||
"""Pin the happy-path acceptance for every field in one place."""
|
||||
q = AppListQuery.model_validate(
|
||||
@ -95,7 +88,6 @@ def test_all_fields_accept_valid_values():
|
||||
"limit": 50,
|
||||
"mode": "workflow",
|
||||
"name": "search",
|
||||
"tag": "prod",
|
||||
}
|
||||
)
|
||||
assert q.workspace_id == "00000000-0000-0000-0000-000000000001"
|
||||
@ -104,4 +96,3 @@ def test_all_fields_accept_valid_values():
|
||||
assert q.mode is not None
|
||||
assert q.mode.value == "workflow"
|
||||
assert q.name == "search"
|
||||
assert q.tag == "prod"
|
||||
|
||||
@ -26,11 +26,13 @@ from controllers.openapi._errors import (
|
||||
ErrorBody,
|
||||
ErrorDetail,
|
||||
FilenameNotExists,
|
||||
HumanInputFormNotFound,
|
||||
MemberLicenseExceeded,
|
||||
MemberLimitExceeded,
|
||||
OpenApiError,
|
||||
OpenApiErrorCode,
|
||||
OpenApiErrorFormatter,
|
||||
RecipientSurfaceMismatch,
|
||||
)
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
@ -319,6 +321,8 @@ ERROR_MATRIX = [
|
||||
(BlockedFileExtensionError(), 400, "file_extension_blocked"),
|
||||
(MemberLimitExceeded(), 403, "member_limit_exceeded"),
|
||||
(MemberLicenseExceeded(), 403, "member_license_exceeded"),
|
||||
(HumanInputFormNotFound(), 404, "form_not_found"),
|
||||
(RecipientSurfaceMismatch(), 403, "recipient_surface_mismatch"),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -11,8 +11,9 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from controllers.openapi._errors import HumanInputFormNotFound, RecipientSurfaceMismatch
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.human_input import RecipientType
|
||||
@ -89,7 +90,7 @@ class TestOpenApiHumanInputFormGet:
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/bad"):
|
||||
with pytest.raises(NotFound):
|
||||
with pytest.raises(HumanInputFormNotFound):
|
||||
api.get.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
@ -101,7 +102,10 @@ class TestOpenApiHumanInputFormGet:
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
form = SimpleNamespace(
|
||||
app_id="other-app", tenant_id="tenant-1", expiration_time=datetime(2099, 1, 1, tzinfo=UTC)
|
||||
app_id="other-app",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
@ -114,7 +118,7 @@ class TestOpenApiHumanInputFormGet:
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
|
||||
with pytest.raises(NotFound):
|
||||
with pytest.raises(HumanInputFormNotFound):
|
||||
api.get.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
@ -142,7 +146,7 @@ class TestOpenApiHumanInputFormGet:
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
|
||||
with pytest.raises(NotFound):
|
||||
with pytest.raises(RecipientSurfaceMismatch):
|
||||
api.get.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
@ -234,6 +238,38 @@ class TestOpenApiHumanInputFormPost:
|
||||
)
|
||||
assert result == ({}, 200)
|
||||
|
||||
def test_post_standalone_web_app_recipient_submits(
|
||||
self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
form = self._make_form(recipient_type=RecipientType.STANDALONE_WEB_APP)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
|
||||
module = sys.modules["controllers.openapi.human_input_form"]
|
||||
monkeypatch.setattr(module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
caller = SimpleNamespace(id="anyone")
|
||||
|
||||
with app.test_request_context(
|
||||
"/openapi/v1/apps/app-1/form/human_input/tok-1",
|
||||
method="POST",
|
||||
json={"action": "approve", "inputs": {}},
|
||||
):
|
||||
result = api.post.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "end_user"),
|
||||
)
|
||||
|
||||
service_mock.submit_form_by_token.assert_called_once()
|
||||
assert result == ({}, 200)
|
||||
|
||||
def test_post_rejects_invalid_body_with_422(self, app: Flask, bypass_pipeline):
|
||||
"""Malformed body → 422 via @accepts (was an unmapped pydantic error → 500)."""
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
@ -63,23 +63,19 @@ def test_envelope_uses_pep695_generics():
|
||||
|
||||
|
||||
def test_app_info_response_dump_matches_spec():
|
||||
from controllers.openapi._models import AppInfoResponse
|
||||
from controllers.openapi._models import AppInfo
|
||||
|
||||
obj = AppInfoResponse(
|
||||
obj = AppInfo(
|
||||
id="app1",
|
||||
name="X",
|
||||
description="d",
|
||||
mode="chat",
|
||||
author="alice",
|
||||
tags=[{"name": "prod"}],
|
||||
)
|
||||
assert obj.model_dump(mode="json") == {
|
||||
"id": "app1",
|
||||
"name": "X",
|
||||
"description": "d",
|
||||
"mode": "chat",
|
||||
"author": "alice",
|
||||
"tags": [{"name": "prod"}],
|
||||
}
|
||||
|
||||
|
||||
@ -91,8 +87,6 @@ def test_app_describe_response_nests_info_and_parameters():
|
||||
name="X",
|
||||
mode="chat",
|
||||
description=None,
|
||||
tags=[],
|
||||
author=None,
|
||||
updated_at="2026-05-05T00:00:00+00:00",
|
||||
service_api_enabled=True,
|
||||
)
|
||||
|
||||
@ -136,6 +136,55 @@ class TestAppParameterApi:
|
||||
assert "user_input_form" in response
|
||||
assert "opening_statement" in response
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
@patch("controllers.service_api.app.app._get_agent_app_feature_dict_and_user_input_form")
|
||||
def test_get_parameters_for_agent_app(
|
||||
self,
|
||||
mock_get_agent_parameters,
|
||||
mock_db,
|
||||
mock_validate_token,
|
||||
mock_current_app,
|
||||
mock_user_logged_in,
|
||||
app: Flask,
|
||||
mock_app_model,
|
||||
):
|
||||
"""Test retrieving parameters for an Agent App from Agent Soul app variables."""
|
||||
_configure_current_app_mock(mock_current_app)
|
||||
|
||||
mock_app_model.mode = AppMode.AGENT
|
||||
mock_app_model.app_model_config = None
|
||||
mock_app_model.workflow = None
|
||||
mock_get_agent_parameters.return_value = (
|
||||
{"opening_statement": "Hi from Agent"},
|
||||
[{"text-input": {"label": "topic", "variable": "topic", "required": True}}],
|
||||
)
|
||||
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
mock_db.session.get.side_effect = [mock_app_model, mock_tenant]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppParameterApi()
|
||||
response = api.get()
|
||||
|
||||
assert response["opening_statement"] == "Hi from Agent"
|
||||
assert response["user_input_form"] == [
|
||||
{"text-input": {"label": "topic", "variable": "topic", "required": True}}
|
||||
]
|
||||
mock_get_agent_parameters.assert_called_once_with(mock_app_model)
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
|
||||
@ -29,7 +29,7 @@ from core.app.entities.task_entities import (
|
||||
WorkflowPauseStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
|
||||
from core.workflow.human_input_policy import HumanInputSurface
|
||||
from core.workflow.human_input_policy import FormDisposition, HumanInputSurface
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
@ -592,8 +592,10 @@ class TestHitlServiceApi:
|
||||
monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
workflow_response_converter,
|
||||
"load_form_tokens_by_form_id",
|
||||
lambda form_ids, session=None, surface=None: {"form-1": "token"},
|
||||
"load_form_dispositions_by_form_id",
|
||||
lambda form_ids, session=None, surface=None: {
|
||||
"form-1": FormDisposition(form_token="token", approval_channels=[])
|
||||
},
|
||||
)
|
||||
|
||||
reason = HumanInputRequired(
|
||||
@ -652,8 +654,10 @@ class TestHitlServiceApi:
|
||||
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
|
||||
resumption_context = _build_resumption_context("task-ctx")
|
||||
monkeypatch.setattr(
|
||||
"services.workflow_event_snapshot_service.load_form_tokens_by_form_id",
|
||||
lambda form_ids, session=None, surface=None: {"form-1": "wtok"},
|
||||
"services.workflow_event_snapshot_service.load_form_dispositions_by_form_id",
|
||||
lambda form_ids, session=None, surface=None: {
|
||||
"form-1": FormDisposition(form_token="wtok", approval_channels=[])
|
||||
},
|
||||
)
|
||||
|
||||
class _SessionContext:
|
||||
|
||||
@ -25,6 +25,15 @@ MINIMAL_GRAPH = {
|
||||
}
|
||||
|
||||
|
||||
def _patch_create_session(mock_session: MagicMock):
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = mock_session
|
||||
session_context.__exit__.return_value = False
|
||||
mock_session.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_session.begin.return_value.__exit__.return_value = False
|
||||
return patch("core.app.apps.advanced_chat.app_runner.create_session", return_value=session_context)
|
||||
|
||||
|
||||
class TestAdvancedChatAppRunnerConversationVariables:
|
||||
"""Test that AdvancedChatAppRunner correctly handles conversation variables."""
|
||||
|
||||
@ -135,10 +144,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
_patch_create_session(mock_session),
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(
|
||||
runner,
|
||||
@ -151,12 +158,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock GraphRuntimeState to accept the variable pool
|
||||
mock_graph_runtime_state_class.return_value = MagicMock()
|
||||
|
||||
@ -281,10 +282,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
_patch_create_session(mock_session),
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(
|
||||
runner,
|
||||
@ -298,12 +297,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock ConversationVariable.from_variable to return mock objects
|
||||
mock_conv_vars = []
|
||||
for var in workflow_vars:
|
||||
@ -434,10 +427,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
_patch_create_session(mock_session),
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(
|
||||
runner,
|
||||
@ -450,12 +441,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock GraphRuntimeState to accept the variable pool
|
||||
mock_graph_runtime_state_class.return_value = MagicMock()
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
import core.app.apps.advanced_chat.app_runner as module
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueStopEvent
|
||||
@ -85,27 +86,24 @@ def build_runner():
|
||||
|
||||
def _patch_common_run_deps(runner: AdvancedChatAppRunner):
|
||||
"""Context manager that patches common heavy deps used by run()."""
|
||||
# create_session() returns a context manager whose body yields a session that
|
||||
# supports both scalar() (app record lookup) and begin()/scalars().all()
|
||||
# (conversation variable initialization).
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = MagicMock()
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = mock_session
|
||||
session_context.__exit__.return_value = False
|
||||
mock_session.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_session.begin.return_value.__exit__.return_value = False
|
||||
|
||||
return patch.multiple(
|
||||
"core.app.apps.advanced_chat.app_runner",
|
||||
Session=MagicMock(
|
||||
return_value=MagicMock(
|
||||
__enter__=lambda s: s,
|
||||
__exit__=lambda *a, **k: False,
|
||||
scalar=lambda *a, **k: MagicMock(),
|
||||
),
|
||||
),
|
||||
sessionmaker=MagicMock(
|
||||
return_value=MagicMock(
|
||||
begin=MagicMock(
|
||||
return_value=MagicMock(
|
||||
__enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))),
|
||||
__exit__=lambda *a, **k: False,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
create_session=MagicMock(return_value=session_context),
|
||||
select=MagicMock(),
|
||||
db=MagicMock(engine=MagicMock()),
|
||||
session_factory=MagicMock(get_session_maker=MagicMock(return_value=MagicMock())),
|
||||
RedisChannel=MagicMock(),
|
||||
redis_client=MagicMock(),
|
||||
WorkflowEntry=MagicMock(**{"return_value.run.return_value": iter([])}),
|
||||
@ -192,3 +190,42 @@ def test_run_returns_early_when_direct_output_via_handle_input_moderation(build_
|
||||
# Ensure no further steps executed
|
||||
mock_anno.assert_not_called()
|
||||
mock_init_graph.assert_not_called()
|
||||
|
||||
|
||||
def test_run_closes_scoped_session_before_workflow_run(build_runner):
|
||||
runner = build_runner
|
||||
events = []
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = MagicMock()
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = mock_session
|
||||
session_context.__exit__.return_value = False
|
||||
|
||||
workflow_entry = MagicMock()
|
||||
|
||||
def run_workflow():
|
||||
events.append("run")
|
||||
return iter([])
|
||||
|
||||
workflow_entry.run.side_effect = run_workflow
|
||||
|
||||
with (
|
||||
patch.object(module, "create_session", return_value=session_context),
|
||||
patch.object(module, "session_factory", MagicMock(get_session_maker=MagicMock(return_value=MagicMock()))),
|
||||
patch.object(module, "RedisChannel"),
|
||||
patch.object(module, "redis_client"),
|
||||
patch.object(module, "WorkflowEntry", return_value=workflow_entry),
|
||||
patch.object(module.db.session, "close", side_effect=lambda: events.append("close")),
|
||||
patch.object(
|
||||
runner,
|
||||
"handle_input_moderation",
|
||||
return_value=(False, runner.application_generate_entity.inputs, runner.application_generate_entity.query),
|
||||
),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch.object(runner, "_initialize_conversation_variables", return_value=[]),
|
||||
patch.object(runner, "_init_graph", return_value=MagicMock()),
|
||||
):
|
||||
runner.run()
|
||||
|
||||
assert events == ["close", "run"]
|
||||
|
||||
@ -175,6 +175,7 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
"actions": [{"id": "approve", "title": "Approve", "button_style": "default"}],
|
||||
"display_in_ui": True,
|
||||
"form_token": "token-1",
|
||||
"approval_channels": [],
|
||||
"resolved_default_values": {},
|
||||
"expiration_time": 123,
|
||||
}
|
||||
|
||||
@ -19,6 +19,10 @@ def _soul() -> AgentSoulConfig:
|
||||
"model_settings": {"temperature": 0.2},
|
||||
},
|
||||
"prompt": {"system_prompt": "You are Iris."},
|
||||
"app_variables": [
|
||||
{"name": "topic", "type": "string", "required": True},
|
||||
{"name": "count", "type": "number", "default": 3},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
@ -32,7 +36,10 @@ def test_model_and_prompt_come_from_soul():
|
||||
"completion_params": {"temperature": 0.2},
|
||||
}
|
||||
assert d["pre_prompt"] == "You are Iris."
|
||||
assert d["user_input_form"] == []
|
||||
assert d["user_input_form"] == [
|
||||
{"text-input": {"label": "topic", "variable": "topic", "required": True}},
|
||||
{"number": {"label": "count", "variable": "count", "required": False, "default": 3}},
|
||||
]
|
||||
|
||||
|
||||
def test_feature_flags_come_from_app_model_config_when_present():
|
||||
|
||||
@ -13,12 +13,24 @@ def runner():
|
||||
return AgentChatAppRunner()
|
||||
|
||||
|
||||
def patch_create_session(mocker: MockerFixture, *, return_value=None, side_effect=None):
|
||||
session = mocker.MagicMock()
|
||||
if side_effect is not None:
|
||||
session.scalar.side_effect = side_effect
|
||||
else:
|
||||
session.scalar.return_value = return_value
|
||||
session_context = mocker.MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.create_session", return_value=session_context)
|
||||
return session
|
||||
|
||||
|
||||
class TestAgentChatAppRunnerRun:
|
||||
def test_run_app_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture):
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
|
||||
generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None)
|
||||
patch_create_session(mocker, return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
@ -37,7 +49,7 @@ class TestAgentChatAppRunnerRun:
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad"))
|
||||
mocker.patch.object(runner, "direct_output")
|
||||
@ -62,7 +74,7 @@ class TestAgentChatAppRunnerRun:
|
||||
invoke_from=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
annotation = mocker.MagicMock(id="anno", content="answer")
|
||||
@ -91,7 +103,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -121,7 +133,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -163,7 +175,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -179,10 +191,7 @@ class TestAgentChatAppRunnerRun:
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
patch_create_session(mocker, side_effect=[app_record, conversation, message])
|
||||
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls)
|
||||
@ -219,7 +228,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -235,10 +244,7 @@ class TestAgentChatAppRunnerRun:
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
patch_create_session(mocker, side_effect=[app_record, conversation, message])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
@ -267,7 +273,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -283,10 +289,7 @@ class TestAgentChatAppRunnerRun:
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
patch_create_session(mocker, side_effect=[app_record, conversation, message])
|
||||
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls)
|
||||
@ -323,10 +326,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, None],
|
||||
)
|
||||
patch_create_session(mocker, side_effect=[app_record, None])
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -357,10 +357,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, mocker.MagicMock(id="conv"), None],
|
||||
)
|
||||
patch_create_session(mocker, side_effect=[app_record, mocker.MagicMock(id="conv"), None])
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -391,7 +388,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -407,10 +404,7 @@ class TestAgentChatAppRunnerRun:
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
patch_create_session(mocker, side_effect=[app_record, conversation, message])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import ANY, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -29,6 +30,19 @@ class DummyQueueManager:
|
||||
self.published.append((event, pub_from))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patched_create_session(*, return_value=None, side_effect=None):
|
||||
session = MagicMock()
|
||||
if side_effect is not None:
|
||||
session.scalar.side_effect = side_effect
|
||||
else:
|
||||
session.scalar.return_value = return_value
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
with patch("core.app.apps.chat.app_runner.create_session", return_value=session_context):
|
||||
yield session
|
||||
|
||||
|
||||
class TestChatAppGenerator:
|
||||
def test_generate_requires_query(self):
|
||||
generator = ChatAppGenerator()
|
||||
@ -167,7 +181,7 @@ class TestChatAppRunner:
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None):
|
||||
with patched_create_session(return_value=None):
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
@ -195,10 +209,7 @@ class TestChatAppRunner:
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")),
|
||||
patch.object(ChatAppRunner, "direct_output") as mock_direct,
|
||||
@ -233,10 +244,7 @@ class TestChatAppRunner:
|
||||
annotation = SimpleNamespace(id="ann-1", content="answer")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
|
||||
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation),
|
||||
@ -272,13 +280,73 @@ class TestChatAppRunner:
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
|
||||
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
|
||||
patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True),
|
||||
):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
def test_run_closes_scoped_session_before_stream_consumption(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
prompt_template=None,
|
||||
external_data_variables=[],
|
||||
dataset=None,
|
||||
additional_features=None,
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model="model-1", parameters={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=True,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
events = []
|
||||
queue_manager = DummyQueueManager()
|
||||
model_instance = MagicMock()
|
||||
|
||||
def invoke_stream():
|
||||
events.append("first-chunk")
|
||||
yield "chunk"
|
||||
|
||||
def invoke_llm(**kwargs):
|
||||
events.append("invoke")
|
||||
return invoke_stream()
|
||||
|
||||
with (
|
||||
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
|
||||
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
|
||||
patch.object(ChatAppRunner, "check_hosting_moderation", return_value=False),
|
||||
patch.object(ChatAppRunner, "recalc_llm_max_tokens"),
|
||||
patch.object(
|
||||
ChatAppRunner,
|
||||
"_handle_invoke_result",
|
||||
side_effect=lambda invoke_result, **kwargs: list(invoke_result),
|
||||
) as mock_handle,
|
||||
patch("core.app.apps.chat.app_runner.ModelInstance", return_value=model_instance),
|
||||
patch("core.app.apps.chat.app_runner.db.session.close", side_effect=lambda: events.append("close")),
|
||||
):
|
||||
model_instance.invoke_llm.side_effect = invoke_llm
|
||||
runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
assert events == ["close", "invoke", "first-chunk"]
|
||||
mock_handle.assert_called_once_with(
|
||||
invoke_result=ANY,
|
||||
queue_manager=queue_manager,
|
||||
stream=True,
|
||||
message_id="m1",
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
@ -47,25 +48,28 @@ def _build_generate_entity(app_config, file_upload_config=None):
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patched_create_session(*, return_value=None):
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = return_value
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
with patch.object(module, "create_session", return_value=session_context):
|
||||
yield session
|
||||
|
||||
|
||||
class TestCompletionAppRunner:
|
||||
def test_run_app_not_found(self, runner, mocker: MockerFixture):
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = None
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock())
|
||||
with patched_create_session(return_value=None):
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock())
|
||||
|
||||
def test_run_moderation_error_outputs_direct(self, runner, mocker: MockerFixture):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
@ -74,7 +78,8 @@ class TestCompletionAppRunner:
|
||||
runner.direct_output = MagicMock()
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
with patched_create_session(return_value=app_record):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
runner.direct_output.assert_called_once()
|
||||
runner._handle_invoke_result.assert_not_called()
|
||||
@ -82,10 +87,6 @@ class TestCompletionAppRunner:
|
||||
def test_run_hosting_moderation_stops(self, runner, mocker: MockerFixture):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
@ -94,18 +95,14 @@ class TestCompletionAppRunner:
|
||||
runner.check_hosting_moderation = MagicMock(return_value=True)
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
with patched_create_session(return_value=app_record):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
runner._handle_invoke_result.assert_not_called()
|
||||
|
||||
def test_run_dataset_and_external_tools_flow(self, runner, mocker: MockerFixture):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
session.close = MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
retrieve_config = MagicMock(query_variable="qvar")
|
||||
dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config)
|
||||
additional_features = MagicMock(show_retrieve_source=True)
|
||||
@ -135,19 +132,56 @@ class TestCompletionAppRunner:
|
||||
model_instance.invoke_llm.return_value = "invoke_result"
|
||||
mocker.patch.object(module, "ModelInstance", return_value=model_instance)
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
|
||||
with patched_create_session(return_value=app_record):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
|
||||
|
||||
dataset_retrieval.retrieve.assert_called_once()
|
||||
assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input"
|
||||
runner._handle_invoke_result.assert_called_once()
|
||||
|
||||
def test_run_closes_scoped_session_before_stream_consumption(self, runner, mocker: MockerFixture):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
queue_manager = MagicMock()
|
||||
|
||||
events = []
|
||||
runner.organize_prompt_messages = MagicMock(return_value=([], None))
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.check_hosting_moderation = MagicMock(return_value=False)
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner._handle_invoke_result = MagicMock(side_effect=lambda invoke_result, **kwargs: list(invoke_result))
|
||||
|
||||
model_instance = MagicMock()
|
||||
|
||||
def invoke_stream():
|
||||
events.append("first-chunk")
|
||||
yield "chunk"
|
||||
|
||||
def invoke_llm(**kwargs):
|
||||
events.append("invoke")
|
||||
return invoke_stream()
|
||||
|
||||
model_instance.invoke_llm.side_effect = invoke_llm
|
||||
mocker.patch.object(module, "ModelInstance", return_value=model_instance)
|
||||
mocker.patch.object(module.db.session, "close", side_effect=lambda: events.append("close"))
|
||||
|
||||
with patched_create_session(return_value=app_record):
|
||||
runner.run(app_generate_entity, queue_manager, MagicMock(id="msg"))
|
||||
|
||||
assert events == ["close", "invoke", "first-chunk"]
|
||||
runner._handle_invoke_result.assert_called_once_with(
|
||||
invoke_result=ANY,
|
||||
queue_manager=queue_manager,
|
||||
stream=True,
|
||||
message_id="msg",
|
||||
user_id="user",
|
||||
tenant_id="tenant",
|
||||
)
|
||||
|
||||
def test_run_uses_low_image_detail_default(self, runner, mocker: MockerFixture):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config, file_upload_config=None)
|
||||
|
||||
@ -155,7 +189,8 @@ class TestCompletionAppRunner:
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.check_hosting_moderation = MagicMock(return_value=True)
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
with patched_create_session(return_value=app_record):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
assert (
|
||||
runner.organize_prompt_messages.call_args.kwargs["image_detail_config"]
|
||||
|
||||
@ -53,6 +53,32 @@ def _build_app_generate_entity() -> SimpleNamespace:
|
||||
)
|
||||
|
||||
|
||||
def _patch_create_session(mocker: MockerFixture, session: MagicMock, *, events: list[str] | None = None):
|
||||
"""Patch create_session() to yield ``session`` inside its ``with`` body and ``begin()`` block.
|
||||
|
||||
The runner now obtains short-lived sessions via ``create_session()`` instead of the
|
||||
Flask scoped ``db.session``, so tests patch the module-level ``create_session`` and
|
||||
hand back a context manager that yields the mock session.
|
||||
"""
|
||||
session_context = MagicMock()
|
||||
|
||||
def enter_session():
|
||||
if events is not None:
|
||||
events.append("session_enter")
|
||||
return session
|
||||
|
||||
def exit_session(*args):
|
||||
if events is not None:
|
||||
events.append("session_exit")
|
||||
return False
|
||||
|
||||
session_context.__enter__.side_effect = enter_session
|
||||
session_context.__exit__.side_effect = exit_session
|
||||
session.begin.return_value.__enter__.return_value = session
|
||||
session.begin.return_value.__exit__.return_value = False
|
||||
return mocker.patch.object(module, "create_session", return_value=session_context)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
@ -77,13 +103,14 @@ def test_get_app_id(runner):
|
||||
assert runner._get_app_id() == "pipe"
|
||||
|
||||
|
||||
def test_get_workflow_returns_workflow(mocker, runner):
|
||||
def test_get_workflow_returns_workflow(runner):
|
||||
pipeline = MagicMock(tenant_id="tenant", id="pipe")
|
||||
workflow = MagicMock(id="wf")
|
||||
|
||||
mocker.patch.object(module.db, "session", MagicMock(scalar=MagicMock(return_value=workflow)))
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = workflow
|
||||
|
||||
result = runner.get_workflow(pipeline=pipeline, workflow_id="wf")
|
||||
result = runner.get_workflow(session=session, pipeline=pipeline, workflow_id="wf")
|
||||
|
||||
assert result == workflow
|
||||
|
||||
@ -116,7 +143,7 @@ def test_update_document_status_on_failure(mocker, runner):
|
||||
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = document
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
_patch_create_session(mocker, session)
|
||||
|
||||
event = GraphRunFailedEvent(error="boom")
|
||||
|
||||
@ -124,7 +151,10 @@ def test_update_document_status_on_failure(mocker, runner):
|
||||
|
||||
assert document.indexing_status == "error"
|
||||
assert document.error == "boom"
|
||||
session.commit.assert_called_once()
|
||||
session.add.assert_called_once_with(document)
|
||||
session.begin.assert_called_once()
|
||||
session.begin.return_value.__enter__.assert_called_once()
|
||||
session.begin.return_value.__exit__.assert_called_once()
|
||||
|
||||
|
||||
def test_run_pipeline_not_found(mocker: MockerFixture):
|
||||
@ -135,7 +165,7 @@ def test_run_pipeline_not_found(mocker: MockerFixture):
|
||||
|
||||
session = MagicMock()
|
||||
session.get.side_effect = [None, None]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
_patch_create_session(mocker, session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
@ -158,7 +188,7 @@ def test_run_workflow_not_initialized(mocker: MockerFixture):
|
||||
|
||||
session = MagicMock()
|
||||
session.get.side_effect = [None, pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
_patch_create_session(mocker, session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
@ -184,7 +214,7 @@ def test_run_single_iteration_path(mocker: MockerFixture):
|
||||
|
||||
session = MagicMock()
|
||||
session.get.side_effect = [end_user, pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
_patch_create_session(mocker, session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
@ -229,10 +259,11 @@ def test_run_normal_path_builds_graph(mocker: MockerFixture):
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
end_user = MagicMock(session_id="sess")
|
||||
events = []
|
||||
|
||||
session = MagicMock()
|
||||
session.get.side_effect = [end_user, pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
_patch_create_session(mocker, session, events=events)
|
||||
|
||||
workflow = MagicMock(
|
||||
id="wf",
|
||||
@ -276,10 +307,11 @@ def test_run_normal_path_builds_graph(mocker: MockerFixture):
|
||||
|
||||
workflow_entry = MagicMock()
|
||||
workflow_entry.graph_engine = MagicMock()
|
||||
workflow_entry.run.return_value = []
|
||||
workflow_entry.run.side_effect = lambda: events.append("workflow_run") or []
|
||||
mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
|
||||
mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
|
||||
|
||||
runner.run()
|
||||
|
||||
assert events == ["session_enter", "session_exit", "workflow_run"]
|
||||
runner._init_rag_pipeline_graph.assert_called_once()
|
||||
|
||||
@ -0,0 +1,59 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueTextChunkEvent
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _message_queue_manager(app_mode: str) -> MessageBasedAppQueueManager:
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
return MessageBasedAppQueueManager(
|
||||
task_id="task-1",
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
conversation_id="conversation-1",
|
||||
app_mode=app_mode,
|
||||
message_id="message-1",
|
||||
)
|
||||
|
||||
|
||||
def _workflow_queue_manager(app_mode: str) -> WorkflowAppQueueManager:
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
return WorkflowAppQueueManager(
|
||||
task_id="task-1",
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
app_mode=app_mode,
|
||||
)
|
||||
|
||||
|
||||
def test_message_queue_does_not_raise_legacy_stop_for_advanced_chat() -> None:
|
||||
manager = _message_queue_manager(AppMode.ADVANCED_CHAT.value)
|
||||
|
||||
with patch.object(manager, "_is_stopped", return_value=True):
|
||||
manager.publish(QueueTextChunkEvent(text="chunk"), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
|
||||
def test_workflow_queue_does_not_read_legacy_stop_flag() -> None:
|
||||
manager = _workflow_queue_manager(AppMode.WORKFLOW.value)
|
||||
|
||||
with patch.object(manager, "_is_stopped", return_value=True) as is_stopped:
|
||||
manager.publish(QueueTextChunkEvent(text="chunk"), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
is_stopped.assert_not_called()
|
||||
|
||||
|
||||
def test_message_queue_keeps_legacy_stop_for_non_graphengine_chat() -> None:
|
||||
manager = _message_queue_manager(AppMode.CHAT.value)
|
||||
|
||||
with patch.object(manager, "_is_stopped", return_value=True):
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
manager.publish(QueueTextChunkEvent(text="chunk"), PublishFrom.APPLICATION_MANAGER)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user