Merge branch 'main' into feat/refine-snippet-siderbar

This commit is contained in:
JzoNg 2026-06-23 12:34:17 +08:00
commit ce9d1c74af
414 changed files with 14243 additions and 5224 deletions

View File

@ -36,6 +36,7 @@ Use this as the decision guide for React/TypeScript component structure. Existin
- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in query and API data flow.
- Do not replace prop drilling with one top-level hook that returns a large view model and then thread that object through section props. Move each hook, query, derived value, and handler to the concrete section that consumes it, or use feature-scoped Jotai atoms for simple shared form/UI state when siblings need the same source of truth.
- When using feature-scoped Jotai state for a form, drawer, or other secondary surface, scope the store to that surface instance when stale cross-instance state is possible. Initialize stable config at the owning boundary, then let descendants read only the atoms or purpose-named hooks they actually need.
- For Jotai-backed surfaces, put shared query atoms, mutation atoms, derived state, and write actions in the feature state file when they coordinate multiple descendants. The lowest-owner rule still applies to independent visual surfaces that do not participate in shared state.
- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action.
- Prefer uncontrolled DOM state and CSS variables before adding controlled props.
@ -45,8 +46,10 @@ Use this as the decision guide for React/TypeScript component structure. Existin
- Atom order in the state file follows the dependency graph: types/constants, editable primitives, query atoms, query-data derived atoms, readiness/business derived atoms, write actions, mutation atoms, submission orchestration, provider exports.
- Derived atom names read as business facts. Write atom names read as user or workflow commands.
- UI components read and write the exact atom they use with `useAtomValue` or `useSetAtom`. Repeated workflow semantics live in named derived atoms or write atoms.
- Non-query derived atoms return a narrow value with a clear domain name. Query atoms expose the TanStack Query result object so loading, error, fetch, and pagination state stay attached to the query contract.
- Write-only atoms own state transitions that update multiple primitives, reset dependent state, guard stale async work, or advance the workflow.
- Non-query derived atoms return a narrow value with a clear domain name; avoid pass-through aliases or bundling unrelated UI facts. Query atoms expose the TanStack Query result object so loading, error, fetch, and pagination state stay attached to the query contract.
- Write-only atoms own synchronous state transitions that update multiple primitives, reset dependent state, or advance the workflow. Async work with loading, error, caching, retry, or stale-result concerns should be modeled as query or mutation atoms, with write atoms only changing the inputs that drive them.
- Avoid feature hooks that aggregate form values, query results, derived state, and commands for sibling components. Prefer named derived atoms and write atoms so UI components read the exact shared fact or command they need.
- When a form library owns validation, keep submit orchestration in feature state when post-submit result or error state is shared by the surface. Avoid duplicating validation gates or request shaping in UI hooks.
- `jotai-tanstack-query` atoms use the same QueryClient as the React Query provider. Query atoms belong in feature state when atoms are the feature's local state surface.
- Jotai scope is an optional instance-isolation tool for secondary surfaces with independent local state. Query atoms keep shared cache behavior through the shared QueryClient.
@ -108,7 +111,7 @@ Use this as the decision guide for React/TypeScript component structure. Existin
- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow.
- Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment.
- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible.
- Avoid shallow wrappers, hook-to-props adapter components, layout-only render-prop wrappers, and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. If a component only calls a hook and forwards every returned field to one child, move the hook into that child or make the wrapper own a real surface.
- Avoid shallow wrappers, hook-to-props adapter components, layout-only render-prop wrappers, children-as-pass-through composition, and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. If a component only calls a hook, forwards props, or passes trigger/content through to one child, move the logic into that child or make the wrapper own a real surface.
## You Might Not Need An Effect

View File

@ -768,7 +768,6 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
# Whether to use Redis cluster mode while use redis as event bus.
# It's highly recommended to enable this for large deployments.
EVENT_BUS_REDIS_USE_CLUSTERS=false
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
# Whether to Enable human input timeout check task
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true

View File

@ -2,7 +2,6 @@ from typing import Literal, Protocol, cast
from urllib.parse import quote_plus, urlunparse
from pydantic import AliasChoices, Field
from pydantic.types import NonNegativeInt
from pydantic_settings import BaseSettings
@ -71,24 +70,6 @@ class RedisPubSubConfig(BaseSettings):
default=600,
)
PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field(
validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"),
description=(
"Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to "
"finish before returning. Bounds the tail latency between a terminal event being delivered to "
"an SSE client and the response stream actually closing.\n\n"
"The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout "
"for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for "
"the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() "
"return promptly while the daemon listener thread cleans itself up on the next poll "
"boundary - safe because the listener holds no critical state and exits within one poll "
"window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up "
"and logs a warning. Default 2000ms preserves the pre-change behaviour.\n\n"
"Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
),
default=2000,
)
def _build_default_pubsub_url(self) -> str:
defaults = _redis_defaults(self)
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:

View File

@ -0,0 +1,107 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING
from services.enterprise import rbac_service as enterprise_rbac_service
if TYPE_CHECKING:
from services.app_service import AppListBaseParams
from services.enterprise.rbac_service import MyPermissionsResponse
# Permission keys (dot-notation, from MyPermissionsResponse) that grant
# list/preview access to an app. Keep this the single source of truth for both
# the console and OpenAPI app-list endpoints.
APP_LIST_PERMISSION_KEYS: frozenset[str] = frozenset({"app.preview", "app.acl.preview", "app.full_access"})
# Workspace permission key that lets a caller see apps they maintain even when
# those apps are not in their preview whitelist.
_MANAGE_OWN_APPS_PERMISSION_KEY = "app.create_and_management"
def has_app_list_permission(permission_keys: Sequence[str]) -> bool:
"""Return True if any of ``permission_keys`` grants app list/preview access."""
return any(permission_key in APP_LIST_PERMISSION_KEYS for permission_key in permission_keys)
@dataclass(frozen=True)
class AppAccessFilter:
"""Resolved RBAC visibility for app list/read endpoints.
``accessible_app_ids`` of ``None`` means the caller can see every app in the
workspace (unrestricted). Otherwise it is the exact set of app ids the
caller may preview; combined with ``can_manage_own_apps`` it also covers
apps the caller maintains.
"""
accessible_app_ids: set[str] | None
can_manage_own_apps: bool
@classmethod
def unrestricted(cls) -> AppAccessFilter:
"""Filter that imposes no restriction (RBAC disabled / not applicable)."""
return cls(accessible_app_ids=None, can_manage_own_apps=False)
def is_app_accessible(self, app_id: str, maintainer: str | None, account_id: str) -> bool:
"""Whether a single app is visible to the caller under this filter.
Mirrors the service-layer query gate: an app is visible when the filter
is unrestricted, the app id is whitelisted, or the caller maintains it
and holds ``app.create_and_management``.
"""
if self.accessible_app_ids is None:
return True
if app_id in self.accessible_app_ids:
return True
return self.can_manage_own_apps and maintainer is not None and maintainer == account_id
def apply_to_params(self, params: AppListBaseParams) -> None:
if self.accessible_app_ids is None:
return
params.accessible_app_ids = sorted(self.accessible_app_ids)
params.include_own_apps = self.can_manage_own_apps
def resolve_app_access_filter(
tenant_id: str,
account_id: str,
*,
permissions: MyPermissionsResponse | None = None,
) -> AppAccessFilter:
"""Compute the RBAC app-access filter for ``account_id`` in ``tenant_id``.
Pass ``permissions`` when the caller has already fetched the snapshot (the
console controller reuses it for per-app permission keys) to avoid a second
inner-API round trip; otherwise it is fetched here.
"""
if permissions is None:
permissions = enterprise_rbac_service.RBACService.MyPermissions.get(tenant_id, account_id)
whitelist_scope = enterprise_rbac_service.RBACService.AppAccess.whitelist_resources(tenant_id, account_id)
can_manage_own_apps = _MANAGE_OWN_APPS_PERMISSION_KEY in permissions.workspace.permission_keys
has_default_preview = has_app_list_permission(permissions.app.default_permission_keys) or has_app_list_permission(
permissions.workspace.permission_keys
)
permission_app_ids: set[str] | None = None
if not has_default_preview:
# Collect apps the caller can preview via per-app permission overrides.
permission_app_ids = {
override.resource_id
for override in permissions.app.overrides
if has_app_list_permission(override.permission_keys)
}
accessible_app_ids: set[str] | None
if getattr(whitelist_scope, "unrestricted", False):
accessible_app_ids = permission_app_ids
else:
accessible_app_ids = set(whitelist_scope.resource_ids)
if permission_app_ids is not None:
accessible_app_ids |= permission_app_ids
elif has_default_preview:
# Default preview overrides the whitelist restriction.
accessible_app_ids = None
return AppAccessFilter(accessible_app_ids=accessible_app_ids, can_manage_own_apps=can_manage_own_apps)

View File

@ -1,23 +1,3 @@
"""Shared decorator utilities for Dify controller layers.
This module provides decorators that are not tied to any single API group (e.g.
console, inner, service). Currently it exposes the RBAC permission gate, which
can be applied to any blueprint.
Key exports
-----------
``rbac_permission_required`` decorator that enforces enterprise RBAC access
control. When ``RBAC_ENABLED`` is ``False`` it is a no-op.
``RBACPermission``, ``RBACResourceScope`` re-exported from ``core.rbac`` so
callers only need a single import site.
Private helpers
---------------
``_extract_resource_id``, ``_is_resource_owned_by_current_user`` kept module-
private but accessible via the module namespace for unit-test patching.
"""
from collections.abc import Callable
from functools import wraps
@ -32,7 +12,57 @@ from models.dataset import Dataset
from models.model import App
from services.enterprise.rbac_service import RBACService
__all__ = ["RBACPermission", "RBACResourceScope", "rbac_permission_required"]
__all__ = ["RBACPermission", "RBACResourceScope", "enforce_rbac_access", "rbac_permission_required"]
def enforce_rbac_access(
*,
tenant_id: str,
account_id: str,
resource_type: RBACResourceScope,
scene: RBACPermission,
resource_required: bool = True,
path_args: dict[str, object] | None = None,
) -> None:
"""Enforce enterprise RBAC for an explicit account/tenant pair.
This is the flask-login-independent core of the RBAC gate so it can run
inside request-handling layers that resolve the caller themselves (e.g. the
openapi auth pipeline, which has the account on ``AuthData`` before
flask-login is mounted).
No-op when ``RBAC_ENABLED`` is ``False``. For resource-scoped checks the
resource ID is taken from ``path_args`` merged with ``request.view_args``;
resource ownership short-circuits the check. Raises ``Forbidden`` when
access is denied. For workspace-level checks pass ``resource_required=False``
so the RBAC request omits ``resource_id``.
Args:
tenant_id: The tenant the access is evaluated against.
account_id: The account requesting access.
resource_type: The :class:`RBACResourceScope` member (app/dataset/workspace).
scene: The :class:`RBACPermission` permission point, e.g. ``RBACPermission.APP_DELETE``.
resource_required: Whether a concrete resource ID is required.
path_args: Extra path arguments to merge with ``request.view_args``.
"""
if not dify_config.RBAC_ENABLED:
return
check_resource_type = None if resource_type == RBACResourceScope.WORKSPACE else resource_type
resource_id = None
if resource_required and check_resource_type:
resource_id = _extract_resource_id(resource_type, path_args)
if _is_resource_owned_by_current_user(tenant_id, account_id, resource_type, resource_id):
return
allowed = RBACService.CheckAccess.check(
tenant_id,
account_id,
scene=scene,
resource_type=check_resource_type,
resource_id=resource_id,
)
if not allowed:
raise Forbidden()
def rbac_permission_required[**P, R](
@ -41,14 +71,12 @@ def rbac_permission_required[**P, R](
*,
resource_required: bool = True,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Check enterprise RBAC permissions for the current user.
"""Check enterprise RBAC permissions for the current flask-login user.
When ``RBAC_ENABLED`` is ``False`` the decorator is a no-op and the
request passes through unchanged. When enabled it extracts the resource ID
from ``request.view_args`` for resource-scoped checks, calls the RBAC
service ``check-access`` endpoint, and raises ``Forbidden`` if the access
is denied. For workspace-level checks, set ``resource_required=False`` so
the RBAC request omits ``resource_id``.
request passes through unchanged. When enabled it resolves the current
account/tenant and delegates to :func:`enforce_rbac_access`, raising
``Forbidden`` if access is denied.
Args:
resource_type: The :class:`RBACResourceScope` member (app/dataset/workspace).
@ -63,23 +91,14 @@ def rbac_permission_required[**P, R](
return view(*args, **kwargs)
current_user, current_tenant_id = current_account_with_tenant()
check_resource_type = None if resource_type == RBACResourceScope.WORKSPACE else resource_type
resource_id = None
if resource_required and check_resource_type:
resource_id = _extract_resource_id(resource_type, kwargs)
if _is_resource_owned_by_current_user(current_tenant_id, current_user.id, resource_type, resource_id):
return view(*args, **kwargs)
allowed = RBACService.CheckAccess.check(
current_tenant_id,
current_user.id,
enforce_rbac_access(
tenant_id=current_tenant_id,
account_id=current_user.id,
resource_type=resource_type,
scene=scene,
resource_type=check_resource_type,
resource_id=resource_id,
resource_required=resource_required,
path_args=kwargs,
)
if not allowed:
raise Forbidden()
return view(*args, **kwargs)
return decorated

View File

@ -3,10 +3,12 @@ from uuid import UUID
from flask import abort, request
from flask_restx import Resource
from pydantic import AliasChoices, BaseModel, Field, field_validator
from sqlalchemy import func, select
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.agent.app_helpers import resolve_agent_app_model
from controllers.console.apikey import ApiKeyItem, ApiKeyList, BaseApiKeyListResource, BaseApiKeyResource
from controllers.console.app.app import (
AppDetailWithSite as GenericAppDetailWithSite,
)
@ -25,9 +27,13 @@ from controllers.console.app.app import (
UpdateAppPayload as GenericUpdateAppPayload,
)
from controllers.console.wraps import (
RBACPermission,
RBACResourceScope,
account_initialization_required,
edit_permission_required,
enterprise_license_required,
is_admin_or_owner_required,
rbac_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
@ -36,6 +42,7 @@ from extensions.ext_database import db
from fields.agent_fields import (
AgentConfigSnapshotDetailResponse,
AgentConfigSnapshotListResponse,
AgentConfigSnapshotRestoreResponse,
AgentInviteOptionsResponse,
AgentLogListResponse,
AgentLogMessageListResponse,
@ -48,7 +55,8 @@ from libs.datetime_utils import parse_time_range
from libs.helper import dump_response
from libs.login import login_required
from models import Account
from models.model import IconType
from models.enums import ApiTokenType
from models.model import ApiToken, App, IconType
from services.agent.errors import AgentNotFoundError
from services.agent.observability_service import (
AgentLogQueryParams,
@ -102,6 +110,27 @@ class AgentAppUpdatePayload(GenericUpdateAppPayload):
return role
class AgentApiStatusPayload(BaseModel):
enable_api: bool = Field(..., description="Enable or disable Agent service API")
class AgentApiAccessResponse(BaseModel):
enabled: bool
service_api_base_url: str
streaming_only: bool = True
chat_endpoint: str
stop_endpoint: str
conversations_endpoint: str
messages_endpoint: str
files_upload_endpoint: str
parameters_endpoint: str
info_endpoint: str
meta_endpoint: str
api_rpm: int
api_rph: int
api_key_count: int
class AgentAppPublishedReferenceResponse(BaseModel):
app_id: str
app_name: str
@ -185,6 +214,7 @@ class AgentStatisticsQuery(BaseModel):
class AgentAppPartial(GenericAppPartial):
app_id: str | None = None
debug_conversation_id: str | None = None
role: str | None = None
active_config_is_published: bool = False
published_reference_count: int = 0
@ -193,6 +223,7 @@ class AgentAppPartial(GenericAppPartial):
class AgentAppDetailWithSite(GenericAppDetailWithSite):
app_id: str | None = None
debug_conversation_id: str | None = None
role: str | None = None
active_config_is_published: bool = False
@ -207,6 +238,7 @@ register_schema_models(
console_ns,
AgentAppCreatePayload,
AgentAppUpdatePayload,
AgentApiStatusPayload,
CopyAppPayload,
AgentInviteOptionsQuery,
AgentLogsQuery,
@ -218,11 +250,13 @@ register_schema_models(
register_response_schema_models(
console_ns,
AgentAppPagination,
AgentApiAccessResponse,
AgentAppPublishedReferenceResponse,
AgentAppDetailWithSite,
AgentAppPartial,
AgentConfigSnapshotDetailResponse,
AgentConfigSnapshotListResponse,
AgentConfigSnapshotRestoreResponse,
AgentInviteOptionsResponse,
AgentLogListResponse,
AgentLogMessageListResponse,
@ -237,7 +271,7 @@ def _agent_roster_service() -> AgentRosterService:
return AgentRosterService(db.session)
def _serialize_agent_app_detail(app_model) -> dict:
def _serialize_agent_app_detail(app_model, *, current_user: Account) -> dict:
"""Serialize an Agent App detail using roster-only DTOs.
`/agent` responses are roster-shaped rather than raw app-shaped: `id`
@ -260,6 +294,11 @@ def _serialize_agent_app_detail(app_model) -> dict:
payload.pop("bound_agent_id", None)
payload["app_id"] = str(app_model.id)
payload["id"] = agent.id
payload["debug_conversation_id"] = roster_service.get_or_create_agent_app_debug_conversation_id(
tenant_id=app_model.tenant_id,
agent_id=agent.id,
account_id=current_user.id,
)
payload["role"] = agent.role or ""
payload["active_config_is_published"] = roster_service.active_config_is_published(
tenant_id=app_model.tenant_id,
@ -268,7 +307,7 @@ def _serialize_agent_app_detail(app_model) -> dict:
return payload
def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str) -> dict:
def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str, current_user: Account) -> dict:
"""Serialize Agent App lists with roster-shaped items.
Each item starts from the shared App list shape, then drops
@ -291,6 +330,11 @@ def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str) -> dict:
tenant_id=tenant_id,
agent_ids=[agent.id for agent in agents_by_app_id.values()],
)
debug_conversation_ids_by_agent_id = roster_service.load_or_create_agent_app_debug_conversation_ids_by_agent_id(
tenant_id=tenant_id,
agents=list(agents_by_app_id.values()),
account_id=current_user.id,
)
payload = AgentAppPagination.model_validate(app_pagination, from_attributes=True).model_dump(mode="json")
for item in payload["data"]:
app_id = item["id"]
@ -299,6 +343,7 @@ def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str) -> dict:
if agent:
item["app_id"] = app_id
item["id"] = agent.id
item["debug_conversation_id"] = debug_conversation_ids_by_agent_id.get(agent.id)
item["role"] = agent.role or ""
item["active_config_is_published"] = active_config_is_published_by_agent_id.get(agent.id, False)
published_references = published_references_by_agent_id.get(agent.id, [])
@ -323,6 +368,38 @@ def _resolve_agent_app_model(*, tenant_id: str, agent_id: UUID):
return resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
def _agent_api_key_count(app_id: str) -> int:
return (
db.session.scalar(
select(func.count(ApiToken.id)).where(
ApiToken.type == ApiTokenType.APP,
ApiToken.app_id == app_id,
)
)
or 0
)
def _serialize_agent_api_access(app_model: App) -> dict:
base_url = app_model.api_base_url
response = AgentApiAccessResponse(
enabled=bool(app_model.enable_api),
service_api_base_url=base_url,
chat_endpoint=f"{base_url}/chat-messages",
stop_endpoint=f"{base_url}/chat-messages/{{task_id}}/stop",
conversations_endpoint=f"{base_url}/conversations",
messages_endpoint=f"{base_url}/messages",
files_upload_endpoint=f"{base_url}/files/upload",
parameters_endpoint=f"{base_url}/parameters",
info_endpoint=f"{base_url}/info",
meta_endpoint=f"{base_url}/meta",
api_rpm=app_model.api_rpm or 0,
api_rph=app_model.api_rph or 0,
api_key_count=_agent_api_key_count(str(app_model.id)),
)
return response.model_dump(mode="json")
def _agent_observability_service() -> AgentObservabilityService:
return AgentObservabilityService(db.session)
@ -374,7 +451,11 @@ class AgentAppListApi(Resource):
empty = AgentAppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json")
return _serialize_agent_app_pagination(app_pagination, tenant_id=current_tenant_id)
return _serialize_agent_app_pagination(
app_pagination,
tenant_id=current_tenant_id,
current_user=current_user,
)
@console_ns.expect(console_ns.models[AgentAppCreatePayload.__name__])
@console_ns.response(201, "Agent app created successfully", console_ns.models[AgentAppDetailWithSite.__name__])
@ -399,7 +480,7 @@ class AgentAppListApi(Resource):
)
app = AppService().create_app(current_tenant_id, params, current_user)
return _serialize_agent_app_detail(app), 201
return _serialize_agent_app_detail(app, current_user=current_user), 201
@console_ns.route("/agent/<uuid:agent_id>")
@ -409,10 +490,11 @@ class AgentAppApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
def get(self, tenant_id: str, current_user: Account, agent_id: UUID):
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
return _serialize_agent_app_detail(app_model)
return _serialize_agent_app_detail(app_model, current_user=current_user)
@console_ns.expect(console_ns.models[AgentAppUpdatePayload.__name__])
@console_ns.response(200, "Agent app updated successfully", console_ns.models[AgentAppDetailWithSite.__name__])
@ -422,8 +504,9 @@ class AgentAppApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
@with_current_tenant_id
def put(self, tenant_id: str, agent_id: UUID):
def put(self, tenant_id: str, current_user: Account, agent_id: UUID):
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
args = AgentAppUpdatePayload.model_validate(console_ns.payload)
args_dict: AppService.ArgsDict = {
@ -437,7 +520,7 @@ class AgentAppApi(Resource):
"role": args.role,
}
updated = AppService().update_app(app_model, args_dict)
return _serialize_agent_app_detail(updated)
return _serialize_agent_app_detail(updated, current_user=current_user)
@console_ns.response(204, "Agent app deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@ -476,7 +559,76 @@ class AgentAppCopyApi(Resource):
icon=args.icon,
icon_background=args.icon_background,
)
return _serialize_agent_app_detail(copied_app), 201
return _serialize_agent_app_detail(copied_app, current_user=current_user), 201
@console_ns.route("/agent/<uuid:agent_id>/api-access")
class AgentApiAccessApi(Resource):
@console_ns.response(200, "Agent service API access", console_ns.models[AgentApiAccessResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
return _serialize_agent_api_access(app_model)
@console_ns.route("/agent/<uuid:agent_id>/api-enable")
class AgentApiStatusApi(Resource):
@console_ns.expect(console_ns.models[AgentApiStatusPayload.__name__])
@console_ns.response(200, "Agent service API status updated", console_ns.models[AgentApiAccessResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_RELEASE_AND_VERSION)
@with_current_tenant_id
def post(self, tenant_id: str, agent_id: UUID):
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
args = AgentApiStatusPayload.model_validate(console_ns.payload)
app_model = AppService().update_app_api_status(app_model, args.enable_api)
return _serialize_agent_api_access(app_model)
@console_ns.route("/agent/<uuid:agent_id>/api-keys")
class AgentApiKeyListApi(BaseApiKeyListResource):
resource_type = ApiTokenType.APP
resource_model = App
resource_id_field = "app_id"
token_prefix = "app-"
@console_ns.response(200, "Agent service API keys", console_ns.models[ApiKeyList.__name__])
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID) -> dict[str, object]:
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
return dump_response(ApiKeyList, self._get_api_key_list(str(app_model.id), tenant_id))
@console_ns.response(201, "Agent service API key created", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
@with_current_tenant_id
@edit_permission_required
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_RELEASE_AND_VERSION)
def post(self, tenant_id: str, agent_id: UUID) -> tuple[dict[str, object], int]:
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
return dump_response(ApiKeyItem, self._create_api_key(str(app_model.id), tenant_id)), 201
@console_ns.route("/agent/<uuid:agent_id>/api-keys/<uuid:api_key_id>")
class AgentApiKeyApi(BaseApiKeyResource):
resource_type = ApiTokenType.APP
resource_model = App
resource_id_field = "app_id"
@console_ns.response(204, "Agent service API key deleted")
@with_current_user
@with_current_tenant_id
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_RELEASE_AND_VERSION)
def delete(self, tenant_id: str, current_user: Account, agent_id: UUID, api_key_id: UUID) -> tuple[str, int]:
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
self._delete_api_key(str(app_model.id), str(api_key_id), tenant_id, current_user)
return "", 204
@console_ns.route("/agent/invite-options")
@ -649,3 +801,24 @@ class AgentRosterVersionDetailApi(Resource):
version_id=str(version_id),
),
)
@console_ns.route("/agent/<uuid:agent_id>/versions/<uuid:version_id>/restore")
class AgentRosterVersionRestoreApi(Resource):
@console_ns.response(200, "Agent version restored", console_ns.models[AgentConfigSnapshotRestoreResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, current_user: Account, agent_id: UUID, version_id: UUID):
return dump_response(
AgentConfigSnapshotRestoreResponse,
_agent_roster_service().restore_agent_version(
tenant_id=tenant_id,
agent_id=str(agent_id),
version_id=str(version_id),
account_id=current_user.id,
),
)

View File

@ -10,8 +10,12 @@ backend — drive data lives in the API's own DB/storage, served straight from
from __future__ import annotations
import json
from collections.abc import Mapping
from typing import Any
from uuid import UUID
from flask import Response
from flask_restx import Resource
from pydantic import BaseModel, Field
@ -49,6 +53,10 @@ class AgentDriveFileByAgentQuery(BaseModel):
key: str = Field(min_length=1, description="Drive key, e.g. tender-analyzer/SKILL.md")
class AgentDriveSkillInspectQuery(BaseModel):
node_id: str | None = Field(default=None, description="Workflow node ID (workflow composer variant)")
class AgentDriveItemResponse(ResponseModel):
key: str
size: int | None = None
@ -56,12 +64,63 @@ class AgentDriveItemResponse(ResponseModel):
hash: str | None = None
file_kind: str
created_at: int | None = None
is_skill: bool | None = None
skill_metadata: str | None = None
class AgentDriveListResponse(ResponseModel):
items: list[AgentDriveItemResponse] = Field(default_factory=list)
class AgentDriveSkillItemResponse(ResponseModel):
path: str
skill_md_key: str
archive_key: str | None = None
name: str
description: str
size: int | None = None
mime_type: str | None = None
hash: str | None = None
created_at: int | None = None
class AgentDriveSkillListResponse(ResponseModel):
items: list[AgentDriveSkillItemResponse] = Field(default_factory=list)
class AgentDriveSkillFileResponse(ResponseModel):
path: str
name: str
type: str
drive_key: str | None = None
available_in_drive: bool
class AgentDriveSkillMarkdownResponse(ResponseModel):
key: str
size: int | None = None
truncated: bool
binary: bool
text: str | None = None
class AgentDriveSkillInspectResponse(ResponseModel):
path: str
skill_md_key: str
archive_key: str | None = None
name: str
description: str
size: int | None = None
mime_type: str | None = None
hash: str | None = None
created_at: int | None = None
source: str
files: list[AgentDriveSkillFileResponse] = Field(default_factory=list)
file_tree: list[dict[str, Any]] = Field(default_factory=list)
skill_md: AgentDriveSkillMarkdownResponse
warnings: list[str] = Field(default_factory=list)
class AgentDrivePreviewResponse(ResponseModel):
key: str
size: int | None = None
@ -75,7 +134,12 @@ class AgentDriveDownloadResponse(ResponseModel):
register_response_schema_models(
console_ns, AgentDriveListResponse, AgentDrivePreviewResponse, AgentDriveDownloadResponse
console_ns,
AgentDriveDownloadResponse,
AgentDriveListResponse,
AgentDrivePreviewResponse,
AgentDriveSkillInspectResponse,
AgentDriveSkillListResponse,
)
@ -96,6 +160,13 @@ def _handle(exc: AgentDriveError) -> tuple[dict[str, object], int]:
return {"code": exc.code, "message": exc.message}, exc.status_code
def _json_response(data: Mapping[str, Any]):
return Response(
response=json.dumps(data, ensure_ascii=False, separators=(",", ":")),
content_type="application/json; charset=utf-8",
)
_WORKFLOW_APP_MODES = [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]
@ -119,6 +190,49 @@ class AgentDriveListByAgentApi(Resource):
return {"items": [{k: v for k, v in item.items() if k != "file_id"} for item in items]}
@console_ns.route("/agent/<uuid:agent_id>/drive/skills")
class AgentDriveSkillListByAgentApi(Resource):
@console_ns.doc("list_agent_drive_skills_by_agent")
@console_ns.doc(description="List drive-backed skills for an Agent App")
@console_ns.doc(params={"agent_id": "Agent ID"})
@console_ns.response(200, "Drive skills", console_ns.models[AgentDriveSkillListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
try:
items = AgentDriveService().list_skills(tenant_id=tenant_id, agent_id=str(agent_id))
except AgentDriveError as exc:
return _handle(exc)
return {"items": items}
@console_ns.route("/agent/<uuid:agent_id>/drive/skills/<path:skill_path>/inspect")
class AgentDriveSkillInspectByAgentApi(Resource):
@console_ns.doc("inspect_agent_drive_skill_by_agent")
@console_ns.doc(description="Inspect one drive-backed skill for slash-menu hover/detail UI")
@console_ns.doc(params={"agent_id": "Agent ID", "skill_path": "Skill path/slug, e.g. tender-analyzer"})
@console_ns.response(200, "Drive skill inspect view", console_ns.models[AgentDriveSkillInspectResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID, skill_path: str):
resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
try:
return _json_response(
AgentDriveService().inspect_skill(
tenant_id=tenant_id,
agent_id=str(agent_id),
skill_path=skill_path,
)
)
except AgentDriveError as exc:
return _handle(exc)
@console_ns.route("/agent/<uuid:agent_id>/drive/files/preview")
class AgentDrivePreviewByAgentApi(Resource):
@console_ns.doc("preview_agent_drive_file_by_agent")
@ -182,6 +296,61 @@ class AgentDriveListApi(Resource):
return {"items": [{k: v for k, v in item.items() if k != "file_id"} for item in items]}
@console_ns.route("/apps/<uuid:app_id>/agent/drive/skills")
class AgentDriveSkillListApi(Resource):
@console_ns.doc("list_agent_drive_skills")
@console_ns.doc(description="List drive-backed skills for the bound agent")
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentDriveListQuery)})
@console_ns.response(200, "Drive skills", console_ns.models[AgentDriveSkillListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=_WORKFLOW_APP_MODES)
def get(self, app_model: App):
query = query_params_from_request(AgentDriveListQuery)
agent_id = _resolve_agent_id(app_model, query.node_id)
if not agent_id:
return _agent_not_bound()
try:
items = AgentDriveService().list_skills(tenant_id=app_model.tenant_id, agent_id=agent_id)
except AgentDriveError as exc:
return _handle(exc)
return {"items": items}
@console_ns.route("/apps/<uuid:app_id>/agent/drive/skills/<path:skill_path>/inspect")
class AgentDriveSkillInspectApi(Resource):
@console_ns.doc("inspect_agent_drive_skill")
@console_ns.doc(description="Inspect one drive-backed skill for slash-menu hover/detail UI")
@console_ns.doc(
params={
"app_id": "Application ID",
"skill_path": "Skill path/slug, e.g. tender-analyzer",
**query_params_from_model(AgentDriveSkillInspectQuery),
}
)
@console_ns.response(200, "Drive skill inspect view", console_ns.models[AgentDriveSkillInspectResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=_WORKFLOW_APP_MODES)
def get(self, app_model: App, skill_path: str):
query = query_params_from_request(AgentDriveSkillInspectQuery)
agent_id = _resolve_agent_id(app_model, query.node_id)
if not agent_id:
return _agent_not_bound()
try:
return _json_response(
AgentDriveService().inspect_skill(
tenant_id=app_model.tenant_id,
agent_id=agent_id,
skill_path=skill_path,
)
)
except AgentDriveError as exc:
return _handle(exc)
@console_ns.route("/apps/<uuid:app_id>/agent/drive/files/preview")
class AgentDrivePreviewApi(Resource):
@console_ns.doc("preview_agent_drive_file")
@ -232,4 +401,8 @@ __all__ = [
"AgentDriveListByAgentApi",
"AgentDrivePreviewApi",
"AgentDrivePreviewByAgentApi",
"AgentDriveSkillInspectApi",
"AgentDriveSkillInspectByAgentApi",
"AgentDriveSkillListApi",
"AgentDriveSkillListByAgentApi",
]

View File

@ -14,6 +14,7 @@ from werkzeug.datastructures import MultiDict
from werkzeug.exceptions import BadRequest, NotFound
from configs import dify_config
from controllers.common.app_access import resolve_app_access_filter
from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
from controllers.common.helpers import FileInfo
from controllers.common.schema import (
@ -78,7 +79,6 @@ _TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
_CREATOR_IDS_BRACKET_PATTERN = re.compile(r"^creator_ids\[(\d+)\]$")
AppListMode = Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"]
DEFAULT_APP_LIST_MODE: AppListMode = "all"
APP_LIST_PERMISSION_KEYS = frozenset({"app.preview", "app.acl.preview", "app.full_access"})
class AppListBaseQuery(BaseModel):
@ -167,10 +167,6 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
return normalized
def _has_app_list_permission(permission_keys: Sequence[str]) -> bool:
return any(permission_key in APP_LIST_PERMISSION_KEYS for permission_key in permission_keys)
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
@ -612,38 +608,12 @@ class AppListApi(Resource):
current_user_id,
)
if dify_config.RBAC_ENABLED:
whitelist_scope = enterprise_rbac_service.RBACService.AppAccess.whitelist_resources(
access_filter = resolve_app_access_filter(
str(current_tenant_id),
current_user_id,
permissions=permissions,
)
can_manage_own_apps = "app.create_and_management" in permissions.workspace.permission_keys
has_default_preview = _has_app_list_permission(
permissions.app.default_permission_keys
) or _has_app_list_permission(permissions.workspace.permission_keys)
permission_app_ids: set[str] | None = None
if not has_default_preview:
permission_app_ids = {
override.resource_id
for override in permissions.app.overrides
if _has_app_list_permission(override.permission_keys)
}
if getattr(whitelist_scope, "unrestricted", False):
accessible_app_ids = permission_app_ids
else:
accessible_app_ids = set(whitelist_scope.resource_ids)
if permission_app_ids is not None:
accessible_app_ids |= permission_app_ids
elif has_default_preview:
accessible_app_ids = None
if accessible_app_ids:
params.accessible_app_ids = sorted(accessible_app_ids)
params.include_own_apps = can_manage_own_apps
elif accessible_app_ids is not None and can_manage_own_apps:
params.is_created_by_me = True
elif accessible_app_ids is not None:
params.accessible_app_ids = []
access_filter.apply_to_params(params)
# get app list
app_service = AppService()

View File

@ -40,12 +40,15 @@ from core.errors.error import (
QuotaExceededError,
)
from core.helper.trace_id_helper import get_external_trace_id
from extensions.ext_database import db
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from libs.login import login_required
from models import Account
from models.model import App, AppMode
from services.agent.errors import AgentNotFoundError
from services.agent.roster_service import AgentRosterService
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
@ -191,10 +194,11 @@ class ChatMessageApi(Resource):
@account_initialization_required
@edit_permission_required
@with_current_user
@with_current_tenant_id
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
def post(self, current_user: Account, app_model: App):
return _create_chat_message(current_user=current_user, app_model=app_model)
def post(self, current_tenant_id: str, current_user: Account, app_model: App):
return _create_chat_message(current_tenant_id=current_tenant_id, current_user=current_user, app_model=app_model)
@console_ns.route("/agent/<uuid:agent_id>/chat-messages")
@ -215,7 +219,12 @@ class AgentChatMessageApi(Resource):
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, agent_id: UUID):
app_model = resolve_agent_app_model(tenant_id=current_tenant_id, agent_id=agent_id)
return _create_chat_message(current_user=current_user, app_model=app_model)
return _create_chat_message(
current_tenant_id=current_tenant_id,
current_user=current_user,
app_model=app_model,
agent_id=str(agent_id),
)
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
@ -249,11 +258,45 @@ class AgentChatMessageStopApi(Resource):
return _stop_chat_message(current_user_id=current_user_id, app_model=app_model, task_id=task_id)
def _create_chat_message(*, current_user: Account, app_model: App):
def _resolve_current_user_agent_debug_conversation_id(
*, current_tenant_id: str, current_user: Account, app_model: App, agent_id: str | None
) -> str:
roster_service = AgentRosterService(db.session)
if agent_id:
return roster_service.get_or_create_agent_app_debug_conversation_id(
tenant_id=current_tenant_id,
agent_id=agent_id,
account_id=current_user.id,
)
agent = roster_service.get_app_backing_agent(tenant_id=current_tenant_id, app_id=str(app_model.id))
if agent is None:
raise AgentNotFoundError()
return roster_service.get_or_create_agent_app_debug_conversation_id(
tenant_id=current_tenant_id,
agent_id=agent.id,
account_id=current_user.id,
)
def _create_chat_message(
*, current_user: Account, app_model: App, current_tenant_id: str | None = None, agent_id: str | None = None
):
raw_payload = console_ns.payload or {}
args_model = ChatMessagePayload.model_validate(raw_payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
if AppMode.value_of(app_model.mode) == AppMode.AGENT:
debug_conversation_id = _resolve_current_user_agent_debug_conversation_id(
current_tenant_id=current_tenant_id or app_model.tenant_id,
current_user=current_user,
app_model=app_model,
agent_id=agent_id,
)
if args_model.conversation_id and args_model.conversation_id != debug_conversation_id:
raise NotFound("Conversation Not Exists.")
args["conversation_id"] = debug_conversation_id
streaming = _resolve_debugger_chat_streaming(
app_mode=AppMode.value_of(app_model.mode),
response_mode=args_model.response_mode,

View File

@ -53,6 +53,7 @@ from libs.login import login_required
from models.account import Account
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService, attach_message_extra_contents
@ -186,10 +187,11 @@ class ChatMessageListApi(Resource):
@account_initialization_required
@setup_required
@edit_permission_required
@with_current_user
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_VIEW_LAYOUT)
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
def get(self, app_model: App):
return _list_chat_messages(app_model=app_model)
def get(self, current_user: Account, app_model: App):
return _list_chat_messages(app_model=app_model, current_user=current_user)
@console_ns.route("/agent/<uuid:agent_id>/chat-messages")
@ -205,10 +207,11 @@ class AgentChatMessageListApi(Resource):
@setup_required
@edit_permission_required
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_VIEW_LAYOUT)
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, agent_id: UUID):
def get(self, current_tenant_id: str, current_user: Account, agent_id: UUID):
app_model = resolve_agent_app_model(tenant_id=current_tenant_id, agent_id=agent_id)
return _list_chat_messages(app_model=app_model)
return _list_chat_messages(app_model=app_model, current_user=current_user)
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
@ -390,14 +393,24 @@ class AgentMessageApi(Resource):
return _get_message_detail(app_model=app_model, message_id=message_id)
def _list_chat_messages(*, app_model: App):
def _list_chat_messages(*, app_model: App, current_user: Account | None = None):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
conversation = db.session.scalar(
select(Conversation)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.limit(1)
)
if AppMode.value_of(app_model.mode) == AppMode.AGENT and current_user is not None:
try:
conversation = ConversationService.get_conversation(
app_model=app_model,
conversation_id=args.conversation_id,
user=current_user,
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
else:
conversation = db.session.scalar(
select(Conversation)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.limit(1)
)
if not conversation:
raise NotFound("Conversation Not Exists.")

View File

@ -83,7 +83,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
@login_required
@account_initialization_required
@is_admin_or_owner_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_CREATE, resource_required=False)
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
@with_current_tenant_id
def post(self, current_tenant_id: str):

View File

@ -222,7 +222,7 @@ class DatasourceAuth(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.CREDENTIAL_CREATE, resource_required=False)
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})

View File

@ -5,6 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
from configs import dify_config
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.account import TenantPluginPermission
@ -17,6 +18,9 @@ def plugin_permission_required(
def interceptor[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if dify_config.RBAC_ENABLED:
return view(*args, **kwargs)
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
tenant_id = current_tenant_id

View File

@ -169,7 +169,7 @@ class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_CREATE, resource_required=False)
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
@ -244,7 +244,7 @@ class ModelProviderCredentialSwitchApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_USE, resource_required=False)
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
@ -326,7 +326,7 @@ class PreferredProviderTypeUpdateApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_USE, resource_required=False)
@account_initialization_required
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):

View File

@ -395,7 +395,7 @@ class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_CREATE, resource_required=False)
@account_initialization_required
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
@ -481,7 +481,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_USE, resource_required=False)
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):

View File

@ -469,6 +469,7 @@ class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_DEBUG, resource_required=False)
@plugin_permission_required(debug_required=True)
@with_current_tenant_id
def get(self, tenant_id: str):
@ -614,6 +615,7 @@ class PluginUploadFromPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def post(self, tenant_id: str):
@ -634,6 +636,7 @@ class PluginUploadFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def post(self, tenant_id: str):
@ -653,6 +656,7 @@ class PluginUploadFromBundleApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def post(self, tenant_id: str):
@ -673,6 +677,7 @@ class PluginInstallFromPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def post(self, tenant_id: str):
@ -693,6 +698,7 @@ class PluginInstallFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def post(self, tenant_id: str):
@ -719,6 +725,7 @@ class PluginInstallFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def post(self, tenant_id: str):
@ -739,6 +746,7 @@ class PluginFetchMarketplacePkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def get(self, tenant_id: str):
@ -764,6 +772,7 @@ class PluginFetchManifestApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def get(self, tenant_id: str):
@ -784,6 +793,7 @@ class PluginFetchInstallTasksApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def get(self, tenant_id: str):
@ -801,6 +811,7 @@ class PluginFetchInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def get(self, tenant_id: str, task_id: str):
@ -816,6 +827,7 @@ class PluginDeleteInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def post(self, tenant_id: str, task_id: str):
@ -831,6 +843,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def post(self, tenant_id: str):
@ -846,6 +859,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def post(self, tenant_id: str, task_id: str, identifier: str):
@ -862,6 +876,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def post(self, tenant_id: str):
@ -884,6 +899,7 @@ class PluginUpgradeFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def post(self, tenant_id: str):
@ -911,6 +927,7 @@ class PluginUninstallApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
@plugin_permission_required(install_required=True)
@with_current_tenant_id
def post(self, tenant_id: str):
@ -1041,10 +1058,11 @@ class PluginChangeAutoUpgradeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_PREFERENCES, resource_required=False)
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account):
if not user.is_admin_or_owner:
if not dify_config.RBAC_ENABLED and not user.is_admin_or_owner:
raise Forbidden()
args = ParserAutoUpgradeChange.model_validate(console_ns.payload)
@ -1097,6 +1115,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_PREFERENCES, resource_required=False)
@with_current_tenant_id
def post(self, tenant_id: str):
# exclude one single plugin

View File

@ -211,7 +211,7 @@ def _legacy_workspace_roles(
name=role_name,
description="",
is_builtin=True,
permission_keys=list(_LEGACY_ROLE_PERMISSION_KEYS[role_name]),
permission_keys=list(dict.fromkeys(_LEGACY_ROLE_PERMISSION_KEYS[role_name])),
role_tag="owner" if role_name == "owner" else "",
)
for role_name in ("owner", "admin", "editor", "normal", "dataset_operator")
@ -244,11 +244,6 @@ def _legacy_workspace_roles(
)
# ---------------------------------------------------------------------------
# Permission catalogs.
# ---------------------------------------------------------------------------
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog")
class RBACWorkspaceCatalogApi(Resource):
@login_required
@ -375,30 +370,6 @@ class RBACRoleCopyApi(Resource):
return _dump(role), 201
@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>/members")
class RBACRoleMembersApi(Resource):
@login_required
@rbac_permission_required(
RBACResourceScope.WORKSPACE, RBACPermission.WORKSPACE_ROLE_MANAGE, resource_required=False
)
@console_ns.response(200, "Success", console_ns.models[_RBACRoleAccountList.__name__])
def get(self, role_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.Roles.members(
tenant_id,
account_id,
str(role_id),
options=_pagination_options(),
)
)
# ---------------------------------------------------------------------------
# Access policies (tenant-level permission sets).
# ---------------------------------------------------------------------------
class _AccessPolicyCreateRequest(BaseModel):
name: str
resource_type: svc.RBACResourceType
@ -788,11 +759,6 @@ class RBACDatasetMemberBindingsApi(Resource):
return {"result": "success"}
# ---------------------------------------------------------------------------
# Workspace-level access (Settings > Access Rules).
# ---------------------------------------------------------------------------
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policy")
class RBACWorkspaceAppMatrixApi(Resource):
@login_required

View File

@ -971,7 +971,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_USE, resource_required=False)
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
@ -1070,6 +1070,7 @@ class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account):
@ -1125,6 +1126,7 @@ class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
@with_current_tenant_id
def put(self, current_tenant_id: str):
payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {})
@ -1178,6 +1180,7 @@ class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
@with_current_tenant_id
def delete(self, current_tenant_id: str):
payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {})
@ -1196,6 +1199,7 @@ class ToolMCPAuthApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
@with_current_tenant_id
def post(self, tenant_id: str):
payload = MCPAuthPayload.model_validate(console_ns.payload or {})
@ -1300,6 +1304,7 @@ class ToolMCPUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
@with_current_tenant_id
def get(self, tenant_id: str, provider_id: str):
with sessionmaker(db.engine).begin() as session:

View File

@ -31,7 +31,7 @@ from controllers.openapi._models import (
AppDslExportQuery,
AppDslExportResponse,
AppDslImportPayload,
AppInfoResponse,
AppInfo,
AppListQuery,
AppListResponse,
AppListRow,
@ -62,7 +62,6 @@ from controllers.openapi._models import (
SessionListQuery,
SessionListResponse,
SessionRow,
TagItem,
TaskStopResponse,
UsageInfo,
WorkflowRunData,
@ -96,12 +95,11 @@ register_response_schema_models(
openapi_ns,
ErrorBody,
EventStreamResponse,
TagItem,
UsageInfo,
MessageMetadata,
AppListRow,
AppListResponse,
AppInfoResponse,
AppInfo,
AppDescribeInfo,
AppDescribeResponse,
AppDslExportResponse,

View File

@ -63,6 +63,8 @@ class OpenApiErrorCode(StrEnum):
FILE_EXTENSION_BLOCKED = "file_extension_blocked"
MEMBER_LIMIT_EXCEEDED = "member_limit_exceeded"
MEMBER_LICENSE_EXCEEDED = "member_license_exceeded"
HUMAN_INPUT_FORM_NOT_FOUND = "form_not_found"
RECIPIENT_SURFACE_MISMATCH = "recipient_surface_mismatch"
class ErrorDetail(BaseModel):
@ -239,3 +241,16 @@ class MemberLicenseExceeded(OpenApiError): # noqa: N818
error_code = OpenApiErrorCode.MEMBER_LICENSE_EXCEEDED
description = "Workspace member license capacity reached."
hint = "Contact your workspace administrator to expand the license seat count."
class HumanInputFormNotFound(OpenApiError): # noqa: N818
code = 404
error_code = OpenApiErrorCode.HUMAN_INPUT_FORM_NOT_FOUND
description = "No human-input form matches this token. It may be wrong, expired, or already submitted."
class RecipientSurfaceMismatch(OpenApiError): # noqa: N818
code = 403
error_code = OpenApiErrorCode.RECIPIENT_SURFACE_MISMATCH
description = "This form's recipient can't be submitted via the OpenAPI surface."
hint = "Action it through its channel (web app or console)."

View File

@ -38,18 +38,12 @@ class PaginationEnvelope[T](BaseModel):
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
class TagItem(BaseModel):
name: str
class AppListRow(BaseModel):
id: str
name: str
description: str | None = None
mode: AppMode
tags: list[TagItem] = []
updated_at: str | None = None
created_by_name: str | None = None
workspace_id: str | None = None
workspace_name: str | None = None
@ -70,16 +64,14 @@ class PermittedExternalAppsListResponse(BaseModel):
data: list[AppListRow]
class AppInfoResponse(BaseModel):
class AppInfo(BaseModel):
id: str
name: str
description: str | None = None
mode: str
author: str | None = None
tags: list[TagItem] = []
class AppDescribeInfo(AppInfoResponse):
class AppDescribeInfo(AppInfo):
updated_at: str | None = None
service_api_enabled: bool
is_agent: bool = False
@ -294,7 +286,6 @@ class AppListQuery(BaseModel):
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
mode: AppMode | None = None
name: str | None = Field(None, max_length=200)
tag: str | None = Field(None, max_length=100)
class AppRunRequest(BaseModel):

View File

@ -5,11 +5,12 @@ from typing import cast
from flask_restx import Resource
from sqlalchemy.orm import Session
from controllers.common.wraps import RBACPermission, RBACResourceScope
from controllers.openapi import openapi_ns
from controllers.openapi._contract import accepts, returns
from controllers.openapi._models import AppDslExportQuery, AppDslExportResponse, AppDslImportPayload
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.data import AuthData, RBACRequirement
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType
from models import Account, App
@ -37,6 +38,11 @@ class AppDslImportApi(Resource):
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
rbac=RBACRequirement(
resource_type=RBACResourceScope.APP,
scene=RBACPermission.APP_IMPORT_EXPORT_DSL,
resource_required=False,
),
)
@returns(200, Import, "Import completed")
@returns(202, Import, "Import pending confirmation")
@ -89,6 +95,11 @@ class AppDslImportConfirmApi(Resource):
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
rbac=RBACRequirement(
resource_type=RBACResourceScope.APP,
scene=RBACPermission.APP_IMPORT_EXPORT_DSL,
resource_required=False,
),
)
@returns(200, Import, "Import confirmed")
@returns(400, Import, "Import failed")
@ -125,6 +136,7 @@ class AppDslExportApi(Resource):
scope=Scope.APPS_READ,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_IMPORT_EXPORT_DSL),
)
@accepts(query=AppDslExportQuery)
@returns(200, AppDslExportResponse, "Export successful")
@ -155,6 +167,7 @@ class AppDslCheckDependenciesApi(Resource):
scope=Scope.APPS_READ,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_IMPORT_EXPORT_DSL),
)
@returns(200, CheckDependenciesResult, "Dependencies checked")
def get(self, app_id: str, *, auth_data: AuthData):

View File

@ -19,12 +19,13 @@ from werkzeug.exceptions import (
import services
from controllers.common.fields import EventStreamResponse
from controllers.common.wraps import RBACPermission, RBACResourceScope
from controllers.openapi import openapi_ns
from controllers.openapi._audit import emit_app_run
from controllers.openapi._contract import accepts, returns
from controllers.openapi._models import AppRunRequest, TaskStopResponse
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.data import AuthData, RBACRequirement
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
@ -136,7 +137,10 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
@openapi_ns.route("/apps/<string:app_id>/run")
class AppRunApi(Resource):
@auth_router.guard(scope=Scope.APPS_RUN)
@auth_router.guard(
scope=Scope.APPS_RUN,
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
)
@openapi_ns.response(200, "Run result (SSE stream)", openapi_ns.models[EventStreamResponse.__name__])
@accepts(body=AppRunRequest)
def post(self, app_id: str, *, auth_data: AuthData, body: AppRunRequest):
@ -167,7 +171,10 @@ class AppRunApi(Resource):
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
class AppRunTaskStopApi(Resource):
@auth_router.guard(scope=Scope.APPS_RUN)
@auth_router.guard(
scope=Scope.APPS_RUN,
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
)
@returns(200, TaskStopResponse, description="Task stopped")
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
app_model, caller, caller_kind = auth_data.require_app_context()

View File

@ -8,7 +8,10 @@ from typing import Any, cast
from flask_restx import Resource
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
from configs import dify_config
from controllers.common.app_access import AppAccessFilter, resolve_app_access_filter
from controllers.common.fields import Parameters
from controllers.common.wraps import RBACPermission, RBACResourceScope
from controllers.openapi import openapi_ns
from controllers.openapi._contract import accepts, returns
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
@ -19,18 +22,17 @@ from controllers.openapi._models import (
AppListQuery,
AppListResponse,
AppListRow,
TagItem,
)
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.data import AuthData, RBACRequirement
from controllers.service_api.app.error import AppUnavailableError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType
from models import App
from models.model import AppMode
from services.account_service import TenantService
from services.app_service import AppListParams, AppService
from services.tag_service import TagService
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
@ -84,54 +86,55 @@ def parameters_payload(app: App) -> dict:
return Parameters.model_validate(parameters).model_dump(mode="json")
def build_app_describe_response(app: App, fields: set[str] | None) -> AppDescribeResponse:
"""Public projection of an app (name / params / input schema) — never internal config."""
want_info = fields is None or "info" in fields
want_params = fields is None or "parameters" in fields
want_schema = fields is None or "input_schema" in fields
info = (
AppDescribeInfo(
id=str(app.id),
name=app.name,
mode=app.mode,
description=app.description,
updated_at=app.updated_at.isoformat() if app.updated_at else None,
service_api_enabled=bool(app.enable_api),
is_agent=app.mode in (AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT),
)
if want_info
else None
)
parameters: dict[str, Any] | None = None
input_schema: dict[str, Any] | None = None
if want_params:
try:
parameters = parameters_payload(app)
except AppUnavailableError:
parameters = dict(_EMPTY_PARAMETERS)
if want_schema:
try:
input_schema = build_input_schema(app)
except AppUnavailableError:
input_schema = dict(EMPTY_INPUT_SCHEMA)
return AppDescribeResponse(info=info, parameters=parameters, input_schema=input_schema)
@openapi_ns.route("/apps/<string:app_id>/describe")
class AppDescribeApi(AppReadResource):
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@auth_router.guard(
scope=Scope.APPS_READ,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_VIEW_LAYOUT),
)
@returns(200, AppDescribeResponse, description="App description")
@accepts(query=AppDescribeQuery)
def get(self, app_id: str, *, auth_data: AuthData, query: AppDescribeQuery):
# describe is UUID-only (workspace_id query param dropped in #37212).
app = self._load(app_id)
requested = query.fields
want_info = requested is None or "info" in requested
want_params = requested is None or "parameters" in requested
want_schema = requested is None or "input_schema" in requested
info = (
AppDescribeInfo(
id=str(app.id),
name=app.name,
mode=app.mode,
description=app.description,
tags=[TagItem(name=t.name) for t in app.tags],
author=app.author_name,
updated_at=app.updated_at.isoformat() if app.updated_at else None,
service_api_enabled=bool(app.enable_api),
is_agent=app.mode in ("agent-chat", "advanced-chat"),
)
if want_info
else None
)
parameters: dict[str, Any] | None = None
input_schema: dict[str, Any] | None = None
if want_params:
try:
parameters = parameters_payload(app)
except AppUnavailableError:
parameters = dict(_EMPTY_PARAMETERS)
if want_schema:
try:
input_schema = build_input_schema(app)
except AppUnavailableError:
input_schema = dict(EMPTY_INPUT_SCHEMA)
return AppDescribeResponse(
info=info,
parameters=parameters,
input_schema=input_schema,
)
return build_app_describe_response(app, query.fields)
@openapi_ns.route("/apps")
@ -152,45 +155,55 @@ class AppListApi(Resource):
else:
parsed_uuid = None
# Compute RBAC-accessible app IDs when RBAC is enabled and the caller is an account.
# ``None`` means unrestricted (caller can see all apps in the workspace);
# an empty set or list means the caller has no accessible apps.
# End-users bypass RBAC here — their access is controlled by scope upstream.
apply_rbac_filter = (
dify_config.RBAC_ENABLED and auth_data.caller_kind != "end_user" and auth_data.account_id is not None
)
access_filter = AppAccessFilter.unrestricted()
if apply_rbac_filter:
access_filter = resolve_app_access_filter(workspace_id, str(auth_data.account_id))
tenant_name: str | None = None
if parsed_uuid is not None:
app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
if app is None or str(app.tenant_id) != workspace_id:
return empty
# Apply RBAC visibility to the UUID fast-path the same way the service
# layer does for paginated queries (id in accessible set OR own app).
if apply_rbac_filter and not access_filter.is_app_accessible(
str(app.id), str(app.maintainer) if app.maintainer else None, str(auth_data.account_id)
):
return empty
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
item = AppListRow(
id=str(app.id),
name=app.name,
description=app.description,
mode=app.mode,
tags=[TagItem(name=t.name) for t in app.tags],
updated_at=app.updated_at.isoformat() if app.updated_at else None,
created_by_name=getattr(app, "author_name", None),
workspace_id=str(workspace_id),
workspace_name=tenant_name,
)
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
return env
tag_ids: list[str] | None = None
if query.tag:
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag, db.session)
if not tags:
return empty
tag_ids = [tag.id for tag in tags]
params = AppListParams(
page=query.page,
limit=query.limit,
mode=query.mode.value if query.mode else "all", # type:ignore
name=query.name,
tag_ids=tag_ids,
status="normal",
# Visibility gate pushed into the query — pagination.total stays
# consistent across pages because invisible rows never count.
openapi_visible=True,
)
if apply_rbac_filter:
access_filter.apply_to_params(params)
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params, db.session)
if pagination is None:
return empty
@ -205,9 +218,7 @@ class AppListApi(Resource):
name=r.name,
description=r.description,
mode=r.mode,
tags=[TagItem(name=t.name) for t in r.tags],
updated_at=r.updated_at.isoformat() if r.updated_at else None,
created_by_name=getattr(r, "author_name", None),
workspace_id=str(workspace_id),
workspace_name=tenant_name,
)

View File

@ -8,14 +8,18 @@ EE blueprint chain so this module is unreachable there.
from __future__ import annotations
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.openapi import openapi_ns
from controllers.openapi._contract import accepts, returns
from controllers.openapi._models import (
AppDescribeQuery,
AppDescribeResponse,
AppListRow,
PermittedExternalAppsListQuery,
PermittedExternalAppsListResponse,
)
from controllers.openapi.apps import build_app_describe_response
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData, Edition
from extensions.ext_database import db
@ -67,9 +71,7 @@ class PermittedExternalAppsListApi(Resource):
name=app.name,
description=app.description,
mode=app.mode,
tags=[], # tenant-scoped; not surfaced cross-tenant
updated_at=app.updated_at.isoformat() if app.updated_at else None,
created_by_name=None, # cross-tenant author leak prevention
workspace_id=str(app.tenant_id),
workspace_name=tenant.name if tenant else None,
)
@ -82,3 +84,20 @@ class PermittedExternalAppsListApi(Resource):
data=items,
)
return env
@openapi_ns.route("/permitted-external-apps/<string:app_id>/describe")
class PermittedExternalAppDescribeApi(Resource):
@auth_router.guard(
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
edition=frozenset({Edition.EE}),
)
@returns(200, AppDescribeResponse, description="Permitted external app description")
@accepts(query=AppDescribeQuery)
def get(self, app_id: str, *, auth_data: AuthData, query: AppDescribeQuery):
# App already loaded and ACL-checked by the external_sso pipeline; project it.
app = auth_data.app
if app is None:
raise NotFound("app not found")
return build_app_describe_response(app, query.fields)

View File

@ -3,9 +3,11 @@ from __future__ import annotations
from controllers.openapi.auth.conditions import (
EDITION_EE,
HAS_ALLOWED_ROLES,
HAS_RBAC,
LOADED_APP_IS_PRIVATE,
PATH_HAS_APP_ID,
WEBAPP_AUTH_ENABLED,
WEBAPP_RUN_SCOPED,
WORKSPACE_MEMBERSHIP_REQUIRED,
WORKSPACE_SCOPED,
)
@ -25,6 +27,7 @@ from controllers.openapi.auth.verify import (
check_acl,
check_app_api_enabled,
check_private_app_permission,
check_rbac_permission,
check_scope,
check_workspace_member,
check_workspace_mismatch,
@ -47,8 +50,9 @@ account_pipeline = AuthPipeline(
When(WORKSPACE_SCOPED, then=check_workspace_member),
When(PATH_HAS_APP_ID, then=check_workspace_mismatch),
When(HAS_ALLOWED_ROLES, then=check_workspace_role),
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
When(HAS_RBAC, then=check_rbac_permission),
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED & WEBAPP_RUN_SCOPED, then=check_acl),
When(EDITION_EE & LOADED_APP_IS_PRIVATE & WEBAPP_RUN_SCOPED, then=check_private_app_permission),
],
)

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Callable
from controllers.openapi.auth.data import AuthData, Edition, RequestContext, current_edition
from libs.oauth_bearer import TokenType
from libs.oauth_bearer import Scope, TokenType
from services.enterprise.enterprise_service import WebAppAccessMode
from services.feature_service import FeatureService
@ -50,8 +50,11 @@ EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
WEBAPP_RUN_SCOPED = request_cond(lambda ctx: ctx.scope == Scope.APPS_RUN)
WORKSPACE_MEMBERSHIP_REQUIRED = request_cond(lambda ctx: ctx.workspace_membership)
HAS_ALLOWED_ROLES = request_cond(lambda ctx: ctx.allowed_roles is not None)
HAS_RBAC = request_cond(lambda ctx: ctx.rbac is not None)
# Caller must belong to the resolved tenant: either an app-scoped path (tenant
# from the app) or an explicit workspace-membership path (tenant from request).

View File

@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import InternalServerError
from configs import dify_config
from core.rbac import RBACPermission, RBACResourceScope
from libs.oauth_bearer import Scope, TokenType
from models.account import Account, Tenant, TenantAccountRole
from models.model import App, EndUser
@ -35,6 +36,14 @@ class ExternalIdentity(BaseModel):
issuer: str | None = None
class RBACRequirement(BaseModel):
model_config = ConfigDict(frozen=True)
resource_type: RBACResourceScope
scene: RBACPermission
resource_required: bool = True
class RequestContext(BaseModel):
model_config = ConfigDict(frozen=True)
@ -43,6 +52,7 @@ class RequestContext(BaseModel):
path_params: dict[str, str]
workspace_membership: bool = False
allowed_roles: frozenset[TenantAccountRole] | None = None
rbac: RBACRequirement | None = None
class AuthData(BaseModel):
@ -59,6 +69,7 @@ class AuthData(BaseModel):
path_params: dict[str, str] = Field(default_factory=dict)
allowed_roles: frozenset[TenantAccountRole] | None = None
rbac: RBACRequirement | None = None
app: App | None = None
tenant: Tenant | None = None

View File

@ -21,6 +21,7 @@ from controllers.openapi.auth.data import (
AuthData,
Edition,
ExternalIdentity,
RBACRequirement,
RequestContext,
current_edition,
)
@ -59,6 +60,7 @@ class AuthPipeline:
scope: Scope | None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
rbac: RBACRequirement | None = None,
) -> Any:
req_ctx = RequestContext(
token_type=identity.token_type,
@ -66,6 +68,7 @@ class AuthPipeline:
path_params=dict(request.view_args or {}),
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
rbac=rbac,
)
data = AuthData(
@ -77,6 +80,7 @@ class AuthPipeline:
tenants=dict(identity.verified_tenants),
required_scope=scope,
allowed_roles=allowed_roles,
rbac=rbac,
path_params=dict(req_ctx.path_params),
external_identity=(
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
@ -129,6 +133,7 @@ class PipelineRouter:
edition: frozenset[Edition] | None = None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
rbac: RBACRequirement | None = None,
) -> Callable:
return self._make_decorator(
scope=scope,
@ -136,6 +141,7 @@ class PipelineRouter:
edition=edition,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
rbac=rbac,
)
def guard_workspace(
@ -145,6 +151,7 @@ class PipelineRouter:
allowed_token_types: frozenset[TokenType] | None = None,
edition: frozenset[Edition] | None = None,
allowed_roles: frozenset[TenantAccountRole] | None = None,
rbac: RBACRequirement | None = None,
) -> Callable:
return self._make_decorator(
scope=scope,
@ -152,6 +159,7 @@ class PipelineRouter:
edition=edition,
workspace_membership=True,
allowed_roles=allowed_roles,
rbac=rbac,
)
def _make_decorator(
@ -162,6 +170,7 @@ class PipelineRouter:
edition: frozenset[Edition] | None,
workspace_membership: bool,
allowed_roles: frozenset[TenantAccountRole] | None,
rbac: RBACRequirement | None,
) -> Callable:
def decorator(view: Callable) -> Callable:
@wraps(view)
@ -175,6 +184,7 @@ class PipelineRouter:
edition=edition,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
rbac=rbac,
)
return decorated
@ -192,6 +202,7 @@ class PipelineRouter:
edition: frozenset[Edition] | None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
rbac: RBACRequirement | None = None,
) -> Any:
# 404 not 403 — this edition doesn't expose the feature at all
if edition is not None and current_edition() not in edition:
@ -235,6 +246,7 @@ class PipelineRouter:
scope=scope,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
rbac=rbac,
)

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from flask import request
from werkzeug.exceptions import Forbidden, NotFound, UnprocessableEntity
from configs import dify_config
from controllers.common.wraps import enforce_rbac_access
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType
@ -38,6 +40,9 @@ def check_workspace_mismatch(data: AuthData) -> None:
def check_workspace_role(data: AuthData) -> None:
if dify_config.RBAC_ENABLED and data.rbac is not None:
# fine-grained permission check is performed by RBAC
return
if data.allowed_roles is None:
return
if data.tenant_role is None:
@ -46,6 +51,27 @@ def check_workspace_role(data: AuthData) -> None:
raise Forbidden("insufficient workspace role")
def check_rbac_permission(data: AuthData) -> None:
req = data.rbac
if req is None:
return
if not dify_config.RBAC_ENABLED:
return
# Only account callers are subject to RBAC; end_user access is scope-controlled.
if data.caller_kind != "account":
return
if data.account_id is None or data.tenant is None:
raise Forbidden("rbac context missing")
enforce_rbac_access(
tenant_id=str(data.tenant.id),
account_id=str(data.account_id),
resource_type=req.resource_type,
scene=req.scene,
resource_required=req.resource_required,
path_args=dict(data.path_params),
)
def check_app_api_enabled(data: AuthData) -> None:
if data.app is None:
return

View File

@ -12,16 +12,21 @@ import logging
from flask import Response
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from werkzeug.exceptions import BadRequest
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
from controllers.common.schema import register_schema_models
from controllers.common.wraps import RBACPermission, RBACResourceScope
from controllers.openapi import openapi_ns
from controllers.openapi._contract import accepts, returns
from controllers.openapi._errors import HumanInputFormNotFound, RecipientSurfaceMismatch
from controllers.openapi._models import FormSubmitResponse, HumanInputFormDefinitionResponse
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from controllers.openapi.auth.data import AuthData, RBACRequirement
from core.workflow.human_input_policy import (
HumanInputSurface,
is_recipient_type_allowed_for_surface,
)
from extensions.ext_database import db
from libs.helper import to_timestamp
from libs.oauth_bearer import Scope
@ -47,31 +52,37 @@ def _jsonify_form_definition(form) -> Response:
def _ensure_form_belongs_to_app(form, app_model: App) -> None:
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
raise NotFound("Form not found")
raise HumanInputFormNotFound()
def _ensure_form_is_allowed_for_openapi(form) -> None:
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.OPENAPI):
raise NotFound("Form not found")
raise RecipientSurfaceMismatch()
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
class OpenApiWorkflowHumanInputFormApi(Resource):
@openapi_ns.response(200, "Form definition", openapi_ns.models[HumanInputFormDefinitionResponse.__name__])
@auth_router.guard(scope=Scope.APPS_RUN)
@auth_router.guard(
scope=Scope.APPS_RUN,
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
)
def get(self, app_id: str, form_token: str, *, auth_data: AuthData):
app_model, caller, caller_kind = auth_data.require_app_context()
app_model, _caller, _caller_kind = auth_data.require_app_context()
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
raise HumanInputFormNotFound()
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_openapi(form)
service.ensure_form_active(form)
return _jsonify_form_definition(form)
@auth_router.guard(scope=Scope.APPS_RUN)
@auth_router.guard(
scope=Scope.APPS_RUN,
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
)
@returns(200, FormSubmitResponse, description="Form submitted")
@accepts(body=HumanInputFormSubmitPayload)
def post(self, app_id: str, form_token: str, *, auth_data: AuthData, body: HumanInputFormSubmitPayload):
@ -80,7 +91,7 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
raise HumanInputFormNotFound()
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_openapi(form)
@ -106,6 +117,6 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
submission_end_user_id=submission_end_user_id,
)
except FormNotFoundError:
raise NotFound("Form not found")
raise HumanInputFormNotFound()
return FormSubmitResponse()

View File

@ -19,9 +19,10 @@ from werkzeug.exceptions import NotFound, UnprocessableEntity
from controllers.common.fields import EventStreamResponse
from controllers.common.schema import query_params_from_model
from controllers.common.wraps import RBACPermission, RBACResourceScope
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.data import AuthData, RBACRequirement
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
@ -46,7 +47,10 @@ class WorkflowEventsQuery(BaseModel):
class OpenApiWorkflowEventsApi(Resource):
@openapi_ns.doc(params=query_params_from_model(WorkflowEventsQuery))
@openapi_ns.response(200, "SSE event stream", openapi_ns.models[EventStreamResponse.__name__])
@auth_router.guard(scope=Scope.APPS_RUN)
@auth_router.guard(
scope=Scope.APPS_RUN,
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
)
def get(self, app_id: str, task_id: str, *, auth_data: AuthData):
app_model, caller, caller_kind = auth_data.require_app_context()
app_mode = AppMode.value_of(app_model.mode)

View File

@ -2,6 +2,7 @@ from typing import Any, cast
from flask_restx import Resource
from pydantic import Field
from sqlalchemy import select
from controllers.common.fields import Parameters
from controllers.common.schema import register_response_schema_models
@ -9,7 +10,11 @@ from controllers.service_api import service_api_ns
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.agent_app.app_variable_projection import agent_app_variables_to_user_input_form
from extensions.ext_database import db
from fields.base import ResponseModel
from models.agent import Agent, AgentConfigSnapshot, AgentScope, AgentSource, AgentStatus
from models.agent_config_entities import AgentSoulConfig
from models.model import App, AppMode
from services.app_service import AppService
@ -29,6 +34,40 @@ class AppMetaResponse(ResponseModel):
register_response_schema_models(service_api_ns, Parameters, AppMetaResponse, AppInfoResponse)
def _get_agent_app_feature_dict_and_user_input_form(app_model: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
app_model_config = app_model.app_model_config
features_dict = cast(dict[str, Any], app_model_config.to_dict()) if app_model_config is not None else {}
agent = db.session.scalar(
select(Agent)
.where(
Agent.tenant_id == app_model.tenant_id,
Agent.app_id == app_model.id,
Agent.scope == AgentScope.ROSTER,
Agent.source == AgentSource.AGENT_APP,
Agent.status == AgentStatus.ACTIVE,
)
.limit(1)
)
if agent is None or not agent.active_config_snapshot_id:
raise AppUnavailableError()
snapshot = db.session.scalar(
select(AgentConfigSnapshot)
.where(
AgentConfigSnapshot.tenant_id == app_model.tenant_id,
AgentConfigSnapshot.agent_id == agent.id,
AgentConfigSnapshot.id == agent.active_config_snapshot_id,
)
.limit(1)
)
if snapshot is None:
raise AppUnavailableError()
agent_soul = AgentSoulConfig.model_validate(snapshot.config_snapshot_dict)
return features_dict, agent_app_variables_to_user_input_form(agent_soul.app_variables)
@service_api_ns.route("/parameters")
class AppParameterApi(Resource):
"""Resource for app variables."""
@ -61,12 +100,16 @@ class AppParameterApi(Resource):
Returns the input form parameters and configuration for the application.
"""
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
features_dict: dict[str, Any]
user_input_form: list[dict[str, Any]]
if app_model.mode == AppMode.AGENT:
features_dict, user_input_form = _get_agent_app_feature_dict_and_user_input_form(app_model)
elif app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()
features_dict: dict[str, Any] = workflow.features_dict
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config

View File

@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -22,7 +22,7 @@ from core.app.entities.queue_entities import (
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.db.session_factory import session_factory
from core.db.session_factory import create_session, session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
@ -107,7 +107,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_execution_id=self.application_generate_entity.workflow_run_id,
)
with Session(db.engine, expire_on_commit=False) as session:
with create_session() as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
if not app_record:
@ -204,6 +204,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
)
# Release the Flask scoped session before workflow execution so a checked-out DB connection
# is not held for the lifetime of the graph run.
db.session.close()
# RUN WORKFLOW
@ -368,7 +370,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
:return: List of conversation variables ready for use
"""
with sessionmaker(bind=db.engine).begin() as session:
with create_session() as session, session.begin():
existing_variables = self._load_existing_conversation_variables(session)
if not existing_variables:

View File

@ -21,6 +21,7 @@ from core.app.app_config.entities import (
EasyUIBasedAppModelConfigFrom,
PromptTemplateEntity,
)
from core.app.apps.agent_app.app_variable_projection import agent_app_variables_to_user_input_form
from models.agent_config_entities import AgentSoulConfig
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
@ -98,8 +99,7 @@ class AgentAppConfigManager(BaseAppConfigManager):
# pipeline's bookkeeping (token counting, persistence).
base["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
base["pre_prompt"] = agent_soul.prompt.system_prompt or ""
# Agent App takes the user message directly; no completion-style inputs form.
base.setdefault("user_input_form", [])
base["user_input_form"] = agent_app_variables_to_user_input_form(agent_soul.app_variables)
return base

View File

@ -0,0 +1,37 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from models.agent_config_entities import AppVariableConfig
def agent_app_variables_to_user_input_form(app_variables: Sequence[AppVariableConfig]) -> list[dict[str, Any]]:
"""Project Agent Soul app variables into the legacy service-API parameter form."""
user_input_form: list[dict[str, Any]] = []
for variable in app_variables:
form_type = _form_type_for_agent_variable(variable.type)
form_item: dict[str, Any] = {
"label": variable.name,
"variable": variable.name,
"required": variable.required,
}
if variable.default is not None:
form_item["default"] = variable.default
user_input_form.append({form_type: form_item})
return user_input_form
def _form_type_for_agent_variable(variable_type: str) -> str:
normalized = variable_type.strip().lower()
if normalized in {"number", "integer", "float"}:
return "number"
if normalized in {"boolean", "bool"}:
return "checkbox"
if normalized in {"paragraph", "long_text", "multiline"}:
return "paragraph"
return "text-input"
__all__ = ["agent_app_variables_to_user_input_form"]

View File

@ -12,10 +12,10 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.db.session_factory import create_session
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationError
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
@ -47,7 +47,10 @@ class AgentChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(app_stmt)
with create_session() as session:
app_record = session.scalar(app_stmt)
if app_record:
session.expunge(app_record)
if not app_record:
raise ValueError("App not found")
@ -185,14 +188,18 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
with create_session() as session:
conversation_result = session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = session.scalar(msg_stmt)
if message_result is not None:
session.expunge(message_result)
session.expunge(conversation_result)
if message_result is None:
raise ValueError("Message not found")
db.session.close()
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
# start agent runner

View File

@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.db.session_factory import create_session
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationError
@ -46,7 +47,10 @@ class ChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
with create_session() as session:
app_record = session.scalar(stmt)
if app_record:
session.expunge(app_record)
if not app_record:
raise ValueError("App not found")
@ -216,6 +220,8 @@ class ChatAppRunner(AppRunner):
model=application_generate_entity.model_conf.model,
)
# Release the Flask scoped session before LLM streaming so a checked-out DB connection
# is not held for the lifetime of the provider response.
db.session.close()
invoke_result = model_instance.invoke_llm(

View File

@ -51,8 +51,11 @@ from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_forms import (
load_form_dispositions_by_form_id,
)
from core.workflow.human_input_policy import (
FormDisposition,
HumanInputSurface,
enrich_human_input_pause_reasons,
resolve_human_input_pause_reason_inputs,
@ -340,13 +343,14 @@ class WorkflowResponseConverter:
human_input_form_ids = [reason.form_id for reason in resolved_reasons if isinstance(reason, HumanInputRequired)]
expiration_times_by_form_id: dict[str, datetime] = {}
display_in_ui_by_form_id: dict[str, bool] = {}
form_token_by_form_id: dict[str, str] = {}
dispositions_by_form_id: dict[str, FormDisposition] = {}
if human_input_form_ids:
stmt = select(
HumanInputForm.id,
HumanInputForm.expiration_time,
HumanInputForm.form_definition,
).where(HumanInputForm.id.in_(human_input_form_ids))
hitl_surface = _INVOKE_FROM_TO_HITL_SURFACE.get(self._application_generate_entity.invoke_from)
with Session(bind=db.engine) as session:
for form_id, expiration_time, form_definition in session.execute(stmt):
expiration_times_by_form_id[str(form_id)] = expiration_time
@ -355,17 +359,17 @@ class WorkflowResponseConverter:
except (TypeError, json.JSONDecodeError):
definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_token_by_form_id = load_form_tokens_by_form_id(
dispositions_by_form_id = load_form_dispositions_by_form_id(
human_input_form_ids,
session=session,
surface=_INVOKE_FROM_TO_HITL_SURFACE.get(self._application_generate_entity.invoke_from),
surface=hitl_surface,
)
# Reconnect paths must preserve the same pause-reason contract as live streams;
# otherwise clients see schema drift after resume.
pause_reasons = enrich_human_input_pause_reasons(
pause_reasons,
form_tokens_by_form_id=form_token_by_form_id,
dispositions_by_form_id=dispositions_by_form_id,
expiration_times_by_form_id={
form_id: int(expiration_time.timestamp())
for form_id, expiration_time in expiration_times_by_form_id.items()
@ -379,6 +383,7 @@ class WorkflowResponseConverter:
expiration_time = expiration_times_by_form_id.get(reason.form_id)
if expiration_time is None:
raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}")
disposition = dispositions_by_form_id.get(reason.form_id)
responses.append(
HumanInputRequiredResponse(
task_id=task_id,
@ -391,7 +396,8 @@ class WorkflowResponseConverter:
inputs=reason.inputs,
actions=reason.actions,
display_in_ui=display_in_ui_by_form_id.get(reason.form_id, False),
form_token=form_token_by_form_id.get(reason.form_id),
form_token=disposition.form_token if disposition else None,
approval_channels=list(disposition.approval_channels) if disposition else [],
resolved_default_values=reason.resolved_default_values,
expiration_time=int(expiration_time.timestamp()),
),

View File

@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
CompletionAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.db.session_factory import create_session
from core.model_manager import ModelInstance
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@ -39,7 +40,10 @@ class CompletionAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
with create_session() as session:
app_record = session.scalar(stmt)
if app_record:
session.expunge(app_record)
if not app_record:
raise ValueError("App not found")
@ -174,6 +178,8 @@ class CompletionAppRunner(AppRunner):
model=application_generate_entity.model_conf.model,
)
# Release the Flask scoped session before LLM streaming so a checked-out DB connection
# is not held for the lifetime of the provider response.
db.session.close()
invoke_result = model_instance.invoke_llm(

View File

@ -11,6 +11,7 @@ from core.app.entities.queue_entities import (
QueueMessageEndEvent,
QueueStopEvent,
)
from models.model import AppMode
class MessageBasedAppQueueManager(AppQueueManager):
@ -47,4 +48,6 @@ class MessageBasedAppQueueManager(AppQueueManager):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
if self._app_mode == AppMode.ADVANCED_CHAT.value:
return
raise GenerateTaskStoppedError()

View File

@ -3,6 +3,7 @@ import time
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
@ -14,12 +15,12 @@ from core.app.entities.app_invoke_entities import (
build_dify_run_context,
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.db.session_factory import create_session
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from graphon.enums import WorkflowType
from graphon.graph import Graph
from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
@ -83,22 +84,24 @@ class PipelineRunner(WorkflowBasedAppRunner):
user_from = self._resolve_user_from(invoke_from)
user_id = None
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.get(EndUser, self.application_generate_entity.user_id)
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
with create_session() as session:
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = session.get(EndUser, self.application_generate_entity.user_id)
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.get(Pipeline, app_config.app_id)
if not pipeline:
raise ValueError("Pipeline not found")
pipeline = session.get(Pipeline, app_config.app_id)
if not pipeline:
raise ValueError("Pipeline not found")
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
workflow = self.get_workflow(session=session, pipeline=pipeline, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
session.expunge(pipeline)
session.expunge(workflow)
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
@ -208,12 +211,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
)
self._handle_event(workflow_entry, event)
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
def get_workflow(self, session: Session, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.scalar(
workflow = session.scalar(
select(Workflow)
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
.limit(1)
@ -298,11 +301,11 @@ class PipelineRunner(WorkflowBasedAppRunner):
"""
if isinstance(event, GraphRunFailedEvent):
if document_id and dataset_id:
document = db.session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if document:
document.indexing_status = "error"
document.error = event.error or "Unknown error"
db.session.add(document)
db.session.commit()
with create_session() as session, session.begin():
document = session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if document:
document.indexing_status = "error"
document.error = event.error or "Unknown error"
session.add(document)

View File

@ -1,7 +1,6 @@
from typing import override
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
@ -43,6 +42,3 @@ class WorkflowAppQueueManager(AppQueueManager):
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedError()

View File

@ -288,6 +288,7 @@ class HumanInputRequiredResponse(StreamResponse):
actions: Sequence[UserActionConfig] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
approval_channels: list[str] = Field(default_factory=list)
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int = Field(..., description="Unix timestamp in seconds")
@ -311,6 +312,7 @@ class HumanInputRequiredPauseReasonPayload(BaseModel):
actions: Sequence[UserActionConfig] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
approval_channels: list[str] = Field(default_factory=list)
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int
@ -325,6 +327,7 @@ class HumanInputRequiredPauseReasonPayload(BaseModel):
actions=data.actions,
display_in_ui=data.display_in_ui,
form_token=data.form_token,
approval_channels=data.approval_channels,
resolved_default_values=data.resolved_default_values,
expiration_time=data.expiration_time,
)

View File

@ -3,7 +3,6 @@ from collections.abc import Generator, Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -13,10 +12,19 @@ from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db.session_factory import create_session
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from extensions.ext_database import db
from models import Account
from models.model import App, AppMode, EndUser
from models import Account, TenantAccountJoin
from models.model import (
App,
AppMode,
AppModelConfig,
AppModelConfigDict,
EndUser,
load_annotation_reply_config,
)
from models.workflow import Workflow
from services.end_user_service import EndUserService
@ -30,18 +38,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""Retrieve app parameters."""
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app.workflow
workflow = cls._get_workflow(app)
if workflow is None:
raise ValueError("unexpected app type")
features_dict: dict[str, Any] = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
app_model_config_dict = cls._get_app_model_config_dict(app)
if app_model_config_dict is None:
raise ValueError("unexpected app type")
features_dict = cast(dict[str, Any], app_model_config.to_dict())
features_dict = cast(dict[str, Any], app_model_config_dict)
user_input_form = features_dict.get("user_input_form", [])
@ -68,7 +76,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
if not user_id:
user = EndUserService.get_or_create_end_user(app)
else:
user = cls._get_user(user_id)
user = cls._get_user(user_id, app)
conversation_id = conversation_id or ""
@ -79,7 +87,10 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
case AppMode.WORKFLOW:
return cls.invoke_workflow_app(app, user, stream, inputs, files)
workflow = cls._get_workflow(app)
if not workflow:
raise ValueError("unexpected app type")
return cls.invoke_workflow_app(app, workflow, user, stream, inputs, files)
case AppMode.COMPLETION:
return cls.invoke_completion_app(app, user, stream, inputs, files)
case _:
@ -101,7 +112,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
match app.mode:
case AppMode.ADVANCED_CHAT:
workflow = app.workflow
workflow = cls._get_workflow(app)
if not workflow:
raise ValueError("unexpected app type")
@ -158,6 +169,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
def invoke_workflow_app(
cls,
app: App,
workflow: Workflow,
user: EndUser | Account,
stream: bool,
inputs: Mapping,
@ -166,10 +178,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke workflow app
"""
workflow = app.workflow
if not workflow:
raise ValueError("unexpected app type")
pause_config = PauseStateLayerConfig(
session_factory=db.engine,
state_owner_user_id=workflow.created_by,
@ -207,16 +215,26 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
)
@classmethod
def _get_user(cls, user_id: str) -> EndUser | Account:
def _get_user(cls, user_id: str, app: App) -> EndUser | Account:
"""
get the user by user id
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(EndUser).where(EndUser.id == user_id)
with create_session() as session:
stmt = select(EndUser).where(
EndUser.id == user_id,
EndUser.tenant_id == app.tenant_id,
EndUser.app_id == app.id,
)
user = session.scalar(stmt)
if not user:
stmt = select(Account).where(Account.id == user_id)
stmt = select(Account).where(
Account.id == user_id,
Account.id == TenantAccountJoin.account_id,
TenantAccountJoin.tenant_id == app.tenant_id,
)
user = session.scalar(stmt)
if user:
session.expunge(user)
if not user:
raise ValueError("user not found")
@ -229,7 +247,10 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get app
"""
try:
app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
with create_session() as session:
app = session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
if app:
session.expunge(app)
except Exception:
raise ValueError("app not found")
@ -237,3 +258,41 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
raise ValueError("app not found")
return app
@classmethod
def _get_workflow(cls, app: App) -> Workflow | None:
"""
get workflow without relying on App.workflow's request-scoped session property
"""
if not app.workflow_id:
return None
with create_session() as session:
workflow = session.scalar(
select(Workflow)
.where(Workflow.id == app.workflow_id, Workflow.tenant_id == app.tenant_id, Workflow.app_id == app.id)
.limit(1)
)
if workflow:
session.expunge(workflow)
return workflow
@classmethod
def _get_app_model_config_dict(cls, app: App) -> AppModelConfigDict | None:
"""
get app model config features without relying on request-scoped session-backed model properties
"""
if not app.app_model_config_id:
return None
with create_session() as session:
app_model_config = session.scalar(
select(AppModelConfig)
.where(AppModelConfig.id == app.app_model_config_id, AppModelConfig.app_id == app.id)
.limit(1)
)
if app_model_config is None:
return None
annotation_reply = load_annotation_reply_config(session, app_model_config.app_id)
return app_model_config.to_dict(annotation_reply=annotation_reply)

View File

@ -22,23 +22,35 @@ class RBACPermission(StrEnum):
APP_VIEW_LAYOUT = "app_view_layout"
APP_TEST_AND_RUN = "app_test_and_run"
APP_PREVIEW = "app_preview"
APP_CREATE_AND_MANAGEMENT = "app_create_and_management"
APP_RELEASE_AND_VERSION = "app_release_and_version"
APP_IMPORT_EXPORT_DSL = "app_import_export_dsl"
APP_EDIT = "app_edit"
APP_MONITOR = "app_monitor"
APP_DELETE = "app_delete"
APP_ACCESS_CONFIG = "app_access_config"
DATASET_PREVIEW = "dataset_preview"
DATASET_READONLY = "dataset_readonly"
DATASET_EDIT = "dataset_edit"
DATASET_CREATE_AND_MANAGEMENT = "dataset_create_and_management"
DATASET_PIPELINE_TEST = "dataset_pipeline_test"
DATASET_DOCUMENT_DOWNLOAD = "dataset_document_download"
DATASET_RETRIEVAL_RECALL = "dataset_retrieval_recall"
DATASET_USE = "dataset_use"
DATASET_DELETE_FILE = "dataset_delete_file"
DATASET_PIPELINE_RELEASE = "dataset_pipeline_release"
DATASET_DELETE = "dataset_delete"
DATASET_ACCESS_CONFIG = "dataset_access_config"
DATASET_API_KEY_MANAGE = "dataset_api_key_manage"
DATASET_EXTERNAL_CONNECT = "dataset_external_connect"
DATASET_IMPORT_EXPORT_DSL = "dataset_import_export_dsl"
WORKSPACE_MEMBER_MANAGE = "workspace_member_manage"
WORKSPACE_ROLE_MANAGE = "workspace_role_manage"
API_EXTENSION_MANAGE = "api_extension_manage"
CUSTOMIZATION_MANAGE = "customization_manage"
SNIPPETS_CREATE_AND_MODIFY = "snippets_create_and_modify"
SNIPPETS_MANAGE = "snippets_management"
@ -49,6 +61,7 @@ class RBACPermission(StrEnum):
PLUGIN_DEBUG = "plugin_debug"
CREDENTIAL_USE = "credential_use"
CREDENTIAL_CREATE = "credential_create"
CREDENTIAL_MANAGE = "credential_manage"
TOOL_MANAGE = "tool_manage"

View File

@ -12,60 +12,61 @@ from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
from core.workflow.human_input_policy import (
FormDisposition,
HumanInputSurface,
disposition_for_surface,
)
from extensions.ext_database import db
from models.human_input import HumanInputFormRecipient, RecipientType
def load_form_dispositions_by_form_id(
form_ids: Sequence[str],
*,
session: Session | None = None,
surface: HumanInputSurface | None = None,
) -> dict[str, FormDisposition]:
"""Resolve each paused form's resume token and approval channels for `surface`."""
unique_form_ids = list(dict.fromkeys(form_ids))
if not unique_form_ids:
return {}
if session is not None:
return _load_form_dispositions_by_form_id(session, unique_form_ids, surface=surface)
with Session(bind=db.engine, expire_on_commit=False) as new_session:
return _load_form_dispositions_by_form_id(new_session, unique_form_ids, surface=surface)
def _load_form_dispositions_by_form_id(
session: Session,
form_ids: Sequence[str],
*,
surface: HumanInputSurface | None,
) -> dict[str, FormDisposition]:
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(stmt):
recipients_by_form_id.setdefault(recipient.form_id, []).append(
(recipient.recipient_type, recipient.access_token or "")
)
return {
form_id: disposition_for_surface(recipients, surface=surface)
for form_id, recipients in recipients_by_form_id.items()
}
def load_form_tokens_by_form_id(
form_ids: Sequence[str],
*,
session: Session | None = None,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
"""Load the preferred access token for each human input form."""
unique_form_ids = list(dict.fromkeys(form_ids))
if not unique_form_ids:
return {}
if session is not None:
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
with Session(bind=db.engine, expire_on_commit=False) as new_session:
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
def _load_form_tokens_by_form_id(
session: Session,
form_ids: Sequence[str],
*,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(stmt):
if not recipient.access_token:
continue
recipients_by_form_id.setdefault(recipient.form_id, []).append(
(recipient.recipient_type, recipient.access_token)
)
tokens_by_form_id: dict[str, str] = {}
for form_id, recipients in recipients_by_form_id.items():
token = _get_surface_form_token(recipients, surface=surface)
if token is not None:
tokens_by_form_id[form_id] = token
return tokens_by_form_id
def _get_surface_form_token(
recipients: Sequence[tuple[RecipientType, str]],
*,
surface: HumanInputSurface | None,
) -> str | None:
if surface in {HumanInputSurface.SERVICE_API, HumanInputSurface.OPENAPI}:
for recipient_type, token in recipients:
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
return token
return get_preferred_form_token(recipients)
"""Resume tokens only, for callers that don't surface approval channels."""
dispositions = load_form_dispositions_by_form_id(form_ids, session=session, surface=surface)
return {
form_id: disposition.form_token
for form_id, disposition in dispositions.items()
if disposition.form_token is not None
}

View File

@ -2,14 +2,14 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from typing import Any, NamedTuple
from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType
from graphon.nodes.human_input.entities import FormInputConfig, SelectInputConfig
from graphon.nodes.human_input.enums import ValueSourceType
from graphon.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
from graphon.variables import ArrayStringSegment
from models.human_input import RecipientType
from models.human_input import ApprovalChannel, RecipientType
class HumanInputSurface(StrEnum):
@ -20,7 +20,7 @@ class HumanInputSurface(StrEnum):
# SERVICE_API and OPENAPI are intentionally narrower than CONSOLE: token callers
# should only be able to act on end-user web forms, not internal console flows.
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
HumanInputSurface.OPENAPI: frozenset({RecipientType.STANDALONE_WEB_APP}),
@ -41,7 +41,7 @@ def is_recipient_type_allowed_for_surface(
) -> bool:
if recipient_type is None:
return False
return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
return recipient_type in ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
def get_preferred_form_token(
@ -59,10 +59,39 @@ def get_preferred_form_token(
return chosen_token
class FormDisposition(NamedTuple):
"""How a paused form resolves for one API surface.
A form's recipients split into those the surface may act on (yielding a resume
`form_token`) and those it may not (their channels named in `approval_channels`
so the caller is told where approval actually happens instead).
"""
form_token: str | None
approval_channels: list[ApprovalChannel]
def disposition_for_surface(
recipients: Sequence[tuple[RecipientType, str]],
*,
surface: HumanInputSurface | None,
) -> FormDisposition:
if surface is None:
return FormDisposition(form_token=get_preferred_form_token(recipients), approval_channels=[])
allowed = ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
actionable = [(recipient_type, token) for recipient_type, token in recipients if recipient_type in allowed]
return FormDisposition(
form_token=get_preferred_form_token(actionable),
approval_channels=sorted(
{recipient_type.approval_channel for recipient_type, _ in recipients if recipient_type not in allowed}
),
)
def enrich_human_input_pause_reasons(
reasons: Sequence[Mapping[str, Any]],
*,
form_tokens_by_form_id: Mapping[str, str],
dispositions_by_form_id: Mapping[str, FormDisposition],
expiration_times_by_form_id: Mapping[str, int],
) -> list[dict[str, Any]]:
enriched: list[dict[str, Any]] = []
@ -71,7 +100,9 @@ def enrich_human_input_pause_reasons(
if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
form_id = updated.get("form_id")
if isinstance(form_id, str):
updated["form_token"] = form_tokens_by_form_id.get(form_id)
disposition = dispositions_by_form_id.get(form_id)
updated["form_token"] = disposition.form_token if disposition else None
updated["approval_channels"] = list(disposition.approval_channels) if disposition else []
expiration_time = expiration_times_by_form_id.get(form_id)
if expiration_time is not None:
updated["expiration_time"] = expiration_time

View File

@ -25,7 +25,7 @@ from extensions.redis_names import (
serialize_redis_name_args,
)
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
from libs.broadcast_channel.redis.pubsub_channel import BroadcastChannel as RedisBroadcastChannel
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
@ -457,16 +457,14 @@ def init_app(app: DifyApp):
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
join_timeout_ms = dify_config.PUBSUB_LISTENER_JOIN_TIMEOUT_MS
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
return ShardedRedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams":
return StreamsBroadcastChannel(
_pubsub_redis_client,
retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS,
join_timeout_ms=join_timeout_ms,
)
return RedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
return RedisBroadcastChannel(_pubsub_redis_client)
def redis_fallback[T](default_return: T | None = None): # type: ignore

View File

@ -291,6 +291,11 @@ class AgentConfigSnapshotListResponse(ResponseModel):
data: list[AgentConfigSnapshotSummaryResponse]
class AgentConfigSnapshotRestoreResponse(ResponseModel):
result: Literal["success"]
active_config_snapshot_id: str
class AgentComposerAgentResponse(ResponseModel):
id: str
name: str

View File

@ -1,4 +1,4 @@
from .channel import BroadcastChannel
from .pubsub_channel import BroadcastChannel
from .sharded_channel import ShardedRedisBroadcastChannel
__all__ = ["BroadcastChannel", "ShardedRedisBroadcastChannel"]

View File

@ -7,6 +7,7 @@ from typing import Any, Self, override
from libs.broadcast_channel.channel import Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from libs.broadcast_channel.signals import SIG_CLOSE
from redis import Redis, RedisCluster
from redis.client import PubSub
@ -26,8 +27,6 @@ class RedisSubscriptionBase(Subscription):
client: Redis | RedisCluster,
pubsub: PubSub,
topic: str,
*,
join_timeout_ms: int = 2000,
):
# The _pubsub is None only if the subscription is closed.
self._client = client
@ -39,11 +38,6 @@ class RedisSubscriptionBase(Subscription):
self._listener_thread: threading.Thread | None = None
self._start_lock = threading.Lock()
self._started = False
# Max time close() will wait for the listener thread to finish before
# returning. Bounds SSE close tail latency. The listener is a daemon
# and exits on its own within one poll window (~1s), so a low value
# here just means close() returns sooner without breaking anything.
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
def _start_if_needed(self) -> None:
"""Start the subscription if not already started."""
@ -90,6 +84,11 @@ class RedisSubscriptionBase(Subscription):
if raw_message is None:
continue
# If close() sent a control event to unblock us, exit immediately
# without processing any message — the subscription is shutting down.
if self._closed.is_set():
break
if raw_message.get("type") != self._get_message_type():
continue
@ -119,6 +118,8 @@ class RedisSubscriptionBase(Subscription):
continue
self._enqueue_message(payload_bytes)
if payload_bytes == SIG_CLOSE:
break
_logger.debug("%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic)
try:
@ -212,13 +213,16 @@ class RedisSubscriptionBase(Subscription):
return
self._closed.set()
# Send a control event on the same Redis channel to unblock the
self._publish_close_event()
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
# message retrieval method should NOT be called concurrently.
#
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
listener = self._listener_thread
if listener is not None:
listener.join(timeout=self._join_timeout_ms / 1000.0)
listener.join(timeout=2)
self._listener_thread = None
# Abstract methods to be implemented by subclasses
@ -226,6 +230,15 @@ class RedisSubscriptionBase(Subscription):
"""Return the subscription type (e.g., 'regular' or 'sharded')."""
raise NotImplementedError
def _publish_close_event(self) -> None:
"""Publish a control event on the Redis channel to unblock the listener.
This is called by close() after setting _closed. The subclass should
publish an empty message on the same topic so that a blocking
get_message() call in the listener thread returns promptly.
"""
raise NotImplementedError
def _subscribe(self) -> None:
"""Subscribe to the Redis topic using the appropriate command."""
raise NotImplementedError

View File

@ -1,13 +1,17 @@
from __future__ import annotations
import logging
from typing import Any, override
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.signals import SIG_CLOSE
from redis import Redis, RedisCluster
from ._subscription import RedisSubscriptionBase
logger = logging.getLogger(__name__)
class BroadcastChannel:
"""
@ -22,16 +26,11 @@ class BroadcastChannel:
def __init__(
self,
redis_client: Redis | RedisCluster,
*,
join_timeout_ms: int = 2000,
):
self._client = redis_client
# See `RedisSubscriptionBase._join_timeout_ms`: how long close()
# waits for the listener thread before returning.
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
def topic(self, topic: str) -> Topic:
return Topic(self._client, topic, join_timeout_ms=self._join_timeout_ms)
return Topic(self._client, topic)
class Topic:
@ -39,13 +38,10 @@ class Topic:
self,
redis_client: Redis | RedisCluster,
topic: str,
*,
join_timeout_ms: int = 2000,
):
self._client = redis_client
self._topic = topic
self._redis_topic = serialize_redis_name(topic)
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
def as_producer(self) -> Producer:
return self
@ -61,7 +57,6 @@ class Topic:
client=self._client,
pubsub=self._client.pubsub(),
topic=self._redis_topic,
join_timeout_ms=self._join_timeout_ms,
)
@ -72,6 +67,13 @@ class _RedisSubscription(RedisSubscriptionBase):
def _get_subscription_type(self) -> str:
return "regular"
@override
def _publish_close_event(self) -> None:
try:
self._client.publish(self._topic, SIG_CLOSE)
except Exception:
logger.exception("failed to publish close event")
@override
def _subscribe(self) -> None:
assert self._pubsub is not None

View File

@ -1,13 +1,17 @@
from __future__ import annotations
import logging
from typing import Any, override
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.signals import SIG_CLOSE
from redis import Redis, RedisCluster
from ._subscription import RedisSubscriptionBase
logger = logging.getLogger(__name__)
class ShardedRedisBroadcastChannel:
"""
@ -20,14 +24,11 @@ class ShardedRedisBroadcastChannel:
def __init__(
self,
redis_client: Redis | RedisCluster,
*,
join_timeout_ms: int = 2000,
):
self._client = redis_client
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
def topic(self, topic: str) -> ShardedTopic:
return ShardedTopic(self._client, topic, join_timeout_ms=self._join_timeout_ms)
return ShardedTopic(self._client, topic)
class ShardedTopic:
@ -35,13 +36,10 @@ class ShardedTopic:
self,
redis_client: Redis | RedisCluster,
topic: str,
*,
join_timeout_ms: int = 2000,
):
self._client = redis_client
self._topic = topic
self._redis_topic = serialize_redis_name(topic)
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
def as_producer(self) -> Producer:
return self
@ -57,7 +55,6 @@ class ShardedTopic:
client=self._client,
pubsub=self._client.pubsub(),
topic=self._redis_topic,
join_timeout_ms=self._join_timeout_ms,
)
@ -68,6 +65,13 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
def _get_subscription_type(self) -> str:
return "sharded"
@override
def _publish_close_event(self) -> None:
try:
self._client.spublish(self._topic, SIG_CLOSE) # type: ignore[attr-defined,union-attr]
except Exception:
logger.exception("failed to publish close event")
@override
def _subscribe(self) -> None:
assert self._pubsub is not None

View File

@ -9,6 +9,7 @@ from typing import Self, override
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from libs.broadcast_channel.signals import SIG_CLOSE
from redis import Redis, RedisCluster
logger = logging.getLogger(__name__)
@ -29,20 +30,15 @@ class StreamsBroadcastChannel:
redis_client: Redis | RedisCluster,
*,
retention_seconds: int = 600,
join_timeout_ms: int = 2000,
):
self._client = redis_client
self._retention_seconds = max(int(retention_seconds or 0), 0)
# Max time close() will wait for the listener thread to finish.
# See `_StreamsSubscription._join_timeout_ms` for the rationale.
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
def topic(self, topic: str) -> StreamsTopic:
return StreamsTopic(
self._client,
topic,
retention_seconds=self._retention_seconds,
join_timeout_ms=self._join_timeout_ms,
)
@ -53,13 +49,11 @@ class StreamsTopic:
topic: str,
*,
retention_seconds: int = 600,
join_timeout_ms: int = 2000,
):
self._client = redis_client
self._topic = topic
self._key = serialize_redis_name(f"stream:{topic}")
self._retention_seconds = retention_seconds
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
self.max_length = 5000
def as_producer(self) -> Producer:
@ -77,23 +71,15 @@ class StreamsTopic:
return self
def subscribe(self) -> Subscription:
return _StreamsSubscription(self._client, self._key, join_timeout_ms=self._join_timeout_ms)
return _StreamsSubscription(self._client, self._key)
class _StreamsSubscription(Subscription):
_SENTINEL = object()
def __init__(self, client: Redis | RedisCluster, key: str, *, join_timeout_ms: int = 2000):
def __init__(self, client: Redis | RedisCluster, key: str):
self._client = client
self._key = key
# Max time close() will wait for the listener thread to finish before
# returning. Bounds SSE close tail latency: the listener blocks on
# XREAD with BLOCK=1000ms, so close() naturally waits up to ~1s for
# the thread to notice _closed. Setting this lower lets close()
# return promptly while the daemon listener exits on its own within
# one BLOCK window - safe because the listener holds no critical
# state. ``0`` means close() does not wait at all.
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
self._queue: queue.Queue[object] = queue.Queue()
@ -106,7 +92,6 @@ class _StreamsSubscription(Subscription):
# reading and writing the _listener / `_closed` attribute.
self._lock = threading.Lock()
self._closed: bool = False
# self._closed = threading.Event()
self._listener: threading.Thread | None = None
def _listen(self) -> None:
@ -144,6 +129,8 @@ class _StreamsSubscription(Subscription):
case bytes() | bytearray():
data_bytes = bytes(data)
if data_bytes is not None:
if data_bytes == SIG_CLOSE:
break
self._queue.put_nowait(data_bytes)
last_id = entry_id
finally:
@ -203,6 +190,13 @@ class _StreamsSubscription(Subscription):
assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue"
return bytes(item)
def _publish_close_event(self) -> None:
"""Publish an empty message to the stream to unblock the listener's xread."""
try:
self._client.xadd(self._key, {b"data": SIG_CLOSE})
except Exception:
logger.exception("failed to publish close event")
@override
def close(self) -> None:
with self._lock:
@ -212,16 +206,17 @@ class _StreamsSubscription(Subscription):
listener = self._listener
if listener is not None:
self._listener = None
# We close the listener outside of the with block to avoid holding the
# lock for a long time.
if listener is not None:
self._publish_close_event()
if listener is not None and listener.is_alive():
listener.join(timeout=self._join_timeout_ms / 1000.0)
listener.join(timeout=2)
if listener.is_alive():
logger.debug(
"Streams subscription listener for key %s did not stop within %dms; "
"Streams subscription listener for key %s did not stop after join; "
"daemon thread will exit on its own within one poll window.",
self._key,
self._join_timeout_ms,
)
# Context manager helpers

View File

@ -0,0 +1 @@
SIG_CLOSE = b"__closed__"

View File

@ -0,0 +1,39 @@
"""agent drive skill metadata refactor
Revision ID: b2515f9d4c2a
Revises: 4f7b2c8d9a10
Create Date: 2026-06-18 23:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = "b2515f9d4c2a"
down_revision = "4f7b2c8d9a10"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"agent_drive_files",
sa.Column("is_skill", sa.Boolean(), nullable=False, server_default=sa.text("false")),
)
op.add_column(
"agent_drive_files",
sa.Column("skill_metadata", sa.Text().with_variant(mysql.LONGTEXT(), "mysql"), nullable=True),
)
op.create_index(
"agent_drive_files_tenant_agent_is_skill_key_idx",
"agent_drive_files",
["tenant_id", "agent_id", "is_skill", "key"],
)
def downgrade() -> None:
op.drop_index("agent_drive_files_tenant_agent_is_skill_key_idx", table_name="agent_drive_files")
op.drop_column("agent_drive_files", "skill_metadata")
op.drop_column("agent_drive_files", "is_skill")

View File

@ -0,0 +1,66 @@
"""add agent debug conversations
Revision ID: c8f4a6b2d3e1
Revises: b2515f9d4c2a
Create Date: 2026-06-22 10:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
import models
# revision identifiers, used by Alembic.
revision = "c8f4a6b2d3e1"
down_revision = "b2515f9d4c2a"
branch_labels = None
depends_on = None
def _is_pg(conn) -> bool:
return conn.dialect.name == "postgresql"
def _uuid_column(name: str, *, nullable: bool = False, primary_key: bool = False) -> sa.Column:
kwargs = {"nullable": nullable, "primary_key": primary_key}
if primary_key and _is_pg(op.get_bind()):
kwargs["server_default"] = sa.text("uuidv7()")
return sa.Column(name, models.types.StringUUID(), **kwargs)
def upgrade():
op.create_table(
"agent_debug_conversations",
_uuid_column("id", primary_key=True),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("agent_id", models.types.StringUUID(), nullable=False),
sa.Column("app_id", models.types.StringUUID(), nullable=False),
sa.Column("account_id", models.types.StringUUID(), nullable=False),
sa.Column("conversation_id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("agent_debug_conversation_pkey")),
sa.UniqueConstraint(
"tenant_id",
"agent_id",
"account_id",
name=op.f("agent_debug_conversation_agent_account_unique"),
),
)
op.create_index(
"agent_debug_conversation_conversation_idx",
"agent_debug_conversations",
["conversation_id"],
)
op.create_index(
"agent_debug_conversation_account_idx",
"agent_debug_conversations",
["tenant_id", "account_id"],
)
def downgrade():
op.drop_index("agent_debug_conversation_account_idx", table_name="agent_debug_conversations")
op.drop_index("agent_debug_conversation_conversation_idx", table_name="agent_debug_conversations")
op.drop_table("agent_debug_conversations")

View File

@ -13,6 +13,7 @@ from .agent import (
AgentConfigRevision,
AgentConfigRevisionOperation,
AgentConfigSnapshot,
AgentDebugConversation,
AgentDriveFile,
AgentDriveFileKind,
AgentIconType,
@ -156,6 +157,7 @@ __all__ = [
"AgentConfigRevision",
"AgentConfigRevisionOperation",
"AgentConfigSnapshot",
"AgentDebugConversation",
"AgentDriveFile",
"AgentDriveFileKind",
"AgentIconType",

View File

@ -83,6 +83,8 @@ class AgentConfigRevisionOperation(StrEnum):
SAVE_NEW_AGENT = "save_new_agent"
# Promotes a workflow-only Agent into the reusable Agent Roster.
SAVE_TO_ROSTER = "save_to_roster"
# Switches the Agent's current published config back to an existing version.
RESTORE_VERSION = "restore_version"
class WorkflowAgentBindingType(StrEnum):
@ -180,6 +182,34 @@ class Agent(DefaultFieldsMixin, Base):
archived_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
class AgentDebugConversation(DefaultFieldsMixin, Base):
"""Per-account console debug conversation for an Agent App.
Agent App preview state must be isolated by editor account. The Agent row is
shared by everyone in the workspace, so this table owns the user-specific
conversation pointer used by console debug chat.
"""
__tablename__ = "agent_debug_conversations"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="agent_debug_conversation_pkey"),
UniqueConstraint(
"tenant_id",
"agent_id",
"account_id",
name="agent_debug_conversation_agent_account_unique",
),
Index("agent_debug_conversation_conversation_idx", "conversation_id"),
Index("agent_debug_conversation_account_idx", "tenant_id", "account_id"),
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
agent_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
class AgentConfigSnapshot(DefaultFieldsMixin, Base):
"""Immutable Agent Soul snapshot.
@ -430,14 +460,17 @@ class AgentDriveFile(DefaultFieldsMixin, Base):
synced. ``value_owned_by_drive`` gates physical cleanup: only drive-owned values
(created by the agent runtime or Skill standardization, not shared with other
business records) have their storage object + record deleted when the KV entry is
overwritten or removed; otherwise only the KV row is dropped. Lifecycle never relies
on ``UploadFile.used/used_by`` (not a reliable refcount).
overwritten or removed; otherwise only the KV row is dropped. Skills are represented
by the canonical ``<path>/SKILL.md`` row with ``is_skill=True`` and a serialized
``skill_metadata`` string. Lifecycle never relies on ``UploadFile.used/used_by``
(not a reliable refcount).
"""
__tablename__ = "agent_drive_files"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="agent_drive_file_pkey"),
UniqueConstraint("tenant_id", "agent_id", "key", name="agent_drive_file_scope_key_unique"),
Index("agent_drive_files_tenant_agent_is_skill_key_idx", "tenant_id", "agent_id", "is_skill", "key"),
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -453,6 +486,8 @@ class AgentDriveFile(DefaultFieldsMixin, Base):
value_owned_by_drive: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, default=False, server_default=sa.text("false")
)
is_skill: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False, server_default=sa.text("false"))
skill_metadata: Mapped[str | None] = mapped_column(LongText, nullable=True)
size: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True)
hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
mime_type: Mapped[str | None] = mapped_column(String(255), nullable=True)

View File

@ -134,20 +134,40 @@ class HumanInputDelivery(DefaultFieldsMixin, Base):
)
class ApprovalChannel(StrEnum):
"""Where a paused human input form can be approved, surfaced to API callers."""
EMAIL = "email"
WEB_APP = "web_app"
CONSOLE = "console"
class RecipientType(StrEnum):
# EMAIL_MEMBER member means that the
EMAIL_MEMBER = "email_member"
EMAIL_EXTERNAL = "email_external"
# Second value = the approval channel this recipient maps to (surfaced in `approval_channels`).
EMAIL_MEMBER = "email_member", ApprovalChannel.EMAIL
EMAIL_EXTERNAL = "email_external", ApprovalChannel.EMAIL
# STANDALONE_WEB_APP is used by the standalone web app.
#
# It's not used while running workflows / chatflows containing HumanInput
# node inside console.
STANDALONE_WEB_APP = "standalone_web_app"
STANDALONE_WEB_APP = "standalone_web_app", ApprovalChannel.WEB_APP
# CONSOLE is used while running workflows / chatflows containing HumanInput
# node inside console. (E.G. running installed apps or debugging workflows / chatflows)
CONSOLE = "console"
CONSOLE = "console", ApprovalChannel.CONSOLE
# BACKSTAGE is used for backstage input inside console.
BACKSTAGE = "backstage"
BACKSTAGE = "backstage", ApprovalChannel.CONSOLE
_approval_channel: ApprovalChannel
def __new__(cls, value: str, approval_channel: ApprovalChannel) -> "RecipientType":
member = str.__new__(cls, value)
member._value_ = value
member._approval_channel = approval_channel
return member
@property
def approval_channel(self) -> ApprovalChannel:
return self._approval_channel
@final

View File

@ -774,26 +774,7 @@ class AppModelConfig(TypeBase):
@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
if not collection_binding_detail:
raise ValueError("Collection binding detail not found")
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
else:
return {"enabled": False}
return load_annotation_reply_config(db.session(), self.app_id)
@property
def more_like_this_dict(self) -> EnabledConfig:
@ -864,7 +845,7 @@ class AppModelConfig(TypeBase):
},
)
def to_dict(self) -> AppModelConfigDict:
def to_dict(self, *, annotation_reply: AnnotationReplyConfig | None = None) -> AppModelConfigDict:
return {
"opening_statement": self.opening_statement,
"suggested_questions": self.suggested_questions_list,
@ -872,7 +853,7 @@ class AppModelConfig(TypeBase):
"speech_to_text": self.speech_to_text_dict,
"text_to_speech": self.text_to_speech_dict,
"retriever_resource": self.retriever_resource_dict,
"annotation_reply": self.annotation_reply_dict,
"annotation_reply": annotation_reply if annotation_reply is not None else self.annotation_reply_dict,
"more_like_this": self.more_like_this_dict,
"sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
"external_data_tools": self.external_data_tools_list,
@ -2038,6 +2019,30 @@ class AppAnnotationSetting(TypeBase):
)
def load_annotation_reply_config(session: Session, app_id: str) -> AnnotationReplyConfig:
annotation_setting = session.scalar(select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id))
if annotation_setting is None:
return {"enabled": False}
from .dataset import DatasetCollectionBinding
collection_binding_detail = session.scalar(
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == annotation_setting.collection_binding_id)
)
if collection_binding_detail is None:
raise ValueError("Collection binding detail not found")
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
class OperationLog(TypeBase):
__tablename__ = "operation_logs"
__table_args__ = (

View File

@ -391,6 +391,80 @@ Check if activation token is valid
| 400 | Invalid request parameters | |
| 403 | Insufficient permissions | |
### [GET] /agent/{agent_id}/api-access
#### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| agent_id | path | | Yes | string (uuid) |
#### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Agent service API access | **application/json**: [AgentApiAccessResponse](#agentapiaccessresponse)<br> |
### [POST] /agent/{agent_id}/api-enable
#### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| agent_id | path | | Yes | string (uuid) |
#### Request Body
| Required | Schema |
| -------- | ------ |
| Yes | **application/json**: [AgentApiStatusPayload](#agentapistatuspayload)<br> |
#### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Agent service API status updated | **application/json**: [AgentApiAccessResponse](#agentapiaccessresponse)<br> |
| 403 | Insufficient permissions | |
### [GET] /agent/{agent_id}/api-keys
#### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| agent_id | path | | Yes | string (uuid) |
#### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Agent service API keys | **application/json**: [ApiKeyList](#apikeylist)<br> |
### [POST] /agent/{agent_id}/api-keys
#### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| agent_id | path | | Yes | string (uuid) |
#### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 201 | Agent service API key created | **application/json**: [ApiKeyItem](#apikeyitem)<br> |
| 400 | Maximum keys exceeded | |
### [DELETE] /agent/{agent_id}/api-keys/{api_key_id}
#### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| agent_id | path | | Yes | string (uuid) |
| api_key_id | path | | Yes | string (uuid) |
#### Responses
| Code | Description |
| ---- | ----------- |
| 204 | Agent service API key deleted |
### [GET] /agent/{agent_id}/chat-messages
Get Agent App chat messages for a conversation with pagination
@ -576,6 +650,37 @@ Truncated text preview of one Agent App drive value
| ---- | ----------- | ------ |
| 200 | Preview | **application/json**: [AgentDrivePreviewResponse](#agentdrivepreviewresponse)<br> |
### [GET] /agent/{agent_id}/drive/skills
List drive-backed skills for an Agent App
#### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| agent_id | path | Agent ID | Yes | string (uuid) |
#### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Drive skills | **application/json**: [AgentDriveSkillListResponse](#agentdriveskilllistresponse)<br> |
### [GET] /agent/{agent_id}/drive/skills/{skill_path}/inspect
Inspect one drive-backed skill for slash-menu hover/detail UI
#### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| agent_id | path | Agent ID | Yes | string (uuid) |
| skill_path | path | Skill path/slug, e.g. tender-analyzer | Yes | string |
#### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Drive skill inspect view | **application/json**: [AgentDriveSkillInspectResponse](#agentdriveskillinspectresponse)<br> |
### [POST] /agent/{agent_id}/features
Update an Agent App's presentation features (opener, follow-up, citations, ...)
@ -905,6 +1010,20 @@ Infer CLI tool + ENV suggestions from a standardized Agent App skill
| ---- | ----------- | ------ |
| 200 | Agent version detail | **application/json**: [AgentConfigSnapshotDetailResponse](#agentconfigsnapshotdetailresponse)<br> |
### [POST] /agent/{agent_id}/versions/{version_id}/restore
#### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| agent_id | path | | Yes | string (uuid) |
| version_id | path | | Yes | string (uuid) |
#### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Agent version restored | **application/json**: [AgentConfigSnapshotRestoreResponse](#agentconfigsnapshotrestoreresponse)<br> |
### [GET] /all-workspaces
#### Parameters
@ -1454,6 +1573,40 @@ Truncated text preview of one drive value (binary-safe; SKILL.md is the main cas
| ---- | ----------- | ------ |
| 200 | Preview | **application/json**: [AgentDrivePreviewResponse](#agentdrivepreviewresponse)<br> |
### [GET] /apps/{app_id}/agent/drive/skills
List drive-backed skills for the bound agent
#### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| app_id | path | Application ID | Yes | string (uuid) |
| node_id | query | Workflow node ID (workflow composer variant) | No | string |
| prefix | query | Key prefix filter: '<slug>/' for one skill, 'files/' for files | No | string |
#### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Drive skills | **application/json**: [AgentDriveSkillListResponse](#agentdriveskilllistresponse)<br> |
### [GET] /apps/{app_id}/agent/drive/skills/{skill_path}/inspect
Inspect one drive-backed skill for slash-menu hover/detail UI
#### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| app_id | path | Application ID | Yes | string (uuid) |
| skill_path | path | Skill path/slug, e.g. tender-analyzer | Yes | string |
| node_id | query | Workflow node ID (workflow composer variant) | No | string |
#### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Drive skill inspect view | **application/json**: [AgentDriveSkillInspectResponse](#agentdriveskillinspectresponse)<br> |
### [DELETE] /apps/{app_id}/agent/files
Delete one drive file by key; soul ref first, then the KV row (ENG-625 D5)
@ -11954,6 +12107,31 @@ Default namespace
| chat_prompt_config | object | | No |
| completion_prompt_config | object | | No |
#### AgentApiAccessResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| api_key_count | integer | | Yes |
| api_rph | integer | | Yes |
| api_rpm | integer | | Yes |
| chat_endpoint | string | | Yes |
| conversations_endpoint | string | | Yes |
| enabled | boolean | | Yes |
| files_upload_endpoint | string | | Yes |
| info_endpoint | string | | Yes |
| messages_endpoint | string | | Yes |
| meta_endpoint | string | | Yes |
| parameters_endpoint | string | | Yes |
| service_api_base_url | string | | Yes |
| stop_endpoint | string | | Yes |
| streaming_only | boolean, <br>**Default:** true | | No |
#### AgentApiStatusPayload
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| enable_api | boolean | Enable or disable Agent service API | Yes |
#### AgentAppComposerResponse
| Name | Type | Description | Required |
@ -11987,6 +12165,7 @@ Default namespace
| bound_agent_id | string | | No |
| created_at | integer | | No |
| created_by | string | | No |
| debug_conversation_id | string | | No |
| deleted_tools | [ [DeletedTool](#deletedtool) ] | | No |
| description | string | | No |
| enable_api | boolean | | Yes |
@ -12050,6 +12229,7 @@ default (the config form sends the full desired feature state on save).
| create_user_name | string | | No |
| created_at | integer | | No |
| created_by | string | | No |
| debug_conversation_id | string | | No |
| description | string | | No |
| has_draft_trigger | boolean | | No |
| icon | string | | No |
@ -12340,6 +12520,13 @@ Audit operation recorded for Agent Soul version/revision changes.
| ---- | ---- | ----------- | -------- |
| data | [ [AgentConfigSnapshotSummaryResponse](#agentconfigsnapshotsummaryresponse) ] | | Yes |
#### AgentConfigSnapshotRestoreResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| active_config_snapshot_id | string | | Yes |
| result | string | | Yes |
#### AgentConfigSnapshotSummaryResponse
| Name | Type | Description | Required |
@ -12425,9 +12612,11 @@ Audit operation recorded for Agent Soul version/revision changes.
| created_at | integer | | No |
| file_kind | string | | Yes |
| hash | string | | No |
| is_skill | boolean | | No |
| key | string | | Yes |
| mime_type | string | | No |
| size | integer | | No |
| skill_metadata | string | | No |
#### AgentDriveListResponse
@ -12445,6 +12634,65 @@ Audit operation recorded for Agent Soul version/revision changes.
| text | string | | No |
| truncated | boolean | | Yes |
#### AgentDriveSkillFileResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| available_in_drive | boolean | | Yes |
| drive_key | string | | No |
| name | string | | Yes |
| path | string | | Yes |
| type | string | | Yes |
#### AgentDriveSkillInspectResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| archive_key | string | | No |
| created_at | integer | | No |
| description | string | | Yes |
| file_tree | [ object ] | | No |
| files | [ [AgentDriveSkillFileResponse](#agentdriveskillfileresponse) ] | | No |
| hash | string | | No |
| mime_type | string | | No |
| name | string | | Yes |
| path | string | | Yes |
| size | integer | | No |
| skill_md | [AgentDriveSkillMarkdownResponse](#agentdriveskillmarkdownresponse) | | Yes |
| skill_md_key | string | | Yes |
| source | string | | Yes |
| warnings | [ string ] | | No |
#### AgentDriveSkillItemResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| archive_key | string | | No |
| created_at | integer | | No |
| description | string | | Yes |
| hash | string | | No |
| mime_type | string | | No |
| name | string | | Yes |
| path | string | | Yes |
| size | integer | | No |
| skill_md_key | string | | Yes |
#### AgentDriveSkillListResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| items | [ [AgentDriveSkillItemResponse](#agentdriveskillitemresponse) ] | | No |
#### AgentDriveSkillMarkdownResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| binary | boolean | | Yes |
| key | string | | Yes |
| size | integer | | No |
| text | string | | No |
| truncated | boolean | | Yes |
#### AgentEnvVariableConfig
| Name | Type | Description | Required |

View File

@ -83,7 +83,6 @@ User-scoped operations
| mode | query | | No | string, <br>**Available values:** "advanced-chat", "agent", "agent-chat", "channel", "chat", "completion", "rag-pipeline", "workflow" |
| name | query | | No | string |
| page | query | | No | integer, <br>**Default:** 1 |
| tag | query | | No | string |
| workspace_id | query | | Yes | string |
#### Responses
@ -331,6 +330,22 @@ Upload a file to use as an input variable when running the app
| 422 | Validation error | **application/json**: [ErrorBody](#errorbody)<br> |
| default | Error | **application/json**: [ErrorBody](#errorbody)<br> |
### [GET] /permitted-external-apps/{app_id}/describe
#### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| fields | query | | No | string |
| app_id | path | | Yes | string |
#### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Permitted external app description | **application/json**: [AppDescribeResponse](#appdescriberesponse)<br> |
| 422 | Validation error | **application/json**: [ErrorBody](#errorbody)<br> |
| default | Error | **application/json**: [ErrorBody](#errorbody)<br> |
### [GET] /workspaces
#### Responses
@ -507,14 +522,12 @@ Upload a file to use as an input variable when running the app
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| author | string | | No |
| description | string | | No |
| id | string | | Yes |
| is_agent | boolean | | No |
| mode | string | | Yes |
| name | string | | Yes |
| service_api_enabled | boolean | | Yes |
| tags | [ [TagItem](#tagitem) ], <br>**Default:** | | No |
| updated_at | string | | No |
#### AppDescribeQuery
@ -568,16 +581,14 @@ Request body for POST /workspaces/<workspace_id>/apps/imports.
| yaml_content | string | Inline YAML DSL string (required when mode is yaml-content) | No |
| yaml_url | string | Remote URL to fetch YAML from (required when mode is yaml-url) | No |
#### AppInfoResponse
#### AppInfo
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| author | string | | No |
| description | string | | No |
| id | string | | Yes |
| mode | string | | Yes |
| name | string | | Yes |
| tags | [ [TagItem](#tagitem) ], <br>**Default:** | | No |
#### AppListQuery
@ -589,7 +600,6 @@ mode is a closed enum.
| mode | [AppMode](#appmode) | | No |
| name | string | | No |
| page | integer, <br>**Default:** 1 | | No |
| tag | string | | No |
| workspace_id | string | | Yes |
#### AppListResponse
@ -606,12 +616,10 @@ mode is a closed enum.
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| created_by_name | string | | No |
| description | string | | No |
| id | string | | Yes |
| mode | [AppMode](#appmode) | | Yes |
| name | string | | Yes |
| tags | [ [TagItem](#tagitem) ], <br>**Default:** | | No |
| updated_at | string | | No |
| workspace_id | string | | No |
| workspace_name | string | | No |
@ -982,12 +990,6 @@ Pagination for GET /account/sessions. Strict (extra='forbid').
| last_used_at | string | | No |
| prefix | string | | Yes |
#### TagItem
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| name | string | | Yes |
#### TaskStopResponse
200 body for POST /apps/<id>/tasks/<task_id>/stop. The handler always returns

View File

@ -1,3 +1,4 @@
import logging
import time
import uuid
from datetime import datetime
@ -142,9 +143,8 @@ class TestTraceClient:
mock_notify.assert_called_once()
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.logger")
def test_add_span_queue_full(
self, mock_logger: MagicMock, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]
self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient], caplog: pytest.LogCaptureFixture
):
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)
@ -164,12 +164,15 @@ class TestTraceClient:
client.add_span(span_data)
assert len(client.queue) == 1
client.add_span(span_data)
assert len(client.queue) == 1
mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")
with caplog.at_level(logging.WARNING):
client.add_span(span_data)
assert len(client.queue) == 1
assert "Queue is full, likely spans will be dropped." in caplog.text
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_export_batch_error(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
def test_export_batch_error(
self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient], caplog: pytest.LogCaptureFixture
):
mock_exporter = mock_exporter_class.return_value
mock_exporter.export.side_effect = Exception("Export failed")
@ -177,9 +180,9 @@ class TestTraceClient:
mock_span = MagicMock(spec=ReadableSpan)
client.queue.append(mock_span)
with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger:
with caplog.at_level(logging.WARNING):
client._export_batch()
mock_logger.warning.assert_called()
assert "Error exporting spans" in caplog.text
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_worker_loop(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):

View File

@ -307,13 +307,12 @@ class TestGetProjectUrl:
monkeypatch.setattr(trace_instance, "entity", None)
monkeypatch.setattr(trace_instance, "project_name", None)
# Force an error by making string formatting fail
with patch("dify_trace_weave.weave_trace.logger") as mock_logger:
# Simulate exception via property
original_entity = trace_instance.entity
trace_instance.entity = None
trace_instance.project_name = None
url = trace_instance.get_project_url()
assert "https://wandb.ai/" in url
# Simulate exception via property
original_entity = trace_instance.entity
trace_instance.entity = None
trace_instance.project_name = None
url = trace_instance.get_project_url()
assert "https://wandb.ai/" in url
# ── TestTraceDispatcher ─────────────────────────────────────────────────────

View File

@ -830,6 +830,16 @@ class AgentComposerService:
) -> WorkflowAgentNodeBinding:
node_job = payload.node_job or WorkflowNodeJobConfig()
if binding:
if cls._is_start_from_scratch_request(binding=binding, payload=payload):
return cls._switch_roster_binding_to_inline_agent(
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
node_id=node_id,
account_id=account_id,
binding=binding,
payload=payload,
)
binding.node_job_config = node_job
if payload.agent_soul is not None and binding.binding_type == WorkflowAgentBindingType.INLINE_AGENT:
current_snapshot = cls._require_version(
@ -880,6 +890,46 @@ class AgentComposerService:
db.session.flush()
return binding
@classmethod
def _is_start_from_scratch_request(cls, *, binding: WorkflowAgentNodeBinding, payload: ComposerSavePayload) -> bool:
return (
binding.binding_type == WorkflowAgentBindingType.ROSTER_AGENT
and payload.binding is not None
and payload.binding.binding_type == WorkflowAgentBindingType.INLINE_AGENT.value
)
@classmethod
def _switch_roster_binding_to_inline_agent(
cls,
*,
tenant_id: str,
app_id: str,
workflow_id: str,
node_id: str,
account_id: str,
binding: WorkflowAgentNodeBinding,
payload: ComposerSavePayload,
) -> WorkflowAgentNodeBinding:
if payload.binding and (payload.binding.agent_id or payload.binding.current_snapshot_id):
raise ValueError("Start from Scratch must not provide an existing inline agent binding.")
agent_soul = payload.agent_soul or AgentSoulConfig()
agent = cls._create_workflow_only_agent(
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
node_id=node_id,
account_id=account_id,
agent_soul=agent_soul,
)
binding.binding_type = WorkflowAgentBindingType.INLINE_AGENT
binding.agent_id = agent.id
binding.current_snapshot_id = agent.active_config_snapshot_id
binding.node_job_config = payload.node_job or binding.node_job_config
binding.updated_by = account_id
db.session.flush()
return binding
@classmethod
def _save_to_current_version(
cls,

View File

@ -3,6 +3,7 @@ from typing import Any, TypedDict
from sqlalchemy import and_, func, or_, select
from sqlalchemy.exc import IntegrityError
from core.app.entities.app_invoke_entities import InvokeFrom
from libs.datetime_utils import naive_utc_now
from libs.helper import to_timestamp
from models.agent import (
@ -10,6 +11,7 @@ from models.agent import (
AgentConfigRevision,
AgentConfigRevisionOperation,
AgentConfigSnapshot,
AgentDebugConversation,
AgentKind,
AgentScope,
AgentSource,
@ -18,8 +20,8 @@ from models.agent import (
WorkflowAgentNodeBinding,
)
from models.agent_config_entities import AgentSoulConfig
from models.enums import AppStatus
from models.model import App, AppMode, IconType
from models.enums import AppStatus, ConversationFromSource, ConversationStatus
from models.model import App, AppMode, Conversation, IconType
from models.workflow import Workflow
from services.agent.agent_soul_state import agent_soul_has_model
from services.agent.composer_validator import ComposerConfigValidator
@ -96,6 +98,7 @@ class AgentRosterService:
"scope": agent.scope.value,
"source": agent.source.value,
"app_id": agent.app_id,
"debug_conversation_id": None,
"workflow_id": agent.workflow_id,
"workflow_node_id": agent.workflow_node_id,
"active_config_snapshot_id": agent.active_config_snapshot_id,
@ -392,8 +395,126 @@ class AgentRosterService:
agent.active_config_snapshot_id = version.id
agent.active_config_has_model = agent_soul_has_model(AgentSoulConfig())
self._session.flush()
self._get_or_create_agent_app_debug_conversation(agent=agent, account_id=account_id)
return agent
def _create_agent_app_debug_conversation(self, *, app_id: str, account_id: str) -> str:
"""Create one console debug conversation for an Agent App editor."""
conversation = Conversation(
app_id=app_id,
app_model_config_id=None,
model_provider=None,
model_id="",
override_model_configs=None,
mode=AppMode.AGENT,
name="Agent Debugging Conversation",
inputs={},
introduction="",
system_instruction="",
system_instruction_tokens=0,
status=ConversationStatus.NORMAL,
invoke_from=InvokeFrom.DEBUGGER,
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=account_id,
)
self._session.add(conversation)
self._session.flush()
return conversation.id
def _get_or_create_agent_app_debug_conversation(self, *, agent: Agent, account_id: str) -> str:
if not agent.app_id:
raise AgentNotFoundError()
mapping = self._session.scalar(
select(AgentDebugConversation).where(
AgentDebugConversation.tenant_id == agent.tenant_id,
AgentDebugConversation.agent_id == agent.id,
AgentDebugConversation.account_id == account_id,
)
)
if mapping is not None:
conversation_id = self._session.scalar(
select(Conversation.id).where(
Conversation.id == mapping.conversation_id,
Conversation.app_id == agent.app_id,
Conversation.from_source == ConversationFromSource.CONSOLE,
Conversation.from_account_id == account_id,
Conversation.is_deleted.is_(False),
)
)
if conversation_id:
return conversation_id
mapping.conversation_id = self._create_agent_app_debug_conversation(
app_id=agent.app_id,
account_id=account_id,
)
self._session.flush()
return mapping.conversation_id
conversation_id = self._create_agent_app_debug_conversation(
app_id=agent.app_id,
account_id=account_id,
)
self._session.add(
AgentDebugConversation(
tenant_id=agent.tenant_id,
agent_id=agent.id,
app_id=agent.app_id,
account_id=account_id,
conversation_id=conversation_id,
)
)
self._session.flush()
return conversation_id
def get_or_create_agent_app_debug_conversation_id(
self, *, tenant_id: str, agent_id: str, account_id: str, commit: bool = True
) -> str:
"""Return the current editor's debug conversation for an Agent App."""
agent = self._session.scalar(
select(Agent).where(
Agent.tenant_id == tenant_id,
Agent.id == agent_id,
Agent.scope == AgentScope.ROSTER,
Agent.source == AgentSource.AGENT_APP,
Agent.status == AgentStatus.ACTIVE,
)
)
if agent is None:
raise AgentNotFoundError()
conversation_id = self._get_or_create_agent_app_debug_conversation(agent=agent, account_id=account_id)
if commit:
self._session.commit()
return conversation_id
def load_or_create_agent_app_debug_conversation_ids_by_agent_id(
self, *, tenant_id: str, agents: list[Agent], account_id: str
) -> dict[str, str]:
"""Return per-account debug conversations for a page of Agent Apps."""
conversation_ids_by_agent_id: dict[str, str] = {}
changed = False
for agent in agents:
if (
agent.tenant_id != tenant_id
or agent.scope != AgentScope.ROSTER
or agent.source != AgentSource.AGENT_APP
):
continue
conversation_ids_by_agent_id[agent.id] = self._get_or_create_agent_app_debug_conversation(
agent=agent,
account_id=account_id,
)
changed = True
if changed:
self._session.commit()
return conversation_ids_by_agent_id
def load_app_backing_agents_by_app_id(self, *, tenant_id: str, app_ids: list[str]) -> dict[str, Agent]:
"""Return active app-backed Agents keyed by Agent App id."""
if not app_ids:
@ -666,12 +787,16 @@ class AgentRosterService:
@staticmethod
def _visible_version_operations(agent: Agent) -> set[AgentConfigRevisionOperation]:
if agent.source == AgentSource.AGENT_APP:
return {AgentConfigRevisionOperation.SAVE_NEW_VERSION}
return {
AgentConfigRevisionOperation.SAVE_NEW_VERSION,
AgentConfigRevisionOperation.RESTORE_VERSION,
}
return {
AgentConfigRevisionOperation.CREATE_VERSION,
AgentConfigRevisionOperation.SAVE_NEW_VERSION,
AgentConfigRevisionOperation.SAVE_NEW_AGENT,
AgentConfigRevisionOperation.SAVE_TO_ROSTER,
AgentConfigRevisionOperation.RESTORE_VERSION,
}
def active_config_is_published(self, *, tenant_id: str, agent: Agent) -> bool:
@ -764,6 +889,46 @@ class AgentRosterService:
]
return result
def restore_agent_version(
self, *, tenant_id: str, agent_id: str, version_id: str, account_id: str
) -> dict[str, Any]:
agent = self._get_agent(tenant_id=tenant_id, agent_id=agent_id, roster_only=True)
visible_version_ids = self._visible_version_ids_stmt(tenant_id=tenant_id, agent_id=agent_id, agent=agent)
visible_version_id = self._session.scalar(
select(AgentConfigSnapshot.id)
.where(
AgentConfigSnapshot.tenant_id == tenant_id,
AgentConfigSnapshot.agent_id == agent_id,
AgentConfigSnapshot.id == version_id,
AgentConfigSnapshot.id.in_(select(visible_version_ids.c.current_snapshot_id)),
)
.limit(1)
)
if not visible_version_id:
raise AgentVersionNotFoundError()
version = self._get_version(tenant_id=tenant_id, agent_id=agent_id, version_id=version_id)
if agent.active_config_snapshot_id == version.id:
return {"result": "success", "active_config_snapshot_id": version.id}
previous_snapshot_id = agent.active_config_snapshot_id
agent.active_config_snapshot_id = version.id
agent.active_config_has_model = agent_soul_has_model(version.config_snapshot)
agent.updated_by = account_id
self._session.add(
AgentConfigRevision(
tenant_id=tenant_id,
agent_id=agent_id,
previous_snapshot_id=previous_snapshot_id,
current_snapshot_id=version.id,
revision=self._next_revision(tenant_id=tenant_id, agent_id=agent_id),
operation=AgentConfigRevisionOperation.RESTORE_VERSION,
created_by=account_id,
)
)
self._session.commit()
return {"result": "success", "active_config_snapshot_id": version.id}
def _get_agent(self, *, tenant_id: str, agent_id: str, roster_only: bool = False) -> Agent:
stmt = select(Agent).where(Agent.tenant_id == tenant_id, Agent.id == agent_id)
if roster_only:
@ -789,6 +954,17 @@ class AgentRosterService:
raise AgentVersionNotFoundError()
return version
def _next_revision(self, *, tenant_id: str, agent_id: str) -> int:
return (
self._session.scalar(
select(func.max(AgentConfigRevision.revision)).where(
AgentConfigRevision.tenant_id == tenant_id,
AgentConfigRevision.agent_id == agent_id,
)
)
or 0
) + 1
def _load_published_active_snapshot_agent_ids(self, *, tenant_id: str, agents: list[Agent]) -> set[str]:
predicates = [
and_(

View File

@ -23,7 +23,7 @@ from typing import Any
from core.tools.tool_file_manager import ToolFileManager
from models.agent_config_entities import AgentSkillRefConfig
from services.agent.skill_package_service import SkillPackageService
from services.agent_drive_service import AgentDriveService, DriveCommitItem, DriveFileRef
from services.agent_drive_service import AgentDriveService, DriveCommitItem, DriveFileRef, DriveSkillMetadata
_FULL_ARCHIVE_NAME = ".DIFY-SKILL-FULL.zip"
_SKILL_MD_NAME = "SKILL.md"
@ -91,6 +91,12 @@ class SkillStandardizeService:
key=skill_md_key,
file_ref=DriveFileRef(kind="tool_file", id=md_tool_file.id),
value_owned_by_drive=True,
is_skill=True,
skill_metadata=DriveSkillMetadata(
name=manifest.name,
description=manifest.description,
manifest_files=manifest.files,
),
),
DriveCommitItem(
key=archive_key,

View File

@ -17,12 +17,14 @@ ToolFile records (see ``AgentDriveFile``). This service is the control plane:
from __future__ import annotations
import json
import logging
import re
import urllib.parse
from typing import Any, Literal
from typing import Any, Literal, TypedDict
from urllib.parse import unquote
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, field_validator
from sqlalchemy import func, select
from sqlalchemy.exc import DataError, SQLAlchemyError
from sqlalchemy.orm import Session
@ -41,6 +43,8 @@ logger = logging.getLogger(__name__)
_MAX_KEY_LENGTH = 512
_DRIVE_REF_PREFIX = "agent-"
_SKILL_MD_SUFFIX = "/SKILL.md"
_SKILL_ARCHIVE_NAME = ".DIFY-SKILL-FULL.zip"
class AgentDriveError(Exception):
@ -58,16 +62,86 @@ class AgentDriveError(Exception):
class DriveFileRef(BaseModel):
model_config = ConfigDict(extra="forbid")
kind: Literal["upload_file", "tool_file"]
id: str
class DriveSkillMetadata(BaseModel):
"""Validated skill catalog metadata stored as a JSON string on the drive row."""
model_config = ConfigDict(extra="forbid")
name: str
description: str = ""
# Safe archive member paths captured during skill standardization. The drive
# stores only canonical SKILL.md + full archive, so the UI uses this manifest
# to show the original uploaded package contents.
manifest_files: list[str] | None = None
@field_validator("name")
@classmethod
def _validate_name(cls, value: str) -> str:
normalized = value.strip()
if not normalized:
raise ValueError("skill metadata name must not be blank")
return normalized
class DriveCommitItem(BaseModel):
model_config = ConfigDict(extra="forbid")
key: str
file_ref: DriveFileRef
# Drive-owned values may be physically cleaned on overwrite/removal; refs to
# files shared with other business records should set this False.
value_owned_by_drive: bool = True
is_skill: bool = False
skill_metadata: DriveSkillMetadata | None = None
class AgentDriveSkillInfo(TypedDict):
path: str
skill_md_key: str
archive_key: str | None
name: str
description: str
size: int | None
mime_type: str | None
hash: str | None
created_at: int | None
class AgentDriveSkillFileInfo(TypedDict):
path: str
name: str
type: str
drive_key: str | None
available_in_drive: bool
class AgentDriveSkillInspectInfo(TypedDict):
path: str
skill_md_key: str
archive_key: str | None
name: str
description: str
size: int | None
mime_type: str | None
hash: str | None
created_at: int | None
source: str
files: list[AgentDriveSkillFileInfo]
file_tree: list[dict[str, Any]]
skill_md: dict[str, Any]
warnings: list[str]
def decode_drive_mention_ref(ref_id: str) -> str:
"""Decode the prompt token's URL-encoded drive-key field."""
return unquote(ref_id or "")
def parse_agent_drive_ref(drive_ref: str) -> str:
@ -132,6 +206,8 @@ class AgentDriveService:
"mime_type": row.mime_type,
"file_kind": row.file_kind.value,
"file_id": row.file_id,
"is_skill": row.is_skill,
"skill_metadata": row.skill_metadata,
"created_at": int(row.created_at.timestamp()) if row.created_at else None,
}
if include_download_url:
@ -217,6 +293,87 @@ class AgentDriveService:
self._delete_storage(storage_key)
return removed_keys
def list_skills(self, *, tenant_id: str, agent_id: str) -> list[AgentDriveSkillInfo]:
"""Return the drive-backed skill catalog derived from canonical ``SKILL.md`` rows."""
with session_factory.create_session() as session:
self._assert_agent_belongs_to_tenant(session, tenant_id=tenant_id, agent_id=agent_id)
skill_rows = list(
session.scalars(
select(AgentDriveFile)
.where(
AgentDriveFile.tenant_id == tenant_id,
AgentDriveFile.agent_id == agent_id,
AgentDriveFile.is_skill.is_(True),
)
.order_by(AgentDriveFile.key)
)
)
archive_keys = set(
session.scalars(
select(AgentDriveFile.key).where(
AgentDriveFile.tenant_id == tenant_id,
AgentDriveFile.agent_id == agent_id,
AgentDriveFile.key.in_([self._skill_archive_key(row.key) for row in skill_rows]),
)
)
)
skills: list[AgentDriveSkillInfo] = []
for row in skill_rows:
metadata = self._parse_skill_metadata(row.key, row.skill_metadata)
archive_key = self._skill_archive_key(row.key)
skills.append(
{
"path": self._skill_path_from_key(row.key),
"skill_md_key": row.key,
"archive_key": archive_key if archive_key in archive_keys else None,
"name": metadata.name,
"description": metadata.description,
"size": row.size,
"mime_type": row.mime_type,
"hash": row.hash,
"created_at": int(row.created_at.timestamp()) if row.created_at else None,
}
)
return skills
def inspect_skill(self, *, tenant_id: str, agent_id: str, skill_path: str) -> AgentDriveSkillInspectInfo:
"""Return the UI-facing skill inspect view for slash-menu hover/detail."""
skill_path = normalize_drive_key(skill_path)
skill_md_key = skill_path if skill_path.endswith(_SKILL_MD_SUFFIX) else f"{skill_path}{_SKILL_MD_SUFFIX}"
skill_path = self._skill_path_from_key(skill_md_key)
catalog = next(
(item for item in self.list_skills(tenant_id=tenant_id, agent_id=agent_id) if item["path"] == skill_path),
None,
)
if catalog is None:
raise AgentDriveError("skill_not_found", "no drive-backed skill for this path", status_code=404)
manifest_files = self._manifest_files_from_skill_metadata(
tenant_id=tenant_id,
agent_id=agent_id,
skill_md_key=skill_md_key,
)
drive_items = self.manifest(tenant_id=tenant_id, agent_id=agent_id, prefix=f"{skill_path}/")
drive_keys = {item["key"] for item in drive_items}
preview = self.preview(tenant_id=tenant_id, agent_id=agent_id, key=skill_md_key)
files, warnings = self._skill_file_entries(
skill_path=skill_path,
skill_md_key=skill_md_key,
manifest_files=manifest_files,
drive_keys=drive_keys,
)
return {
**catalog,
"source": "skill_md",
"files": files,
"file_tree": self._build_file_tree(files),
"skill_md": preview,
"warnings": warnings,
}
def _commit_one(
self,
session: Session,
@ -228,9 +385,10 @@ class AgentDriveService:
pending_storage_deletes: list[str],
) -> dict[str, Any]:
key = normalize_drive_key(item.key)
skill_metadata = self._validate_skill_commit_fields(key=key, item=item)
file_kind = AgentDriveFileKind(item.file_ref.kind)
file_id = item.file_ref.id
size, mime_type = self._validate_source(
size, mime_type, file_hash = self._validate_source(
session, tenant_id=tenant_id, user_id=user_id, file_kind=file_kind, file_id=file_id
)
@ -245,6 +403,11 @@ class AgentDriveService:
# Idempotent re-commit of the same value: leave it (do not clean).
if existing.file_kind == file_kind and existing.file_id == file_id:
existing.value_owned_by_drive = item.value_owned_by_drive
existing.is_skill = item.is_skill
existing.skill_metadata = skill_metadata
existing.size = size
existing.mime_type = mime_type
existing.hash = file_hash
return self._row_dict(existing)
# Overwrite: clean the previous drive-owned value if no longer referenced.
if existing.value_owned_by_drive:
@ -259,7 +422,10 @@ class AgentDriveService:
existing.file_kind = file_kind
existing.file_id = file_id
existing.value_owned_by_drive = item.value_owned_by_drive
existing.is_skill = item.is_skill
existing.skill_metadata = skill_metadata
existing.size = size
existing.hash = file_hash
existing.mime_type = mime_type
return self._row_dict(existing)
@ -271,7 +437,10 @@ class AgentDriveService:
file_kind=file_kind,
file_id=file_id,
value_owned_by_drive=item.value_owned_by_drive,
is_skill=item.is_skill,
skill_metadata=skill_metadata,
size=size,
hash=file_hash,
mime_type=mime_type,
created_by=user_id,
)
@ -287,8 +456,187 @@ class AgentDriveService:
"size": row.size,
"mime_type": row.mime_type,
"value_owned_by_drive": row.value_owned_by_drive,
"is_skill": row.is_skill,
"skill_metadata": row.skill_metadata,
}
@staticmethod
def _skill_path_from_key(key: str) -> str:
if not key.endswith(_SKILL_MD_SUFFIX):
raise AgentDriveError(
"invalid_skill_key",
"skill rows must use the canonical '<path>/SKILL.md' key",
status_code=500,
)
path = key[: -len(_SKILL_MD_SUFFIX)]
if not path:
raise AgentDriveError(
"invalid_skill_key",
"skill rows must use the canonical '<path>/SKILL.md' key",
status_code=500,
)
return path
@classmethod
def _skill_archive_key(cls, key: str) -> str:
return f"{cls._skill_path_from_key(key)}/{_SKILL_ARCHIVE_NAME}"
@classmethod
def _validate_skill_commit_fields(cls, *, key: str, item: DriveCommitItem) -> str | None:
if not item.is_skill:
if item.skill_metadata is not None:
raise AgentDriveError(
"invalid_skill_metadata",
"skill metadata is only allowed for canonical skill rows",
status_code=400,
)
return None
cls._skill_path_from_key(key)
if item.skill_metadata is None:
raise AgentDriveError(
"invalid_skill_metadata",
"skill metadata is required for canonical skill rows",
status_code=400,
)
return json.dumps(
item.skill_metadata.model_dump(mode="json", exclude_none=True),
separators=(",", ":"),
sort_keys=True,
)
@staticmethod
def _parse_skill_metadata(key: str, raw_metadata: str | None) -> DriveSkillMetadata:
if raw_metadata is None:
raise AgentDriveError(
"invalid_skill_metadata",
f"skill row '{key}' is missing required metadata",
status_code=500,
)
try:
return DriveSkillMetadata.model_validate(json.loads(raw_metadata))
except (ValueError, TypeError) as exc:
raise AgentDriveError(
"invalid_skill_metadata",
f"skill row '{key}' has invalid stored metadata",
status_code=500,
) from exc
@staticmethod
def _manifest_files_from_skill_metadata(*, tenant_id: str, agent_id: str, skill_md_key: str) -> list[str] | None:
with session_factory.create_session() as session:
row = session.scalar(
select(AgentDriveFile).where(
AgentDriveFile.tenant_id == tenant_id,
AgentDriveFile.agent_id == agent_id,
AgentDriveFile.key == skill_md_key,
AgentDriveFile.is_skill.is_(True),
)
)
if row is None:
return None
try:
metadata = AgentDriveService._parse_skill_metadata(row.key, row.skill_metadata)
except Exception:
logger.warning("drive skill inspect: malformed skill metadata for %s", skill_md_key, exc_info=True)
return None
return [str(item) for item in (metadata.manifest_files or []) if str(item).strip()] or None
@classmethod
def _skill_file_entries(
cls,
*,
skill_path: str,
skill_md_key: str,
manifest_files: list[str] | None,
drive_keys: set[str],
) -> tuple[list[AgentDriveSkillFileInfo], list[str]]:
warnings: list[str] = []
if manifest_files:
paths = sorted({normalize_drive_key(path) for path in manifest_files})
else:
paths = sorted(
{
key.removeprefix(f"{skill_path}/")
for key in drive_keys
if not key.endswith(f"/{_SKILL_ARCHIVE_NAME}")
}
)
warnings.append("manifest_files_unavailable")
files: list[AgentDriveSkillFileInfo] = []
for path in paths:
if path == _SKILL_ARCHIVE_NAME:
continue
drive_key = f"{skill_path}/{path}"
files.append(
{
"path": path,
"name": path.rsplit("/", 1)[-1],
"type": "file",
"drive_key": drive_key if drive_key in drive_keys else None,
"available_in_drive": drive_key in drive_keys,
}
)
if "SKILL.md" not in {file["path"] for file in files}:
files.insert(
0,
{
"path": "SKILL.md",
"name": "SKILL.md",
"type": "file",
"drive_key": skill_md_key,
"available_in_drive": skill_md_key in drive_keys,
},
)
return files, warnings
@staticmethod
def _build_file_tree(files: list[AgentDriveSkillFileInfo]) -> list[dict[str, Any]]:
root: dict[str, Any] = {}
for file in files:
cursor = root
parts = [part for part in file["path"].split("/") if part]
path_parts: list[str] = []
for part in parts[:-1]:
path_parts.append(part)
directory = cursor.setdefault(
part,
{
"name": part,
"path": "/".join(path_parts),
"type": "directory",
"children": {},
},
)
cursor = directory["children"]
leaf_name = parts[-1] if parts else file["name"]
cursor[leaf_name] = {
"name": leaf_name,
"path": file["path"],
"type": file["type"],
"drive_key": file["drive_key"],
"available_in_drive": file["available_in_drive"],
}
def serialize(node: dict[str, Any]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = []
for item in sorted(node.values(), key=lambda value: (value["type"] != "directory", value["name"])):
if item["type"] == "directory":
children = serialize(item["children"])
result.append(
{
"name": item["name"],
"path": item["path"],
"type": "directory",
"children": children,
}
)
else:
result.append(item)
return result
return serialize(root)
@staticmethod
def _assert_agent_belongs_to_tenant(session: Session, *, tenant_id: str, agent_id: str) -> None:
try:
@ -309,7 +657,7 @@ class AgentDriveService:
user_id: str,
file_kind: AgentDriveFileKind,
file_id: str,
) -> tuple[int | None, str | None]:
) -> tuple[int | None, str | None, str | None]:
"""Verify the source file exists for the tenant (and user, for ToolFile).
Malformed ids (e.g. a non-UUID hitting a UUID column) are treated as a
@ -328,7 +676,7 @@ class AgentDriveService:
raise AgentDriveError(
"source_not_found", "source ToolFile not found for this tenant/user", status_code=404
)
return tool_file.size, tool_file.mimetype
return tool_file.size, tool_file.mimetype, None
upload_file = session.scalar(
select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id)
)
@ -337,7 +685,7 @@ class AgentDriveService:
raise AgentDriveError("source_not_found", "source file ref is invalid", status_code=404) from exc
if upload_file is None:
raise AgentDriveError("source_not_found", "source UploadFile not found for this tenant", status_code=404)
return upload_file.size, upload_file.mime_type
return upload_file.size, upload_file.mime_type, upload_file.hash
def _cleanup_value(
self,
@ -509,6 +857,8 @@ __all__ = [
"AgentDriveService",
"DriveCommitItem",
"DriveFileRef",
"DriveSkillMetadata",
"decode_drive_mention_ref",
"normalize_drive_key",
"parse_agent_drive_ref",
]

View File

@ -1,4 +1,12 @@
"""Tenant credit pool accounting.
Credit deductions are guarded by a tenant-level Redis lock before the database
row lock is acquired. This keeps concurrent usage accounting for one tenant
from piling up database transactions while preserving cross-tenant concurrency.
"""
import logging
from collections.abc import Callable
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -7,13 +15,44 @@ from configs import dify_config
from core.db.session_factory import session_factory
from core.errors.error import QuotaExceededError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import TenantCreditPool
from models.enums import ProviderQuotaType
logger = logging.getLogger(__name__)
CREDIT_POOL_TENANT_LOCK_TIMEOUT_SECONDS = 10
CREDIT_POOL_TENANT_LOCK_BLOCKING_TIMEOUT_SECONDS = 5
class CreditPoolService:
@staticmethod
def _get_tenant_lock_key(tenant_id: str) -> str:
return f"credit_pool:tenant:{tenant_id}:deduct_lock"
@classmethod
def _deduct_with_tenant_lock(cls, tenant_id: str, deduct: Callable[[], int]) -> int:
lock_key = cls._get_tenant_lock_key(tenant_id)
lock = redis_client.lock(
lock_key,
timeout=CREDIT_POOL_TENANT_LOCK_TIMEOUT_SECONDS,
blocking_timeout=CREDIT_POOL_TENANT_LOCK_BLOCKING_TIMEOUT_SECONDS,
)
lock_acquired = False
try:
lock_acquired = lock.acquire(blocking=True)
if not lock_acquired:
raise QuotaExceededError("Failed to acquire credit pool lock")
return deduct()
finally:
if lock_acquired:
try:
lock.release()
except Exception:
logger.warning("Failed to release credit pool lock, tenant_id=%s", tenant_id, exc_info=True)
@staticmethod
def _get_locked_pool(session: Session, tenant_id: str, pool_type: str) -> TenantCreditPool | None:
return session.scalar(
@ -76,7 +115,7 @@ class CreditPoolService:
if credits_required <= 0:
return 0
try:
def deduct() -> int:
with session_factory.get_session_maker().begin() as session:
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
if not pool:
@ -89,14 +128,16 @@ class CreditPoolService:
raise QuotaExceededError("Insufficient credits remaining")
pool.quota_used += credits_required
return credits_required
try:
return cls._deduct_with_tenant_lock(tenant_id, deduct)
except QuotaExceededError:
raise
except Exception:
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
raise QuotaExceededError("Failed to deduct credits")
return credits_required
@classmethod
def deduct_credits_capped(
cls,
@ -108,7 +149,7 @@ class CreditPoolService:
if credits_required <= 0:
return 0
try:
def deduct() -> int:
with session_factory.get_session_maker().begin() as session:
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
if not pool:
@ -121,6 +162,9 @@ class CreditPoolService:
pool.quota_used += deducted_credits
return deducted_credits
try:
return cls._deduct_with_tenant_lock(tenant_id, deduct)
except QuotaExceededError:
raise
except Exception:

View File

@ -182,6 +182,10 @@ class EnterpriseRequest(BaseRequest):
inner_headers: dict[str, str] = {INNER_TENANT_ID_HEADER: tenant_id}
if account_id:
inner_headers[INNER_ACCOUNT_ID_HEADER] = account_id
if not cls.base_url.startswith("http") or not cls.base_url.startswith("https") or not cls.base_url:
raise ValueError("ENTERPRISE_RBAC_API_URL is required when RBAC_ENABLED=true")
url = f"{cls.rbac_base_url}{endpoint}"
mounts = cls._build_mounts()

View File

@ -312,15 +312,26 @@ _LEGACY_WORKSPACE_OWNER_KEYS: list[str] = [
"plugin.manage",
"plugin.debug",
"credential.use",
"credential.create",
"credential.manage",
"billing.view",
"billing.subscription.manage",
"billing.manage",
"app.acl.preview",
"app_library.access",
"app.create_and_management",
"app.tag.manage",
"dataset.acl.preview",
"dataset.create_and_management",
"dataset.tag.manage",
"dataset.external.connect",
"dataset.api_key.manage",
"snippets.create_and_modify",
"snippets.management",
"tool.manage",
"mcp.manage",
"snippets.create_and_modify",
"snippets.management",
]
_LEGACY_WORKSPACE_ADMIN_KEYS: list[str] = [
@ -334,15 +345,24 @@ _LEGACY_WORKSPACE_ADMIN_KEYS: list[str] = [
"plugin.manage",
"plugin.debug",
"credential.use",
"credential.create",
"credential.manage",
"billing.view",
"billing.subscription.manage",
"billing.manage",
"app_library.access",
"app.create_and_management",
"app.tag.manage",
"dataset.create_and_management",
"dataset.tag.manage",
"dataset.external.connect",
"dataset.api_key.manage",
"snippets.create_and_modify",
"snippets.management",
"tool.manage",
"mcp.manage",
"snippets.create_and_modify",
"snippets.management",
]
_LEGACY_WORKSPACE_EDITOR_KEYS: list[str] = [
@ -356,7 +376,9 @@ _LEGACY_WORKSPACE_EDITOR_KEYS: list[str] = [
"dataset.create_and_management",
"dataset.tag.manage",
"dataset.external.connect",
"snippets.create_and_modify",
"tool.manage",
"snippets.create_and_modify",
]
_LEGACY_WORKSPACE_NORMAL_KEYS: list[str] = [
@ -373,6 +395,7 @@ _LEGACY_WORKSPACE_DATASET_OPERATOR_KEYS: list[str] = [
]
_LEGACY_APP_OWNER_KEYS: list[str] = [
"app.acl.preview",
"app.acl.view_layout",
"app.acl.test_and_run",
"app.acl.edit",
@ -384,6 +407,7 @@ _LEGACY_APP_OWNER_KEYS: list[str] = [
]
_LEGACY_APP_ADMIN_KEYS: list[str] = [
"app.acl.preview",
"app.acl.view_layout",
"app.acl.test_and_run",
"app.acl.edit",
@ -395,6 +419,7 @@ _LEGACY_APP_ADMIN_KEYS: list[str] = [
]
_LEGACY_APP_EDITOR_KEYS: list[str] = [
"app.acl.preview",
"app.acl.view_layout",
"app.acl.test_and_run",
"app.acl.edit",
@ -406,12 +431,14 @@ _LEGACY_APP_EDITOR_KEYS: list[str] = [
]
_LEGACY_APP_NORMAL_KEYS: list[str] = [
"app.acl.preview",
"app.acl.view_layout",
"app.acl.test_and_run",
"app.acl.monitor",
]
_LEGACY_DATASET_OWNER_KEYS: list[str] = [
"dataset.acl.preview",
"dataset.acl.readonly",
"dataset.acl.edit",
"dataset.acl.import_export_dsl",
@ -427,6 +454,7 @@ _LEGACY_DATASET_OWNER_KEYS: list[str] = [
]
_LEGACY_DATASET_ADMIN_KEYS: list[str] = [
"dataset.acl.preview",
"dataset.acl.readonly",
"dataset.acl.edit",
"dataset.acl.import_export_dsl",
@ -442,6 +470,7 @@ _LEGACY_DATASET_ADMIN_KEYS: list[str] = [
]
_LEGACY_DATASET_EDITOR_KEYS: list[str] = [
"dataset.acl.preview",
"dataset.acl.readonly",
"dataset.acl.edit",
"dataset.acl.import_export_dsl",
@ -492,6 +521,19 @@ _LEGACY_MY_PERMISSIONS: dict[TenantAccountRole, dict[str, list[str]]] = {
}
def _legacy_role_permission_keys(role: TenantAccountRole) -> list[str]:
permissions = _LEGACY_MY_PERMISSIONS.get(role, {})
return list(
dict.fromkeys(
[
*permissions.get("workspace", []),
*permissions.get("app", []),
*permissions.get("dataset", []),
]
)
)
def _legacy_my_permissions(tenant_id: str, account_id: str | None) -> MyPermissionsResponse:
if not account_id:
return MyPermissionsResponse()
@ -1518,21 +1560,44 @@ class RBACService:
)
return AccessMatrixItem.model_validate(data or {})
# ------------------------------------------------------------------
# Member ↔ role bindings (screenshot 3: Settings > Members > Assign roles).
# ------------------------------------------------------------------
class MemberRoles:
@staticmethod
def get(tenant_id: str, account_id: str | None, member_account_id: str) -> MemberRolesResponse:
data = _inner_call(
"GET",
f"{_INNER_PREFIX}/members/rbac-roles",
tenant_id=tenant_id,
account_id=account_id,
params={"account_id": member_account_id},
)
rst = MemberRolesResponse.model_validate(data or {})
return rst
if dify_config.RBAC_ENABLED:
data = _inner_call(
"GET",
f"{_INNER_PREFIX}/members/rbac-roles",
tenant_id=tenant_id,
account_id=account_id,
params={"account_id": member_account_id},
)
rst = MemberRolesResponse.model_validate(data or {})
return rst
else:
with session_factory.create_session() as session:
role = session.scalar(
select(TenantAccountJoin.role).where(
TenantAccountJoin.tenant_id == tenant_id,
TenantAccountJoin.account_id == member_account_id,
)
)
return MemberRolesResponse(
account_id=member_account_id,
roles=[
RBACRole(
id="",
name=role,
description="",
is_builtin=True,
type="",
permission_keys=_legacy_role_permission_keys(role),
role_tag="owner" if role == "owner" else role,
tenant_id=tenant_id,
)
]
if role
else [],
)
@staticmethod
def batch_get(

View File

@ -3,6 +3,8 @@ from typing import TypedDict
import httpx
OPERATION_REQUEST_TIMEOUT = httpx.Timeout(10.0, connect=3.0)
class UtmInfo(TypedDict, total=False):
"""Expected shape of the utm_info dict passed to record_utm.
@ -26,7 +28,9 @@ class OperationService:
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers)
response = httpx.request(
method, url, json=json, params=params, headers=headers, timeout=OPERATION_REQUEST_TIMEOUT
)
return response.json()

View File

@ -23,8 +23,11 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_forms import (
load_form_dispositions_by_form_id,
)
from core.workflow.human_input_policy import (
FormDisposition,
HumanInputSurface,
enrich_human_input_pause_reasons,
resolve_human_input_pause_reason_inputs,
@ -359,7 +362,7 @@ def _build_human_input_required_events(
expiration_times_by_form_id: dict[str, int] = {}
display_in_ui_by_form_id: dict[str, bool] = {}
form_tokens_by_form_id: dict[str, str] = {}
dispositions_by_form_id: dict[str, FormDisposition] = {}
if human_input_form_ids and session_maker is not None:
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time, HumanInputForm.form_definition).where(
HumanInputForm.id.in_(human_input_form_ids)
@ -372,7 +375,7 @@ def _build_human_input_required_events(
except (TypeError, json.JSONDecodeError):
definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_tokens_by_form_id = load_form_tokens_by_form_id(
dispositions_by_form_id = load_form_dispositions_by_form_id(
human_input_form_ids,
session=session,
surface=human_input_surface,
@ -393,6 +396,7 @@ def _build_human_input_required_events(
reason.inputs,
variable_pool=variable_pool,
)
disposition = dispositions_by_form_id.get(form_id)
response = HumanInputRequiredResponse(
task_id=task_id,
@ -405,7 +409,8 @@ def _build_human_input_required_events(
inputs=resolved_inputs,
actions=reason.actions,
display_in_ui=display_in_ui_by_form_id.get(form_id, False),
form_token=form_tokens_by_form_id.get(form_id),
form_token=disposition.form_token if disposition else None,
approval_channels=list(disposition.approval_channels) if disposition else [],
resolved_default_values=reason.resolved_default_values,
expiration_time=expiration_time,
),
@ -493,11 +498,11 @@ def _build_pause_event(
for form_id in [reason.get("form_id")]
if isinstance(form_id, str)
]
form_tokens_by_form_id: dict[str, str] = {}
dispositions_by_form_id: dict[str, FormDisposition] = {}
expiration_times_by_form_id: dict[str, int] = {}
if human_input_form_ids and session_maker is not None:
with session_maker() as session:
form_tokens_by_form_id = load_form_tokens_by_form_id(
dispositions_by_form_id = load_form_dispositions_by_form_id(
human_input_form_ids,
session=session,
surface=human_input_surface,
@ -512,7 +517,7 @@ def _build_pause_event(
# otherwise clients see schema drift after resume.
reasons = enrich_human_input_pause_reasons(
reasons,
form_tokens_by_form_id=form_tokens_by_form_id,
dispositions_by_form_id=dispositions_by_form_id,
expiration_times_by_form_id=expiration_times_by_form_id,
)

View File

@ -20,7 +20,7 @@ from testcontainers.redis import RedisContainer
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
from libs.broadcast_channel.exc import SubscriptionClosedError
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
from libs.broadcast_channel.redis.pubsub_channel import BroadcastChannel as RedisBroadcastChannel
class TestRedisBroadcastChannelIntegration:

View File

@ -0,0 +1,161 @@
"""Unit tests for controllers.common.app_access RBAC app-id access filtering."""
from __future__ import annotations
import pytest
from controllers.common.app_access import (
APP_LIST_PERMISSION_KEYS,
AppAccessFilter,
has_app_list_permission,
resolve_app_access_filter,
)
from services.app_service import AppListParams
from services.enterprise.rbac_service import (
MyPermissionsResponse,
ResourcePermissionKeys,
ResourcePermissionSnapshot,
ResourceWhitelistResources,
WorkspacePermissionSnapshot,
)
_RBAC_MODULE = "controllers.common.app_access.enterprise_rbac_service"
def _permissions(
*,
workspace_keys: list[str] | None = None,
app_default_keys: list[str] | None = None,
app_overrides: list[ResourcePermissionKeys] | None = None,
) -> MyPermissionsResponse:
return MyPermissionsResponse(
workspace=WorkspacePermissionSnapshot(permission_keys=workspace_keys or []),
app=ResourcePermissionSnapshot(
default_permission_keys=app_default_keys or [],
overrides=app_overrides or [],
),
)
class TestHasAppListPermission:
def test_matches_known_preview_keys(self):
for key in APP_LIST_PERMISSION_KEYS:
assert has_app_list_permission([key])
def test_rejects_unknown_keys(self):
assert not has_app_list_permission(["app.export", "app.delete"])
assert not has_app_list_permission([])
class TestAppAccessFilterIsAppAccessible:
def test_unrestricted_sees_everything(self):
flt = AppAccessFilter.unrestricted()
assert flt.is_app_accessible("app-1", maintainer="someone", account_id="acc-1")
def test_whitelisted_app_is_visible(self):
flt = AppAccessFilter(accessible_app_ids={"app-1"}, can_manage_own_apps=False)
assert flt.is_app_accessible("app-1", maintainer=None, account_id="acc-1")
assert not flt.is_app_accessible("app-2", maintainer=None, account_id="acc-1")
def test_own_app_visible_only_with_manage_permission(self):
own = AppAccessFilter(accessible_app_ids=set(), can_manage_own_apps=True)
assert own.is_app_accessible("app-1", maintainer="acc-1", account_id="acc-1")
assert not own.is_app_accessible("app-1", maintainer="acc-2", account_id="acc-1")
no_manage = AppAccessFilter(accessible_app_ids=set(), can_manage_own_apps=False)
assert not no_manage.is_app_accessible("app-1", maintainer="acc-1", account_id="acc-1")
class TestAppAccessFilterApplyToParams:
def test_unrestricted_leaves_params_untouched(self):
params = AppListParams()
AppAccessFilter.unrestricted().apply_to_params(params)
assert params.accessible_app_ids is None
assert params.include_own_apps is False
assert params.is_created_by_me is None
def test_whitelisted_ids_are_sorted_with_own_apps_flag(self):
params = AppListParams()
AppAccessFilter(accessible_app_ids={"b", "a"}, can_manage_own_apps=True).apply_to_params(params)
assert params.accessible_app_ids == ["a", "b"]
assert params.include_own_apps is True
def test_empty_set_with_manage_falls_back_to_maintained_apps(self):
# Own-app fallback must use maintainer (include_own_apps), consistent
# with is_app_accessible — not created_by (is_created_by_me).
params = AppListParams()
AppAccessFilter(accessible_app_ids=set(), can_manage_own_apps=True).apply_to_params(params)
assert params.accessible_app_ids == []
assert params.include_own_apps is True
assert params.is_created_by_me is None
def test_empty_set_without_manage_sees_nothing(self):
params = AppListParams()
AppAccessFilter(accessible_app_ids=set(), can_manage_own_apps=False).apply_to_params(params)
assert params.accessible_app_ids == []
assert params.include_own_apps is False
assert params.is_created_by_me is None
class TestResolveAppAccessFilter:
def _patch_whitelist(self, monkeypatch: pytest.MonkeyPatch, whitelist: ResourceWhitelistResources) -> None:
monkeypatch.setattr(
f"{_RBAC_MODULE}.RBACService.AppAccess.whitelist_resources",
lambda tenant_id, account_id: whitelist,
)
def test_default_preview_is_unrestricted(self, monkeypatch: pytest.MonkeyPatch):
self._patch_whitelist(monkeypatch, ResourceWhitelistResources(unrestricted=True))
permissions = _permissions(app_default_keys=["app.preview"])
flt = resolve_app_access_filter("tenant-1", "acc-1", permissions=permissions)
assert flt.accessible_app_ids is None
assert flt.can_manage_own_apps is False
def test_default_preview_overrides_whitelist_restriction(self, monkeypatch: pytest.MonkeyPatch):
self._patch_whitelist(monkeypatch, ResourceWhitelistResources(unrestricted=False, resource_ids=["app-9"]))
permissions = _permissions(
workspace_keys=["app.full_access", "app.create_and_management"],
)
flt = resolve_app_access_filter("tenant-1", "acc-1", permissions=permissions)
# Workspace-level preview grant defeats the whitelist restriction.
assert flt.accessible_app_ids is None
assert flt.can_manage_own_apps is True
def test_override_apps_collected_without_default_preview(self, monkeypatch: pytest.MonkeyPatch):
self._patch_whitelist(monkeypatch, ResourceWhitelistResources(unrestricted=True))
permissions = _permissions(
app_overrides=[
ResourcePermissionKeys(resource_id="app-1", permission_keys=["app.preview"]),
ResourcePermissionKeys(resource_id="app-2", permission_keys=["app.export"]),
],
)
flt = resolve_app_access_filter("tenant-1", "acc-1", permissions=permissions)
assert flt.accessible_app_ids == {"app-1"}
def test_whitelist_union_with_override_apps(self, monkeypatch: pytest.MonkeyPatch):
self._patch_whitelist(monkeypatch, ResourceWhitelistResources(unrestricted=False, resource_ids=["app-5"]))
permissions = _permissions(
app_overrides=[ResourcePermissionKeys(resource_id="app-1", permission_keys=["app.acl.preview"])],
)
flt = resolve_app_access_filter("tenant-1", "acc-1", permissions=permissions)
assert flt.accessible_app_ids == {"app-1", "app-5"}
def test_fetches_permissions_when_not_supplied(self, monkeypatch: pytest.MonkeyPatch):
self._patch_whitelist(monkeypatch, ResourceWhitelistResources(unrestricted=False, resource_ids=[]))
monkeypatch.setattr(
f"{_RBAC_MODULE}.RBACService.MyPermissions.get",
lambda tenant_id, account_id: _permissions(workspace_keys=["app.create_and_management"]),
)
flt = resolve_app_access_filter("tenant-1", "acc-1")
assert flt.accessible_app_ids == set()
assert flt.can_manage_own_apps is True

View File

@ -20,6 +20,10 @@ from controllers.console.agent.composer import (
WorkflowAgentComposerValidateApi,
)
from controllers.console.agent.roster import (
AgentApiAccessApi,
AgentApiKeyApi,
AgentApiKeyListApi,
AgentApiStatusApi,
AgentAppApi,
AgentAppCopyApi,
AgentAppListApi,
@ -28,6 +32,7 @@ from controllers.console.agent.roster import (
AgentLogsApi,
AgentLogSourcesApi,
AgentRosterVersionDetailApi,
AgentRosterVersionRestoreApi,
AgentRosterVersionsApi,
AgentStatisticsSummaryApi,
)
@ -149,6 +154,10 @@ def test_agent_v2_console_routes_are_agent_id_first() -> None:
"/agent/<uuid:agent_id>/sandbox/files",
"/agent/<uuid:agent_id>/skills/upload",
"/agent/<uuid:agent_id>/files",
"/agent/<uuid:agent_id>/api-access",
"/agent/<uuid:agent_id>/api-enable",
"/agent/<uuid:agent_id>/api-keys",
"/agent/<uuid:agent_id>/api-keys/<uuid:api_key_id>",
"/agent/<uuid:agent_id>/chat-messages",
"/agent/<uuid:agent_id>/chat-messages/<string:task_id>/stop",
"/agent/<uuid:agent_id>/feedbacks",
@ -158,6 +167,9 @@ def test_agent_v2_console_routes_are_agent_id_first() -> None:
"/agent/<uuid:agent_id>/logs/<uuid:conversation_id>/messages",
"/agent/<uuid:agent_id>/log-sources",
"/agent/<uuid:agent_id>/statistics/summary",
"/agent/<uuid:agent_id>/versions",
"/agent/<uuid:agent_id>/versions/<uuid:version_id>",
"/agent/<uuid:agent_id>/versions/<uuid:version_id>/restore",
"/agent/invite-options",
):
assert route in paths
@ -173,6 +185,7 @@ def test_agent_v2_console_routes_are_agent_id_first() -> None:
"/apps/<uuid:app_id>/agent-features",
"/apps/<uuid:app_id>/agent-referencing-workflows",
"/apps/<uuid:app_id>/agent-sandbox/files",
"/apps/<uuid:agent_id>/api-access",
):
assert route not in paths
@ -215,16 +228,34 @@ def test_agent_app_list_and_create_use_agent_route(
roster_controller.AgentRosterService,
"load_app_backing_agents_by_app_id",
lambda _self, **kwargs: {
"app-list": SimpleNamespace(id="agent-list", role="List role", active_config_snapshot_id=None)
"app-list": SimpleNamespace(
id="agent-list",
role="List role",
debug_conversation_id="debug-conversation-list",
active_config_snapshot_id=None,
)
},
)
monkeypatch.setattr(
roster_controller.AgentRosterService,
"get_app_backing_agent",
lambda _self, **kwargs: SimpleNamespace(
id="agent-created", role="Created role", active_config_snapshot_id=None
id="agent-created",
role="Created role",
debug_conversation_id="debug-conversation-created",
active_config_snapshot_id=None,
),
)
monkeypatch.setattr(
roster_controller.AgentRosterService,
"get_or_create_agent_app_debug_conversation_id",
lambda _self, **kwargs: "debug-conversation-detail",
)
monkeypatch.setattr(
roster_controller.AgentRosterService,
"get_or_create_agent_app_debug_conversation_id",
lambda _self, **kwargs: "debug-conversation-detail",
)
monkeypatch.setattr(
roster_controller.AgentRosterService,
"load_published_references_by_agent_id",
@ -245,6 +276,16 @@ def test_agent_app_list_and_create_use_agent_route(
]
},
)
monkeypatch.setattr(
roster_controller.AgentRosterService,
"load_or_create_agent_app_debug_conversation_ids_by_agent_id",
lambda _self, **kwargs: {"agent-list": "debug-conversation-list"},
)
monkeypatch.setattr(
roster_controller.AgentRosterService,
"get_or_create_agent_app_debug_conversation_id",
lambda _self, **kwargs: "debug-conversation-created",
)
monkeypatch.setattr(
roster_controller.FeatureService,
"get_system_features",
@ -259,6 +300,7 @@ def test_agent_app_list_and_create_use_agent_route(
assert listed["total"] == 1
assert listed["data"][0]["id"] == "agent-list"
assert listed["data"][0]["app_id"] == "app-list"
assert listed["data"][0]["debug_conversation_id"] == "debug-conversation-list"
assert listed["data"][0]["role"] == "List role"
assert listed["data"][0]["active_config_is_published"] is False
assert listed["data"][0]["published_reference_count"] == 1
@ -292,6 +334,7 @@ def test_agent_app_list_and_create_use_agent_route(
assert status == 201
assert created["id"] == "agent-created"
assert created["app_id"] == "app-created"
assert created["debug_conversation_id"] == "debug-conversation-created"
assert created["role"] == "Created role"
assert created["active_config_is_published"] is False
assert "bound_agent_id" not in created
@ -332,7 +375,17 @@ def test_agent_app_detail_update_delete_resolve_app_from_agent_id(
monkeypatch.setattr(
roster_controller.AgentRosterService,
"get_app_backing_agent",
lambda _self, **kwargs: SimpleNamespace(id=agent_id, role="Resolved role", active_config_snapshot_id=None),
lambda _self, **kwargs: SimpleNamespace(
id=agent_id,
role="Resolved role",
debug_conversation_id="debug-conversation-detail",
active_config_snapshot_id=None,
),
)
monkeypatch.setattr(
roster_controller.AgentRosterService,
"get_or_create_agent_app_debug_conversation_id",
lambda _self, **kwargs: "debug-conversation-detail",
)
monkeypatch.setattr(
roster_controller.FeatureService,
@ -354,9 +407,10 @@ def test_agent_app_detail_update_delete_resolve_app_from_agent_id(
monkeypatch.setattr(roster_controller, "AppService", FakeAppService)
detail = unwrap(AgentAppApi.get)(AgentAppApi(), "tenant-1", agent_id)
detail = unwrap(AgentAppApi.get)(AgentAppApi(), "tenant-1", SimpleNamespace(id=account_id), agent_id)
assert detail["id"] == agent_id
assert detail["app_id"] == "app-1"
assert detail["debug_conversation_id"] == "debug-conversation-detail"
assert detail["role"] == "Resolved role"
assert detail["active_config_is_published"] is False
assert "bound_agent_id" not in detail
@ -365,11 +419,12 @@ def test_agent_app_detail_update_delete_resolve_app_from_agent_id(
"/console/api/agent/00000000-0000-0000-0000-000000000001",
json={"name": "Renamed", "description": "", "role": "Reviewer", "icon_type": "emoji", "icon": "R"},
):
updated = unwrap(AgentAppApi.put)(AgentAppApi(), "tenant-1", agent_id)
updated = unwrap(AgentAppApi.put)(AgentAppApi(), "tenant-1", SimpleNamespace(id=account_id), agent_id)
assert updated["name"] == "Renamed"
assert updated["id"] == agent_id
assert updated["app_id"] == "app-1"
assert updated["debug_conversation_id"] == "debug-conversation-detail"
assert updated["role"] == "Resolved role"
assert updated["active_config_is_published"] is False
assert "bound_agent_id" not in updated
@ -399,7 +454,7 @@ def test_agent_app_copy_uses_agent_id_and_returns_agent_detail(
monkeypatch.setattr(
roster_controller,
"_serialize_agent_app_detail",
lambda app_model: {"id": "copied-agent", "app_id": app_model.id, "name": app_model.name},
lambda app_model, **_kwargs: {"id": "copied-agent", "app_id": app_model.id, "name": app_model.name},
)
with app.test_request_context(
@ -428,6 +483,127 @@ def test_agent_app_copy_uses_agent_id_and_returns_agent_detail(
}
def test_agent_api_access_uses_agent_id_and_returns_service_api_metadata(
monkeypatch: pytest.MonkeyPatch,
) -> None:
agent_id = "00000000-0000-0000-0000-000000000001"
app_model = SimpleNamespace(
id="app-1",
enable_api=True,
api_base_url="https://api.example.test/v1",
api_rpm=60,
api_rph=600,
)
monkeypatch.setattr(roster_controller, "_resolve_agent_app_model", lambda **kwargs: app_model)
monkeypatch.setattr(roster_controller, "_agent_api_key_count", lambda app_id: 2)
response = unwrap(AgentApiAccessApi.get)(AgentApiAccessApi(), "tenant-1", agent_id)
assert response == {
"enabled": True,
"service_api_base_url": "https://api.example.test/v1",
"streaming_only": True,
"chat_endpoint": "https://api.example.test/v1/chat-messages",
"stop_endpoint": "https://api.example.test/v1/chat-messages/{task_id}/stop",
"conversations_endpoint": "https://api.example.test/v1/conversations",
"messages_endpoint": "https://api.example.test/v1/messages",
"files_upload_endpoint": "https://api.example.test/v1/files/upload",
"parameters_endpoint": "https://api.example.test/v1/parameters",
"info_endpoint": "https://api.example.test/v1/info",
"meta_endpoint": "https://api.example.test/v1/meta",
"api_rpm": 60,
"api_rph": 600,
"api_key_count": 2,
}
def test_agent_api_status_and_key_routes_resolve_backing_app(
app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
agent_id = "00000000-0000-0000-0000-000000000001"
api_key_id = "00000000-0000-0000-0000-000000000002"
app_model = SimpleNamespace(
id="app-1",
enable_api=False,
api_base_url="https://api.example.test/v1",
api_rpm=0,
api_rph=0,
)
captured: dict[str, object] = {}
monkeypatch.setattr(roster_controller, "_resolve_agent_app_model", lambda **kwargs: app_model)
monkeypatch.setattr(roster_controller, "_agent_api_key_count", lambda app_id: 1)
class FakeAppService:
def update_app_api_status(self, app_obj: object, enable_api: bool) -> object:
captured["enable"] = {"app": app_obj, "enable_api": enable_api}
app_model.enable_api = enable_api
return app_model
monkeypatch.setattr(roster_controller, "AppService", FakeAppService)
def fake_get_api_key_list(self, resource_id: str, tenant_id: str):
captured["list_keys"] = {"resource_id": resource_id, "tenant_id": tenant_id}
return roster_controller.ApiKeyList(data=[])
def fake_create_api_key(self, resource_id: str, tenant_id: str):
captured["create_key"] = {"resource_id": resource_id, "tenant_id": tenant_id}
return SimpleNamespace(
id=api_key_id,
type="app",
token="app-test-token",
last_used_at=None,
created_at=None,
)
def fake_delete_api_key(self, resource_id: str, key_id: str, tenant_id: str, current_user: object) -> None:
captured["delete_key"] = {
"resource_id": resource_id,
"api_key_id": key_id,
"tenant_id": tenant_id,
"current_user": current_user,
}
monkeypatch.setattr(AgentApiKeyListApi, "_get_api_key_list", fake_get_api_key_list)
monkeypatch.setattr(AgentApiKeyListApi, "_create_api_key", fake_create_api_key)
monkeypatch.setattr(AgentApiKeyApi, "_delete_api_key", fake_delete_api_key)
with app.test_request_context(
"/console/api/agent/00000000-0000-0000-0000-000000000001/api-enable",
json={"enable_api": True},
):
enabled = unwrap(AgentApiStatusApi.post)(AgentApiStatusApi(), "tenant-1", agent_id)
assert enabled["enabled"] is True
assert captured["enable"] == {"app": app_model, "enable_api": True}
keys = unwrap(AgentApiKeyListApi.get)(AgentApiKeyListApi(), "tenant-1", agent_id)
assert keys == {"data": []}
assert captured["list_keys"] == {"resource_id": "app-1", "tenant_id": "tenant-1"}
created, status = unwrap(AgentApiKeyListApi.post)(AgentApiKeyListApi(), "tenant-1", agent_id)
assert status == 201
assert created["id"] == api_key_id
assert created["token"] == "app-test-token"
assert captured["create_key"] == {"resource_id": "app-1", "tenant_id": "tenant-1"}
current_user = SimpleNamespace(id="account-1", is_admin_or_owner=True)
deleted, delete_status = unwrap(AgentApiKeyApi.delete)(
AgentApiKeyApi(),
"tenant-1",
current_user,
agent_id,
api_key_id,
)
assert (deleted, delete_status) == ("", 204)
assert captured["delete_key"] == {
"resource_id": "app-1",
"api_key_id": api_key_id,
"tenant_id": "tenant-1",
"current_user": current_user,
}
def test_agent_app_update_rejects_empty_role(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
agent_id = "00000000-0000-0000-0000-000000000001"
app_model = _app_detail_obj(id="app-1", bound_agent_id=agent_id)
@ -441,7 +617,12 @@ def test_agent_app_update_rejects_empty_role(app: Flask, monkeypatch: pytest.Mon
monkeypatch.setattr(
roster_controller.AgentRosterService,
"get_app_backing_agent",
lambda _self, **kwargs: SimpleNamespace(id=agent_id, role="", active_config_snapshot_id=None),
lambda _self, **kwargs: SimpleNamespace(
id=agent_id,
role="",
debug_conversation_id="debug-conversation-detail",
active_config_snapshot_id=None,
),
)
monkeypatch.setattr(
roster_controller.FeatureService,
@ -464,7 +645,7 @@ def test_agent_app_update_rejects_empty_role(app: Flask, monkeypatch: pytest.Mon
json={"name": "Renamed", "description": "", "role": "", "icon_type": "emoji", "icon": "R"},
):
with pytest.raises(ValueError, match="String should have at least 1 character"):
unwrap(AgentAppApi.put)(AgentAppApi(), "tenant-1", agent_id)
unwrap(AgentAppApi.put)(AgentAppApi(), "tenant-1", SimpleNamespace(id="account-1"), agent_id)
def test_invite_options_get_parses_app_id(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
@ -513,6 +694,13 @@ def test_agent_versions_call_services(app: Flask, monkeypatch: pytest.MonkeyPatc
],
},
)
captured_restore: dict[str, object] = {}
def restore_agent_version(_self, **kwargs):
captured_restore.update(kwargs)
return {"result": "success", "active_config_snapshot_id": kwargs["version_id"]}
monkeypatch.setattr(roster_controller.AgentRosterService, "restore_agent_version", restore_agent_version)
assert (
unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), "tenant-1", agent_id)["data"][0]["id"]
@ -523,6 +711,16 @@ def test_agent_versions_call_services(app: Flask, monkeypatch: pytest.MonkeyPatc
)
assert version_detail["id"] == version_id
assert version_detail["agent_id"] == agent_id
restored = unwrap(AgentRosterVersionRestoreApi.post)(
AgentRosterVersionRestoreApi(), "tenant-1", SimpleNamespace(id="account-1"), agent_id, version_id
)
assert restored == {"result": "success", "active_config_snapshot_id": version_id}
assert captured_restore == {
"tenant_id": "tenant-1",
"agent_id": agent_id,
"version_id": version_id,
"account_id": "account-1",
}
def test_agent_observability_routes_resolve_app_from_agent_id(
@ -870,7 +1068,7 @@ def test_agent_chat_generate_and_stop_routes_resolve_app_from_agent_id(
app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str
) -> None:
agent_id = "00000000-0000-0000-0000-000000000001"
app_model = SimpleNamespace(id="app-1", mode="agent")
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="agent")
captured: dict[str, object] = {}
def resolve_agent_app_model(**kwargs: object) -> object:
@ -909,7 +1107,7 @@ def test_agent_chat_generate_and_stop_routes_resolve_app_from_agent_id(
def test_agent_chat_helper_forces_agent_streaming_and_external_trace(
app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str
) -> None:
app_model = SimpleNamespace(id="app-1", mode="agent")
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="agent")
current_user = SimpleNamespace(id=account_id)
captured: dict[str, object] = {}
@ -918,6 +1116,11 @@ def test_agent_chat_helper_forces_agent_streaming_and_external_trace(
return {"answer": "ok"}
monkeypatch.setattr(completion_controller.AppGenerateService, "generate", generate)
monkeypatch.setattr(
completion_controller,
"_resolve_current_user_agent_debug_conversation_id",
lambda **kwargs: "debug-conversation-1",
)
monkeypatch.setattr(
completion_controller.helper,
"compact_generate_response",
@ -936,10 +1139,83 @@ def test_agent_chat_helper_forces_agent_streaming_and_external_trace(
assert captured["streaming"] is True
args = cast(dict[str, object], captured["args"])
assert args["response_mode"] == "streaming"
assert args["conversation_id"] == "debug-conversation-1"
assert args["auto_generate_name"] is False
assert args["external_trace_id"] == "trace-1"
def test_agent_chat_helper_rejects_foreign_debug_conversation(
app: Flask,
monkeypatch: pytest.MonkeyPatch,
account_id: str,
) -> None:
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="agent")
monkeypatch.setattr(
completion_controller,
"_resolve_current_user_agent_debug_conversation_id",
lambda **kwargs: "owned-conversation",
)
with app.test_request_context(
json={
"inputs": {},
"query": "hello",
"response_mode": "streaming",
"conversation_id": "00000000-0000-0000-0000-000000000001",
}
):
with pytest.raises(NotFound):
completion_controller._create_chat_message(
current_tenant_id="tenant-1",
current_user=SimpleNamespace(id=account_id),
app_model=app_model,
agent_id="agent-1",
)
def test_resolve_current_user_agent_debug_conversation_uses_agent_or_backing_app(
monkeypatch: pytest.MonkeyPatch,
) -> None:
calls: list[dict[str, object]] = []
class FakeRosterService:
def __init__(self, session: object) -> None:
calls.append({"session": session})
def get_or_create_agent_app_debug_conversation_id(self, **kwargs: object) -> str:
calls.append({"get_or_create": kwargs})
return f"debug-{kwargs['agent_id']}"
def get_app_backing_agent(self, **kwargs: object) -> object:
calls.append({"get_app_backing_agent": kwargs})
return SimpleNamespace(id="backing-agent")
monkeypatch.setattr(completion_controller, "AgentRosterService", FakeRosterService)
monkeypatch.setattr(completion_controller, "db", SimpleNamespace(session="session-1"))
explicit_id = completion_controller._resolve_current_user_agent_debug_conversation_id(
current_tenant_id="tenant-1",
current_user=SimpleNamespace(id="account-1"),
app_model=SimpleNamespace(id="app-1"),
agent_id="agent-1",
)
fallback_id = completion_controller._resolve_current_user_agent_debug_conversation_id(
current_tenant_id="tenant-1",
current_user=SimpleNamespace(id="account-1"),
app_model=SimpleNamespace(id="app-1"),
agent_id=None,
)
assert explicit_id == "debug-agent-1"
assert fallback_id == "debug-backing-agent"
assert calls[1] == {"get_or_create": {"tenant_id": "tenant-1", "agent_id": "agent-1", "account_id": "account-1"}}
assert calls[3] == {"get_app_backing_agent": {"tenant_id": "tenant-1", "app_id": "app-1"}}
assert calls[4] == {
"get_or_create": {"tenant_id": "tenant-1", "agent_id": "backing-agent", "account_id": "account-1"}
}
@pytest.mark.parametrize(
("error", "expected"),
[
@ -986,7 +1262,7 @@ def test_agent_chat_helper_maps_generation_errors(
def test_agent_chat_message_routes_resolve_app_from_agent_id(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
agent_id = "00000000-0000-0000-0000-000000000001"
message_id = "00000000-0000-0000-0000-000000000002"
app_model = SimpleNamespace(id="app-1")
app_model = SimpleNamespace(id="app-1", mode="agent")
current_user = SimpleNamespace(id="account-1")
captured: dict[str, object] = {}
@ -1016,7 +1292,9 @@ def test_agent_chat_message_routes_resolve_app_from_agent_id(app: Flask, monkeyp
monkeypatch.setattr(message_controller, "_get_message_suggested_questions", get_message_suggested_questions)
monkeypatch.setattr(message_controller, "_get_message_detail", get_message_detail)
assert unwrap(AgentChatMessageListApi.get)(AgentChatMessageListApi(), "tenant-1", agent_id) == {"data": []}
assert unwrap(AgentChatMessageListApi.get)(AgentChatMessageListApi(), "tenant-1", current_user, agent_id) == {
"data": []
}
assert cast(dict[str, object], captured["list"])["app_model"] is app_model
with app.test_request_context(json={"message_id": message_id, "rating": "like"}):
@ -1073,11 +1351,73 @@ def test_list_chat_messages_supports_first_id_pagination(app: Flask, monkeypatch
"/console/api/agent/agent-1/chat-messages"
f"?conversation_id={conversation_id}&first_id={first_message_id}&limit=1"
):
result = message_controller._list_chat_messages(app_model=SimpleNamespace(id="app-1"))
result = message_controller._list_chat_messages(app_model=SimpleNamespace(id="app-1", mode="chat"))
assert result == {"data": [older_message_id], "limit": 1, "has_more": True}
def test_list_agent_chat_messages_uses_current_user_conversation(
app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
conversation_id = "00000000-0000-0000-0000-000000000010"
message_id = "00000000-0000-0000-0000-000000000011"
conversation = SimpleNamespace(id=conversation_id)
message = SimpleNamespace(id=message_id, created_at=1)
current_user = SimpleNamespace(id="account-1")
app_model = SimpleNamespace(id="app-1", mode="agent")
captured: dict[str, object] = {}
session = SimpleNamespace(
scalar=lambda _stmt: False,
scalars=lambda _stmt: SimpleNamespace(all=lambda: [message]),
)
class FakeMessagePaginationResponse:
@classmethod
def model_validate(cls, pagination: object, from_attributes: bool = False) -> object:
return SimpleNamespace(
model_dump=lambda mode: {
"data": [item.id for item in pagination.data],
"limit": pagination.limit,
"has_more": pagination.has_more,
}
)
def get_conversation(**kwargs: object) -> object:
captured.update(kwargs)
return conversation
monkeypatch.setattr(message_controller.ConversationService, "get_conversation", get_conversation)
monkeypatch.setattr(message_controller, "db", SimpleNamespace(session=session))
monkeypatch.setattr(message_controller, "attach_message_extra_contents", lambda messages: None)
monkeypatch.setattr(message_controller, "MessageInfiniteScrollPaginationResponse", FakeMessagePaginationResponse)
with app.test_request_context(f"/console/api/agent/agent-1/chat-messages?conversation_id={conversation_id}"):
result = message_controller._list_chat_messages(app_model=app_model, current_user=current_user)
assert result == {"data": [message_id], "limit": 20, "has_more": False}
assert captured == {"app_model": app_model, "conversation_id": conversation_id, "user": current_user}
def test_list_agent_chat_messages_rejects_foreign_conversation(
app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
conversation_id = "00000000-0000-0000-0000-000000000010"
monkeypatch.setattr(
message_controller.ConversationService,
"get_conversation",
lambda **kwargs: (_ for _ in ()).throw(message_controller.ConversationNotExistsError()),
)
with app.test_request_context(f"/console/api/agent/agent-1/chat-messages?conversation_id={conversation_id}"):
with pytest.raises(NotFound):
message_controller._list_chat_messages(
app_model=SimpleNamespace(id="app-1", mode="agent"),
current_user=SimpleNamespace(id="account-1"),
)
def test_update_message_feedback_rejects_empty_rating_without_existing_feedback(
app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:

View File

@ -20,6 +20,10 @@ from controllers.console.app.agent_drive_inspector import (
AgentDriveListByAgentApi,
AgentDrivePreviewApi,
AgentDrivePreviewByAgentApi,
AgentDriveSkillInspectApi,
AgentDriveSkillInspectByAgentApi,
AgentDriveSkillListApi,
AgentDriveSkillListByAgentApi,
)
from services.agent_drive_service import AgentDriveError
@ -97,6 +101,124 @@ def test_list_resolves_workflow_node_binding_agent():
assert composer.resolve_workflow_node_agent_id.call_args.kwargs["node_id"] == "agent-node-1"
def test_skill_list_by_agent_calls_service():
raw = _raw(AgentDriveSkillListByAgentApi.get)
with app.test_request_context("/"):
with (
patch(f"{_MOD}.resolve_agent_app_model", return_value=_APP) as resolve_app,
patch(f"{_MOD}.AgentDriveService") as drive,
):
drive.return_value.list_skills.return_value = [
{
"path": "pdf-toolkit",
"skill_md_key": "pdf-toolkit/SKILL.md",
"archive_key": "pdf-toolkit/.DIFY-SKILL-FULL.zip",
"name": "PDF Toolkit",
"description": "Work with PDFs.",
"size": 5,
"mime_type": "text/markdown",
"hash": None,
"created_at": 1718000000,
}
]
body = raw(AgentDriveSkillListByAgentApi(), "tenant-1", "agent-1")
assert body["items"][0]["path"] == "pdf-toolkit"
resolve_app.assert_called_once_with(tenant_id="tenant-1", agent_id="agent-1")
assert drive.return_value.list_skills.call_args.kwargs["agent_id"] == "agent-1"
def test_skill_list_resolves_workflow_node_binding_agent():
raw = _raw(AgentDriveSkillListApi.get)
with app.test_request_context("/?node_id=agent-node-1"):
with (
patch(f"{_MOD}.AgentComposerService") as composer,
patch(f"{_MOD}.AgentDriveService") as drive,
):
composer.resolve_workflow_node_agent_id.return_value = "wf-agent-9"
drive.return_value.list_skills.return_value = []
body = raw(AgentDriveSkillListApi(), _APP)
assert body == {"items": []}
assert drive.return_value.list_skills.call_args.kwargs["agent_id"] == "wf-agent-9"
def test_skill_inspect_by_agent_returns_strict_json_response():
raw = _raw(AgentDriveSkillInspectByAgentApi.get)
payload = {
"path": "pdf-toolkit",
"skill_md_key": "pdf-toolkit/SKILL.md",
"archive_key": "pdf-toolkit/.DIFY-SKILL-FULL.zip",
"name": "PDF Toolkit",
"description": "Work with PDFs.",
"size": 5,
"mime_type": "text/markdown",
"hash": None,
"created_at": 1718000000,
"source": "skill_md",
"files": [
{
"path": "SKILL.md",
"name": "SKILL.md",
"type": "file",
"drive_key": "pdf-toolkit/SKILL.md",
"available_in_drive": True,
}
],
"file_tree": [],
"skill_md": {
"key": "pdf-toolkit/SKILL.md",
"size": 5,
"truncated": False,
"binary": False,
"text": "# PDF Toolkit\nUse it.\n",
},
"warnings": [],
}
with app.test_request_context("/"):
with (
patch(f"{_MOD}.resolve_agent_app_model", return_value=_APP),
patch(f"{_MOD}.AgentDriveService") as drive,
):
drive.return_value.inspect_skill.return_value = payload
response = raw(AgentDriveSkillInspectByAgentApi(), "tenant-1", "agent-1", "pdf-toolkit")
assert response.status_code == 200
assert response.get_json()["skill_md"]["text"] == "# PDF Toolkit\nUse it.\n"
assert b"# PDF Toolkit\\nUse it.\\n" in response.get_data()
def test_skill_inspect_resolves_workflow_node_binding_agent():
raw = _raw(AgentDriveSkillInspectApi.get)
payload = {
"path": "pdf-toolkit",
"skill_md_key": "pdf-toolkit/SKILL.md",
"archive_key": None,
"name": "PDF Toolkit",
"description": "",
"size": 5,
"mime_type": "text/markdown",
"hash": None,
"created_at": None,
"source": "skill_md",
"files": [],
"file_tree": [],
"skill_md": {"key": "pdf-toolkit/SKILL.md", "size": 5, "truncated": False, "binary": False, "text": "# hi"},
"warnings": [],
}
with app.test_request_context("/?node_id=agent-node-1"):
with (
patch(f"{_MOD}.AgentComposerService") as composer,
patch(f"{_MOD}.AgentDriveService") as drive,
):
composer.resolve_workflow_node_agent_id.return_value = "wf-agent-9"
drive.return_value.inspect_skill.return_value = payload
response = raw(AgentDriveSkillInspectApi(), _APP, "pdf-toolkit")
assert response.get_json()["path"] == "pdf-toolkit"
assert drive.return_value.inspect_skill.call_args.kwargs["agent_id"] == "wf-agent-9"
def test_list_400_when_no_agent_bound():
raw = _raw(AgentDriveListApi.get)
app_without_agent = SimpleNamespace(id="app-1", tenant_id="tenant-1", bound_agent_id=None)

View File

@ -185,7 +185,7 @@ class TestPaginationMapping:
"name": "owner",
"description": "",
"is_builtin": True,
"permission_keys": list(rbac_mod._LEGACY_ROLE_PERMISSION_KEYS["owner"]),
"permission_keys": list(dict.fromkeys(rbac_mod._LEGACY_ROLE_PERMISSION_KEYS["owner"])),
"role_tag": "owner",
},
{
@ -196,7 +196,7 @@ class TestPaginationMapping:
"name": "admin",
"description": "",
"is_builtin": True,
"permission_keys": list(rbac_mod._LEGACY_ROLE_PERMISSION_KEYS["admin"]),
"permission_keys": list(dict.fromkeys(rbac_mod._LEGACY_ROLE_PERMISSION_KEYS["admin"])),
"role_tag": "",
},
]
@ -336,23 +336,6 @@ class TestResourceAccessScopeBindings:
class TestPaginationForwarding:
def test_role_members_get_forwards_outer_pagination_params(self, app):
with (
app.test_request_context("/workspaces/current/rbac/roles/role-1/members?page=2&limit=50&reverse=true"),
patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")),
patch("controllers.console.workspace.rbac.svc.RBACService.Roles.members") as mock_members,
patch("controllers.console.workspace.rbac._dump", return_value={}),
):
inspect.unwrap(rbac_mod.RBACRoleMembersApi.get)(rbac_mod.RBACRoleMembersApi(), "role-1")
_, _, role_id = mock_members.call_args.args
_, kwargs = mock_members.call_args
assert role_id == "role-1"
options = kwargs["options"]
assert options.page_number == 2
assert options.results_per_page == 50
assert options.reverse is True
def test_access_policies_get_forwards_outer_pagination_params(self, app):
with (
app.test_request_context(

View File

@ -1,16 +1,21 @@
import uuid
from controllers.openapi.auth.composition import account_pipeline, auth_router, external_sso_pipeline
from controllers.openapi.auth.data import RequestContext
from controllers.openapi.auth.data import RBACRequirement, RequestContext
from controllers.openapi.auth.flow import When
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
from controllers.openapi.auth.verify import (
check_acl,
check_private_app_permission,
check_rbac_permission,
check_workspace_member,
check_workspace_mismatch,
check_workspace_role,
)
from libs.oauth_bearer import TokenType
from core.rbac import RBACPermission, RBACResourceScope
from libs.oauth_bearer import Scope, TokenType
from models.account import TenantAccountRole
from services.enterprise.enterprise_service import WebAppAccessMode
def test_account_pipeline_is_auth_pipeline():
@ -29,8 +34,8 @@ def test_account_pipeline_prepare_has_six_entries():
assert len(account_pipeline._prepare) == 6
def test_account_auth_list_has_seven_entries():
assert len(account_pipeline._auth) == 7
def test_account_auth_list_has_eight_entries():
assert len(account_pipeline._auth) == 8
def test_external_sso_pipeline_prepare_has_four_entries():
@ -132,3 +137,89 @@ def test_app_path_selects_workspace_mismatch_check():
def test_workspace_path_skips_workspace_mismatch_check():
steps = _selected_auth_steps(app_id=False, workspace_membership=True, allowed_roles=None)
assert check_workspace_mismatch not in steps
def _selected_webapp_steps(*, scope, app_access_mode):
"""Select auth steps for an EE, webapp-auth-enabled, app-scoped request.
Patches the config-backed conditions (edition + webapp_auth) so the gating
reduces to PATH_HAS_APP_ID, LOADED_APP_IS_PRIVATE, and the request scope.
"""
from unittest.mock import MagicMock, patch
from controllers.openapi.auth.data import AuthData, Edition
ctx = RequestContext(
token_type=TokenType.OAUTH_ACCOUNT,
scope=scope,
path_params={"app_id": str(uuid.uuid4())},
)
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
token_hash="x",
scopes=frozenset({scope}) if scope is not None else frozenset(),
app_access_mode=app_access_mode,
)
features = MagicMock()
features.webapp_auth.enabled = True
selected = []
with (
patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.EE),
patch("controllers.openapi.auth.conditions.FeatureService.get_system_features", return_value=features),
):
for step in account_pipeline._auth:
if isinstance(step, When):
if step.applies(ctx, data):
selected.append(step._step)
else:
selected.append(step)
return selected
def test_apps_run_scope_selects_webapp_checks():
steps = _selected_webapp_steps(scope=Scope.APPS_RUN, app_access_mode=WebAppAccessMode.PRIVATE)
assert check_acl in steps
assert check_private_app_permission in steps
def test_management_scope_skips_webapp_checks_on_private_app():
# Export DSL et al. carry an app_id but use a management scope; the webapp
# end-user ACL / private-app gate must not block workspace members.
steps = _selected_webapp_steps(scope=Scope.APPS_READ, app_access_mode=WebAppAccessMode.PRIVATE)
assert check_acl not in steps
assert check_private_app_permission not in steps
def _selected_auth_steps_with_rbac(rbac):
ctx = RequestContext(
token_type=TokenType.OAUTH_ACCOUNT,
scope=Scope.APPS_READ,
path_params={"app_id": str(uuid.uuid4())},
rbac=rbac,
)
selected = []
for step in account_pipeline._auth:
if isinstance(step, When):
if step.applies(ctx, None):
selected.append(step._step)
else:
selected.append(step)
return selected
def test_account_pipeline_selects_rbac_step_when_required():
rbac = RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_VIEW_LAYOUT)
assert check_rbac_permission in _selected_auth_steps_with_rbac(rbac)
def test_account_pipeline_skips_rbac_step_without_requirement():
assert check_rbac_permission not in _selected_auth_steps_with_rbac(None)
def test_external_sso_pipeline_never_enforces_rbac():
# RBAC is a console (account) concern; external SSO callers are scope-gated.
rbac_steps = [
s._step for s in external_sso_pipeline._auth if isinstance(s, When) and s._step is check_rbac_permission
]
assert rbac_steps == []
assert check_rbac_permission not in external_sso_pipeline._auth

View File

@ -5,19 +5,22 @@ from controllers.openapi.auth.conditions import (
EDITION_EE,
EDITION_SAAS,
HAS_ALLOWED_ROLES,
HAS_RBAC,
LOADED_APP_IS_PRIVATE,
PATH_HAS_APP_ID,
TOKEN_IS_OAUTH_ACCOUNT,
TOKEN_IS_OAUTH_EXTERNAL_SSO,
WEBAPP_AUTH_ENABLED,
WEBAPP_RUN_SCOPED,
WORKSPACE_MEMBERSHIP_REQUIRED,
Cond,
config_cond,
data_cond,
request_cond,
)
from controllers.openapi.auth.data import AuthData, Edition, RequestContext
from libs.oauth_bearer import TokenType
from controllers.openapi.auth.data import AuthData, Edition, RBACRequirement, RequestContext
from core.rbac import RBACPermission, RBACResourceScope
from libs.oauth_bearer import Scope, TokenType
from models.account import TenantAccountRole
from services.enterprise.enterprise_service import WebAppAccessMode
@ -137,6 +140,34 @@ def test_webapp_auth_enabled():
assert WEBAPP_AUTH_ENABLED(_ctx()) is True
def test_webapp_run_scoped_true_for_apps_run():
assert WEBAPP_RUN_SCOPED(_ctx(scope=Scope.APPS_RUN)) is True
def test_webapp_run_scoped_false_for_management_scope():
assert WEBAPP_RUN_SCOPED(_ctx(scope=Scope.APPS_READ)) is False
def test_webapp_run_scoped_false_when_scope_none():
assert WEBAPP_RUN_SCOPED(_ctx()) is False
def _rbac_req():
return RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN)
def test_has_rbac_true():
assert HAS_RBAC(_ctx(rbac=_rbac_req())) is True
def test_has_rbac_false():
assert HAS_RBAC(_ctx(rbac=None)) is False
def test_has_rbac_default():
assert HAS_RBAC(_ctx()) is False
def test_loaded_app_is_private():
data_private = _data(app_access_mode=WebAppAccessMode.PRIVATE)
data_public = _data(app_access_mode=WebAppAccessMode.PUBLIC)

View File

@ -5,17 +5,19 @@ import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.data import AuthData, RBACRequirement
from controllers.openapi.auth.verify import (
check_acl,
check_app_access,
check_app_api_enabled,
check_private_app_permission,
check_rbac_permission,
check_scope,
check_workspace_member,
check_workspace_mismatch,
check_workspace_role,
)
from core.rbac import RBACPermission, RBACResourceScope
from libs.oauth_bearer import Scope, TokenType
from models.account import Tenant, TenantAccountRole
from models.model import App
@ -75,6 +77,67 @@ def test_check_app_access_raises_when_not_member():
check_app_access(data)
# --- check_rbac_permission ---
_RBAC_REQ = RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_VIEW_LAYOUT)
def test_check_rbac_noop_when_no_requirement():
with patch("controllers.openapi.auth.verify.enforce_rbac_access") as mock_enforce:
check_rbac_permission(_data(rbac=None, caller_kind="account"))
mock_enforce.assert_not_called()
def test_check_rbac_noop_when_rbac_disabled():
with (
patch("controllers.openapi.auth.verify.dify_config.RBAC_ENABLED", False),
patch("controllers.openapi.auth.verify.enforce_rbac_access") as mock_enforce,
):
check_rbac_permission(_data(rbac=_RBAC_REQ, caller_kind="account"))
mock_enforce.assert_not_called()
def test_check_rbac_skips_end_user_caller():
with (
patch("controllers.openapi.auth.verify.dify_config.RBAC_ENABLED", True),
patch("controllers.openapi.auth.verify.enforce_rbac_access") as mock_enforce,
):
check_rbac_permission(_data(rbac=_RBAC_REQ, caller_kind="end_user"))
mock_enforce.assert_not_called()
def test_check_rbac_raises_when_context_missing():
with patch("controllers.openapi.auth.verify.dify_config.RBAC_ENABLED", True):
with pytest.raises(Forbidden, match="rbac context missing"):
check_rbac_permission(_data(rbac=_RBAC_REQ, caller_kind="account", account_id=None, tenant=None))
def test_check_rbac_enforces_for_account_caller():
tenant = MagicMock(spec=Tenant)
tenant.id = "t1"
account_id = uuid.uuid4()
data = _data(
rbac=_RBAC_REQ,
caller_kind="account",
account_id=account_id,
tenant=tenant,
path_params={"app_id": "app-1"},
)
with (
patch("controllers.openapi.auth.verify.dify_config.RBAC_ENABLED", True),
patch("controllers.openapi.auth.verify.enforce_rbac_access") as mock_enforce,
):
check_rbac_permission(data)
mock_enforce.assert_called_once_with(
tenant_id="t1",
account_id=str(account_id),
resource_type=RBACResourceScope.APP,
scene=RBACPermission.APP_VIEW_LAYOUT,
resource_required=True,
path_args={"app_id": "app-1"},
)
def test_check_acl_raises_when_app_or_mode_missing():
with pytest.raises(Forbidden):
check_acl(_data(app=None, app_access_mode=None))

View File

@ -20,6 +20,7 @@ def _stub_execute(
edition=None,
workspace_membership=False,
allowed_roles=None,
rbac=None,
):
"""Bypass all auth logic; inject minimal AuthData and call the view directly."""
kwargs["auth_data"] = AuthData(
@ -30,6 +31,7 @@ def _stub_execute(
scopes=frozenset({Scope.FULL}),
required_scope=scope,
allowed_roles=allowed_roles,
rbac=rbac,
)
return view(*args, **kwargs)

View File

@ -0,0 +1,73 @@
from types import SimpleNamespace
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA
from controllers.openapi.apps import _EMPTY_PARAMETERS, build_app_describe_response
from controllers.service_api.app.error import AppUnavailableError
class _FakeApp(SimpleNamespace):
pass
def _app() -> _FakeApp:
from datetime import datetime
return _FakeApp(
id="11111111-1111-1111-1111-111111111111",
name="Demo",
mode="chat",
description="d",
tags=[],
author_name="me",
updated_at=datetime(2026, 1, 1),
enable_api=True,
)
def test_fields_none_returns_all_blocks(monkeypatch):
monkeypatch.setattr("controllers.openapi.apps.parameters_payload", lambda app: {"k": "v"})
monkeypatch.setattr("controllers.openapi.apps.build_input_schema", lambda app: {"s": 1})
resp = build_app_describe_response(_app(), None)
assert resp.info is not None
assert resp.info.name == "Demo"
assert resp.parameters == {"k": "v"}
assert resp.input_schema == {"s": 1}
def test_fields_subset_limits_blocks(monkeypatch):
monkeypatch.setattr("controllers.openapi.apps.parameters_payload", lambda app: {"k": "v"})
monkeypatch.setattr("controllers.openapi.apps.build_input_schema", lambda app: {"s": 1})
resp = build_app_describe_response(_app(), ["info"])
assert resp.info is not None
assert resp.parameters is None
assert resp.input_schema is None
def test_info_omits_author_and_tags(monkeypatch):
monkeypatch.setattr("controllers.openapi.apps.parameters_payload", lambda app: {})
monkeypatch.setattr("controllers.openapi.apps.build_input_schema", lambda app: {})
resp = build_app_describe_response(_app(), ["info"])
assert resp.info is not None
# Usage-face describe must not expose creator identity or tags (cross-tenant leak).
assert not hasattr(resp.info, "author")
assert not hasattr(resp.info, "tags")
def test_parameters_fallback_on_app_unavailable(monkeypatch):
def _raise(app):
raise AppUnavailableError()
monkeypatch.setattr("controllers.openapi.apps.parameters_payload", _raise)
monkeypatch.setattr("controllers.openapi.apps.build_input_schema", lambda app: {"s": 1})
resp = build_app_describe_response(_app(), ["parameters"])
assert resp.parameters == dict(_EMPTY_PARAMETERS)
def test_input_schema_fallback_on_app_unavailable(monkeypatch):
def _raise(app):
raise AppUnavailableError()
monkeypatch.setattr("controllers.openapi.apps.parameters_payload", lambda app: {"k": "v"})
monkeypatch.setattr("controllers.openapi.apps.build_input_schema", _raise)
resp = build_app_describe_response(_app(), ["input_schema"])
assert resp.input_schema == dict(EMPTY_INPUT_SCHEMA)

View File

@ -5,7 +5,7 @@ Runs against the model directly, not the HTTP layer. Pins:
- workspace_id is required.
- numeric bounds enforced (page >= 1, limit in [1, MAX_PAGE_LIMIT]).
- mode validates against the AppMode enum.
- name and tag have length caps.
- name has a length cap.
"""
from __future__ import annotations
@ -24,7 +24,6 @@ def test_defaults():
assert q.limit == 20
assert q.mode is None
assert q.name is None
assert q.tag is None
def test_workspace_id_required():
@ -80,12 +79,6 @@ def test_name_length_capped():
AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "name": "x" * 201})
def test_tag_length_capped():
AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "tag": "x" * 100})
with pytest.raises(ValidationError):
AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "tag": "x" * 101})
def test_all_fields_accept_valid_values():
"""Pin the happy-path acceptance for every field in one place."""
q = AppListQuery.model_validate(
@ -95,7 +88,6 @@ def test_all_fields_accept_valid_values():
"limit": 50,
"mode": "workflow",
"name": "search",
"tag": "prod",
}
)
assert q.workspace_id == "00000000-0000-0000-0000-000000000001"
@ -104,4 +96,3 @@ def test_all_fields_accept_valid_values():
assert q.mode is not None
assert q.mode.value == "workflow"
assert q.name == "search"
assert q.tag == "prod"

View File

@ -26,11 +26,13 @@ from controllers.openapi._errors import (
ErrorBody,
ErrorDetail,
FilenameNotExists,
HumanInputFormNotFound,
MemberLicenseExceeded,
MemberLimitExceeded,
OpenApiError,
OpenApiErrorCode,
OpenApiErrorFormatter,
RecipientSurfaceMismatch,
)
from controllers.service_api.app.error import (
AppUnavailableError,
@ -319,6 +321,8 @@ ERROR_MATRIX = [
(BlockedFileExtensionError(), 400, "file_extension_blocked"),
(MemberLimitExceeded(), 403, "member_limit_exceeded"),
(MemberLicenseExceeded(), 403, "member_license_exceeded"),
(HumanInputFormNotFound(), 404, "form_not_found"),
(RecipientSurfaceMismatch(), 403, "recipient_surface_mismatch"),
]

View File

@ -11,8 +11,9 @@ from unittest.mock import Mock
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound, UnprocessableEntity
from werkzeug.exceptions import UnprocessableEntity
from controllers.openapi._errors import HumanInputFormNotFound, RecipientSurfaceMismatch
from controllers.openapi.auth.data import AuthData
from libs.oauth_bearer import Scope, TokenType
from models.human_input import RecipientType
@ -89,7 +90,7 @@ class TestOpenApiHumanInputFormGet:
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/bad"):
with pytest.raises(NotFound):
with pytest.raises(HumanInputFormNotFound):
api.get.__wrapped__(
api,
app_id="app-1",
@ -101,7 +102,10 @@ class TestOpenApiHumanInputFormGet:
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
form = SimpleNamespace(
app_id="other-app", tenant_id="tenant-1", expiration_time=datetime(2099, 1, 1, tzinfo=UTC)
app_id="other-app",
tenant_id="tenant-1",
recipient_type=RecipientType.STANDALONE_WEB_APP,
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
@ -114,7 +118,7 @@ class TestOpenApiHumanInputFormGet:
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
with pytest.raises(NotFound):
with pytest.raises(HumanInputFormNotFound):
api.get.__wrapped__(
api,
app_id="app-1",
@ -142,7 +146,7 @@ class TestOpenApiHumanInputFormGet:
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
with pytest.raises(NotFound):
with pytest.raises(RecipientSurfaceMismatch):
api.get.__wrapped__(
api,
app_id="app-1",
@ -234,6 +238,38 @@ class TestOpenApiHumanInputFormPost:
)
assert result == ({}, 200)
def test_post_standalone_web_app_recipient_submits(
self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
):
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
form = self._make_form(recipient_type=RecipientType.STANDALONE_WEB_APP)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
module = sys.modules["controllers.openapi.human_input_form"]
monkeypatch.setattr(module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
api = OpenApiWorkflowHumanInputFormApi()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
caller = SimpleNamespace(id="anyone")
with app.test_request_context(
"/openapi/v1/apps/app-1/form/human_input/tok-1",
method="POST",
json={"action": "approve", "inputs": {}},
):
result = api.post.__wrapped__(
api,
app_id="app-1",
form_token="tok-1",
auth_data=_make_auth_data(app_model, caller, "end_user"),
)
service_mock.submit_form_by_token.assert_called_once()
assert result == ({}, 200)
def test_post_rejects_invalid_body_with_422(self, app: Flask, bypass_pipeline):
"""Malformed body → 422 via @accepts (was an unmapped pydantic error → 500)."""
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi

View File

@ -63,23 +63,19 @@ def test_envelope_uses_pep695_generics():
def test_app_info_response_dump_matches_spec():
from controllers.openapi._models import AppInfoResponse
from controllers.openapi._models import AppInfo
obj = AppInfoResponse(
obj = AppInfo(
id="app1",
name="X",
description="d",
mode="chat",
author="alice",
tags=[{"name": "prod"}],
)
assert obj.model_dump(mode="json") == {
"id": "app1",
"name": "X",
"description": "d",
"mode": "chat",
"author": "alice",
"tags": [{"name": "prod"}],
}
@ -91,8 +87,6 @@ def test_app_describe_response_nests_info_and_parameters():
name="X",
mode="chat",
description=None,
tags=[],
author=None,
updated_at="2026-05-05T00:00:00+00:00",
service_api_enabled=True,
)

View File

@ -136,6 +136,55 @@ class TestAppParameterApi:
assert "user_input_form" in response
assert "opening_statement" in response
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.app.app._get_agent_app_feature_dict_and_user_input_form")
def test_get_parameters_for_agent_app(
self,
mock_get_agent_parameters,
mock_db,
mock_validate_token,
mock_current_app,
mock_user_logged_in,
app: Flask,
mock_app_model,
):
"""Test retrieving parameters for an Agent App from Agent Soul app variables."""
_configure_current_app_mock(mock_current_app)
mock_app_model.mode = AppMode.AGENT
mock_app_model.app_model_config = None
mock_app_model.workflow = None
mock_get_agent_parameters.return_value = (
{"opening_statement": "Hi from Agent"},
[{"text-input": {"label": "topic", "variable": "topic", "required": True}}],
)
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_db.session.get.side_effect = [mock_app_model, mock_tenant]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppParameterApi()
response = api.get()
assert response["opening_statement"] == "Hi from Agent"
assert response["user_input_form"] == [
{"text-input": {"label": "topic", "variable": "topic", "required": True}}
]
mock_get_agent_parameters.assert_called_once_with(mock_app_model)
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")

View File

@ -29,7 +29,7 @@ from core.app.entities.task_entities import (
WorkflowPauseStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from core.workflow.human_input_policy import HumanInputSurface
from core.workflow.human_input_policy import FormDisposition, HumanInputSurface
from core.workflow.system_variables import build_system_variables
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
@ -592,8 +592,10 @@ class TestHitlServiceApi:
monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
workflow_response_converter,
"load_form_tokens_by_form_id",
lambda form_ids, session=None, surface=None: {"form-1": "token"},
"load_form_dispositions_by_form_id",
lambda form_ids, session=None, surface=None: {
"form-1": FormDisposition(form_token="token", approval_channels=[])
},
)
reason = HumanInputRequired(
@ -652,8 +654,10 @@ class TestHitlServiceApi:
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
resumption_context = _build_resumption_context("task-ctx")
monkeypatch.setattr(
"services.workflow_event_snapshot_service.load_form_tokens_by_form_id",
lambda form_ids, session=None, surface=None: {"form-1": "wtok"},
"services.workflow_event_snapshot_service.load_form_dispositions_by_form_id",
lambda form_ids, session=None, surface=None: {
"form-1": FormDisposition(form_token="wtok", approval_channels=[])
},
)
class _SessionContext:

View File

@ -25,6 +25,15 @@ MINIMAL_GRAPH = {
}
def _patch_create_session(mock_session: MagicMock):
session_context = MagicMock()
session_context.__enter__.return_value = mock_session
session_context.__exit__.return_value = False
mock_session.begin.return_value.__enter__.return_value = mock_session
mock_session.begin.return_value.__exit__.return_value = False
return patch("core.app.apps.advanced_chat.app_runner.create_session", return_value=session_context)
class TestAdvancedChatAppRunnerConversationVariables:
"""Test that AdvancedChatAppRunner correctly handles conversation variables."""
@ -135,10 +144,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
_patch_create_session(mock_session),
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
patch.object(runner, "_init_graph") as mock_init_graph,
patch.object(
runner,
@ -151,12 +158,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
):
# Setup mocks
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.engine = MagicMock()
# Mock GraphRuntimeState to accept the variable pool
mock_graph_runtime_state_class.return_value = MagicMock()
@ -281,10 +282,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
_patch_create_session(mock_session),
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
patch.object(runner, "_init_graph") as mock_init_graph,
patch.object(
runner,
@ -298,12 +297,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
):
# Setup mocks
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.engine = MagicMock()
# Mock ConversationVariable.from_variable to return mock objects
mock_conv_vars = []
for var in workflow_vars:
@ -434,10 +427,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
_patch_create_session(mock_session),
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
patch.object(runner, "_init_graph") as mock_init_graph,
patch.object(
runner,
@ -450,12 +441,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
):
# Setup mocks
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.engine = MagicMock()
# Mock GraphRuntimeState to accept the variable pool
mock_graph_runtime_state_class.return_value = MagicMock()

View File

@ -3,6 +3,7 @@ from uuid import uuid4
import pytest
import core.app.apps.advanced_chat.app_runner as module
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.queue_entities import QueueStopEvent
@ -85,27 +86,24 @@ def build_runner():
def _patch_common_run_deps(runner: AdvancedChatAppRunner):
"""Context manager that patches common heavy deps used by run()."""
# create_session() returns a context manager whose body yields a session that
# supports both scalar() (app record lookup) and begin()/scalars().all()
# (conversation variable initialization).
mock_session = MagicMock()
mock_session.scalar.return_value = MagicMock()
mock_session.scalars.return_value.all.return_value = []
session_context = MagicMock()
session_context.__enter__.return_value = mock_session
session_context.__exit__.return_value = False
mock_session.begin.return_value.__enter__.return_value = mock_session
mock_session.begin.return_value.__exit__.return_value = False
return patch.multiple(
"core.app.apps.advanced_chat.app_runner",
Session=MagicMock(
return_value=MagicMock(
__enter__=lambda s: s,
__exit__=lambda *a, **k: False,
scalar=lambda *a, **k: MagicMock(),
),
),
sessionmaker=MagicMock(
return_value=MagicMock(
begin=MagicMock(
return_value=MagicMock(
__enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))),
__exit__=lambda *a, **k: False,
),
),
),
),
create_session=MagicMock(return_value=session_context),
select=MagicMock(),
db=MagicMock(engine=MagicMock()),
session_factory=MagicMock(get_session_maker=MagicMock(return_value=MagicMock())),
RedisChannel=MagicMock(),
redis_client=MagicMock(),
WorkflowEntry=MagicMock(**{"return_value.run.return_value": iter([])}),
@ -192,3 +190,42 @@ def test_run_returns_early_when_direct_output_via_handle_input_moderation(build_
# Ensure no further steps executed
mock_anno.assert_not_called()
mock_init_graph.assert_not_called()
def test_run_closes_scoped_session_before_workflow_run(build_runner):
runner = build_runner
events = []
mock_session = MagicMock()
mock_session.scalar.return_value = MagicMock()
session_context = MagicMock()
session_context.__enter__.return_value = mock_session
session_context.__exit__.return_value = False
workflow_entry = MagicMock()
def run_workflow():
events.append("run")
return iter([])
workflow_entry.run.side_effect = run_workflow
with (
patch.object(module, "create_session", return_value=session_context),
patch.object(module, "session_factory", MagicMock(get_session_maker=MagicMock(return_value=MagicMock()))),
patch.object(module, "RedisChannel"),
patch.object(module, "redis_client"),
patch.object(module, "WorkflowEntry", return_value=workflow_entry),
patch.object(module.db.session, "close", side_effect=lambda: events.append("close")),
patch.object(
runner,
"handle_input_moderation",
return_value=(False, runner.application_generate_entity.inputs, runner.application_generate_entity.query),
),
patch.object(runner, "handle_annotation_reply", return_value=False),
patch.object(runner, "_initialize_conversation_variables", return_value=[]),
patch.object(runner, "_init_graph", return_value=MagicMock()),
):
runner.run()
assert events == ["close", "run"]

View File

@ -175,6 +175,7 @@ class TestAdvancedChatGenerateTaskPipeline:
"actions": [{"id": "approve", "title": "Approve", "button_style": "default"}],
"display_in_ui": True,
"form_token": "token-1",
"approval_channels": [],
"resolved_default_values": {},
"expiration_time": 123,
}

View File

@ -19,6 +19,10 @@ def _soul() -> AgentSoulConfig:
"model_settings": {"temperature": 0.2},
},
"prompt": {"system_prompt": "You are Iris."},
"app_variables": [
{"name": "topic", "type": "string", "required": True},
{"name": "count", "type": "number", "default": 3},
],
}
)
@ -32,7 +36,10 @@ def test_model_and_prompt_come_from_soul():
"completion_params": {"temperature": 0.2},
}
assert d["pre_prompt"] == "You are Iris."
assert d["user_input_form"] == []
assert d["user_input_form"] == [
{"text-input": {"label": "topic", "variable": "topic", "required": True}},
{"number": {"label": "count", "variable": "count", "required": False, "default": 3}},
]
def test_feature_flags_come_from_app_model_config_when_present():

View File

@ -13,12 +13,24 @@ def runner():
return AgentChatAppRunner()
def patch_create_session(mocker: MockerFixture, *, return_value=None, side_effect=None):
session = mocker.MagicMock()
if side_effect is not None:
session.scalar.side_effect = side_effect
else:
session.scalar.return_value = return_value
session_context = mocker.MagicMock()
session_context.__enter__.return_value = session
mocker.patch("core.app.apps.agent_chat.app_runner.create_session", return_value=session_context)
return session
class TestAgentChatAppRunnerRun:
def test_run_app_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture):
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None)
patch_create_session(mocker, return_value=None)
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
@ -37,7 +49,7 @@ class TestAgentChatAppRunnerRun:
conversation_id=None,
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad"))
mocker.patch.object(runner, "direct_output")
@ -62,7 +74,7 @@ class TestAgentChatAppRunnerRun:
invoke_from=mocker.MagicMock(),
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
annotation = mocker.MagicMock(id="anno", content="answer")
@ -91,7 +103,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -121,7 +133,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -163,7 +175,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -179,10 +191,7 @@ class TestAgentChatAppRunnerRun:
conversation = mocker.MagicMock(id="conv")
message = mocker.MagicMock(id="msg")
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, conversation, message],
)
patch_create_session(mocker, side_effect=[app_record, conversation, message])
runner_cls = mocker.MagicMock()
mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls)
@ -219,7 +228,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -235,10 +244,7 @@ class TestAgentChatAppRunnerRun:
conversation = mocker.MagicMock(id="conv")
message = mocker.MagicMock(id="msg")
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, conversation, message],
)
patch_create_session(mocker, side_effect=[app_record, conversation, message])
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
@ -267,7 +273,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -283,10 +289,7 @@ class TestAgentChatAppRunnerRun:
conversation = mocker.MagicMock(id="conv")
message = mocker.MagicMock(id="msg")
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, conversation, message],
)
patch_create_session(mocker, side_effect=[app_record, conversation, message])
runner_cls = mocker.MagicMock()
mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls)
@ -323,10 +326,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, None],
)
patch_create_session(mocker, side_effect=[app_record, None])
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -357,10 +357,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, mocker.MagicMock(id="conv"), None],
)
patch_create_session(mocker, side_effect=[app_record, mocker.MagicMock(id="conv"), None])
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -391,7 +388,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -407,10 +404,7 @@ class TestAgentChatAppRunnerRun:
conversation = mocker.MagicMock(id="conv")
message = mocker.MagicMock(id="msg")
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, conversation, message],
)
patch_create_session(mocker, side_effect=[app_record, conversation, message])
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), conversation, message)

View File

@ -1,5 +1,6 @@
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import Mock, patch
from unittest.mock import ANY, MagicMock, Mock, patch
import pytest
@ -29,6 +30,19 @@ class DummyQueueManager:
self.published.append((event, pub_from))
@contextmanager
def patched_create_session(*, return_value=None, side_effect=None):
session = MagicMock()
if side_effect is not None:
session.scalar.side_effect = side_effect
else:
session.scalar.return_value = return_value
session_context = MagicMock()
session_context.__enter__.return_value = session
with patch("core.app.apps.chat.app_runner.create_session", return_value=session_context):
yield session
class TestChatAppGenerator:
def test_generate_requires_query(self):
generator = ChatAppGenerator()
@ -167,7 +181,7 @@ class TestChatAppRunner:
invoke_from=InvokeFrom.SERVICE_API,
)
with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None):
with patched_create_session(return_value=None):
with pytest.raises(ValueError):
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
@ -195,10 +209,7 @@ class TestChatAppRunner:
)
with (
patch(
"core.app.apps.chat.app_runner.db.session.scalar",
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
),
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")),
patch.object(ChatAppRunner, "direct_output") as mock_direct,
@ -233,10 +244,7 @@ class TestChatAppRunner:
annotation = SimpleNamespace(id="ann-1", content="answer")
with (
patch(
"core.app.apps.chat.app_runner.db.session.scalar",
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
),
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation),
@ -272,13 +280,73 @@ class TestChatAppRunner:
)
with (
patch(
"core.app.apps.chat.app_runner.db.session.scalar",
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
),
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True),
):
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
def test_run_closes_scoped_session_before_stream_consumption(self):
runner = ChatAppRunner()
app_config = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
prompt_template=None,
external_data_variables=[],
dataset=None,
additional_features=None,
)
app_generate_entity = DummyGenerateEntity(
app_config=app_config,
model_conf=SimpleNamespace(provider_model_bundle=None, model="model-1", parameters={}),
inputs={},
query="hi",
files=[],
file_upload_config=None,
conversation_id=None,
stream=True,
user_id="user-1",
invoke_from=InvokeFrom.SERVICE_API,
)
events = []
queue_manager = DummyQueueManager()
model_instance = MagicMock()
def invoke_stream():
events.append("first-chunk")
yield "chunk"
def invoke_llm(**kwargs):
events.append("invoke")
return invoke_stream()
with (
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
patch.object(ChatAppRunner, "check_hosting_moderation", return_value=False),
patch.object(ChatAppRunner, "recalc_llm_max_tokens"),
patch.object(
ChatAppRunner,
"_handle_invoke_result",
side_effect=lambda invoke_result, **kwargs: list(invoke_result),
) as mock_handle,
patch("core.app.apps.chat.app_runner.ModelInstance", return_value=model_instance),
patch("core.app.apps.chat.app_runner.db.session.close", side_effect=lambda: events.append("close")),
):
model_instance.invoke_llm.side_effect = invoke_llm
runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1"))
assert events == ["close", "invoke", "first-chunk"]
mock_handle.assert_called_once_with(
invoke_result=ANY,
queue_manager=queue_manager,
stream=True,
message_id="m1",
user_id="user-1",
tenant_id="tenant-1",
)

View File

@ -1,5 +1,6 @@
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock
from unittest.mock import ANY, MagicMock, patch
import pytest
from pytest_mock import MockerFixture
@ -47,25 +48,28 @@ def _build_generate_entity(app_config, file_upload_config=None):
)
@contextmanager
def patched_create_session(*, return_value=None):
session = MagicMock()
session.scalar.return_value = return_value
session_context = MagicMock()
session_context.__enter__.return_value = session
with patch.object(module, "create_session", return_value=session_context):
yield session
class TestCompletionAppRunner:
def test_run_app_not_found(self, runner, mocker: MockerFixture):
session = mocker.MagicMock()
session.scalar.return_value = None
mocker.patch.object(module.db, "session", session)
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config)
with pytest.raises(ValueError):
runner.run(app_generate_entity, MagicMock(), MagicMock())
with patched_create_session(return_value=None):
with pytest.raises(ValueError):
runner.run(app_generate_entity, MagicMock(), MagicMock())
def test_run_moderation_error_outputs_direct(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
session.scalar.return_value = app_record
mocker.patch.object(module.db, "session", session)
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config)
@ -74,7 +78,8 @@ class TestCompletionAppRunner:
runner.direct_output = MagicMock()
runner._handle_invoke_result = MagicMock()
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
with patched_create_session(return_value=app_record):
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
runner.direct_output.assert_called_once()
runner._handle_invoke_result.assert_not_called()
@ -82,10 +87,6 @@ class TestCompletionAppRunner:
def test_run_hosting_moderation_stops(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
session.scalar.return_value = app_record
mocker.patch.object(module.db, "session", session)
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config)
@ -94,18 +95,14 @@ class TestCompletionAppRunner:
runner.check_hosting_moderation = MagicMock(return_value=True)
runner._handle_invoke_result = MagicMock()
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
with patched_create_session(return_value=app_record):
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
runner._handle_invoke_result.assert_not_called()
def test_run_dataset_and_external_tools_flow(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
session.scalar.return_value = app_record
session.close = MagicMock()
mocker.patch.object(module.db, "session", session)
retrieve_config = MagicMock(query_variable="qvar")
dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config)
additional_features = MagicMock(show_retrieve_source=True)
@ -135,19 +132,56 @@ class TestCompletionAppRunner:
model_instance.invoke_llm.return_value = "invoke_result"
mocker.patch.object(module, "ModelInstance", return_value=model_instance)
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
with patched_create_session(return_value=app_record):
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
dataset_retrieval.retrieve.assert_called_once()
assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input"
runner._handle_invoke_result.assert_called_once()
def test_run_closes_scoped_session_before_stream_consumption(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config)
queue_manager = MagicMock()
events = []
runner.organize_prompt_messages = MagicMock(return_value=([], None))
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
runner.check_hosting_moderation = MagicMock(return_value=False)
runner.recalc_llm_max_tokens = MagicMock()
runner._handle_invoke_result = MagicMock(side_effect=lambda invoke_result, **kwargs: list(invoke_result))
model_instance = MagicMock()
def invoke_stream():
events.append("first-chunk")
yield "chunk"
def invoke_llm(**kwargs):
events.append("invoke")
return invoke_stream()
model_instance.invoke_llm.side_effect = invoke_llm
mocker.patch.object(module, "ModelInstance", return_value=model_instance)
mocker.patch.object(module.db.session, "close", side_effect=lambda: events.append("close"))
with patched_create_session(return_value=app_record):
runner.run(app_generate_entity, queue_manager, MagicMock(id="msg"))
assert events == ["close", "invoke", "first-chunk"]
runner._handle_invoke_result.assert_called_once_with(
invoke_result=ANY,
queue_manager=queue_manager,
stream=True,
message_id="msg",
user_id="user",
tenant_id="tenant",
)
def test_run_uses_low_image_detail_default(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
session.scalar.return_value = app_record
mocker.patch.object(module.db, "session", session)
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config, file_upload_config=None)
@ -155,7 +189,8 @@ class TestCompletionAppRunner:
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
runner.check_hosting_moderation = MagicMock(return_value=True)
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
with patched_create_session(return_value=app_record):
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
assert (
runner.organize_prompt_messages.call_args.kwargs["image_detail_config"]

View File

@ -53,6 +53,32 @@ def _build_app_generate_entity() -> SimpleNamespace:
)
def _patch_create_session(mocker: MockerFixture, session: MagicMock, *, events: list[str] | None = None):
"""Patch create_session() to yield ``session`` inside its ``with`` body and ``begin()`` block.
The runner now obtains short-lived sessions via ``create_session()`` instead of the
Flask scoped ``db.session``, so tests patch the module-level ``create_session`` and
hand back a context manager that yields the mock session.
"""
session_context = MagicMock()
def enter_session():
if events is not None:
events.append("session_enter")
return session
def exit_session(*args):
if events is not None:
events.append("session_exit")
return False
session_context.__enter__.side_effect = enter_session
session_context.__exit__.side_effect = exit_session
session.begin.return_value.__enter__.return_value = session
session.begin.return_value.__exit__.return_value = False
return mocker.patch.object(module, "create_session", return_value=session_context)
@pytest.fixture
def runner():
app_generate_entity = _build_app_generate_entity()
@ -77,13 +103,14 @@ def test_get_app_id(runner):
assert runner._get_app_id() == "pipe"
def test_get_workflow_returns_workflow(mocker, runner):
def test_get_workflow_returns_workflow(runner):
pipeline = MagicMock(tenant_id="tenant", id="pipe")
workflow = MagicMock(id="wf")
mocker.patch.object(module.db, "session", MagicMock(scalar=MagicMock(return_value=workflow)))
session = MagicMock()
session.scalar.return_value = workflow
result = runner.get_workflow(pipeline=pipeline, workflow_id="wf")
result = runner.get_workflow(session=session, pipeline=pipeline, workflow_id="wf")
assert result == workflow
@ -116,7 +143,7 @@ def test_update_document_status_on_failure(mocker, runner):
session = MagicMock()
session.scalar.return_value = document
mocker.patch.object(module.db, "session", session)
_patch_create_session(mocker, session)
event = GraphRunFailedEvent(error="boom")
@ -124,7 +151,10 @@ def test_update_document_status_on_failure(mocker, runner):
assert document.indexing_status == "error"
assert document.error == "boom"
session.commit.assert_called_once()
session.add.assert_called_once_with(document)
session.begin.assert_called_once()
session.begin.return_value.__enter__.assert_called_once()
session.begin.return_value.__exit__.assert_called_once()
def test_run_pipeline_not_found(mocker: MockerFixture):
@ -135,7 +165,7 @@ def test_run_pipeline_not_found(mocker: MockerFixture):
session = MagicMock()
session.get.side_effect = [None, None]
mocker.patch.object(module.db, "session", session)
_patch_create_session(mocker, session)
runner = PipelineRunner(
application_generate_entity=app_generate_entity,
@ -158,7 +188,7 @@ def test_run_workflow_not_initialized(mocker: MockerFixture):
session = MagicMock()
session.get.side_effect = [None, pipeline]
mocker.patch.object(module.db, "session", session)
_patch_create_session(mocker, session)
runner = PipelineRunner(
application_generate_entity=app_generate_entity,
@ -184,7 +214,7 @@ def test_run_single_iteration_path(mocker: MockerFixture):
session = MagicMock()
session.get.side_effect = [end_user, pipeline]
mocker.patch.object(module.db, "session", session)
_patch_create_session(mocker, session)
runner = PipelineRunner(
application_generate_entity=app_generate_entity,
@ -229,10 +259,11 @@ def test_run_normal_path_builds_graph(mocker: MockerFixture):
pipeline = MagicMock(id="pipe")
end_user = MagicMock(session_id="sess")
events = []
session = MagicMock()
session.get.side_effect = [end_user, pipeline]
mocker.patch.object(module.db, "session", session)
_patch_create_session(mocker, session, events=events)
workflow = MagicMock(
id="wf",
@ -276,10 +307,11 @@ def test_run_normal_path_builds_graph(mocker: MockerFixture):
workflow_entry = MagicMock()
workflow_entry.graph_engine = MagicMock()
workflow_entry.run.return_value = []
workflow_entry.run.side_effect = lambda: events.append("workflow_run") or []
mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
runner.run()
assert events == ["session_enter", "session_exit", "workflow_run"]
runner._init_rag_pipeline_graph.assert_called_once()

View File

@ -0,0 +1,59 @@
from unittest.mock import patch
import pytest
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueTextChunkEvent
from models.model import AppMode
def _message_queue_manager(app_mode: str) -> MessageBasedAppQueueManager:
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
mock_redis.setex.return_value = True
return MessageBasedAppQueueManager(
task_id="task-1",
user_id="user-1",
invoke_from=InvokeFrom.DEBUGGER,
conversation_id="conversation-1",
app_mode=app_mode,
message_id="message-1",
)
def _workflow_queue_manager(app_mode: str) -> WorkflowAppQueueManager:
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
mock_redis.setex.return_value = True
return WorkflowAppQueueManager(
task_id="task-1",
user_id="user-1",
invoke_from=InvokeFrom.DEBUGGER,
app_mode=app_mode,
)
def test_message_queue_does_not_raise_legacy_stop_for_advanced_chat() -> None:
manager = _message_queue_manager(AppMode.ADVANCED_CHAT.value)
with patch.object(manager, "_is_stopped", return_value=True):
manager.publish(QueueTextChunkEvent(text="chunk"), PublishFrom.APPLICATION_MANAGER)
def test_workflow_queue_does_not_read_legacy_stop_flag() -> None:
manager = _workflow_queue_manager(AppMode.WORKFLOW.value)
with patch.object(manager, "_is_stopped", return_value=True) as is_stopped:
manager.publish(QueueTextChunkEvent(text="chunk"), PublishFrom.APPLICATION_MANAGER)
is_stopped.assert_not_called()
def test_message_queue_keeps_legacy_stop_for_non_graphengine_chat() -> None:
manager = _message_queue_manager(AppMode.CHAT.value)
with patch.object(manager, "_is_stopped", return_value=True):
with pytest.raises(GenerateTaskStoppedError):
manager.publish(QueueTextChunkEvent(text="chunk"), PublishFrom.APPLICATION_MANAGER)

Some files were not shown because too many files have changed in this diff Show More