From 26639e0923da9e2f453791907b86073a00eafe1c Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Tue, 23 Jun 2026 12:34:13 +0800 Subject: [PATCH 01/12] feat: add agent debug conversation refresh API (#37784) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/agent/roster.py | 30 ++++++++ api/openapi/markdown/console-openapi.md | 20 ++++++ api/services/agent/roster_service.py | 46 +++++++++++++ .../console/agent/test_agent_controllers.py | 34 ++++++++++ .../services/agent/test_agent_services.py | 68 +++++++++++++++++++ .../generated/api/console/agent/orpc.gen.ts | 54 ++++++++++----- .../generated/api/console/agent/types.gen.ts | 24 +++++++ .../generated/api/console/agent/zod.gen.ts | 17 +++++ 8 files changed, 277 insertions(+), 16 deletions(-) diff --git a/api/controllers/console/agent/roster.py b/api/controllers/console/agent/roster.py index 96bce6763f5..810dfda965a 100644 --- a/api/controllers/console/agent/roster.py +++ b/api/controllers/console/agent/roster.py @@ -228,6 +228,10 @@ class AgentAppDetailWithSite(GenericAppDetailWithSite): active_config_is_published: bool = False +class AgentDebugConversationRefreshResponse(BaseModel): + debug_conversation_id: str + + class AgentAppPagination(GenericAppPagination): data: list[AgentAppPartial] = Field( # type: ignore[assignment] # pyrefly: ignore[bad-override-mutable-attribute] validation_alias=AliasChoices("items", "data") @@ -254,6 +258,7 @@ register_response_schema_models( AgentAppPublishedReferenceResponse, AgentAppDetailWithSite, AgentAppPartial, + AgentDebugConversationRefreshResponse, AgentConfigSnapshotDetailResponse, AgentConfigSnapshotListResponse, AgentConfigSnapshotRestoreResponse, @@ -535,6 +540,31 @@ class AgentAppApi(Resource): return "", 204 +@console_ns.route("/agent//debug-conversation/refresh") +class AgentDebugConversationRefreshApi(Resource): + @console_ns.response( + 200, + "Agent debug conversation refreshed", + 
console_ns.models[AgentDebugConversationRefreshResponse.__name__], + ) + @console_ns.response(403, "Insufficient permissions") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, current_user: Account, agent_id: UUID): + debug_conversation_id = _agent_roster_service().refresh_agent_app_debug_conversation_id( + tenant_id=tenant_id, + agent_id=str(agent_id), + account_id=current_user.id, + ) + return AgentDebugConversationRefreshResponse(debug_conversation_id=debug_conversation_id).model_dump( + mode="json" + ) + + @console_ns.route("/agent//copy") class AgentAppCopyApi(Resource): @console_ns.expect(console_ns.models[CopyAppPayload.__name__]) diff --git a/api/openapi/markdown/console-openapi.md b/api/openapi/markdown/console-openapi.md index ef11a817662..a60e958f84b 100644 --- a/api/openapi/markdown/console-openapi.md +++ b/api/openapi/markdown/console-openapi.md @@ -602,6 +602,20 @@ Stop a running Agent App chat message generation | 400 | Invalid request parameters | | | 403 | Insufficient permissions | | +### [POST] /agent/{agent_id}/debug-conversation/refresh +#### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| agent_id | path | | Yes | string (uuid) | + +#### Responses + +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Agent debug conversation refreshed | **application/json**: [AgentDebugConversationRefreshResponse](#agentdebugconversationrefreshresponse)
| +| 403 | Insufficient permissions | | + ### [GET] /agent/{agent_id}/drive/files List agent drive entries for an Agent App @@ -12562,6 +12576,12 @@ Audit operation recorded for Agent Soul version/revision changes. | date | string | | Yes | | message_count | integer | | Yes | +#### AgentDebugConversationRefreshResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| debug_conversation_id | string | | Yes | + #### AgentDriveDeleteFileByAgentQuery | Name | Type | Description | Required | diff --git a/api/services/agent/roster_service.py b/api/services/agent/roster_service.py index e78d49c65b7..b75fa0bb1ae 100644 --- a/api/services/agent/roster_service.py +++ b/api/services/agent/roster_service.py @@ -492,6 +492,52 @@ class AgentRosterService: self._session.commit() return conversation_id + def refresh_agent_app_debug_conversation_id( + self, *, tenant_id: str, agent_id: str, account_id: str, commit: bool = True + ) -> str: + """Start a new console debug conversation for the current Agent App editor.""" + + agent = self._session.scalar( + select(Agent).where( + Agent.tenant_id == tenant_id, + Agent.id == agent_id, + Agent.scope == AgentScope.ROSTER, + Agent.source == AgentSource.AGENT_APP, + Agent.status == AgentStatus.ACTIVE, + ) + ) + if agent is None or not agent.app_id: + raise AgentNotFoundError() + + conversation_id = self._create_agent_app_debug_conversation( + app_id=agent.app_id, + account_id=account_id, + ) + mapping = self._session.scalar( + select(AgentDebugConversation).where( + AgentDebugConversation.tenant_id == tenant_id, + AgentDebugConversation.agent_id == agent_id, + AgentDebugConversation.account_id == account_id, + ) + ) + if mapping is None: + self._session.add( + AgentDebugConversation( + tenant_id=tenant_id, + agent_id=agent_id, + app_id=agent.app_id, + account_id=account_id, + conversation_id=conversation_id, + ) + ) + else: + mapping.app_id = agent.app_id + mapping.conversation_id = conversation_id + 
self._session.flush() + if commit: + self._session.commit() + return conversation_id + def load_or_create_agent_app_debug_conversation_ids_by_agent_id( self, *, tenant_id: str, agents: list[Agent], account_id: str ) -> dict[str, str]: diff --git a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py index 02fd5da55ac..32a165ccd01 100644 --- a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py +++ b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py @@ -27,6 +27,7 @@ from controllers.console.agent.roster import ( AgentAppApi, AgentAppCopyApi, AgentAppListApi, + AgentDebugConversationRefreshApi, AgentInviteOptionsApi, AgentLogMessagesApi, AgentLogsApi, @@ -158,6 +159,7 @@ def test_agent_v2_console_routes_are_agent_id_first() -> None: "/agent//api-enable", "/agent//api-keys", "/agent//api-keys/", + "/agent//debug-conversation/refresh", "/agent//chat-messages", "/agent//chat-messages//stop", "/agent//feedbacks", @@ -483,6 +485,38 @@ def test_agent_app_copy_uses_agent_id_and_returns_agent_detail( } +def test_agent_debug_conversation_refresh_uses_current_user( + app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str +) -> None: + agent_id = "00000000-0000-0000-0000-000000000001" + captured: dict[str, object] = {} + + class FakeRosterService: + def refresh_agent_app_debug_conversation_id(self, **kwargs: object) -> str: + captured.update(kwargs) + return "new-debug-conversation-id" + + monkeypatch.setattr(roster_controller, "_agent_roster_service", lambda: FakeRosterService()) + + with app.test_request_context( + "/console/api/agent/00000000-0000-0000-0000-000000000001/debug-conversation/refresh", + method="POST", + ): + response = unwrap(AgentDebugConversationRefreshApi.post)( + AgentDebugConversationRefreshApi(), + "tenant-1", + SimpleNamespace(id=account_id), + agent_id, + ) + + assert response == {"debug_conversation_id": 
"new-debug-conversation-id"} + assert captured == { + "tenant_id": "tenant-1", + "agent_id": agent_id, + "account_id": account_id, + } + + def test_agent_api_access_uses_agent_id_and_returns_service_api_metadata( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/api/tests/unit_tests/services/agent/test_agent_services.py b/api/tests/unit_tests/services/agent/test_agent_services.py index 52ff00c0855..846ce5a3e62 100644 --- a/api/tests/unit_tests/services/agent/test_agent_services.py +++ b/api/tests/unit_tests/services/agent/test_agent_services.py @@ -1649,6 +1649,74 @@ class TestAgentAppBackingAgent: with pytest.raises(roster_service.AgentNotFoundError): service.get_agent_app_model(tenant_id="tenant-1", agent_id="agent-x") + def test_refresh_agent_app_debug_conversation_creates_mapping(self): + agent = Agent( + id="agent-1", + tenant_id="tenant-1", + name="Iris", + description="", + agent_kind=AgentKind.DIFY_AGENT, + scope=AgentScope.ROSTER, + source=AgentSource.AGENT_APP, + status=AgentStatus.ACTIVE, + app_id="app-1", + ) + session = FakeSession(scalar=[agent, None]) + service = AgentRosterService(session) + + conversation_id = service.refresh_agent_app_debug_conversation_id( + tenant_id="tenant-1", + agent_id="agent-1", + account_id="account-1", + ) + + conversations = [a for a in session.added if isinstance(a, Conversation)] + assert len(conversations) == 1 + assert conversations[0].id == conversation_id + assert conversations[0].app_id == "app-1" + assert conversations[0].from_source == ConversationFromSource.CONSOLE + assert conversations[0].from_account_id == "account-1" + mappings = [a for a in session.added if isinstance(a, AgentDebugConversation)] + assert len(mappings) == 1 + assert mappings[0].tenant_id == "tenant-1" + assert mappings[0].agent_id == "agent-1" + assert mappings[0].app_id == "app-1" + assert mappings[0].account_id == "account-1" + assert mappings[0].conversation_id == conversation_id + assert session.deleted == [] + assert 
session.commits == 1 + + def test_refresh_agent_app_debug_conversation_replaces_existing_mapping(self): + agent = Agent( + id="agent-1", + tenant_id="tenant-1", + name="Iris", + description="", + agent_kind=AgentKind.DIFY_AGENT, + scope=AgentScope.ROSTER, + source=AgentSource.AGENT_APP, + status=AgentStatus.ACTIVE, + app_id="app-1", + ) + mapping = SimpleNamespace(app_id="old-app", conversation_id="old-conversation") + session = FakeSession(scalar=[agent, mapping]) + service = AgentRosterService(session) + + conversation_id = service.refresh_agent_app_debug_conversation_id( + tenant_id="tenant-1", + agent_id="agent-1", + account_id="account-1", + ) + + assert mapping.app_id == "app-1" + assert mapping.conversation_id == conversation_id + assert [a for a in session.added if isinstance(a, AgentDebugConversation)] == [] + conversations = [a for a in session.added if isinstance(a, Conversation)] + assert len(conversations) == 1 + assert conversations[0].id == conversation_id + assert session.deleted == [] + assert session.commits == 1 + def test_duplicate_agent_app_copies_app_config_and_active_soul(self, monkeypatch: pytest.MonkeyPatch): source_config = SimpleNamespace( opening_statement="hello", diff --git a/packages/contracts/generated/api/console/agent/orpc.gen.ts b/packages/contracts/generated/api/console/agent/orpc.gen.ts index 99f44301703..5e4a692f244 100644 --- a/packages/contracts/generated/api/console/agent/orpc.gen.ts +++ b/packages/contracts/generated/api/console/agent/orpc.gen.ts @@ -84,6 +84,8 @@ import { zPostAgentByAgentIdCopyBody, zPostAgentByAgentIdCopyPath, zPostAgentByAgentIdCopyResponse, + zPostAgentByAgentIdDebugConversationRefreshPath, + zPostAgentByAgentIdDebugConversationRefreshResponse, zPostAgentByAgentIdFeaturesBody, zPostAgentByAgentIdFeaturesPath, zPostAgentByAgentIdFeaturesResponse, @@ -356,6 +358,25 @@ export const copy = { post: post5, } +export const post6 = oc + .route({ + inputStructure: 'detailed', + method: 'POST', + operationId: 
'postAgentByAgentIdDebugConversationRefresh', + path: '/agent/{agent_id}/debug-conversation/refresh', + tags: ['console'], + }) + .input(z.object({ params: zPostAgentByAgentIdDebugConversationRefreshPath })) + .output(zPostAgentByAgentIdDebugConversationRefreshResponse) + +export const refresh = { + post: post6, +} + +export const debugConversation = { + refresh, +} + /** * Time-limited external signed URL for one Agent App drive value */ @@ -481,7 +502,7 @@ export const drive = { /** * Update an Agent App's presentation features (opener, follow-up, citations, ...) */ -export const post6 = oc +export const post7 = oc .route({ description: 'Update an Agent App\'s presentation features (opener, follow-up, citations, ...)', inputStructure: 'detailed', @@ -496,13 +517,13 @@ export const post6 = oc .output(zPostAgentByAgentIdFeaturesResponse) export const features = { - post: post6, + post: post7, } /** * Create or update Agent App message feedback */ -export const post7 = oc +export const post8 = oc .route({ description: 'Create or update Agent App message feedback', inputStructure: 'detailed', @@ -517,7 +538,7 @@ export const post7 = oc .output(zPostAgentByAgentIdFeedbacksResponse) export const feedbacks = { - post: post7, + post: post8, } /** @@ -540,7 +561,7 @@ export const delete2 = oc /** * Commit an uploaded file into the Agent App drive under files/ */ -export const post8 = oc +export const post9 = oc .route({ description: 'Commit an uploaded file into the Agent App drive under files/', inputStructure: 'detailed', @@ -555,7 +576,7 @@ export const post8 = oc export const files2 = { delete: delete2, - post: post8, + post: post9, } export const get13 = oc @@ -684,7 +705,7 @@ export const read = { /** * Upload one Agent App sandbox file as a Dify ToolFile mapping */ -export const post9 = oc +export const post10 = oc .route({ description: 'Upload one Agent App sandbox file as a Dify ToolFile mapping', inputStructure: 'detailed', @@ -702,7 +723,7 @@ export const post9 
= oc .output(zPostAgentByAgentIdSandboxFilesUploadResponse) export const upload = { - post: post9, + post: post10, } /** @@ -738,7 +759,7 @@ export const sandbox = { /** * Upload + standardize a Skill into an Agent App drive */ -export const post10 = oc +export const post11 = oc .route({ description: 'Upload + standardize a Skill into an Agent App drive', inputStructure: 'detailed', @@ -757,13 +778,13 @@ export const post10 = oc .output(zPostAgentByAgentIdSkillsUploadResponse) export const upload2 = { - post: post10, + post: post11, } /** * Infer CLI tool + ENV suggestions from a standardized Agent App skill */ -export const post11 = oc +export const post12 = oc .route({ description: 'Infer CLI tool + ENV suggestions from a standardized Agent App skill', inputStructure: 'detailed', @@ -776,7 +797,7 @@ export const post11 = oc .output(zPostAgentByAgentIdSkillsBySlugInferToolsResponse) export const inferTools = { - post: post11, + post: post12, } /** @@ -828,7 +849,7 @@ export const statistics = { summary, } -export const post12 = oc +export const post13 = oc .route({ inputStructure: 'detailed', method: 'POST', @@ -840,7 +861,7 @@ export const post12 = oc .output(zPostAgentByAgentIdVersionsByVersionIdRestoreResponse) export const restore = { - post: post12, + post: post13, } export const get21 = oc @@ -919,6 +940,7 @@ export const byAgentId = { chatMessages, composer, copy, + debugConversation, drive, features, feedbacks, @@ -944,7 +966,7 @@ export const get24 = oc .input(z.object({ query: zGetAgentQuery.optional() })) .output(zGetAgentResponse) -export const post13 = oc +export const post14 = oc .route({ inputStructure: 'detailed', method: 'POST', @@ -958,7 +980,7 @@ export const post13 = oc export const agent = { get: get24, - post: post13, + post: post14, inviteOptions, byAgentId, } diff --git a/packages/contracts/generated/api/console/agent/types.gen.ts b/packages/contracts/generated/api/console/agent/types.gen.ts index 837b267a636..988a8999c30 100644 --- 
a/packages/contracts/generated/api/console/agent/types.gen.ts +++ b/packages/contracts/generated/api/console/agent/types.gen.ts @@ -166,6 +166,10 @@ export type CopyAppPayload = { name?: string | null } +export type AgentDebugConversationRefreshResponse = { + debug_conversation_id: string +} + export type AgentDriveListResponse = { items?: Array } @@ -1968,6 +1972,26 @@ export type PostAgentByAgentIdCopyResponses = { export type PostAgentByAgentIdCopyResponse = PostAgentByAgentIdCopyResponses[keyof PostAgentByAgentIdCopyResponses] +export type PostAgentByAgentIdDebugConversationRefreshData = { + body?: never + path: { + agent_id: string + } + query?: never + url: '/agent/{agent_id}/debug-conversation/refresh' +} + +export type PostAgentByAgentIdDebugConversationRefreshErrors = { + 403: unknown +} + +export type PostAgentByAgentIdDebugConversationRefreshResponses = { + 200: AgentDebugConversationRefreshResponse +} + +export type PostAgentByAgentIdDebugConversationRefreshResponse + = PostAgentByAgentIdDebugConversationRefreshResponses[keyof PostAgentByAgentIdDebugConversationRefreshResponses] + export type GetAgentByAgentIdDriveFilesData = { body?: never path: { diff --git a/packages/contracts/generated/api/console/agent/zod.gen.ts b/packages/contracts/generated/api/console/agent/zod.gen.ts index 297055a155c..aeab80c9463 100644 --- a/packages/contracts/generated/api/console/agent/zod.gen.ts +++ b/packages/contracts/generated/api/console/agent/zod.gen.ts @@ -61,6 +61,13 @@ export const zSimpleResultResponse = z.object({ result: z.string(), }) +/** + * AgentDebugConversationRefreshResponse + */ +export const zAgentDebugConversationRefreshResponse = z.object({ + debug_conversation_id: z.string(), +}) + /** * AgentDriveDownloadResponse */ @@ -2437,6 +2444,16 @@ export const zPostAgentByAgentIdCopyPath = z.object({ */ export const zPostAgentByAgentIdCopyResponse = zAgentAppDetailWithSite +export const zPostAgentByAgentIdDebugConversationRefreshPath = z.object({ + 
agent_id: z.uuid(), +}) + +/** + * Agent debug conversation refreshed + */ +export const zPostAgentByAgentIdDebugConversationRefreshResponse + = zAgentDebugConversationRefreshResponse + export const zGetAgentByAgentIdDriveFilesPath = z.object({ agent_id: z.uuid(), }) From b3e5f29421d1695c07c8794f05523ad163483b21 Mon Sep 17 00:00:00 2001 From: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> Date: Mon, 22 Jun 2026 21:40:05 -0700 Subject: [PATCH 02/12] fix(app): derive get-app --mode whitelist from listable app types (#37761) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/openapi/_models.py | 33 ++++++- api/controllers/openapi/apps.py | 9 ++ api/openapi/markdown/openapi-openapi.md | 28 +++++- .../controllers/openapi/_mode_constants.py | 10 ++ .../openapi/test_app_list_query.py | 55 ++++++----- .../controllers/openapi/test_app_payloads.py | 15 +++ .../test_apps_permitted_external_query.py | 11 ++- .../openapi/test_supported_app_type.py | 24 +++++ cli/src/api/apps.ts | 6 +- cli/src/commands/get/app/index.ts | 18 ++-- .../commands/get/app/mode-whitelist.test.ts | 13 +++ cli/src/commands/get/app/run.ts | 4 +- .../e2e/suites/discovery/get-app-list.e2e.ts | 13 +++ .../generated/api/openapi/types.gen.ts | 26 ++--- .../generated/api/openapi/zod.gen.ts | 99 ++++++++++--------- 15 files changed, 246 insertions(+), 118 deletions(-) create mode 100644 api/tests/unit_tests/controllers/openapi/_mode_constants.py create mode 100644 api/tests/unit_tests/controllers/openapi/test_supported_app_type.py create mode 100644 cli/src/commands/get/app/mode-whitelist.test.ts diff --git a/api/controllers/openapi/_models.py b/api/controllers/openapi/_models.py index e846db3ea75..6e8a9c9d439 100644 --- a/api/controllers/openapi/_models.py +++ b/api/controllers/openapi/_models.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import Any, Literal +from enum import StrEnum +from typing import Any, Final, Literal from 
pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator @@ -13,6 +14,30 @@ from models.model import AppMode MAX_PAGE_LIMIT = 200 +class SupportedAppType(StrEnum): + """App types the ``app`` usage face (``get app``) lists and filters. + + A curated subset of :class:`AppMode`: the real, user-facing app categories. + Excludes runtime-only mode tags that are not standalone apps + (``rag-pipeline`` is a knowledge ``Pipeline``; ``channel`` is unused) and the + roster-owned ``agent`` type (surfaced through the roster, not this list). + + Members reference ``AppMode.*.value`` so the subset relationship is + type-checked: dropping a member from ``AppMode`` breaks this at import. + This is the single source for the listable set — params, filters, and the + generated CLI whitelist all derive from it. + """ + + COMPLETION = AppMode.COMPLETION.value + CHAT = AppMode.CHAT.value + ADVANCED_CHAT = AppMode.ADVANCED_CHAT.value + WORKFLOW = AppMode.WORKFLOW.value + AGENT_CHAT = AppMode.AGENT_CHAT.value + + +SUPPORTED_APP_TYPES: Final[tuple[AppMode, ...]] = tuple(AppMode(t.value) for t in SupportedAppType) + + class UsageInfo(BaseModel): prompt_tokens: int = 0 completion_tokens: int = 0 @@ -279,12 +304,12 @@ class AppDescribeQuery(BaseModel): class AppListQuery(BaseModel): - """mode is a closed enum.""" + """mode is a closed enum of listable app types.""" workspace_id: UUIDStr page: int = Field(1, ge=1) limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT) - mode: AppMode | None = None + mode: SupportedAppType | None = None name: str | None = Field(None, max_length=200) @@ -335,7 +360,7 @@ class PermittedExternalAppsListQuery(BaseModel): page: int = Field(1, ge=1) limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT) - mode: AppMode | None = None + mode: SupportedAppType | None = None name: str | None = Field(None, max_length=200) diff --git a/api/controllers/openapi/apps.py b/api/controllers/openapi/apps.py index c2626cd5d8c..181af5c0742 100644 --- 
a/api/controllers/openapi/apps.py +++ b/api/controllers/openapi/apps.py @@ -16,6 +16,7 @@ from controllers.openapi import openapi_ns from controllers.openapi._contract import accepts, returns from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config from controllers.openapi._models import ( + SUPPORTED_APP_TYPES, AppDescribeInfo, AppDescribeQuery, AppDescribeResponse, @@ -37,6 +38,11 @@ from services.app_service import AppListParams, AppService _ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"}) +def _is_listable(app: App) -> bool: + """Whether the openapi app face exposes this app (curated, listable types only).""" + return app.mode in SUPPORTED_APP_TYPES + + _EMPTY_PARAMETERS: dict[str, Any] = { "opening_statement": None, "suggested_questions": [], @@ -171,6 +177,8 @@ class AppListApi(Resource): app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid)) if app is None or str(app.tenant_id) != workspace_id: return empty + if not _is_listable(app): + return empty # Apply RBAC visibility to the UUID fast-path the same way the service # layer does for paginated queries (id in accessible set OR own app). if apply_rbac_filter and not access_filter.is_app_accessible( @@ -223,6 +231,7 @@ class AppListApi(Resource): workspace_name=tenant_name, ) for r in pagination.items + if _is_listable(r) ] env = AppListResponse( diff --git a/api/openapi/markdown/openapi-openapi.md b/api/openapi/markdown/openapi-openapi.md index bd93557edcf..4bb6761c22e 100644 --- a/api/openapi/markdown/openapi-openapi.md +++ b/api/openapi/markdown/openapi-openapi.md @@ -80,7 +80,7 @@ User-scoped operations | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | | limit | query | | No | integer,
**Default:** 20 | -| mode | query | | No | string,
**Available values:** "advanced-chat", "agent", "agent-chat", "channel", "chat", "completion", "rag-pipeline", "workflow" | +| mode | query | App types the ``app`` usage face (``get app``) lists and filters. A curated subset of :class:`AppMode`: the real, user-facing app categories. Excludes runtime-only mode tags that are not standalone apps (``rag-pipeline`` is a knowledge ``Pipeline``; ``channel`` is unused) and the roster-owned ``agent`` type (surfaced through the roster, not this list). Members reference ``AppMode.*.value`` so the subset relationship is type-checked: dropping a member from ``AppMode`` breaks this at import. This is the single source for the listable set — params, filters, and the generated CLI whitelist all derive from it. | No | string,
**Available values:** "advanced-chat", "agent-chat", "chat", "completion", "workflow" | | name | query | | No | string | | page | query | | No | integer,
**Default:** 1 | | workspace_id | query | | Yes | string | @@ -318,7 +318,7 @@ Upload a file to use as an input variable when running the app | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | | limit | query | | No | integer,
**Default:** 20 | -| mode | query | | No | string,
**Available values:** "advanced-chat", "agent", "agent-chat", "channel", "chat", "completion", "rag-pipeline", "workflow" | +| mode | query | App types the ``app`` usage face (``get app``) lists and filters. A curated subset of :class:`AppMode`: the real, user-facing app categories. Excludes runtime-only mode tags that are not standalone apps (``rag-pipeline`` is a knowledge ``Pipeline``; ``channel`` is unused) and the roster-owned ``agent`` type (surfaced through the roster, not this list). Members reference ``AppMode.*.value`` so the subset relationship is type-checked: dropping a member from ``AppMode`` breaks this at import. This is the single source for the listable set — params, filters, and the generated CLI whitelist all derive from it. | No | string,
**Available values:** "advanced-chat", "agent-chat", "chat", "completion", "workflow" | | name | query | | No | string | | page | query | | No | integer,
**Default:** 1 | @@ -592,12 +592,12 @@ Request body for POST /workspaces//apps/imports. #### AppListQuery -mode is a closed enum. +mode is a closed enum of listable app types. | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | | limit | integer,
**Default:** 20 | | No | -| mode | [AppMode](#appmode) | | No | +| mode | [SupportedAppType](#supportedapptype) | | No | | name | string | | No | | page | integer,
**Default:** 1 | | No | | workspace_id | string | | Yes | @@ -922,7 +922,7 @@ Strict (extra='forbid'). | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | | limit | integer,
**Default:** 20 | | No | -| mode | [AppMode](#appmode) | | No | +| mode | [SupportedAppType](#supportedapptype) | | No | | name | string | | No | | page | integer,
**Default:** 1 | | No | @@ -990,6 +990,24 @@ Pagination for GET /account/sessions. Strict (extra='forbid'). | last_used_at | string | | No | | prefix | string | | Yes | +#### SupportedAppType + +App types the ``app`` usage face (``get app``) lists and filters. + +A curated subset of :class:`AppMode`: the real, user-facing app categories. +Excludes runtime-only mode tags that are not standalone apps +(``rag-pipeline`` is a knowledge ``Pipeline``; ``channel`` is unused) and the +roster-owned ``agent`` type (surfaced through the roster, not this list). + +Members reference ``AppMode.*.value`` so the subset relationship is +type-checked: dropping a member from ``AppMode`` breaks this at import. +This is the single source for the listable set — params, filters, and the +generated CLI whitelist all derive from it. + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| SupportedAppType | string | App types the ``app`` usage face (``get app``) lists and filters. A curated subset of :class:`AppMode`: the real, user-facing app categories. Excludes runtime-only mode tags that are not standalone apps (``rag-pipeline`` is a knowledge ``Pipeline``; ``channel`` is unused) and the roster-owned ``agent`` type (surfaced through the roster, not this list). Members reference ``AppMode.*.value`` so the subset relationship is type-checked: dropping a member from ``AppMode`` breaks this at import. This is the single source for the listable set — params, filters, and the generated CLI whitelist all derive from it. | | + #### TaskStopResponse 200 body for POST /apps//tasks//stop. The handler always returns diff --git a/api/tests/unit_tests/controllers/openapi/_mode_constants.py b/api/tests/unit_tests/controllers/openapi/_mode_constants.py new file mode 100644 index 00000000000..2a8e477c754 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/_mode_constants.py @@ -0,0 +1,10 @@ +"""Shared mode lists for the openapi app-list query tests. 
+ +Single source so adding/removing a listable app type is a one-line change +across every query-validator test. +""" + +from __future__ import annotations + +LISTABLE_MODES = ["completion", "chat", "advanced-chat", "workflow", "agent-chat"] +NON_LISTABLE_MODES = ["rag-pipeline", "channel", "agent"] diff --git a/api/tests/unit_tests/controllers/openapi/test_app_list_query.py b/api/tests/unit_tests/controllers/openapi/test_app_list_query.py index e0b15585323..7cc2149baff 100644 --- a/api/tests/unit_tests/controllers/openapi/test_app_list_query.py +++ b/api/tests/unit_tests/controllers/openapi/test_app_list_query.py @@ -4,7 +4,7 @@ Runs against the model directly, not the HTTP layer. Pins: - defaults match the plan (page=1, limit=20). - workspace_id is required. - numeric bounds enforced (page >= 1, limit in [1, MAX_PAGE_LIMIT]). -- mode validates against the AppMode enum. +- mode validates against the SupportedAppType enum (listable app types only). - name has a length cap. """ @@ -16,10 +16,14 @@ from pydantic import ValidationError from controllers.openapi._models import MAX_PAGE_LIMIT from controllers.openapi.apps import AppListQuery +from ._mode_constants import LISTABLE_MODES, NON_LISTABLE_MODES + +WS_ID = "00000000-0000-0000-0000-000000000001" + def test_defaults(): - q = AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001"}) - assert q.workspace_id == "00000000-0000-0000-0000-000000000001" + q = AppListQuery.model_validate({"workspace_id": WS_ID}) + assert q.workspace_id == WS_ID assert q.page == 1 assert q.limit == 20 assert q.mode is None @@ -33,64 +37,71 @@ def test_workspace_id_required(): def test_page_must_be_positive(): with pytest.raises(ValidationError): - AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "page": 0}) + AppListQuery.model_validate({"workspace_id": WS_ID, "page": 0}) with pytest.raises(ValidationError): - AppListQuery.model_validate({"workspace_id": 
"00000000-0000-0000-0000-000000000001", "page": -1}) + AppListQuery.model_validate({"workspace_id": WS_ID, "page": -1}) def test_page_rejects_non_integer_string(): with pytest.raises(ValidationError): - AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "page": "abc"}) + AppListQuery.model_validate({"workspace_id": WS_ID, "page": "abc"}) def test_limit_must_be_positive(): with pytest.raises(ValidationError): - AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "limit": 0}) + AppListQuery.model_validate({"workspace_id": WS_ID, "limit": 0}) with pytest.raises(ValidationError): - AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "limit": -1}) + AppListQuery.model_validate({"workspace_id": WS_ID, "limit": -1}) def test_limit_caps_at_max_page_limit(): # Boundary accepts. - q = AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "limit": MAX_PAGE_LIMIT}) + q = AppListQuery.model_validate({"workspace_id": WS_ID, "limit": MAX_PAGE_LIMIT}) assert q.limit == MAX_PAGE_LIMIT # Just over rejects. with pytest.raises(ValidationError): - AppListQuery.model_validate( - {"workspace_id": "00000000-0000-0000-0000-000000000001", "limit": MAX_PAGE_LIMIT + 1} - ) + AppListQuery.model_validate({"workspace_id": WS_ID, "limit": MAX_PAGE_LIMIT + 1}) -def test_mode_whitelisted_against_app_mode(): - # Valid mode passes. - q = AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "mode": "chat"}) +@pytest.mark.parametrize("mode", LISTABLE_MODES) +def test_mode_accepts_listable_app_types(mode: str): + q = AppListQuery.model_validate({"workspace_id": WS_ID, "mode": mode}) assert q.mode is not None - assert q.mode.value == "chat" + assert q.mode.value == mode - # Invalid mode rejects. 
+ +def test_mode_rejects_unknown_value(): with pytest.raises(ValidationError): - AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "mode": "not-a-mode"}) + AppListQuery.model_validate({"workspace_id": WS_ID, "mode": "not-a-mode"}) + + +@pytest.mark.parametrize("mode", NON_LISTABLE_MODES) +def test_mode_rejects_non_listable_app_modes(mode: str): + """rag-pipeline (a knowledge Pipeline), channel (unused) and agent (roster-owned) + are AppMode members but not standalone listable apps — the `app` face rejects them.""" + with pytest.raises(ValidationError): + AppListQuery.model_validate({"workspace_id": WS_ID, "mode": mode}) def test_name_length_capped(): - AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "name": "x" * 200}) + AppListQuery.model_validate({"workspace_id": WS_ID, "name": "x" * 200}) with pytest.raises(ValidationError): - AppListQuery.model_validate({"workspace_id": "00000000-0000-0000-0000-000000000001", "name": "x" * 201}) + AppListQuery.model_validate({"workspace_id": WS_ID, "name": "x" * 201}) def test_all_fields_accept_valid_values(): """Pin the happy-path acceptance for every field in one place.""" q = AppListQuery.model_validate( { - "workspace_id": "00000000-0000-0000-0000-000000000001", + "workspace_id": WS_ID, "page": 5, "limit": 50, "mode": "workflow", "name": "search", } ) - assert q.workspace_id == "00000000-0000-0000-0000-000000000001" + assert q.workspace_id == WS_ID assert q.page == 5 assert q.limit == 50 assert q.mode is not None diff --git a/api/tests/unit_tests/controllers/openapi/test_app_payloads.py b/api/tests/unit_tests/controllers/openapi/test_app_payloads.py index 64cdc382500..12bf4c696f2 100644 --- a/api/tests/unit_tests/controllers/openapi/test_app_payloads.py +++ b/api/tests/unit_tests/controllers/openapi/test_app_payloads.py @@ -10,9 +10,11 @@ import pytest from controllers.openapi.apps import ( # pyright: ignore[reportPrivateUsage] _EMPTY_PARAMETERS, + 
_is_listable, parameters_payload, ) from controllers.service_api.app.error import AppUnavailableError +from models.model import AppMode def _fake_app(**overrides): @@ -53,3 +55,16 @@ def test_empty_parameters_constant_matches_describe_fallback_shape(): assert _EMPTY_PARAMETERS["opening_statement"] is None assert _EMPTY_PARAMETERS["file_upload"] is None assert _EMPTY_PARAMETERS["system_parameters"] == {} + + +@pytest.mark.parametrize( + "mode", + [AppMode.COMPLETION, AppMode.CHAT, AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT_CHAT], +) +def test_is_listable_accepts_supported_app_types(mode): + assert _is_listable(_fake_app(mode=mode)) is True + + +@pytest.mark.parametrize("mode", [AppMode.AGENT, AppMode.CHANNEL, AppMode.RAG_PIPELINE]) +def test_is_listable_hides_non_app_modes(mode): + assert _is_listable(_fake_app(mode=mode)) is False diff --git a/api/tests/unit_tests/controllers/openapi/test_apps_permitted_external_query.py b/api/tests/unit_tests/controllers/openapi/test_apps_permitted_external_query.py index 96873b04f46..0f530e3c3c7 100644 --- a/api/tests/unit_tests/controllers/openapi/test_apps_permitted_external_query.py +++ b/api/tests/unit_tests/controllers/openapi/test_apps_permitted_external_query.py @@ -13,6 +13,8 @@ from pydantic import ValidationError from controllers.openapi.apps_permitted_external import PermittedExternalAppsListQuery +from ._mode_constants import NON_LISTABLE_MODES + def test_query_defaults_match_apps_list(): q = PermittedExternalAppsListQuery.model_validate({}) @@ -36,11 +38,18 @@ def test_query_rejects_tag(): PermittedExternalAppsListQuery.model_validate({"tag": "prod"}) -def test_query_validates_mode_against_app_mode(): +def test_query_validates_mode_against_supported_app_type(): with pytest.raises(ValidationError): PermittedExternalAppsListQuery.model_validate({"mode": "not-a-mode"}) +@pytest.mark.parametrize("mode", NON_LISTABLE_MODES) +def test_query_rejects_non_listable_app_modes(mode: str): + """Non-app runtime modes 
and roster-owned agent are not listable here.""" + with pytest.raises(ValidationError): + PermittedExternalAppsListQuery.model_validate({"mode": mode}) + + def test_query_clamps_limit_at_max(): with pytest.raises(ValidationError): PermittedExternalAppsListQuery.model_validate({"limit": 500}) diff --git a/api/tests/unit_tests/controllers/openapi/test_supported_app_type.py b/api/tests/unit_tests/controllers/openapi/test_supported_app_type.py new file mode 100644 index 00000000000..3af1c148644 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_supported_app_type.py @@ -0,0 +1,24 @@ +"""Unit tests for SupportedAppType — the listable subset of AppMode that the +openapi `app` face (`get app`) exposes and the CLI `--mode` whitelist derives from. +""" + +from __future__ import annotations + +from controllers.openapi._models import SUPPORTED_APP_TYPES, SupportedAppType +from models.model import AppMode + + +def test_supported_app_type_is_the_listable_subset_of_app_mode(): + """SupportedAppType (and the derived SUPPORTED_APP_TYPES tuple) is exactly the + curated, listable subset of AppMode; non-app/runtime modes stay out.""" + assert {t.value for t in SupportedAppType} == { + "completion", + "chat", + "advanced-chat", + "workflow", + "agent-chat", + } + assert set(SUPPORTED_APP_TYPES) <= set(AppMode) + assert AppMode.AGENT not in SUPPORTED_APP_TYPES + assert AppMode.RAG_PIPELINE not in SUPPORTED_APP_TYPES + assert AppMode.CHANNEL not in SUPPORTED_APP_TYPES diff --git a/cli/src/api/apps.ts b/cli/src/api/apps.ts index 01b18d9a9da..1189fdeaa06 100644 --- a/cli/src/api/apps.ts +++ b/cli/src/api/apps.ts @@ -1,4 +1,4 @@ -import type { AppDescribeResponse, AppListResponse, AppMode } from '@dify/contracts/api/openapi/types.gen' +import type { AppDescribeResponse, AppListResponse, SupportedAppType } from '@dify/contracts/api/openapi/types.gen' import type { AppReader } from './app-reader' import type { OpenApiClient } from '@/http/orpc' import type { HttpClient } from 
'@/http/types' @@ -8,12 +8,12 @@ export type ListQuery = { readonly workspaceId: string readonly page?: number readonly limit?: number - readonly mode?: AppMode | '' + readonly mode?: SupportedAppType | '' readonly name?: string } // An absent or empty mode filter means "any mode" — collapse both to undefined for the query. -export function normalizeMode(mode: AppMode | '' | undefined): AppMode | undefined { +export function normalizeMode(mode: SupportedAppType | '' | undefined): SupportedAppType | undefined { return mode !== undefined && mode !== '' ? mode : undefined } diff --git a/cli/src/commands/get/app/index.ts b/cli/src/commands/get/app/index.ts index ffce31b7c49..4deed4ade9f 100644 --- a/cli/src/commands/get/app/index.ts +++ b/cli/src/commands/get/app/index.ts @@ -1,4 +1,5 @@ -import type { AppMode } from '@dify/contracts/api/openapi/types.gen' +import type { SupportedAppType } from '@dify/contracts/api/openapi/types.gen' +import { zSupportedAppType } from '@dify/contracts/api/openapi/zod.gen' import { DifyCommand } from '@/commands/_shared/dify-command' import { httpRetryFlag } from '@/commands/_shared/global-flags' import { Args, Flags } from '@/framework/flags' @@ -6,16 +7,9 @@ import { OutputFormat, table } from '@/framework/output' import { agentGuide } from './guide' import { runGetApp } from './run' -const APP_MODE_VALUES: readonly AppMode[] = [ - 'advanced-chat', - 'agent', - 'agent-chat', - 'channel', - 'chat', - 'completion', - 'rag-pipeline', - 'workflow', -] +// Single source: derived from the backend's listable app types (openapi codegen). +// Adding/removing a listable type is a backend-only change that flows here on regen. 
+const APP_MODE_VALUES: readonly SupportedAppType[] = zSupportedAppType.options export default class GetApp extends DifyCommand { static override description = 'List apps or describe one app\'s basic info' @@ -56,7 +50,7 @@ export default class GetApp extends DifyCommand { allWorkspaces: flags['all-workspaces'], page: flags.page, limitRaw: flags.limit, - mode: flags.mode as AppMode | undefined, + mode: flags.mode as SupportedAppType | undefined, name: flags.name, format, }, { active: ctx.active, http: ctx.http, io: ctx.io }) diff --git a/cli/src/commands/get/app/mode-whitelist.test.ts b/cli/src/commands/get/app/mode-whitelist.test.ts new file mode 100644 index 00000000000..93ecd86402c --- /dev/null +++ b/cli/src/commands/get/app/mode-whitelist.test.ts @@ -0,0 +1,13 @@ +import { zSupportedAppType } from '@dify/contracts/api/openapi/zod.gen' +import { describe, expect, it } from 'vitest' + +// The `get app --mode` whitelist is derived from this generated enum (see index.ts). +// These pins guard the original bug: the CLI must not advertise modes the backend +// rejects (rag-pipeline, channel) or modes that aren't listable here (agent). 
+describe('get app --mode whitelist', () => { + it('is exactly the listable app types', () => { + expect([...zSupportedAppType.options].sort()).toEqual( + ['advanced-chat', 'agent-chat', 'chat', 'completion', 'workflow'], + ) + }) +}) diff --git a/cli/src/commands/get/app/run.ts b/cli/src/commands/get/app/run.ts index c4a7911e0db..9008cfb70c5 100644 --- a/cli/src/commands/get/app/run.ts +++ b/cli/src/commands/get/app/run.ts @@ -1,4 +1,4 @@ -import type { AppDescribeResponse, AppListResponse, AppMode } from '@dify/contracts/api/openapi/types.gen' +import type { AppDescribeResponse, AppListResponse, AppMode, SupportedAppType } from '@dify/contracts/api/openapi/types.gen' import type { AppReader } from '@/api/app-reader' import type { ActiveContext } from '@/auth/hosts' import type { HttpClient } from '@/http/types' @@ -20,7 +20,7 @@ export type GetAppOptions = { readonly allWorkspaces?: boolean readonly page?: number readonly limitRaw?: string - readonly mode?: AppMode + readonly mode?: SupportedAppType readonly name?: string readonly format?: string } diff --git a/cli/test/e2e/suites/discovery/get-app-list.e2e.ts b/cli/test/e2e/suites/discovery/get-app-list.e2e.ts index 2e7623933b0..61d235794ad 100644 --- a/cli/test/e2e/suites/discovery/get-app-list.e2e.ts +++ b/cli/test/e2e/suites/discovery/get-app-list.e2e.ts @@ -177,6 +177,19 @@ describe('E2E / difyctl get app (list)', () => { expect(result.exitCode, '--mode chatbot should cause non-zero exit').not.toBe(0) }) + // Regression: rag-pipeline (a knowledge Pipeline), channel (unused) and agent + // (roster-owned) are AppMode members but not listable app types. The old CLI + // whitelist advertised rag-pipeline/channel, so the CLI forwarded them and the + // server replied 400. The whitelist now derives from SupportedAppType, so the + // CLI rejects them before any HTTP call. 
+ it.each(['rag-pipeline', 'channel', 'agent'])( + '[P0] non-listable mode %s is intercepted client-side', + async (mode) => { + const result = await fx.r(['get', 'app', '--mode', mode]) + expect(result.exitCode, `--mode ${mode} should be rejected client-side`).not.toBe(0) + }, + ) + // ── workspace override ──────────────────────────────────────────────────── it('[P0] -w overrides the default workspace', async () => { diff --git a/packages/contracts/generated/api/openapi/types.gen.ts b/packages/contracts/generated/api/openapi/types.gen.ts index 185ee37aa6f..2d47f947247 100644 --- a/packages/contracts/generated/api/openapi/types.gen.ts +++ b/packages/contracts/generated/api/openapi/types.gen.ts @@ -73,7 +73,7 @@ export type AppInfo = { export type AppListQuery = { limit?: number - mode?: AppMode | null + mode?: SupportedAppType | null name?: string | null page?: number workspace_id: string @@ -354,7 +354,7 @@ export type Package = { export type PermittedExternalAppsListQuery = { limit?: number - mode?: AppMode | null + mode?: SupportedAppType | null name?: string | null page?: number } @@ -405,6 +405,8 @@ export type SessionRow = { prefix: string } +export type SupportedAppType = 'advanced-chat' | 'agent-chat' | 'chat' | 'completion' | 'workflow' + export type TaskStopResponse = { result: 'success' } @@ -589,15 +591,7 @@ export type GetAppsData = { path?: never query: { limit?: number - mode?: - | 'advanced-chat' - | 'agent' - | 'agent-chat' - | 'channel' - | 'chat' - | 'completion' - | 'rag-pipeline' - | 'workflow' + mode?: 'advanced-chat' | 'agent-chat' | 'chat' | 'completion' | 'workflow' name?: string page?: number workspace_id: string @@ -905,15 +899,7 @@ export type GetPermittedExternalAppsData = { path?: never query?: { limit?: number - mode?: - | 'advanced-chat' - | 'agent' - | 'agent-chat' - | 'channel' - | 'chat' - | 'completion' - | 'rag-pipeline' - | 'workflow' + mode?: 'advanced-chat' | 'agent-chat' | 'chat' | 'completion' | 'workflow' name?: string 
page?: number } diff --git a/packages/contracts/generated/api/openapi/zod.gen.ts b/packages/contracts/generated/api/openapi/zod.gen.ts index 804c75394f6..557447cc769 100644 --- a/packages/contracts/generated/api/openapi/zod.gen.ts +++ b/packages/contracts/generated/api/openapi/zod.gen.ts @@ -104,19 +104,6 @@ export const zAppMode = z.enum([ 'workflow', ]) -/** - * AppListQuery - * - * mode is a closed enum. - */ -export const zAppListQuery = z.object({ - limit: z.int().gte(1).lte(200).optional().default(20), - mode: zAppMode.nullish(), - name: z.string().max(200).nullish(), - page: z.int().gte(1).optional().default(1), - workspace_id: z.string(), -}) - /** * AppListRow */ @@ -452,18 +439,6 @@ export const zPackage = z.object({ version: z.string().nullish(), }) -/** - * PermittedExternalAppsListQuery - * - * Strict (extra='forbid'). - */ -export const zPermittedExternalAppsListQuery = z.object({ - limit: z.int().gte(1).lte(200).optional().default(20), - mode: zAppMode.nullish(), - name: z.string().max(200).nullish(), - page: z.int().gte(1).optional().default(1), -}) - /** * PermittedExternalAppsListResponse */ @@ -526,6 +501,54 @@ export const zSessionListResponse = z.object({ total: z.int(), }) +/** + * SupportedAppType + * + * App types the ``app`` usage face (``get app``) lists and filters. + * + * A curated subset of :class:`AppMode`: the real, user-facing app categories. + * Excludes runtime-only mode tags that are not standalone apps + * (``rag-pipeline`` is a knowledge ``Pipeline``; ``channel`` is unused) and the + * roster-owned ``agent`` type (surfaced through the roster, not this list). + * + * Members reference ``AppMode.*.value`` so the subset relationship is + * type-checked: dropping a member from ``AppMode`` breaks this at import. + * This is the single source for the listable set — params, filters, and the + * generated CLI whitelist all derive from it. 
+ */ +export const zSupportedAppType = z.enum([ + 'advanced-chat', + 'agent-chat', + 'chat', + 'completion', + 'workflow', +]) + +/** + * AppListQuery + * + * mode is a closed enum of listable app types. + */ +export const zAppListQuery = z.object({ + limit: z.int().gte(1).lte(200).optional().default(20), + mode: zSupportedAppType.nullish(), + name: z.string().max(200).nullish(), + page: z.int().gte(1).optional().default(1), + workspace_id: z.string(), +}) + +/** + * PermittedExternalAppsListQuery + * + * Strict (extra='forbid'). + */ +export const zPermittedExternalAppsListQuery = z.object({ + limit: z.int().gte(1).lte(200).optional().default(20), + mode: zSupportedAppType.nullish(), + name: z.string().max(200).nullish(), + page: z.int().gte(1).optional().default(1), +}) + /** * TaskStopResponse * @@ -698,18 +721,7 @@ export const zDeleteAccountSessionsBySessionIdResponse = zRevokeResponse export const zGetAppsQuery = z.object({ limit: z.int().gte(1).lte(200).optional().default(20), - mode: z - .enum([ - 'advanced-chat', - 'agent', - 'agent-chat', - 'channel', - 'chat', - 'completion', - 'rag-pipeline', - 'workflow', - ]) - .optional(), + mode: z.enum(['advanced-chat', 'agent-chat', 'chat', 'completion', 'workflow']).optional(), name: z.string().max(200).optional(), page: z.int().gte(1).optional().default(1), workspace_id: z.string(), @@ -862,18 +874,7 @@ export const zPostOauthDeviceTokenResponse = zDeviceTokenResponse export const zGetPermittedExternalAppsQuery = z.object({ limit: z.int().gte(1).lte(200).optional().default(20), - mode: z - .enum([ - 'advanced-chat', - 'agent', - 'agent-chat', - 'channel', - 'chat', - 'completion', - 'rag-pipeline', - 'workflow', - ]) - .optional(), + mode: z.enum(['advanced-chat', 'agent-chat', 'chat', 'completion', 'workflow']).optional(), name: z.string().max(200).optional(), page: z.int().gte(1).optional().default(1), }) From 7fc8eed7164fdbefc697e2bac61a8c2e5098b5dc Mon Sep 17 00:00:00 2001 From: Myshkin451 
<79880574+myshkin451@users.noreply.github.com> Date: Tue, 23 Jun 2026 14:21:38 +0800 Subject: [PATCH 03/12] refactor: pass session into hit testing service (#37785) --- api/controllers/console/datasets/external.py | 2 + .../console/datasets/hit_testing_base.py | 2 + api/services/hit_testing_service.py | 30 ++++++----- .../services/test_hit_testing_service.py | 44 +++++++++++---- api/tests/unit_tests/services/hit_service.py | 54 ++++++++++++------- 5 files changed, 90 insertions(+), 42 deletions(-) diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 033c9a69af6..eb7b9aa84f8 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -26,6 +26,7 @@ from controllers.console.wraps import ( with_current_tenant_id, with_current_user, ) +from extensions.ext_database import db from fields.base import ResponseModel from fields.dataset_fields import ( dataset_detail_fields, @@ -390,6 +391,7 @@ class ExternalKnowledgeHitTestingApi(Resource): try: response = HitTestingService.external_retrieve( + session=db.session, dataset=dataset, query=payload.query, account=current_user, diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 4e90e66eb25..c343effa9a1 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -18,6 +18,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from libs.login import resolve_account_fallback from models.account import Account @@ -115,6 +116,7 @@ class DatasetsHitTestingBase: try: current_user, _ = resolve_account_fallback(current_user, current_tenant_id) response = HitTestingService.retrieve( + session=db.session, dataset=dataset, query=cast(str, args.get("query")), 
account=current_user, diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 6c70e324b17..9a2843864d9 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -4,6 +4,7 @@ import time from typing import Any, TypedDict, cast from sqlalchemy import select +from sqlalchemy.orm import Session, scoped_session from core.app.app_config.entities import ModelConfig from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService @@ -12,7 +13,6 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from extensions.ext_database import db from graphon.model_runtime.entities import LLMMode from models import Account from models.dataset import Dataset, DatasetQuery @@ -56,7 +56,9 @@ class HitTestingService: } @classmethod - def _dump_retrieval_records(cls, records: list[RetrievalSegments]) -> list[dict[str, Any]]: + def _dump_retrieval_records( + cls, session: Session | scoped_session, records: list[RetrievalSegments] + ) -> list[dict[str, Any]]: document_ids = { document_id for record in records @@ -69,9 +71,7 @@ class HitTestingService: documents = { document.id: cls._dump_dataset_document(document) - for document in db.session.scalars( - select(DatasetDocument).where(DatasetDocument.id.in_(document_ids)) - ).all() + for document in session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all() } records_with_documents: list[dict[str, Any]] = [] @@ -105,6 +105,7 @@ class HitTestingService: @classmethod def retrieve( cls, + session: Session | scoped_session, dataset: Dataset, query: str, account: Account, @@ -142,7 +143,7 @@ class HitTestingService: if metadata_filter_document_ids: document_ids_filter = metadata_filter_document_ids.get(dataset.id, []) if 
metadata_condition and not document_ids_filter: - return cls.compact_retrieve_response(query, []) + return cls.compact_retrieve_response(session, query, []) all_documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod( resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH) @@ -181,14 +182,15 @@ class HitTestingService: created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) - db.session.add(dataset_query) - db.session.commit() + session.add(dataset_query) + session.commit() - return cls.compact_retrieve_response(query, all_documents) + return cls.compact_retrieve_response(session, query, all_documents) @classmethod def external_retrieve( cls, + session: Session | scoped_session, dataset: Dataset, query: str, account: Account, @@ -222,20 +224,22 @@ class HitTestingService: created_by=account.id, ) - db.session.add(dataset_query) - db.session.commit() + session.add(dataset_query) + session.commit() return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod - def compact_retrieve_response(cls, query: str, documents: list[Document]) -> RetrieveResponseDict: + def compact_retrieve_response( + cls, session: Session | scoped_session, query: str, documents: list[Document] + ) -> RetrieveResponseDict: records = RetrievalService.format_retrieval_documents(documents) return { "query": { "content": query, }, - "records": cls._dump_retrieval_records(records), + "records": cls._dump_retrieval_records(session, records), } @classmethod diff --git a/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py index 2d23ae8f68f..fbf993f7d69 100644 --- a/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py @@ -181,7 +181,9 @@ class TestHitTestingService: # ── Response formatting 
──────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") - def test_compact_retrieve_response_should_format_correctly(self, mock_format: MagicMock) -> None: + def test_compact_retrieve_response_should_format_correctly( + self, mock_format: MagicMock, db_session_with_containers: Session + ) -> None: query = "test query" mock_doc = MagicMock(spec=Document) @@ -189,7 +191,9 @@ class TestHitTestingService: mock_record.model_dump.return_value = {"content": "formatted content"} mock_format.return_value = [mock_record] - response = _RetrieveResponse.model_validate(HitTestingService.compact_retrieve_response(query, [mock_doc])) + response = _RetrieveResponse.model_validate( + HitTestingService.compact_retrieve_response(db_session_with_containers, query, [mock_doc]) + ) assert response.query.content == query assert len(response.records) == 1 @@ -242,6 +246,7 @@ class TestHitTestingService: response = _RetrieveResponse.model_validate( HitTestingService.external_retrieve( + db_session_with_containers, dataset=dataset, query='test "query"', account=account, @@ -269,7 +274,9 @@ class TestHitTestingService: dataset = _create_dataset(db_session_with_containers, provider="vendor") account = MagicMock() - response = _RetrieveResponse.model_validate(HitTestingService.external_retrieve(dataset, "test query", account)) + response = _RetrieveResponse.model_validate( + HitTestingService.external_retrieve(db_session_with_containers, dataset, "test query", account) + ) assert response.query.content == "test query" assert response.records == [] @@ -292,6 +299,7 @@ class TestHitTestingService: response = _RetrieveResponse.model_validate( HitTestingService.retrieve( + db_session_with_containers, dataset=dataset, query="test query", account=account, @@ -320,7 +328,11 @@ class TestHitTestingService: retrieval_model = { "search_method": "semantic_search", - "metadata_filtering_conditions": {"some": "condition"}, + 
"metadata_filtering_conditions": { + "conditions": [ + {"name": "category", "comparison_operator": "is", "value": "test"}, + ], + }, "top_k": 5, "reranking_enable": False, "score_threshold_enabled": False, @@ -330,6 +342,7 @@ class TestHitTestingService: mock_retrieve.return_value = retrieved_documents HitTestingService.retrieve( + db_session_with_containers, dataset=dataset, query="test query", account=account, @@ -352,7 +365,11 @@ class TestHitTestingService: retrieval_model = { "search_method": "semantic_search", - "metadata_filtering_conditions": {"some": "condition"}, + "metadata_filtering_conditions": { + "conditions": [ + {"name": "category", "comparison_operator": "is", "value": "test"}, + ], + }, "top_k": 5, "reranking_enable": False, "score_threshold_enabled": False, @@ -362,6 +379,7 @@ class TestHitTestingService: response = _RetrieveResponse.model_validate( HitTestingService.retrieve( + db_session_with_containers, dataset=dataset, query="test query", account=account, @@ -393,6 +411,7 @@ class TestHitTestingService: mock_retrieve.return_value = retrieved_documents HitTestingService.retrieve( + db_session_with_containers, dataset=dataset, query="test query", account=account, @@ -452,6 +471,7 @@ class TestHitTestingService: mock_retrieve.return_value = retrieved_documents HitTestingService.retrieve( + db_session_with_containers, dataset=dataset, query="test query", account=account, @@ -477,11 +497,15 @@ class TestHitTestingService: "doc_metadata": {"source": "manual"}, } - def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self) -> None: + def test_dump_retrieval_records_returns_dumped_records_without_document_ids( + self, db_session_with_containers: Session + ) -> None: segment = _build_segment(document_id="") record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95}) - records = _DUMPED_RETRIEVAL_RECORDS.validate_python(HitTestingService._dump_retrieval_records([record])) + records = 
_DUMPED_RETRIEVAL_RECORDS.validate_python( + HitTestingService._dump_retrieval_records(db_session_with_containers, [record]) + ) assert len(records) == 1 assert records[0].segment.id == segment.id @@ -493,7 +517,9 @@ class TestHitTestingService: segment = _create_segment(db_session_with_containers, document=document) record = RetrievalSegments.model_validate({"segment": segment, "score": 0.9}) - records = _DUMPED_RETRIEVAL_RECORDS.validate_python(HitTestingService._dump_retrieval_records([record])) + records = _DUMPED_RETRIEVAL_RECORDS.validate_python( + HitTestingService._dump_retrieval_records(db_session_with_containers, [record]) + ) assert len(records) == 1 dumped_segment = records[0].segment @@ -515,7 +541,7 @@ class TestHitTestingService: segment = _create_segment(db_session_with_containers) record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95}) - result = HitTestingService._dump_retrieval_records([record]) + result = HitTestingService._dump_retrieval_records(db_session_with_containers, [record]) assert result == [] assert "Skipping hit-testing records with missing documents" in caplog.text diff --git a/api/tests/unit_tests/services/hit_service.py b/api/tests/unit_tests/services/hit_service.py index ddbc7dc0413..ae19daba898 100644 --- a/api/tests/unit_tests/services/hit_service.py +++ b/api/tests/unit_tests/services/hit_service.py @@ -147,8 +147,7 @@ class TestHitTestingServiceRetrieve: Provides a mocked database session for testing database operations like adding and committing DatasetQuery records. 
""" - with patch("services.hit_testing_service.db.session", autospec=True) as mock_db: - yield mock_db + return MagicMock() def test_retrieve_success_with_default_retrieval_model(self, mock_db_session): """ @@ -186,7 +185,9 @@ class TestHitTestingServiceRetrieve: mock_format.return_value = mock_records # Act - result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model) + result = HitTestingService.retrieve( + mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model + ) # Assert assert result["query"]["content"] == query @@ -232,7 +233,9 @@ class TestHitTestingServiceRetrieve: mock_format.return_value = mock_records # Act - result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model) + result = HitTestingService.retrieve( + mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model + ) # Assert assert result["query"]["content"] == query @@ -257,9 +260,11 @@ class TestHitTestingServiceRetrieve: retrieval_model = { "metadata_filtering_conditions": { "conditions": [ - {"field": "category", "operator": "is", "value": "test"}, + {"name": "category", "comparison_operator": "is", "value": "test"}, ], }, + "reranking_enable": False, + "score_threshold_enabled": False, } external_retrieval_model = {} @@ -286,7 +291,9 @@ class TestHitTestingServiceRetrieve: mock_format.return_value = mock_records # Act - result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model) + result = HitTestingService.retrieve( + mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model + ) # Assert assert result["query"]["content"] == query @@ -308,9 +315,11 @@ class TestHitTestingServiceRetrieve: retrieval_model = { "metadata_filtering_conditions": { "conditions": [ - {"field": "category", "operator": "is", "value": "test"}, + {"name": "category", "comparison_operator": "is", "value": "test"}, 
], }, + "reranking_enable": False, + "score_threshold_enabled": False, } external_retrieval_model = {} @@ -327,7 +336,9 @@ class TestHitTestingServiceRetrieve: mock_format.return_value = [] # Act - result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model) + result = HitTestingService.retrieve( + mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model + ) # Assert assert result["query"]["content"] == query @@ -344,6 +355,8 @@ class TestHitTestingServiceRetrieve: dataset_retrieval_model = { "search_method": RetrievalMethod.HYBRID_SEARCH, "top_k": 3, + "reranking_enable": False, + "score_threshold_enabled": False, } dataset = HitTestingTestDataFactory.create_dataset_mock(retrieval_model=dataset_retrieval_model) account = HitTestingTestDataFactory.create_user_mock() @@ -366,7 +379,9 @@ class TestHitTestingServiceRetrieve: mock_format.return_value = mock_records # Act - result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model) + result = HitTestingService.retrieve( + mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model + ) # Assert assert result["query"]["content"] == query @@ -391,8 +406,7 @@ class TestHitTestingServiceExternalRetrieve: Provides a mocked database session for testing database operations like adding and committing DatasetQuery records. 
""" - with patch("services.hit_testing_service.db.session", autospec=True) as mock_db: - yield mock_db + return MagicMock() def test_external_retrieve_success(self, mock_db_session): """ @@ -424,7 +438,7 @@ class TestHitTestingServiceExternalRetrieve: # Act result = HitTestingService.external_retrieve( - dataset, query, account, external_retrieval_model, metadata_filtering_conditions + mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions ) # Assert @@ -455,7 +469,7 @@ class TestHitTestingServiceExternalRetrieve: # Act result = HitTestingService.external_retrieve( - dataset, query, account, external_retrieval_model, metadata_filtering_conditions + mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions ) # Assert @@ -490,7 +504,7 @@ class TestHitTestingServiceExternalRetrieve: # Act result = HitTestingService.external_retrieve( - dataset, query, account, external_retrieval_model, metadata_filtering_conditions + mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions ) # Assert @@ -524,7 +538,7 @@ class TestHitTestingServiceExternalRetrieve: # Act result = HitTestingService.external_retrieve( - dataset, query, account, external_retrieval_model, metadata_filtering_conditions + mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions ) # Assert @@ -565,7 +579,7 @@ class TestHitTestingServiceCompactRetrieveResponse: mock_format.return_value = mock_records # Act - result = HitTestingService.compact_retrieve_response(query, documents) + result = HitTestingService.compact_retrieve_response(MagicMock(), query, documents) # Assert assert result["query"]["content"] == query @@ -591,7 +605,7 @@ class TestHitTestingServiceCompactRetrieveResponse: mock_format.return_value = [] # Act - result = HitTestingService.compact_retrieve_response(query, documents) + result = 
HitTestingService.compact_retrieve_response(MagicMock(), query, documents) # Assert assert result["query"]["content"] == query @@ -708,7 +722,7 @@ class TestHitTestingServiceHitTestingArgsCheck: args = {"query": ""} # Act & Assert - with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"): + with pytest.raises(ValueError, match="Query or attachment_ids is required"): HitTestingService.hit_testing_args_check(args) def test_hit_testing_args_check_none_query(self): @@ -721,7 +735,7 @@ class TestHitTestingServiceHitTestingArgsCheck: args = {"query": None} # Act & Assert - with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"): + with pytest.raises(ValueError, match="Query or attachment_ids is required"): HitTestingService.hit_testing_args_check(args) def test_hit_testing_args_check_too_long_query(self): @@ -734,7 +748,7 @@ class TestHitTestingServiceHitTestingArgsCheck: args = {"query": "a" * 251} # Act & Assert - with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"): + with pytest.raises(ValueError, match="Query cannot exceed 250 characters"): HitTestingService.hit_testing_args_check(args) def test_hit_testing_args_check_exactly_250_characters(self): From a3309cd857ba4a9a1a6f5b9fbed61ac59b92bdc7 Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Tue, 23 Jun 2026 14:35:26 +0800 Subject: [PATCH 04/12] fix: support agent duplicate role and skill file preview (#37788) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/agent/roster.py | 27 ++++++++++++++--- api/openapi/markdown/console-openapi.md | 13 +++++++- api/services/agent/roster_service.py | 4 ++- .../agent/skill_standardize_service.py | 30 ++++++++++++++++++- .../console/agent/test_agent_controllers.py | 2 ++ .../services/agent/test_agent_services.py | 5 ++++ .../agent/test_skill_standardize_service.py | 21 +++++++++---- 
.../generated/api/console/agent/types.gen.ts | 5 ++-- .../generated/api/console/agent/zod.gen.ts | 7 +++-- 9 files changed, 96 insertions(+), 18 deletions(-) diff --git a/api/controllers/console/agent/roster.py b/api/controllers/console/agent/roster.py index 810dfda965a..ac3f7ef4824 100644 --- a/api/controllers/console/agent/roster.py +++ b/api/controllers/console/agent/roster.py @@ -14,7 +14,6 @@ from controllers.console.app.app import ( ) from controllers.console.app.app import ( AppListQuery, - CopyAppPayload, _normalize_app_list_query_args, ) from controllers.console.app.app import ( @@ -110,6 +109,25 @@ class AgentAppUpdatePayload(GenericUpdateAppPayload): return role +class AgentAppCopyPayload(BaseModel): + name: str | None = Field(default=None, description="Name for the copied agent") + description: str | None = Field(default=None, description="Description for the copied agent", max_length=400) + role: str | None = Field(default=None, description="Role for the copied agent", max_length=255) + icon_type: IconType | None = Field(default=None, description="Icon type") + icon: str | None = Field(default=None, description="Icon") + icon_background: str | None = Field(default=None, description="Icon background color") + + @field_validator("role") + @classmethod + def validate_role(cls, value: str | None) -> str | None: + if value is None: + return None + role = value.strip() + if not role: + raise ValueError("Agent role is required when provided.") + return role + + class AgentApiStatusPayload(BaseModel): enable_api: bool = Field(..., description="Enable or disable Agent service API") @@ -242,8 +260,8 @@ register_schema_models( console_ns, AgentAppCreatePayload, AgentAppUpdatePayload, + AgentAppCopyPayload, AgentApiStatusPayload, - CopyAppPayload, AgentInviteOptionsQuery, AgentLogsQuery, AgentStatisticsQuery, @@ -567,7 +585,7 @@ class AgentDebugConversationRefreshApi(Resource): @console_ns.route("/agent//copy") class AgentAppCopyApi(Resource): - 
@console_ns.expect(console_ns.models[CopyAppPayload.__name__]) + @console_ns.expect(console_ns.models[AgentAppCopyPayload.__name__]) @console_ns.response(201, "Agent app copied successfully", console_ns.models[AgentAppDetailWithSite.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @@ -578,13 +596,14 @@ class AgentAppCopyApi(Resource): @with_current_user @with_current_tenant_id def post(self, tenant_id: str, current_user: Account, agent_id: UUID): - args = CopyAppPayload.model_validate(console_ns.payload or {}) + args = AgentAppCopyPayload.model_validate(console_ns.payload or {}) copied_app = _agent_roster_service().duplicate_agent_app( tenant_id=tenant_id, agent_id=str(agent_id), account=current_user, name=args.name, description=args.description, + role=args.role, icon_type=args.icon_type, icon=args.icon, icon_background=args.icon_background, diff --git a/api/openapi/markdown/console-openapi.md b/api/openapi/markdown/console-openapi.md index a60e958f84b..881e83061cc 100644 --- a/api/openapi/markdown/console-openapi.md +++ b/api/openapi/markdown/console-openapi.md @@ -592,7 +592,7 @@ Stop a running Agent App chat message generation | Required | Schema | | -------- | ------ | -| Yes | **application/json**: [CopyAppPayload](#copyapppayload)
| +| Yes | **application/json**: [AgentAppCopyPayload](#agentappcopypayload)
| #### Responses @@ -12157,6 +12157,17 @@ Default namespace | validation | [ComposerValidationFindingsResponse](#composervalidationfindingsresponse) | | No | | variant | string | | Yes | +#### AgentAppCopyPayload + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| description | string | Description for the copied agent | No | +| icon | string | Icon | No | +| icon_background | string | Icon background color | No | +| icon_type | [IconType](#icontype) | Icon type | No | +| name | string | Name for the copied agent | No | +| role | string | Role for the copied agent | No | + #### AgentAppCreatePayload | Name | Type | Description | Required | diff --git a/api/services/agent/roster_service.py b/api/services/agent/roster_service.py index b75fa0bb1ae..6a9d5818647 100644 --- a/api/services/agent/roster_service.py +++ b/api/services/agent/roster_service.py @@ -633,6 +633,7 @@ class AgentRosterService: account: Any, name: str | None = None, description: str | None = None, + role: str | None = None, icon_type: Any = None, icon: str | None = None, icon_background: str | None = None, @@ -644,6 +645,7 @@ class AgentRosterService: copied_name = name or self._next_duplicate_agent_name(tenant_id=tenant_id, base_name=source_app.name) copied_description = description if description is not None else source_app.description + copied_role = role if role is not None else source_agent.role or "" copied_icon_type = icon_type if icon_type is not None else source_app.icon_type copied_icon = icon if icon is not None else source_app.icon copied_icon_background = icon_background if icon_background is not None else source_app.icon_background @@ -654,7 +656,7 @@ class AgentRosterService: name=copied_name, description=copied_description, mode="agent", - agent_role=source_agent.role or "", + agent_role=copied_role, icon_type=self._normalize_app_icon_type(copied_icon_type), icon=copied_icon, icon_background=copied_icon_background, diff --git 
a/api/services/agent/skill_standardize_service.py b/api/services/agent/skill_standardize_service.py index 3fbcb81e61f..f8e2c8c4633 100644 --- a/api/services/agent/skill_standardize_service.py +++ b/api/services/agent/skill_standardize_service.py @@ -17,6 +17,8 @@ normalization. from __future__ import annotations +import mimetypes +import posixpath import re from typing import Any @@ -62,7 +64,8 @@ class SkillStandardizeService: skill_md_bytes = self._package.read_member_bytes(content=content, member_path=manifest.entry_path) slug = slugify_skill_name(manifest.name) - # Two drive-owned ToolFiles: canonical SKILL.md + the full archive. + # Drive-owned files: canonical SKILL.md, every inspectable archive file, + # and the full archive for future restore/export. md_tool_file = self._tool_files.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, @@ -82,6 +85,30 @@ class SkillStandardizeService: skill_md_key = f"{slug}/{_SKILL_MD_NAME}" archive_key = f"{slug}/{_FULL_ARCHIVE_NAME}" + member_items: list[DriveCommitItem] = [] + for member_path in sorted(set(manifest.files)): + member_key = f"{slug}/{member_path}" + if member_key in {skill_md_key, archive_key}: + continue + + member_bytes = self._package.read_member_bytes(content=content, member_path=member_path) + mimetype = mimetypes.guess_type(member_path)[0] or "application/octet-stream" + member_tool_file = self._tool_files.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=member_bytes, + mimetype=mimetype, + filename=posixpath.basename(member_path), + ) + member_items.append( + DriveCommitItem( + key=member_key, + file_ref=DriveFileRef(kind="tool_file", id=member_tool_file.id), + value_owned_by_drive=True, + ) + ) + self._drive.commit( tenant_id=tenant_id, user_id=user_id, @@ -103,6 +130,7 @@ class SkillStandardizeService: file_ref=DriveFileRef(kind="tool_file", id=archive_tool_file.id), value_owned_by_drive=True, ), + *member_items, ], ) diff --git 
a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py index 32a165ccd01..ec3de9928a5 100644 --- a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py +++ b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py @@ -464,6 +464,7 @@ def test_agent_app_copy_uses_agent_id_and_returns_agent_detail( json={ "name": "Iris copy", "description": "Copied", + "role": "Copied role", "icon_type": "emoji", "icon": "sparkles", "icon_background": "#fff", @@ -479,6 +480,7 @@ def test_agent_app_copy_uses_agent_id_and_returns_agent_detail( "account": current_user, "name": "Iris copy", "description": "Copied", + "role": "Copied role", "icon_type": "emoji", "icon": "sparkles", "icon_background": "#fff", diff --git a/api/tests/unit_tests/services/agent/test_agent_services.py b/api/tests/unit_tests/services/agent/test_agent_services.py index 846ce5a3e62..01db8caede4 100644 --- a/api/tests/unit_tests/services/agent/test_agent_services.py +++ b/api/tests/unit_tests/services/agent/test_agent_services.py @@ -1883,8 +1883,11 @@ class TestAgentAppBackingAgent: monkeypatch.setattr(service, "_copy_agent_active_snapshot", lambda **_: None) monkeypatch.setattr(service, "_next_duplicate_agent_name", lambda **_: "Iris copy") + captured: dict[str, object] = {} + class FakeAppService: def create_app(self, tenant_id: str, params, account: object) -> object: + captured["params"] = params return target_app access_mode_updates = [] @@ -1910,9 +1913,11 @@ class TestAgentAppBackingAgent: tenant_id="tenant-1", agent_id="source-agent", account=SimpleNamespace(id="account-1"), + role="Custom Analyst", ) assert duplicated is target_app + assert captured["params"].agent_role == "Custom Analyst" assert access_mode_updates == [("target-app", "private")] def test_duplicate_agent_app_falls_back_to_public_access_mode(self, monkeypatch: pytest.MonkeyPatch): diff --git 
a/api/tests/unit_tests/services/agent/test_skill_standardize_service.py b/api/tests/unit_tests/services/agent/test_skill_standardize_service.py index 29b4c7e59d6..cd27e1b40d0 100644 --- a/api/tests/unit_tests/services/agent/test_skill_standardize_service.py +++ b/api/tests/unit_tests/services/agent/test_skill_standardize_service.py @@ -32,13 +32,14 @@ def test_slugify_skill_name(): assert slugify_skill_name("") == "skill" -def test_standardize_creates_two_drive_owned_toolfiles_and_commits(): +def test_standardize_creates_drive_owned_toolfiles_and_commits_archive_members(): content = _zip({"SKILL.md": _SKILL_MD, "scripts/run.py": b"print('x')\n"}) tool_files = MagicMock() tool_files.create_file_by_raw.side_effect = [ SimpleNamespace(id="md-tool-file"), SimpleNamespace(id="zip-tool-file"), + SimpleNamespace(id="script-tool-file"), ] drive = MagicMock() drive.commit.return_value = [] @@ -52,25 +53,33 @@ def test_standardize_creates_two_drive_owned_toolfiles_and_commits(): agent_id="agent-1", ) - # Two ToolFiles: SKILL.md (markdown) + full archive (zip). - assert tool_files.create_file_by_raw.call_count == 2 - md_call, zip_call = tool_files.create_file_by_raw.call_args_list + # ToolFiles: SKILL.md, full archive, and each inspectable package member. + assert tool_files.create_file_by_raw.call_count == 3 + md_call, zip_call, script_call = tool_files.create_file_by_raw.call_args_list assert md_call.kwargs["mimetype"] == "text/markdown" assert md_call.kwargs["file_binary"] == _SKILL_MD assert zip_call.kwargs["mimetype"] == "application/zip" assert zip_call.kwargs["file_binary"] == content + assert script_call.kwargs["mimetype"] in {"text/x-python", "text/plain", "application/octet-stream"} + assert script_call.kwargs["file_binary"] == b"print('x')\n" + assert script_call.kwargs["filename"] == "run.py" # Committed as drive-owned with the standardized keys. 
commit_kwargs = drive.commit.call_args.kwargs assert commit_kwargs["agent_id"] == "agent-1" items = commit_kwargs["items"] - assert [item.key for item in items] == ["pdf-toolkit/SKILL.md", "pdf-toolkit/.DIFY-SKILL-FULL.zip"] + assert [item.key for item in items] == [ + "pdf-toolkit/SKILL.md", + "pdf-toolkit/.DIFY-SKILL-FULL.zip", + "pdf-toolkit/scripts/run.py", + ] assert all(item.value_owned_by_drive for item in items) - assert [item.file_ref.id for item in items] == ["md-tool-file", "zip-tool-file"] + assert [item.file_ref.id for item in items] == ["md-tool-file", "zip-tool-file", "script-tool-file"] assert items[0].is_skill is True assert items[0].skill_metadata.name == "PDF Toolkit" assert items[0].skill_metadata.manifest_files == ["SKILL.md", "scripts/run.py"] assert items[1].is_skill is False + assert items[2].is_skill is False # The returned skill ref carries stable drive paths + file ids. skill = result["skill"] diff --git a/packages/contracts/generated/api/console/agent/types.gen.ts b/packages/contracts/generated/api/console/agent/types.gen.ts index 988a8999c30..fb096b9c520 100644 --- a/packages/contracts/generated/api/console/agent/types.gen.ts +++ b/packages/contracts/generated/api/console/agent/types.gen.ts @@ -158,12 +158,13 @@ export type AgentComposerValidateResponse = { warnings?: Array } -export type CopyAppPayload = { +export type AgentAppCopyPayload = { description?: string | null icon?: string | null icon_background?: string | null icon_type?: IconType | null name?: string | null + role?: string | null } export type AgentDebugConversationRefreshResponse = { @@ -1952,7 +1953,7 @@ export type PostAgentByAgentIdComposerValidateResponse = PostAgentByAgentIdComposerValidateResponses[keyof PostAgentByAgentIdComposerValidateResponses] export type PostAgentByAgentIdCopyData = { - body: CopyAppPayload + body: AgentAppCopyPayload path: { agent_id: string } diff --git a/packages/contracts/generated/api/console/agent/zod.gen.ts 
b/packages/contracts/generated/api/console/agent/zod.gen.ts index aeab80c9463..e97bb5b71f9 100644 --- a/packages/contracts/generated/api/console/agent/zod.gen.ts +++ b/packages/contracts/generated/api/console/agent/zod.gen.ts @@ -170,14 +170,15 @@ export const zAgentAppUpdatePayload = z.object({ }) /** - * CopyAppPayload + * AgentAppCopyPayload */ -export const zCopyAppPayload = z.object({ +export const zAgentAppCopyPayload = z.object({ description: z.string().max(400).nullish(), icon: z.string().nullish(), icon_background: z.string().nullish(), icon_type: zIconType.nullish(), name: z.string().nullish(), + role: z.string().max(255).nullish(), }) /** @@ -2433,7 +2434,7 @@ export const zPostAgentByAgentIdComposerValidatePath = z.object({ */ export const zPostAgentByAgentIdComposerValidateResponse = zAgentComposerValidateResponse -export const zPostAgentByAgentIdCopyBody = zCopyAppPayload +export const zPostAgentByAgentIdCopyBody = zAgentAppCopyPayload export const zPostAgentByAgentIdCopyPath = z.object({ agent_id: z.uuid(), From cf1ebdadf5aec185fe4210cd74cece41d5982fec Mon Sep 17 00:00:00 2001 From: "Byron.wang" Date: Tue, 23 Jun 2026 00:20:25 -0700 Subject: [PATCH 05/12] feat(retention): add V2 workflow run archive bundles (#37747) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/commands/__init__.py | 2 + api/commands/retention.py | 639 +++++++++++-- api/extensions/ext_commands.py | 2 + .../api_workflow_run_repository.py | 5 + .../sqlalchemy_api_workflow_run_repository.py | 47 +- .../archive_paid_plan_workflow_run.py | 600 ++++++++---- .../bundle_archive_maintenance.py | 872 ++++++++++++++++++ .../retention/workflow_run/constants.py | 8 + .../delete_archived_workflow_run.py | 333 ++++++- .../retention/workflow_run/tenant_prefix.py | 20 + .../retention/test_workflow_run_archiver.py | 245 ++++- .../test_delete_archived_workflow_run.py | 177 +++- .../test_delete_archived_workflow_run.py | 96 ++
.../test_archive_workflow_run_logs.py | 18 +- 14 files changed, 2736 insertions(+), 328 deletions(-) create mode 100644 api/services/retention/workflow_run/bundle_archive_maintenance.py create mode 100644 api/services/retention/workflow_run/tenant_prefix.py create mode 100644 api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py diff --git a/api/commands/__init__.py b/api/commands/__init__.py index 94321ed1e49..e4207bea74e 100644 --- a/api/commands/__init__.py +++ b/api/commands/__init__.py @@ -25,6 +25,7 @@ from .plugin import ( from .rbac import migrate_member_roles_to_rbac from .retention import ( archive_workflow_runs, + archive_workflow_runs_plan, clean_expired_messages, clean_workflow_runs, cleanup_orphaned_draft_variables, @@ -51,6 +52,7 @@ from .vector import ( __all__ = [ "add_qdrant_index", "archive_workflow_runs", + "archive_workflow_runs_plan", "backfill_plugin_auto_upgrade", "clean_expired_messages", "clean_workflow_runs", diff --git a/api/commands/retention.py b/api/commands/retention.py index 657a2a2e839..1386e367aff 100644 --- a/api/commands/retention.py +++ b/api/commands/retention.py @@ -12,10 +12,160 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi from services.retention.conversation.messages_clean_policy import create_message_clean_policy from services.retention.conversation.messages_clean_service import MessagesCleanService from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup +from services.retention.workflow_run.tenant_prefix import tenant_prefix_condition from tasks.remove_app_and_related_data_task import delete_draft_variables_batch logger = logging.getLogger(__name__) +_HEX_PREFIXES = tuple("0123456789abcdef") + + +class WorkflowRunArchivePlanRow(TypedDict): + tenant_prefix: str + total_tenants: int + workflow_runs: int + workflow_node_executions: int + paid_tenants: int + unpaid_tenants: int + + +class 
WorkflowRunArchiveTenantPlan(TypedDict): + archive_tenant_ids: list[str] | None + paid_tenant_ids: list[str] + unpaid_tenant_ids: list[str] + + +def _parse_tenant_prefixes(prefixes: str | None) -> list[str]: + if not prefixes: + return [] + + parsed = [] + for raw_prefix in prefixes.split(","): + prefix = raw_prefix.strip().lower() + if not prefix: + continue + if len(prefix) != 1 or prefix not in _HEX_PREFIXES: + raise click.UsageError("--tenant-prefixes must be a comma-separated list of hex digits, e.g. 0,1,a,f.") + parsed.append(prefix) + return sorted(set(parsed)) + + +def _get_archive_candidate_tenant_ids_by_prefix( + prefix: str, + *, + start_from: datetime.datetime | None, + end_before: datetime.datetime, +) -> list[str]: + from graphon.enums import WorkflowExecutionStatus + from models.workflow import WorkflowRun + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + conditions = [ + WorkflowRun.created_at < end_before, + WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()), + WorkflowRun.type.in_(WorkflowRunArchiver.ARCHIVED_TYPE), + tenant_prefix_condition(WorkflowRun.tenant_id, prefix), + ] + if start_from is not None: + conditions.append(WorkflowRun.created_at >= start_from) + + tenant_ids = db.session.scalars( + sa.select(WorkflowRun.tenant_id).where(*conditions).distinct().order_by(WorkflowRun.tenant_id) + ).all() + return list(tenant_ids) + + +def _filter_paid_workflow_archive_tenant_ids(tenant_ids: list[str]) -> tuple[list[str], list[str]]: + from configs import dify_config + from enums.cloud_plan import CloudPlan + from services.billing_service import BillingService + + tenant_ids = sorted(set(tenant_ids)) + if not tenant_ids: + return [], [] + if not dify_config.BILLING_ENABLED: + return tenant_ids, [] + + plans = BillingService.get_plan_bulk_with_cache(tenant_ids) + paid_tenant_ids = [ + tenant_id + for tenant_id in tenant_ids + if plans.get(tenant_id) and plans[tenant_id].get("plan") in 
(CloudPlan.PROFESSIONAL, CloudPlan.TEAM) + ] + unpaid_tenant_ids = sorted(set(tenant_ids) - set(paid_tenant_ids)) + return paid_tenant_ids, unpaid_tenant_ids + + +def _resolve_archive_tenant_ids_from_plan( + *, + tenant_ids: str | None, + tenant_prefixes: list[str], + start_from: datetime.datetime | None, + end_before: datetime.datetime, +) -> WorkflowRunArchiveTenantPlan: + """ + Resolve the archive tenant scope once before scanning workflow_runs. + + Prefix rollout should use the tenant list collected by the same planning path, then archive by + tenant_id IN (...). Scanning workflow_runs with a tenant prefix range in every archive run is too expensive on + the large production table this command is meant to shrink. + """ + if tenant_ids: + requested_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + elif tenant_prefixes: + requested_tenant_ids = [] + for prefix in tenant_prefixes: + requested_tenant_ids.extend( + _get_archive_candidate_tenant_ids_by_prefix( + prefix, + start_from=start_from, + end_before=end_before, + ) + ) + else: + return WorkflowRunArchiveTenantPlan( + archive_tenant_ids=None, + paid_tenant_ids=[], + unpaid_tenant_ids=[], + ) + + paid_tenant_ids, unpaid_tenant_ids = _filter_paid_workflow_archive_tenant_ids(requested_tenant_ids) + return WorkflowRunArchiveTenantPlan( + archive_tenant_ids=paid_tenant_ids, + paid_tenant_ids=paid_tenant_ids, + unpaid_tenant_ids=unpaid_tenant_ids, + ) + + +def _resolve_archive_time_range( + *, + before_days: int, + from_days_ago: int | None, + to_days_ago: int | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, +) -> tuple[int, datetime.datetime | None, datetime.datetime | None]: + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + + if (from_days_ago is None) ^ (to_days_ago is None): + raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.") 
+ + if from_days_ago is not None and to_days_ago is not None: + if start_from or end_before: + raise click.UsageError("Choose either day offsets or explicit dates, not both.") + if from_days_ago <= to_days_ago: + raise click.UsageError("--from-days-ago must be greater than --to-days-ago.") + now = datetime.datetime.now() + start_from = now - datetime.timedelta(days=from_days_ago) + end_before = now - datetime.timedelta(days=to_days_ago) + before_days = 0 + + if start_from and end_before and start_from >= end_before: + raise click.UsageError("--start-from must be earlier than --end-before.") + + return before_days, start_from, end_before + @click.command("clear-free-plan-tenant-expired-logs", help="Clear free plan tenant expired logs.") @click.option("--days", prompt=True, help="The days to clear free plan tenant expired logs.", default=30) @@ -139,11 +289,143 @@ def clean_workflow_runs( ) +@click.command( + "archive-workflow-runs-plan", + help="Plan workflow run archive rollout by tenant ID first hex digit.", +) +@click.option("--before-days", default=90, show_default=True, help="Plan runs older than N days.") +@click.option( + "--from-days-ago", + default=None, + type=click.IntRange(min=0), + help="Lower bound in days ago (older). Must be paired with --to-days-ago.", +) +@click.option( + "--to-days-ago", + default=None, + type=click.IntRange(min=0), + help="Upper bound in days ago (newer). 
Must be paired with --from-days-ago.", +) +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Plan runs created at or after this timestamp (UTC if no timezone).", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Plan runs created before this timestamp (UTC if no timezone).", +) +@click.option( + "--include-archived", + is_flag=True, + help="Compatibility no-op for V2 bundle archive; plan counts source rows in the requested window.", +) +def archive_workflow_runs_plan( + before_days: int, + from_days_ago: int | None, + to_days_ago: int | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + include_archived: bool, +): + """ + Print the 16 tenant-prefix rollout rows used to choose archive execution order. + + Counts use the same workflow run eligibility as archive-workflow-runs: ended runs, + supported workflow types, and the requested created_at window. V2 bundle archive + does not maintain per-run archive logs, so this plan reports source-table volume. 
+ """ + from graphon.enums import WorkflowExecutionStatus + from models.workflow import WorkflowNodeExecutionModel, WorkflowRun + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + before_days, start_from, end_before = _resolve_archive_time_range( + before_days=before_days, + from_days_ago=from_days_ago, + to_days_ago=to_days_ago, + start_from=start_from, + end_before=end_before, + ) + plan_end_before = end_before or datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=before_days) + if include_archived: + click.echo(click.style("--include-archived is a no-op for V2 bundle archive plans.", fg="yellow")) + + rows: list[WorkflowRunArchivePlanRow] = [] + for prefix in _HEX_PREFIXES: + tenant_ids = _get_archive_candidate_tenant_ids_by_prefix( + prefix, + start_from=start_from, + end_before=plan_end_before, + ) + total_tenants = len(tenant_ids) + paid_tenant_ids, unpaid_tenant_ids = _filter_paid_workflow_archive_tenant_ids(tenant_ids) + + run_conditions = [ + WorkflowRun.created_at < plan_end_before, + WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()), + WorkflowRun.type.in_(WorkflowRunArchiver.ARCHIVED_TYPE), + tenant_prefix_condition(WorkflowRun.tenant_id, prefix), + ] + if start_from is not None: + run_conditions.append(WorkflowRun.created_at >= start_from) + workflow_runs = ( + db.session.scalar(sa.select(sa.func.count()).select_from(WorkflowRun).where(*run_conditions)) or 0 + ) + candidate_runs = sa.select(WorkflowRun.id).where(*run_conditions).subquery() + workflow_node_executions = ( + db.session.scalar( + sa.select(sa.func.count()) + .select_from(WorkflowNodeExecutionModel) + .join(candidate_runs, WorkflowNodeExecutionModel.workflow_run_id == candidate_runs.c.id) + ) + or 0 + ) + + rows.append( + WorkflowRunArchivePlanRow( + tenant_prefix=prefix, + total_tenants=total_tenants, + workflow_runs=workflow_runs, + workflow_node_executions=workflow_node_executions, + 
paid_tenants=len(paid_tenant_ids), + unpaid_tenants=len(unpaid_tenant_ids), + ) + ) + + click.echo( + click.style( + f"Workflow archive plan for runs before {plan_end_before.isoformat()}" + f"{f' and at/after {start_from.isoformat()}' if start_from else ''}.", + fg="white", + ) + ) + click.echo("tenant_prefix,total_tenants,workflow_runs,workflow_node_executions,paid_tenants,unpaid_tenants") + for row in rows: + click.echo( + f"{row['tenant_prefix']},{row['total_tenants']},{row['workflow_runs']}," + f"{row['workflow_node_executions']},{row['paid_tenants']},{row['unpaid_tenants']}" + ) + + ordered_rows = sorted( + rows, + key=lambda row: (row["workflow_runs"] + row["workflow_node_executions"], row["tenant_prefix"]), + ) + click.echo("suggested_execution_order=" + ",".join(row["tenant_prefix"] for row in ordered_rows)) + + @click.command( "archive-workflow-runs", help="Archive workflow runs for paid plan tenants to S3-compatible storage.", ) @click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.") +@click.option( + "--tenant-prefixes", + default=None, + help="Optional comma-separated tenant ID first hex digits for rollout waves, e.g. 
0,1,a,f.", +) @click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.") @click.option( "--from-days-ago", @@ -169,13 +451,36 @@ def clean_workflow_runs( default=None, help="Archive runs created before this timestamp (UTC if no timezone).", ) -@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.") -@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.") +@click.option("--batch-size", default=100, show_default=True, help="Maximum workflow runs per archive bundle.") +@click.option( + "--workers", + default=1, + show_default=True, + type=int, + help="Reserved; bundle archive currently runs serially.", +) +@click.option( + "--run-shard-index", + default=None, + type=click.IntRange(min=0), + help="Zero-based workflow run shard index for parallel cron jobs. Must be paired with --run-shard-total.", +) +@click.option( + "--run-shard-total", + default=None, + type=click.IntRange(min=1, max=16), + help="Total workflow run shard count for parallel cron jobs. 
Must be paired with --run-shard-index.", +) @click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.") @click.option("--dry-run", is_flag=True, help="Preview without archiving.") -@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.") +@click.option( + "--delete-after-archive", + is_flag=True, + help="Not supported by bundle archive; use a separate bundle delete workflow after validation.", +) def archive_workflow_runs( tenant_ids: str | None, + tenant_prefixes: str | None, before_days: int, from_days_ago: int | None, to_days_ago: int | None, @@ -183,6 +488,8 @@ def archive_workflow_runs( end_before: datetime.datetime | None, batch_size: int, workers: int, + run_shard_index: int | None, + run_shard_total: int | None, limit: int | None, dry_run: bool, delete_after_archive: bool, @@ -190,14 +497,19 @@ def archive_workflow_runs( """ Archive workflow runs for paid plan tenants older than the specified days. - This command archives the following tables to storage: + This command writes V2 tenant/month/shard archive bundles. Each bundle contains Parquet snapshots from: + - workflow_runs + - workflow_app_logs - workflow_node_executions - workflow_node_execution_offload - workflow_pauses - workflow_pause_reasons - workflow_trigger_logs - The workflow_runs and workflow_app_logs tables are preserved for UI listing. + Source database rows are always preserved by archive. Deletion must be handled by + a separate bundle-level delete workflow after manifest, checksum, row-count, and + restore-sampling validation. In --dry-run mode, no storage or database writes + happen; the command estimates per-table Parquet bytes and object size instead. 
""" from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver @@ -209,32 +521,58 @@ def archive_workflow_runs( ) ) - if (start_from is None) ^ (end_before is None): - click.echo(click.style("start-from and end-before must be provided together.", fg="red")) - return - - if (from_days_ago is None) ^ (to_days_ago is None): - click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red")) - return - - if from_days_ago is not None and to_days_ago is not None: - if start_from or end_before: - click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red")) - return - if from_days_ago <= to_days_ago: - click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red")) - return - now = datetime.datetime.now() - start_from = now - datetime.timedelta(days=from_days_ago) - end_before = now - datetime.timedelta(days=to_days_ago) - before_days = 0 - - if start_from and end_before and start_from >= end_before: - click.echo(click.style("start-from must be earlier than end-before.", fg="red")) + try: + before_days, start_from, end_before = _resolve_archive_time_range( + before_days=before_days, + from_days_ago=from_days_ago, + to_days_ago=to_days_ago, + start_from=start_from, + end_before=end_before, + ) + parsed_tenant_prefixes = _parse_tenant_prefixes(tenant_prefixes) + except click.UsageError as e: + click.echo(click.style(e.message, fg="red")) return + plan_end_before = end_before or datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=before_days) if workers < 1: click.echo(click.style("workers must be at least 1.", fg="red")) return + if (run_shard_index is None) ^ (run_shard_total is None): + click.echo(click.style("run-shard-index and run-shard-total must be provided together.", fg="red")) + return + if run_shard_index is not None and run_shard_total is not None and run_shard_index >= run_shard_total: + click.echo(click.style("run-shard-index must be 
less than run-shard-total.", fg="red")) + return + if delete_after_archive: + click.echo(click.style("delete-after-archive is not supported by bundle archive.", fg="red")) + return + + try: + tenant_plan = _resolve_archive_tenant_ids_from_plan( + tenant_ids=tenant_ids, + tenant_prefixes=parsed_tenant_prefixes, + start_from=start_from, + end_before=plan_end_before, + ) + except Exception: + logger.exception("Failed to resolve workflow archive tenant plan") + click.echo(click.style("Failed to resolve workflow archive tenant plan.", fg="red")) + return + + planned_tenant_ids = tenant_plan["archive_tenant_ids"] + planned_paid_tenant_ids = tenant_plan["paid_tenant_ids"] if planned_tenant_ids is not None else None + paid_tenants = len(tenant_plan["paid_tenant_ids"]) + unpaid_tenants = len(tenant_plan["unpaid_tenant_ids"]) + if planned_tenant_ids is not None: + click.echo( + click.style( + f"Resolved archive tenant plan: paid_tenants={paid_tenants}, unpaid_tenants={unpaid_tenants}.", + fg="white", + ) + ) + if not planned_tenant_ids: + click.echo(click.style("No paid tenants matched the archive plan; nothing to archive.", fg="yellow")) + return archiver = WorkflowRunArchiver( days=before_days, @@ -242,7 +580,11 @@ def archive_workflow_runs( start_from=start_from, end_before=end_before, workers=workers, - tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None, + tenant_ids=planned_tenant_ids, + tenant_prefixes=parsed_tenant_prefixes, + paid_tenant_ids=planned_paid_tenant_ids, + run_shard_index=run_shard_index, + run_shard_total=run_shard_total, limit=limit, dry_run=dry_run, delete_after_archive=delete_after_archive, @@ -252,7 +594,9 @@ def archive_workflow_runs( click.style( f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, " f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " - f"time={summary.total_elapsed_time:.2f}s", + f"bundles_archived={summary.bundles_archived}, 
bundles_skipped={summary.bundles_skipped}, " + f"bundles_failed={summary.bundles_failed}, " + f"object_size_bytes={summary.total_object_size_bytes}, time={summary.total_elapsed_time:.2f}s", fg="cyan", ) ) @@ -268,6 +612,52 @@ def archive_workflow_runs( ) +def _echo_bundle_archive_operation_summary(summary) -> None: + status = "completed successfully" if summary.bundles_failed == 0 else "completed with failures" + fg = "green" if summary.bundles_failed == 0 else "red" + click.echo( + click.style( + f"{summary.operation} {status}. " + f"bundles_success={summary.bundles_succeeded} bundles_failed={summary.bundles_failed} " + f"runs={summary.runs_processed} rows={summary.rows_processed} " + f"archive_bytes={summary.archive_bytes} duration={summary.elapsed_time:.2f}s " + f"validation_time={summary.validation_time:.2f}s " + f"runs_per_second={summary.runs_per_second:.2f} rows_per_second={summary.rows_per_second:.2f} " + f"bytes_per_second={summary.bytes_per_second:.2f}", + fg=fg, + ) + ) + click.echo(click.style("table,row_count", fg="white")) + for table_name in [ + "workflow_runs", + "workflow_app_logs", + "workflow_node_executions", + "workflow_node_execution_offload", + "workflow_pauses", + "workflow_pause_reasons", + "workflow_trigger_logs", + ]: + click.echo(f"{table_name},{summary.table_counts.get(table_name, 0)}") + for result in summary.results: + if result.success: + click.echo( + click.style( + f" bundle={result.bundle_id} tenant={result.tenant_id} runs={result.run_count} " + f"rows={result.row_count} archive_bytes={result.archive_bytes} " + f"time={result.elapsed_time:.2f}s validation={result.validation_time:.2f}s", + fg="white", + ) + ) + else: + click.echo( + click.style( + f" failed bundle={result.bundle_id} tenant={result.tenant_id} " + f"object_prefix={result.object_prefix} error={result.error}", + fg="red", + ) + ) + + @click.command( "restore-workflow-runs", help="Restore archived workflow runs from S3-compatible storage.", @@ -290,8 +680,8 @@ def 
archive_workflow_runs( default=None, help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", ) -@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.") -@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.") +@click.option("--workers", default=1, show_default=True, type=int, help="V1 --run-id compatibility only.") +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of V2 bundles to restore.") @click.option("--dry-run", is_flag=True, help="Preview without restoring.") def restore_workflow_runs( tenant_ids: str | None, @@ -303,15 +693,18 @@ def restore_workflow_runs( dry_run: bool, ): """ - Restore an archived workflow run from storage to the database. + Restore archived workflow runs from storage to the database. - This restores the following tables: + Batch restore uses V2 bundle metadata and validates archive objects before writing source rows. 
This restores: + - workflow_runs + - workflow_app_logs - workflow_node_executions - workflow_node_execution_offload - workflow_pauses - workflow_pause_reasons - workflow_trigger_logs """ + from services.retention.workflow_run.bundle_archive_maintenance import WorkflowRunBundleArchiveMaintenance from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore parsed_tenant_ids = None @@ -335,39 +728,46 @@ def restore_workflow_runs( ) ) - restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers) if run_id: + restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers) results = [restorer.restore_by_run_id(run_id)] - else: - assert start_from is not None - assert end_before is not None - results = restorer.restore_batch( - parsed_tenant_ids, - start_date=start_from, - end_date=end_before, - limit=limit, - ) + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time - end_time = datetime.datetime.now(datetime.UTC) - elapsed = end_time - start_time + successes = sum(1 for result in results if result.success) + failures = len(results) - successes - successes = sum(1 for result in results if result.success) - failures = len(results) - successes - - if failures == 0: - click.echo( - click.style( - f"Restore completed successfully. success={successes} duration={elapsed}", - fg="green", + if failures == 0: + click.echo( + click.style( + f"Restore completed successfully. success={successes} duration={elapsed}", + fg="green", + ) ) - ) - else: - click.echo( - click.style( - f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}", - fg="red", + else: + click.echo( + click.style( + f"Restore completed with failures. 
success={successes} failed={failures} duration={elapsed}", + fg="red", + ) ) + return + + if workers != 1: + click.echo( + click.style("--workers is ignored for V2 bundle restore; bundles are processed serially.", fg="yellow") ) + assert start_from is not None + assert end_before is not None + bundle_restorer = WorkflowRunBundleArchiveMaintenance(dry_run=dry_run, strict_content_validation=True) + summary = bundle_restorer.restore_batch( + tenant_ids=parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + _echo_bundle_archive_operation_summary(summary) + return @click.command( @@ -392,8 +792,20 @@ def restore_workflow_runs( default=None, help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", ) -@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.") +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of V2 bundles to delete.") @click.option("--dry-run", is_flag=True, help="Preview without deleting.") +@click.option( + "--skip-bad-archives", + is_flag=True, + help="Continue batch deletion when one archive object fails validation.", +) +@click.option( + "--restore-sample-interval", + type=int, + default=0, + show_default=True, + help="Run restore dry-run after every N successful deletes; 0 disables restore sampling.", +) def delete_archived_workflow_runs( tenant_ids: str | None, run_id: str | None, @@ -401,10 +813,16 @@ def delete_archived_workflow_runs( end_before: datetime.datetime | None, limit: int, dry_run: bool, + skip_bad_archives: bool, + restore_sample_interval: int, ): """ Delete archived workflow runs from the database. + + Batch delete uses V2 bundle metadata and validates object existence, manifest schema, object size, checksum, row + counts, and source/archive content checksums before deleting source rows. `--run-id` keeps the V1 per-run path. 
""" + from services.retention.workflow_run.bundle_archive_maintenance import WorkflowRunBundleArchiveMaintenance from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion parsed_tenant_ids = None @@ -417,6 +835,8 @@ def delete_archived_workflow_runs( raise click.UsageError("--start-from and --end-before must be provided together.") if run_id is None and (start_from is None or end_before is None): raise click.UsageError("--start-from and --end-before are required for batch delete.") + if restore_sample_interval < 0: + raise click.BadParameter("restore-sample-interval must be >= 0") start_time = datetime.datetime.now(datetime.UTC) target_desc = f"workflow run {run_id}" if run_id else "workflow runs" @@ -427,56 +847,85 @@ def delete_archived_workflow_runs( ) ) - deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run) if run_id: - results = [deleter.delete_by_run_id(run_id)] - else: - assert start_from is not None - assert end_before is not None - results = deleter.delete_batch( - parsed_tenant_ids, - start_date=start_from, - end_date=end_before, - limit=limit, + deleter = ArchivedWorkflowRunDeletion( + dry_run=dry_run, + skip_bad_archives=skip_bad_archives, + restore_sample_interval=restore_sample_interval, ) + results = [deleter.delete_by_run_id(run_id)] + for result in results: + if result.success: + click.echo( + click.style( + f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} " + f"workflow run {result.run_id} (tenant={result.tenant_id}, " + f"archive_key={result.archive_key}, counts={result.validated_counts})", + fg="green", + ) + ) + if result.restore_sampled: + sample_status = "passed" if result.restore_sample_success else "failed" + click.echo( + click.style( + f" restore dry-run sample {sample_status} for workflow run {result.run_id}", + fg="green" if result.restore_sample_success else "red", + ) + ) + else: + click.echo( + click.style( + f"Failed to delete workflow run {result.run_id}: {result.error}", + 
fg="red", + ) + ) + click.echo( + click.style( + " runbook: pause this delete window, verify archive storage object and manifest/checksum, " + "retry the same run after fixing storage or DB drift, or rerun with --skip-bad-archives " + "to quarantine this run and continue the batch.", + fg="yellow", + ) + ) - for result in results: - if result.success: + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: click.echo( click.style( - f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} " - f"workflow run {result.run_id} (tenant={result.tenant_id})", + f"Delete completed successfully. success={successes} duration={elapsed}", fg="green", ) ) else: click.echo( click.style( - f"Failed to delete workflow run {result.run_id}: {result.error}", + f"Delete completed with failures. success={successes} failed={failures} duration={elapsed}", fg="red", ) ) + return - end_time = datetime.datetime.now(datetime.UTC) - elapsed = end_time - start_time - - successes = sum(1 for result in results if result.success) - failures = len(results) - successes - - if failures == 0: - click.echo( - click.style( - f"Delete completed successfully. success={successes} duration={elapsed}", - fg="green", - ) - ) - else: - click.echo( - click.style( - f"Delete completed with failures. 
success={successes} failed={failures} duration={elapsed}", - fg="red", - ) - ) + if restore_sample_interval: + click.echo(click.style("--restore-sample-interval is ignored for V2 bundle delete.", fg="yellow")) + assert start_from is not None + assert end_before is not None + bundle_deleter = WorkflowRunBundleArchiveMaintenance( + dry_run=dry_run, + strict_content_validation=True, + stop_on_error=not skip_bad_archives, + ) + summary = bundle_deleter.delete_batch( + tenant_ids=parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + _echo_bundle_archive_operation_summary(summary) def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]: diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 6cd4b08b900..a85f6569978 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -5,6 +5,7 @@ def init_app(app: DifyApp): from commands import ( add_qdrant_index, archive_workflow_runs, + archive_workflow_runs_plan, backfill_plugin_auto_upgrade, clean_expired_messages, clean_workflow_runs, @@ -72,6 +73,7 @@ def init_app(app: DifyApp): setup_datasource_oauth_client, transform_datasource_credentials, install_rag_pipeline_plugins, + archive_workflow_runs_plan, archive_workflow_runs, delete_archived_workflow_runs, restore_workflow_runs, diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 2659e550552..bc30e980619 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -290,7 +290,10 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): batch_size: int, run_types: Sequence[WorkflowType] | None = None, tenant_ids: Sequence[str] | None = None, + tenant_prefixes: Sequence[str] | None = None, workflow_ids: Sequence[str] | None = None, + run_shard_index: int | None = None, + run_shard_total: int | None = None, ) -> Sequence[WorkflowRun]: """ Fetch 
ended workflow runs in a time window for archival and clean batching. @@ -298,7 +301,9 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): Optional filters: - run_types - tenant_ids + - tenant_prefixes, using the first hexadecimal digit of tenant_id for rollout waves - workflow_ids + - run_shard_index/run_shard_total, using a deterministic workflow_run_id shard """ ... diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index b40eb4bdd8a..2394377c9d4 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -56,6 +56,7 @@ from repositories.types import ( DailyTerminalsStats, DailyTokenCostStats, ) +from services.retention.workflow_run.tenant_prefix import tenant_prefix_condition logger = logging.getLogger(__name__) @@ -64,6 +65,40 @@ class _WorkflowRunError(Exception): pass +_HEX_SHARD_VALUES = { + "0": 0, + "1": 1, + "2": 2, + "3": 3, + "4": 4, + "5": 5, + "6": 6, + "7": 7, + "8": 8, + "9": 9, + "a": 10, + "b": 11, + "c": 12, + "d": 13, + "e": 14, + "f": 15, +} + + +def _tenant_prefix_condition(prefixes: Sequence[str]) -> sa.ColumnElement[bool]: + conditions = [tenant_prefix_condition(WorkflowRun.tenant_id, prefix) for prefix in prefixes] + return sa.or_(*conditions) + + +def _workflow_run_id_shard_expr() -> sa.ColumnElement[int]: + normalized_id = func.lower(func.replace(sa.cast(WorkflowRun.id, sa.String()), "-", "")) + last_hex = func.substr(normalized_id, func.length(normalized_id), 1) + return sa.case( + *[(last_hex == hex_digit, shard_value) for hex_digit, shard_value in _HEX_SHARD_VALUES.items()], + else_=0, + ) + + def _build_human_input_required_reason( reason_model: WorkflowPauseReason, form_model: HumanInputForm | None, @@ -378,7 +413,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): batch_size: int, run_types: Sequence[WorkflowType] | None = 
None, tenant_ids: Sequence[str] | None = None, + tenant_prefixes: Sequence[str] | None = None, workflow_ids: Sequence[str] | None = None, + run_shard_index: int | None = None, + run_shard_total: int | None = None, ) -> Sequence[WorkflowRun]: """ Fetch ended workflow runs in a time window for archival and clean batching. @@ -387,7 +425,8 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): - created_at in [start_from, end_before) - type in run_types (when provided) - status is an ended state - - optional tenant_id, workflow_id filters and cursor (last_seen) for pagination + - optional tenant_id, tenant_prefix, workflow_id filters and cursor (last_seen) for pagination + - optional deterministic shard by the last hexadecimal digit of workflow_run_id """ with self._session_maker() as session: stmt = ( @@ -410,9 +449,15 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): if tenant_ids: stmt = stmt.where(WorkflowRun.tenant_id.in_(tenant_ids)) + if tenant_prefixes: + stmt = stmt.where(_tenant_prefix_condition(tenant_prefixes)) + if workflow_ids: stmt = stmt.where(WorkflowRun.workflow_id.in_(workflow_ids)) + if run_shard_index is not None and run_shard_total is not None: + stmt = stmt.where((_workflow_run_id_shard_expr() % run_shard_total) == run_shard_index) + if last_seen: stmt = stmt.where( tuple_(WorkflowRun.created_at, WorkflowRun.id) diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index 21be411bea7..e046013b2ce 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -1,8 +1,12 @@ """ Archive Paid Plan Workflow Run Logs Service. -This service archives workflow run logs for paid plan users older than the configured -retention period (default: 90 days) to S3-compatible storage. 
+This service archives workflow run logs for paid plan users older than the configured retention period (default: +90 days) to S3-compatible storage. + +Archive V2 writes bundle-level Parquet objects. A bundle contains many workflow runs and their related table rows. +Bundle metadata lives in the object-store manifest instead of a database table, so archive/delete/restore does not move +the large-table retention problem into another OLTP table. Archived tables: - workflow_runs @@ -16,18 +20,19 @@ Archived tables: """ import datetime -import io +import hashlib import json import logging import time -import zipfile from collections.abc import Sequence -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field +from enum import Enum from typing import Any, TypedDict import click -from sqlalchemy import inspect +import pyarrow as pa +import pyarrow.parquet as pq +from sqlalchemy import inspect, select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -39,12 +44,24 @@ from libs.archive_storage import ( ArchiveStorageNotConfiguredError, get_archive_storage, ) -from models.workflow import WorkflowAppLog, WorkflowRun +from models.trigger import WorkflowTriggerLog +from models.workflow import ( + WorkflowAppLog, + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, + WorkflowPause, + WorkflowPauseReason, + WorkflowRun, +) from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.billing_service import BillingService -from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHIVE_SCHEMA_VERSION +from services.retention.workflow_run.constants import ( + ARCHIVE_BUNDLE_FORMAT, + ARCHIVE_BUNDLE_MANIFEST_NAME, + ARCHIVE_BUNDLE_SCHEMA_VERSION, +) 
logger = logging.getLogger(__name__) @@ -53,17 +70,41 @@ class TableStatsManifestEntry(TypedDict): row_count: int checksum: str size_bytes: int + object_key: str class ArchiveManifestDict(TypedDict): schema_version: str - workflow_run_id: str + archive_format: str tenant_id: str - app_id: str - workflow_id: str - created_at: str + tenant_prefix: str + year: int + month: int + shard: str + bundle_id: str + object_prefix: str + workflow_run_count: int + workflow_node_execution_count: int + min_created_at: str + max_created_at: str + min_run_id: str + max_run_id: str archived_at: str tables: dict[str, TableStatsManifestEntry] + run_ids: list[str] + + +@dataclass(frozen=True) +class ArchiveBundleIdentity: + """Stable identity and object prefix for one V2 archive bundle.""" + + tenant_prefix: str + tenant_id: str + year: int + month: int + shard: str + bundle_id: str + object_prefix: str @dataclass @@ -74,16 +115,21 @@ class TableStats: row_count: int checksum: str size_bytes: int + object_key: str = "" @dataclass class ArchiveResult: - """Result of archiving a single workflow run.""" + """Result of archiving a bundle of workflow runs.""" - run_id: str + bundle_id: str tenant_id: str + object_prefix: str success: bool + run_count: int = 0 tables: list[TableStats] = field(default_factory=list) + object_size_bytes: int = 0 + skipped: bool = False error: str | None = None elapsed_time: float = 0.0 @@ -96,6 +142,12 @@ class ArchiveSummary: runs_archived: int = 0 runs_skipped: int = 0 runs_failed: int = 0 + total_bundles_processed: int = 0 + bundles_archived: int = 0 + bundles_skipped: int = 0 + bundles_failed: int = 0 + total_object_size_bytes: int = 0 + table_stats: dict[str, TableStats] = field(default_factory=dict) total_elapsed_time: float = 0.0 @@ -104,16 +156,20 @@ class WorkflowRunArchiver: Archive workflow run logs for paid plan users. 
Storage Layout: - {tenant_id}/app_id={app_id}/year={YYYY}/month={MM}/workflow_run_id={run_id}/ - └── archive.v1.0.zip + workflow-runs/v2/tenant_prefix={prefix}/tenant_id={tenant_id}/year={YYYY}/month={MM}/ + shard={shard}/bundle={bundle_id}/ ├── manifest.json - ├── workflow_runs.jsonl - ├── workflow_app_logs.jsonl - ├── workflow_node_executions.jsonl - ├── workflow_node_execution_offload.jsonl - ├── workflow_pauses.jsonl - ├── workflow_pause_reasons.jsonl - └── workflow_trigger_logs.jsonl + ├── workflow_runs.parquet + ├── workflow_app_logs.parquet + ├── workflow_node_executions.parquet + ├── workflow_node_execution_offload.parquet + ├── workflow_pauses.parquet + ├── workflow_pause_reasons.parquet + └── workflow_trigger_logs.parquet + + `batch_size` is the maximum workflow_runs per bundle. The current implementation groups each fetched page by + tenant/month before writing bundles. Bundle idempotency is based on the manifest object key; the manifest is + uploaded after all table objects, so a missing manifest means the bundle should be retried. 
""" ARCHIVED_TYPE = [ @@ -132,6 +188,10 @@ class WorkflowRunArchiver: start_from: datetime.datetime | None end_before: datetime.datetime + paid_tenant_ids: set[str] | None + tenant_prefixes: list[str] + run_shard_index: int | None + run_shard_total: int | None def __init__( self, @@ -141,6 +201,10 @@ class WorkflowRunArchiver: end_before: datetime.datetime | None = None, workers: int = 1, tenant_ids: Sequence[str] | None = None, + tenant_prefixes: Sequence[str] | None = None, + paid_tenant_ids: Sequence[str] | None = None, + run_shard_index: int | None = None, + run_shard_total: int | None = None, limit: int | None = None, dry_run: bool = False, delete_after_archive: bool = False, @@ -156,10 +220,19 @@ class WorkflowRunArchiver: end_before: Optional end time (exclusive) for archiving workers: Number of concurrent workflow runs to archive tenant_ids: Optional tenant IDs for grayscale rollout + tenant_prefixes: Optional tenant ID first-hex prefixes for rollout waves. CLI callers should resolve these + to tenant_ids during planning so workflow_runs scan uses tenant_id IN (...) instead of a prefix range. + paid_tenant_ids: Optional paid-tenant whitelist resolved by the archive plan. When provided, archive uses it + for per-run paid filtering and does not call billing on every fetched page. + run_shard_index: Optional zero-based workflow run shard index for parallel cron jobs + run_shard_total: Optional total workflow run shard count for parallel cron jobs limit: Maximum number of runs to archive (None for unlimited) dry_run: If True, only preview without making changes - delete_after_archive: If True, delete runs and related data after archiving + delete_after_archive: Reserved for the V1 per-run path. Bundle archive requires a separate validated + bundle delete workflow. 
""" + if delete_after_archive: + raise ValueError("delete_after_archive is not supported by bundle archive") self.days = days self.batch_size = batch_size if start_from or end_before: @@ -176,6 +249,16 @@ class WorkflowRunArchiver: raise ValueError("workers must be at least 1") self.workers = workers self.tenant_ids = sorted(set(tenant_ids)) if tenant_ids else [] + self.tenant_prefixes = sorted(set(tenant_prefixes)) if tenant_prefixes else [] + self.paid_tenant_ids = set(paid_tenant_ids) if paid_tenant_ids is not None else None + if (run_shard_index is None) ^ (run_shard_total is None): + raise ValueError("run_shard_index and run_shard_total must be provided together") + if run_shard_total is not None and not 1 <= run_shard_total <= 16: + raise ValueError("run_shard_total must be between 1 and 16") + if run_shard_index is not None and run_shard_total is not None and not 0 <= run_shard_index < run_shard_total: + raise ValueError("run_shard_index must be between 0 and run_shard_total - 1") + self.run_shard_index = run_shard_index + self.run_shard_total = run_shard_total self.limit = limit self.dry_run = dry_run self.delete_after_archive = delete_after_archive @@ -209,124 +292,185 @@ class WorkflowRunArchiver: return summary session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) - repo = self._get_workflow_run_repo() + attempted_count = 0 - def _archive_with_session(run: WorkflowRun) -> ArchiveResult: - with session_maker() as session: - return self._archive_run(session, storage, run) - - last_seen: tuple[datetime.datetime, str] | None = None - archived_count = 0 - - with ThreadPoolExecutor(max_workers=self.workers) as executor: + for tenant_scope in self._tenant_scan_scopes(): + last_seen: tuple[datetime.datetime, str] | None = None while True: - # Check limit - if self.limit and archived_count >= self.limit: + if self.limit and attempted_count >= self.limit: click.echo(click.style(f"Reached limit of {self.limit} runs", fg="yellow")) break - # Fetch 
batch of runs - runs = self._get_runs_batch(last_seen) - + runs = self._get_runs_batch(last_seen, tenant_scope=tenant_scope) if not runs: break - run_ids = [run.id for run in runs] - with session_maker() as session: - archived_run_ids = repo.get_archived_run_ids(session, run_ids) - last_seen = (runs[-1].created_at, runs[-1].id) - - # Filter to paid tenants only tenant_ids = {run.tenant_id for run in runs} paid_tenants = self._filter_paid_tenants(tenant_ids) runs_to_process: list[WorkflowRun] = [] for run in runs: summary.total_runs_processed += 1 - - # Skip non-paid tenants if run.tenant_id not in paid_tenants: summary.runs_skipped += 1 continue - - # Skip already archived runs - if run.id in archived_run_ids: - summary.runs_skipped += 1 - continue - - # Check limit - if self.limit and archived_count + len(runs_to_process) >= self.limit: + if self.limit and attempted_count + len(runs_to_process) >= self.limit: break - runs_to_process.append(run) if not runs_to_process: continue - results = list(executor.map(_archive_with_session, runs_to_process)) + for bundle_runs in self._group_runs_for_bundles(runs_to_process): + summary.total_bundles_processed += 1 + with session_maker() as session: + result = self._archive_bundle(session, storage, bundle_runs) - for run, result in zip(runs_to_process, results): - if result.success: - summary.runs_archived += 1 - archived_count += 1 + if result.skipped: + attempted_count += result.run_count + summary.bundles_skipped += 1 + summary.runs_skipped += result.run_count + click.echo( + click.style( + f"Skipped bundle {result.bundle_id} (tenant={result.tenant_id}, " + f"runs={result.run_count}, reason={result.error or 'already handled'})", + fg="yellow", + ) + ) + elif result.success: + attempted_count += result.run_count + summary.bundles_archived += 1 + summary.runs_archived += result.run_count + self._merge_result_stats(summary, result) click.echo( click.style( f"{'[DRY RUN] Would archive' if self.dry_run else 'Archived'} " - f"run 
{run.id} (tenant={run.tenant_id}, " - f"tables={len(result.tables)}, time={result.elapsed_time:.2f}s)", + f"bundle {result.bundle_id} (tenant={result.tenant_id}, runs={result.run_count}, " + f"tables={len(result.tables)}, object_size_bytes={result.object_size_bytes}, " + f"time={result.elapsed_time:.2f}s)", fg="green", ) ) + if self.dry_run: + self._echo_table_estimates(result.tables) else: - summary.runs_failed += 1 + attempted_count += result.run_count + summary.bundles_failed += 1 + summary.runs_failed += result.run_count click.echo( click.style( - f"Failed to archive run {run.id}: {result.error}", + f"Failed to archive bundle {result.bundle_id}: {result.error}", fg="red", ) ) + if self.limit and attempted_count >= self.limit: + break + summary.total_elapsed_time = time.time() - start_time click.echo( click.style( f"{'[DRY RUN] ' if self.dry_run else ''}Archive complete: " f"processed={summary.total_runs_processed}, archived={summary.runs_archived}, " f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " + f"bundles_archived={summary.bundles_archived}, bundles_skipped={summary.bundles_skipped}, " + f"bundles_failed={summary.bundles_failed}, " + f"object_size_bytes={summary.total_object_size_bytes}, " f"time={summary.total_elapsed_time:.2f}s", fg="white", ) ) + if self.dry_run: + self._echo_summary_estimates(summary) return summary + @staticmethod + def _merge_result_stats(summary: ArchiveSummary, result: ArchiveResult) -> None: + summary.total_object_size_bytes += result.object_size_bytes + for table_stat in result.tables: + summary_stat = summary.table_stats.get(table_stat.table_name) + if summary_stat is None: + summary.table_stats[table_stat.table_name] = TableStats( + table_name=table_stat.table_name, + row_count=table_stat.row_count, + checksum="", + size_bytes=table_stat.size_bytes, + ) + continue + summary_stat.row_count += table_stat.row_count + summary_stat.size_bytes += table_stat.size_bytes + + @staticmethod + def 
_echo_table_estimates(table_stats: Sequence[TableStats]) -> None: + for stat in table_stats: + click.echo( + click.style( + f" table={stat.table_name} rows={stat.row_count} parquet_bytes={stat.size_bytes}", + fg="white", + ) + ) + + def _echo_summary_estimates(self, summary: ArchiveSummary) -> None: + click.echo(click.style("[DRY RUN] Estimated archive totals by table:", fg="white")) + for table_name in self.ARCHIVED_TABLES: + stat = summary.table_stats.get(table_name) + row_count = stat.row_count if stat else 0 + size_bytes = stat.size_bytes if stat else 0 + click.echo(click.style(f" table={table_name} rows={row_count} parquet_bytes={size_bytes}", fg="white")) + def _get_runs_batch( self, last_seen: tuple[datetime.datetime, str] | None, + tenant_scope: Sequence[str] | None = None, ) -> Sequence[WorkflowRun]: """Fetch a batch of workflow runs to archive.""" repo = self._get_workflow_run_repo() + tenant_ids = list(tenant_scope) if tenant_scope is not None else self.tenant_ids or None return repo.get_runs_batch_by_time_range( start_from=self.start_from, end_before=self.end_before, last_seen=last_seen, batch_size=self.batch_size, run_types=self.ARCHIVED_TYPE, - tenant_ids=self.tenant_ids or None, + tenant_ids=tenant_ids, + tenant_prefixes=None if tenant_ids else self.tenant_prefixes or None, + run_shard_index=self.run_shard_index, + run_shard_total=self.run_shard_total, ) + def _tenant_scan_scopes(self) -> list[list[str] | None]: + if not self.tenant_ids: + return [None] + return [[tenant_id] for tenant_id in self.tenant_ids] + def _build_start_message(self) -> str: range_desc = f"before {self.end_before.isoformat()}" if self.start_from: range_desc = f"between {self.start_from.isoformat()} and {self.end_before.isoformat()}" + run_shard_desc = "all" + if self.run_shard_index is not None and self.run_shard_total is not None: + run_shard_desc = f"{self.run_shard_index}/{self.run_shard_total}" return ( f"{'[DRY RUN] ' if self.dry_run else ''}Starting workflow run 
archiving " f"for runs {range_desc} " - f"(batch_size={self.batch_size}, tenant_ids={','.join(self.tenant_ids) or 'all'})" + f"(batch_size={self.batch_size}, tenant_ids={self._format_tenant_scope()}, " + f"tenant_prefixes={','.join(self.tenant_prefixes) or 'all'}, run_shard={run_shard_desc})" ) + def _format_tenant_scope(self) -> str: + if not self.tenant_ids: + return "all" + if len(self.tenant_ids) <= 10: + return ",".join(self.tenant_ids) + return f"{len(self.tenant_ids)} planned tenants" + def _filter_paid_tenants(self, tenant_ids: set[str]) -> set[str]: """Filter tenant IDs to only include paid tenants.""" + if self.paid_tenant_ids is not None: + return tenant_ids & self.paid_tenant_ids + if not dify_config.BILLING_ENABLED: # If billing is not enabled, treat all tenants as paid return tenant_ids @@ -349,177 +493,293 @@ class WorkflowRunArchiver: return paid - def _archive_run( + def _archive_bundle( self, session: Session, storage: ArchiveStorage | None, - run: WorkflowRun, + runs: Sequence[WorkflowRun], ) -> ArchiveResult: - """Archive a single workflow run.""" + """Archive one tenant/month bundle of workflow runs.""" + if not runs: + raise ValueError("runs must not be empty") start_time = time.time() - result = ArchiveResult(run_id=run.id, tenant_id=run.tenant_id, success=False) + identity = self._build_bundle_identity(runs) + result = ArchiveResult( + bundle_id=identity.bundle_id, + tenant_id=identity.tenant_id, + object_prefix=identity.object_prefix, + run_count=len(runs), + success=False, + ) try: - # Extract data from all tables - table_data, app_logs, trigger_metadata = self._extract_data(session, run) + if not self.dry_run: + if storage is None: + raise ArchiveStorageNotConfiguredError("Archive storage not configured") + if storage.object_exists(self._get_manifest_object_key(identity)): + result.success = True + result.skipped = True + result.error = "bundle already archived" + result.elapsed_time = time.time() - start_time + return result + + 
locked_runs = self._lock_runs_for_archive(session, [run.id for run in runs]) + if len(locked_runs) != len(runs): + result.success = True + result.skipped = True + result.error = "one or more runs locked or deleted by another archiver" + result.elapsed_time = time.time() - start_time + return result + runs = locked_runs + + table_data = self._extract_bundle_data(session, runs) + table_stats, table_payloads, manifest_data = self._build_archive_payload(identity, runs, table_data) + object_size = len(manifest_data) + sum(len(payload) for payload in table_payloads.values()) if self.dry_run: - # In dry run, just report what would be archived - for table_name in self.ARCHIVED_TABLES: - records = table_data.get(table_name, []) - result.tables.append( - TableStats( - table_name=table_name, - row_count=len(records), - checksum="", - size_bytes=0, - ) - ) + result.tables = table_stats + result.object_size_bytes = object_size result.success = True else: if storage is None: raise ArchiveStorageNotConfiguredError("Archive storage not configured") - archive_key = self._get_archive_key(run) - # Serialize tables for the archive bundle - table_stats: list[TableStats] = [] - table_payloads: dict[str, bytes] = {} - for table_name in self.ARCHIVED_TABLES: - records = table_data.get(table_name, []) - data = ArchiveStorage.serialize_to_jsonl(records) - table_payloads[table_name] = data - checksum = ArchiveStorage.compute_checksum(data) - - table_stats.append( - TableStats( - table_name=table_name, - row_count=len(records), - checksum=checksum, - size_bytes=len(data), - ) - ) - - # Generate and upload archive bundle - manifest = self._generate_manifest(run, table_stats) - manifest_data = json.dumps(manifest, indent=2, default=str).encode("utf-8") - archive_data = self._build_archive_bundle(manifest_data, table_payloads) - storage.put_object(archive_key, archive_data) - - repo = self._get_workflow_run_repo() - archived_log_count = repo.create_archive_logs(session, run, app_logs, 
trigger_metadata) + for table_name, payload in table_payloads.items(): + storage.put_object(self._get_table_object_key(identity, table_name), payload) + storage.put_object(self._get_manifest_object_key(identity), manifest_data) session.commit() - deleted_counts = None - if self.delete_after_archive: - deleted_counts = repo.delete_runs_with_related( - [run], - delete_node_executions=self._delete_node_executions, - delete_trigger_logs=self._delete_trigger_logs, - ) - logger.info( - "Archived workflow run %s: tables=%s, archived_logs=%s, deleted=%s", - run.id, + "Archived workflow run bundle %s: tenant=%s runs=%s tables=%s object_prefix=%s", + identity.bundle_id, + identity.tenant_id, + len(runs), {s.table_name: s.row_count for s in table_stats}, - archived_log_count, - deleted_counts, + identity.object_prefix, ) result.tables = table_stats + result.object_size_bytes = object_size result.success = True except Exception as e: - logger.exception("Failed to archive workflow run %s", run.id) + logger.exception("Failed to archive workflow run bundle %s", identity.bundle_id) result.error = str(e) session.rollback() result.elapsed_time = time.time() - start_time return result - def _extract_data( + def _lock_runs_for_archive( self, session: Session, - run: WorkflowRun, - ) -> tuple[dict[str, list[dict[str, Any]]], Sequence[WorkflowAppLog], str | None]: + run_ids: Sequence[str], + ) -> list[WorkflowRun]: + """ + Lock workflow runs before archiving a bundle. + + Parallel cron jobs may select overlapping pages. Row-level SKIP LOCKED keeps duplicate archivers from uploading + conflicting bundle objects for the same source rows. 
+ """ + if not run_ids: + return [] + stmt = ( + select(WorkflowRun) + .where(WorkflowRun.id.in_(run_ids)) + .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc()) + .with_for_update(skip_locked=True) + ) + return list(session.scalars(stmt)) + + def _extract_bundle_data( + self, + session: Session, + runs: Sequence[WorkflowRun], + ) -> dict[str, list[dict[str, Any]]]: + """Extract all archived table rows for a bundle.""" + run_ids = [run.id for run in runs] table_data: dict[str, list[dict[str, Any]]] = {} - table_data["workflow_runs"] = [self._row_to_dict(run)] - repo = self._get_workflow_run_repo() - app_logs = repo.get_app_logs_by_run_id(session, run.id) + table_data["workflow_runs"] = [self._row_to_dict(run) for run in runs] + + app_logs = list(session.scalars(select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids)))) table_data["workflow_app_logs"] = [self._row_to_dict(row) for row in app_logs] - node_exec_repo = self._get_workflow_node_execution_repo(session) - node_exec_records = node_exec_repo.get_executions_by_workflow_run( - tenant_id=run.tenant_id, - app_id=run.app_id, - workflow_run_id=run.id, + + node_exec_records = list( + session.scalars( + select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)) + ) ) node_exec_ids = [record.id for record in node_exec_records] - offload_records = node_exec_repo.get_offloads_by_execution_ids(session, node_exec_ids) + offload_records = [] + if node_exec_ids: + offload_records = list( + session.scalars( + select(WorkflowNodeExecutionOffload).where( + WorkflowNodeExecutionOffload.node_execution_id.in_(node_exec_ids) + ) + ) + ) table_data["workflow_node_executions"] = [self._row_to_dict(row) for row in node_exec_records] table_data["workflow_node_execution_offload"] = [self._row_to_dict(row) for row in offload_records] - repo = self._get_workflow_run_repo() - pause_records = repo.get_pause_records_by_run_id(session, run.id) + + pause_records = 
list(session.scalars(select(WorkflowPause).where(WorkflowPause.workflow_run_id.in_(run_ids)))) pause_ids = [pause.id for pause in pause_records] - pause_reason_records = repo.get_pause_reason_records_by_run_id( - session, - pause_ids, - ) + pause_reason_records = [] + if pause_ids: + pause_reason_records = list( + session.scalars(select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))) + ) table_data["workflow_pauses"] = [self._row_to_dict(row) for row in pause_records] table_data["workflow_pause_reasons"] = [self._row_to_dict(row) for row in pause_reason_records] + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) - trigger_records = trigger_repo.list_by_run_id(run.id) + trigger_records: list[WorkflowTriggerLog] = [] + for run_id in run_ids: + trigger_records.extend(trigger_repo.list_by_run_id(run_id)) table_data["workflow_trigger_logs"] = [self._row_to_dict(row) for row in trigger_records] - trigger_metadata = trigger_records[0].trigger_metadata if trigger_records else None - return table_data, app_logs, trigger_metadata + return table_data @staticmethod def _row_to_dict(row: Any) -> dict[str, Any]: mapper = inspect(row).mapper return {str(column.name): getattr(row, mapper.get_property_by_column(column).key) for column in mapper.columns} - def _get_archive_key(self, run: WorkflowRun) -> str: - """Get the storage key for the archive bundle.""" - created_at = run.created_at - prefix = ( - f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/" - f"month={created_at.strftime('%m')}/workflow_run_id={run.id}" - ) - return f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + def _build_archive_payload( + self, + identity: ArchiveBundleIdentity, + runs: Sequence[WorkflowRun], + table_data: dict[str, list[dict[str, Any]]], + ) -> tuple[list[TableStats], dict[str, bytes], bytes]: + """Build the archive payload and size stats without writing it to storage.""" + table_stats: list[TableStats] = [] + table_payloads: dict[str, bytes] = {} + 
for table_name in self.ARCHIVED_TABLES: + records = table_data.get(table_name, []) + data = self._serialize_to_parquet(records) + table_payloads[table_name] = data + checksum = ArchiveStorage.compute_checksum(data) + + table_stats.append( + TableStats( + table_name=table_name, + row_count=len(records), + checksum=checksum, + size_bytes=len(data), + object_key=self._get_table_object_key(identity, table_name), + ) + ) + + manifest = self._generate_manifest(identity, runs, table_stats) + manifest_data = json.dumps(manifest, indent=2, default=str).encode("utf-8") + return table_stats, table_payloads, manifest_data def _generate_manifest( self, - run: WorkflowRun, + identity: ArchiveBundleIdentity, + runs: Sequence[WorkflowRun], table_stats: list[TableStats], ) -> ArchiveManifestDict: - """Generate a manifest for the archived workflow run.""" + """Generate a manifest for the archived workflow run bundle.""" tables: dict[str, TableStatsManifestEntry] = { stat.table_name: { "row_count": stat.row_count, "checksum": stat.checksum, "size_bytes": stat.size_bytes, + "object_key": stat.object_key, } for stat in table_stats } + sorted_runs = sorted(runs, key=lambda run: (run.created_at, run.id)) return ArchiveManifestDict( - schema_version=ARCHIVE_SCHEMA_VERSION, - workflow_run_id=run.id, - tenant_id=run.tenant_id, - app_id=run.app_id, - workflow_id=run.workflow_id, - created_at=run.created_at.isoformat(), + schema_version=ARCHIVE_BUNDLE_SCHEMA_VERSION, + archive_format=ARCHIVE_BUNDLE_FORMAT, + tenant_id=identity.tenant_id, + tenant_prefix=identity.tenant_prefix, + year=identity.year, + month=identity.month, + shard=identity.shard, + bundle_id=identity.bundle_id, + object_prefix=identity.object_prefix, + workflow_run_count=len(runs), + workflow_node_execution_count=tables["workflow_node_executions"]["row_count"], + min_created_at=sorted_runs[0].created_at.isoformat(), + max_created_at=sorted_runs[-1].created_at.isoformat(), + min_run_id=min(run.id for run in runs), + 
max_run_id=max(run.id for run in runs), archived_at=datetime.datetime.now(datetime.UTC).isoformat(), tables=tables, + run_ids=[run.id for run in sorted_runs], ) - def _build_archive_bundle(self, manifest_data: bytes, table_payloads: dict[str, bytes]) -> bytes: - buffer = io.BytesIO() - with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive: - archive.writestr("manifest.json", manifest_data) - for table_name in self.ARCHIVED_TABLES: - data = table_payloads.get(table_name) - if data is None: - raise ValueError(f"Missing archive payload for {table_name}") - archive.writestr(f"{table_name}.jsonl", data) - return buffer.getvalue() + @staticmethod + def _serialize_to_parquet(records: list[dict[str, Any]]) -> bytes: + normalized_records = [WorkflowRunArchiver._normalize_record_for_parquet(record) for record in records] + table = pa.Table.from_pylist(normalized_records) if normalized_records else pa.table({}) + sink = pa.BufferOutputStream() + pq.write_table(table, sink, compression="zstd") + return sink.getvalue().to_pybytes() + + @staticmethod + def _normalize_record_for_parquet(record: dict[str, Any]) -> dict[str, Any]: + def normalize(value: Any) -> Any: + if isinstance(value, Enum): + return value.value + if isinstance(value, dict | list): + return json.dumps(value, default=str, ensure_ascii=False) + return value + + return {key: normalize(value) for key, value in record.items()} + + def _group_runs_for_bundles(self, runs: Sequence[WorkflowRun]) -> list[list[WorkflowRun]]: + """Group a fetched page into tenant/month bundles.""" + grouped: dict[tuple[str, int, int], list[WorkflowRun]] = {} + for run in runs: + key = (run.tenant_id, run.created_at.year, run.created_at.month) + grouped.setdefault(key, []).append(run) + return [sorted(group, key=lambda run: (run.created_at, run.id)) for group in grouped.values()] + + def _build_bundle_identity(self, runs: Sequence[WorkflowRun]) -> ArchiveBundleIdentity: + """Build the object-store identity for 
a bundle.""" + sorted_runs = sorted(runs, key=lambda run: (run.created_at, run.id)) + first_run = sorted_runs[0] + tenant_ids = {run.tenant_id for run in sorted_runs} + if len(tenant_ids) != 1: + raise ValueError("archive bundle cannot span multiple tenants") + years_months = {(run.created_at.year, run.created_at.month) for run in sorted_runs} + if len(years_months) != 1: + raise ValueError("archive bundle cannot span multiple months") + + run_ids_digest = hashlib.sha256(",".join(run.id for run in sorted_runs).encode("utf-8")).hexdigest() + tenant_prefix = first_run.tenant_id[0].lower() + shard = self._bundle_shard_name() + year, month = next(iter(years_months)) + bundle_id = run_ids_digest[:16] + object_prefix = ( + f"workflow-runs/v2/tenant_prefix={tenant_prefix}/tenant_id={first_run.tenant_id}/" + f"year={year:04d}/month={month:02d}/shard={shard}/bundle={bundle_id}" + ) + return ArchiveBundleIdentity( + tenant_prefix=tenant_prefix, + tenant_id=first_run.tenant_id, + year=year, + month=month, + shard=shard, + bundle_id=bundle_id, + object_prefix=object_prefix, + ) + + def _bundle_shard_name(self) -> str: + if self.run_shard_index is None or self.run_shard_total is None: + return "00-of-01" + return f"{self.run_shard_index:02d}-of-{self.run_shard_total:02d}" + + @staticmethod + def _get_table_object_key(identity: ArchiveBundleIdentity, table_name: str) -> str: + return f"{identity.object_prefix}/{table_name}.parquet" + + @staticmethod + def _get_manifest_object_key(identity: ArchiveBundleIdentity) -> str: + return f"{identity.object_prefix}/{ARCHIVE_BUNDLE_MANIFEST_NAME}" def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) diff --git a/api/services/retention/workflow_run/bundle_archive_maintenance.py b/api/services/retention/workflow_run/bundle_archive_maintenance.py new file mode 100644 index 00000000000..8678e044c5e --- /dev/null +++ 
b/api/services/retention/workflow_run/bundle_archive_maintenance.py @@ -0,0 +1,872 @@ +""" +Maintain V2 workflow-run archive bundles. + +Archive V2 keeps bundle metadata in object-store manifests, not in a database table. This module discovers bundles by +listing `manifest.json` objects, uses object-store marker files for delete/restore state, and only touches the database +for source-table validation, deletion, and restoration. + +Each bundle is processed in its own database transaction. A failed bundle leaves source rows unchanged unless the +transaction has already committed; marker handling makes the next run able to reconcile the common committed-but-marker +not-updated case. +""" + +import datetime +import io +import json +import logging +import time +from collections.abc import Sequence +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, TypedDict, cast + +import pyarrow.parquet as pq +import sqlalchemy as sa +from sqlalchemy import delete, func, inspect, select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.engine import CursorResult +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_database import db +from libs.archive_storage import ArchiveStorage, ArchiveStorageNotConfiguredError, get_archive_storage +from models.trigger import WorkflowTriggerLog +from models.workflow import ( + WorkflowAppLog, + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, + WorkflowPause, + WorkflowPauseReason, + WorkflowRun, +) +from services.retention.workflow_run.constants import ( + ARCHIVE_BUNDLE_DELETE_STARTED_MARKER_NAME, + ARCHIVE_BUNDLE_DELETED_MARKER_NAME, + ARCHIVE_BUNDLE_FORMAT, + ARCHIVE_BUNDLE_MANIFEST_NAME, + ARCHIVE_BUNDLE_RESTORE_STARTED_MARKER_NAME, + ARCHIVE_BUNDLE_RESTORED_MARKER_NAME, + ARCHIVE_BUNDLE_SCHEMA_VERSION, +) + +logger = logging.getLogger(__name__) + +_ARCHIVE_ROOT_PREFIX = "workflow-runs/v2/" +_CHUNK_SIZE = 5_000 + + +class 
TableManifestEntry(TypedDict): + row_count: int + checksum: str + size_bytes: int + object_key: str + + +class BundleManifest(TypedDict): + schema_version: str + archive_format: str + tenant_id: str + tenant_prefix: str + year: int + month: int + shard: str + bundle_id: str + object_prefix: str + workflow_run_count: int + workflow_node_execution_count: int + min_created_at: str + max_created_at: str + min_run_id: str + max_run_id: str + archived_at: str + tables: dict[str, TableManifestEntry] + run_ids: list[str] + + +@dataclass(frozen=True) +class BundleReference: + """Object-store reference for one V2 archive bundle.""" + + object_prefix: str + manifest_key: str + manifest: BundleManifest + + +@dataclass +class BundleOperationResult: + """Result for one V2 bundle delete or restore operation.""" + + bundle_id: str + tenant_id: str + object_prefix: str + success: bool = False + table_counts: dict[str, int] = field(default_factory=dict) + archive_bytes: int = 0 + elapsed_time: float = 0.0 + validation_time: float = 0.0 + error: str | None = None + + @property + def run_count(self) -> int: + return self.table_counts.get("workflow_runs", 0) + + @property + def row_count(self) -> int: + return sum(self.table_counts.values()) + + +@dataclass +class BundleOperationSummary: + """Aggregate metrics for a V2 bundle maintenance command.""" + + operation: str + bundles_processed: int = 0 + bundles_succeeded: int = 0 + bundles_failed: int = 0 + rows_processed: int = 0 + runs_processed: int = 0 + archive_bytes: int = 0 + elapsed_time: float = 0.0 + validation_time: float = 0.0 + table_counts: dict[str, int] = field(default_factory=dict) + results: list[BundleOperationResult] = field(default_factory=list) + + @property + def runs_per_second(self) -> float: + if self.elapsed_time <= 0: + return 0.0 + return self.runs_processed / self.elapsed_time + + @property + def rows_per_second(self) -> float: + if self.elapsed_time <= 0: + return 0.0 + return self.rows_processed / 
self.elapsed_time + + @property + def bytes_per_second(self) -> float: + if self.elapsed_time <= 0: + return 0.0 + return self.archive_bytes / self.elapsed_time + + +TABLE_MODELS: dict[str, Any] = { + "workflow_runs": WorkflowRun, + "workflow_app_logs": WorkflowAppLog, + "workflow_node_executions": WorkflowNodeExecutionModel, + "workflow_node_execution_offload": WorkflowNodeExecutionOffload, + "workflow_pauses": WorkflowPause, + "workflow_pause_reasons": WorkflowPauseReason, + "workflow_trigger_logs": WorkflowTriggerLog, +} + +ARCHIVED_TABLES = [ + "workflow_runs", + "workflow_app_logs", + "workflow_node_executions", + "workflow_node_execution_offload", + "workflow_pauses", + "workflow_pause_reasons", + "workflow_trigger_logs", +] + +RESTORE_ORDER = [ + "workflow_runs", + "workflow_app_logs", + "workflow_node_executions", + "workflow_node_execution_offload", + "workflow_pauses", + "workflow_pause_reasons", + "workflow_trigger_logs", +] + + +class WorkflowRunBundleArchiveMaintenance: + """ + Delete and restore V2 workflow-run archive bundles. + + Args: + dry_run: Validate and report counts without changing source rows or object-store markers. + strict_content_validation: Compare source-table content checksums against Parquet content before destructive + delete and after restore. Keep enabled for real maintenance. + stop_on_error: Stop batch processing after the first failed bundle. 
+ """ + + dry_run: bool + strict_content_validation: bool + stop_on_error: bool + + def __init__( + self, + *, + dry_run: bool = False, + strict_content_validation: bool = True, + stop_on_error: bool = True, + ) -> None: + self.dry_run = dry_run + self.strict_content_validation = strict_content_validation + self.stop_on_error = stop_on_error + + def delete_batch( + self, + *, + tenant_ids: Sequence[str] | None, + start_date: datetime.datetime, + end_date: datetime.datetime, + limit: int, + ) -> BundleOperationSummary: + """Validate and delete source rows for archived V2 bundles in the requested created_at window.""" + return self._process_batch( + operation="delete", + tenant_ids=tenant_ids, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + + def restore_batch( + self, + *, + tenant_ids: Sequence[str] | None, + start_date: datetime.datetime, + end_date: datetime.datetime, + limit: int, + ) -> BundleOperationSummary: + """Restore source rows for deleted V2 bundles in the requested created_at window.""" + return self._process_batch( + operation="restore", + tenant_ids=tenant_ids, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + + def _process_batch( + self, + *, + operation: str, + tenant_ids: Sequence[str] | None, + start_date: datetime.datetime, + end_date: datetime.datetime, + limit: int, + ) -> BundleOperationSummary: + start_time = time.time() + summary = BundleOperationSummary(operation=operation) + if tenant_ids is not None and not tenant_ids: + return summary + + storage = self._get_archive_storage() + bundle_refs = self._list_bundle_refs( + storage, + operation=operation, + tenant_ids=tenant_ids, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + + logger.info("Found %s V2 archive bundles for %s", len(bundle_refs), operation) + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + for bundle_ref in bundle_refs: + with session_maker() as session: + if operation == "delete": + result = 
self._delete_bundle(session, storage, bundle_ref) + elif operation == "restore": + result = self._restore_bundle(session, storage, bundle_ref) + else: + raise ValueError(f"Unsupported operation: {operation}") + + self._merge_result(summary, result) + if not result.success and self.stop_on_error: + logger.error("Stopping V2 bundle %s after failure: %s", operation, result.error) + break + + summary.elapsed_time = time.time() - start_time + return summary + + def _list_bundle_refs( + self, + storage: ArchiveStorage, + *, + operation: str, + tenant_ids: Sequence[str] | None, + start_date: datetime.datetime, + end_date: datetime.datetime, + limit: int, + ) -> list[BundleReference]: + start_date = self._to_naive_utc(start_date) + end_date = self._to_naive_utc(end_date) + manifest_keys = self._list_manifest_keys(storage, tenant_ids) + refs: list[BundleReference] = [] + for manifest_key in manifest_keys: + manifest_data = self._get_checked_object(storage, manifest_key) + object_prefix = manifest_key.removesuffix(f"/{ARCHIVE_BUNDLE_MANIFEST_NAME}") + manifest = self._load_and_validate_manifest(manifest_data, object_prefix=object_prefix) + min_created_at = self._parse_manifest_datetime(manifest["min_created_at"]) + max_created_at = self._parse_manifest_datetime(manifest["max_created_at"]) + if max_created_at < start_date or min_created_at >= end_date: + continue + if tenant_ids and manifest["tenant_id"] not in tenant_ids: + continue + if operation == "delete" and self._is_deleted(storage, object_prefix): + continue + if operation == "restore" and not self._is_deleted(storage, object_prefix): + continue + refs.append(BundleReference(object_prefix=object_prefix, manifest_key=manifest_key, manifest=manifest)) + + refs.sort( + key=lambda ref: ( + self._parse_manifest_datetime(ref.manifest["min_created_at"]), + ref.manifest["tenant_id"], + ref.manifest["bundle_id"], + ) + ) + return refs[:limit] + + @staticmethod + def _list_manifest_keys(storage: ArchiveStorage, tenant_ids: 
Sequence[str] | None) -> list[str]: + keys: list[str] = [] + if tenant_ids: + prefixes = [ + f"{_ARCHIVE_ROOT_PREFIX}tenant_prefix={tenant_id[0].lower()}/tenant_id={tenant_id}/" + for tenant_id in tenant_ids + ] + else: + prefixes = [_ARCHIVE_ROOT_PREFIX] + for prefix in prefixes: + keys.extend(storage.list_objects(prefix)) + return sorted(key for key in keys if key.endswith(f"/{ARCHIVE_BUNDLE_MANIFEST_NAME}")) + + def _delete_bundle( + self, + session: Session, + storage: ArchiveStorage, + bundle_ref: BundleReference, + ) -> BundleOperationResult: + start_time = time.time() + result = self._new_result(bundle_ref.manifest) + try: + validation_start = time.time() + manifest, table_records, archive_bytes = self._validate_archive_object(storage, bundle_ref) + result.table_counts = self._manifest_table_counts(manifest) + result.archive_bytes = archive_bytes + + self._lock_workflow_runs(session, manifest["run_ids"]) + if self._is_delete_started(storage, bundle_ref.object_prefix) and self._live_counts_match( + session, manifest, expected_present=False + ): + result.validation_time = time.time() - validation_start + if not self.dry_run: + self._mark_deleted(storage, bundle_ref.object_prefix) + self._delete_marker(storage, bundle_ref.object_prefix, ARCHIVE_BUNDLE_DELETE_STARTED_MARKER_NAME) + result.success = True + return result + + self._validate_live_counts(session, manifest, expected_present=True) + if self.strict_content_validation: + self._validate_live_content(session, table_records) + result.validation_time = time.time() - validation_start + + if not self.dry_run: + self._put_marker(storage, bundle_ref.object_prefix, ARCHIVE_BUNDLE_DELETE_STARTED_MARKER_NAME) + deleted_counts = self._delete_bundle_rows(session, table_records) + if deleted_counts != result.table_counts: + raise ValueError( + f"Deleted row count mismatch: expected={result.table_counts}, actual={deleted_counts}" + ) + self._validate_live_counts(session, manifest, expected_present=False) + 
session.commit() + self._mark_deleted(storage, bundle_ref.object_prefix) + self._delete_marker(storage, bundle_ref.object_prefix, ARCHIVE_BUNDLE_DELETE_STARTED_MARKER_NAME) + self._delete_marker(storage, bundle_ref.object_prefix, ARCHIVE_BUNDLE_RESTORED_MARKER_NAME) + result.success = True + except Exception as e: + session.rollback() + result.error = str(e) + logger.exception("Failed to delete V2 archive bundle %s", bundle_ref.object_prefix) + result.elapsed_time = time.time() - start_time + return result + + def _restore_bundle( + self, + session: Session, + storage: ArchiveStorage, + bundle_ref: BundleReference, + ) -> BundleOperationResult: + start_time = time.time() + result = self._new_result(bundle_ref.manifest) + try: + validation_start = time.time() + manifest, table_records, archive_bytes = self._validate_archive_object(storage, bundle_ref) + result.table_counts = self._manifest_table_counts(manifest) + result.archive_bytes = archive_bytes + + if self._live_counts_match(session, manifest, expected_present=True): + if self.strict_content_validation: + self._validate_live_content(session, table_records) + result.validation_time = time.time() - validation_start + if not self.dry_run: + self._mark_restored(storage, bundle_ref.object_prefix) + result.success = True + return result + + self._validate_live_counts(session, manifest, expected_present=False) + result.validation_time = time.time() - validation_start + + if not self.dry_run: + self._put_marker(storage, bundle_ref.object_prefix, ARCHIVE_BUNDLE_RESTORE_STARTED_MARKER_NAME) + restored_counts = self._restore_bundle_rows(session, table_records) + if restored_counts != result.table_counts: + self._validate_live_counts(session, manifest, expected_present=True) + self._validate_live_counts(session, manifest, expected_present=True) + if self.strict_content_validation: + self._validate_live_content(session, table_records) + session.commit() + self._mark_restored(storage, bundle_ref.object_prefix) + 
result.success = True + except Exception as e: + session.rollback() + result.error = str(e) + logger.exception("Failed to restore V2 archive bundle %s", bundle_ref.object_prefix) + result.elapsed_time = time.time() - start_time + return result + + @staticmethod + def _new_result(manifest: BundleManifest) -> BundleOperationResult: + return BundleOperationResult( + bundle_id=manifest["bundle_id"], + tenant_id=manifest["tenant_id"], + object_prefix=manifest["object_prefix"], + ) + + def _validate_archive_object( + self, + storage: ArchiveStorage, + bundle_ref: BundleReference, + ) -> tuple[BundleManifest, dict[str, list[dict[str, Any]]], int]: + manifest = bundle_ref.manifest + table_records: dict[str, list[dict[str, Any]]] = {} + total_size = len(storage.get_object(bundle_ref.manifest_key)) + for table_name in ARCHIVED_TABLES: + info = manifest["tables"][table_name] + payload = self._get_checked_object(storage, info["object_key"]) + total_size += len(payload) + if len(payload) != info["size_bytes"]: + raise ValueError( + f"Archive object size mismatch for {info['object_key']}: " + f"expected={info['size_bytes']}, actual={len(payload)}" + ) + checksum = ArchiveStorage.compute_checksum(payload) + if checksum != info["checksum"]: + raise ValueError( + f"Archive object checksum mismatch for {info['object_key']}: " + f"expected={info['checksum']}, actual={checksum}" + ) + records = self._deserialize_parquet(payload) + if len(records) != info["row_count"]: + raise ValueError( + f"Parquet row count mismatch for {info['object_key']}: " + f"expected={info['row_count']}, actual={len(records)}" + ) + table_records[table_name] = records + return manifest, table_records, total_size + + @staticmethod + def _get_checked_object(storage: ArchiveStorage, object_key: str) -> bytes: + if not storage.object_exists(object_key): + raise FileNotFoundError(f"Archive object not found: {object_key}") + return storage.get_object(object_key) + + @staticmethod + def _load_and_validate_manifest( + 
manifest_data: bytes, + *, + object_prefix: str, + ) -> BundleManifest: + loaded = json.loads(manifest_data) + if not isinstance(loaded, dict): + raise ValueError("manifest.json must be an object") + required_fields = { + "schema_version", + "archive_format", + "tenant_id", + "tenant_prefix", + "year", + "month", + "shard", + "bundle_id", + "object_prefix", + "workflow_run_count", + "workflow_node_execution_count", + "tables", + "run_ids", + } + missing_fields = sorted(required_fields - set(loaded)) + if missing_fields: + raise ValueError(f"manifest missing required fields: {', '.join(missing_fields)}") + manifest = cast(BundleManifest, loaded) + if manifest["schema_version"] != ARCHIVE_BUNDLE_SCHEMA_VERSION: + raise ValueError(f"unsupported bundle schema_version: {manifest['schema_version']}") + if manifest["archive_format"] != ARCHIVE_BUNDLE_FORMAT: + raise ValueError(f"unsupported bundle archive_format: {manifest['archive_format']}") + if manifest["object_prefix"] != object_prefix: + raise ValueError("manifest object_prefix does not match object key") + if manifest["tenant_id"][0].lower() != manifest["tenant_prefix"]: + raise ValueError("manifest tenant_prefix does not match tenant_id") + if len(manifest["run_ids"]) != manifest["workflow_run_count"]: + raise ValueError("manifest run_ids count does not match workflow_run_count") + + tables = manifest["tables"] + if not isinstance(tables, dict): + raise ValueError("manifest tables must be an object") + for table_name in ARCHIVED_TABLES: + if table_name not in tables: + raise ValueError(f"manifest missing table: {table_name}") + info = tables[table_name] + for key in ("row_count", "checksum", "size_bytes", "object_key"): + if key not in info: + raise ValueError(f"manifest table {table_name} missing {key}") + expected_key = f"{object_prefix}/{table_name}.parquet" + if info["object_key"] != expected_key: + raise ValueError( + f"manifest object_key mismatch for {table_name}: " + f"expected={expected_key}, 
actual={info['object_key']}" + ) + return manifest + + @staticmethod + def _deserialize_parquet(payload: bytes) -> list[dict[str, Any]]: + table = pq.read_table(io.BytesIO(payload)) + return table.to_pylist() + + def _validate_live_counts( + self, + session: Session, + manifest: BundleManifest, + *, + expected_present: bool, + ) -> None: + expected_counts = self._manifest_table_counts(manifest) + actual_counts = self._count_live_rows(session, manifest["run_ids"]) + if not expected_present: + expected_counts = dict.fromkeys(expected_counts, 0) + if actual_counts != expected_counts: + state = "present" if expected_present else "deleted" + raise ValueError( + f"Live row count mismatch for {state} bundle: expected={expected_counts}, actual={actual_counts}" + ) + + def _live_counts_match(self, session: Session, manifest: BundleManifest, *, expected_present: bool) -> bool: + expected_counts = self._manifest_table_counts(manifest) + if not expected_present: + expected_counts = dict.fromkeys(expected_counts, 0) + return self._count_live_rows(session, manifest["run_ids"]) == expected_counts + + @staticmethod + def _manifest_table_counts(manifest: BundleManifest) -> dict[str, int]: + return {table_name: manifest["tables"][table_name]["row_count"] for table_name in ARCHIVED_TABLES} + + def _count_live_rows(self, session: Session, run_ids: Sequence[str]) -> dict[str, int]: + node_ids = self._select_ids_by_run_ids(session, WorkflowNodeExecutionModel, run_ids) + pause_ids = self._select_ids_by_run_ids(session, WorkflowPause, run_ids) + return { + "workflow_runs": self._count_by_run_ids(session, WorkflowRun, run_ids), + "workflow_app_logs": self._count_by_run_ids(session, WorkflowAppLog, run_ids), + "workflow_node_executions": len(node_ids), + "workflow_node_execution_offload": self._count_by_column( + session, WorkflowNodeExecutionOffload, WorkflowNodeExecutionOffload.node_execution_id, node_ids + ), + "workflow_pauses": len(pause_ids), + "workflow_pause_reasons": 
self._count_by_column( + session, WorkflowPauseReason, WorkflowPauseReason.pause_id, pause_ids + ), + "workflow_trigger_logs": self._count_by_run_ids(session, WorkflowTriggerLog, run_ids), + } + + def _validate_live_content( + self, + session: Session, + table_records: dict[str, list[dict[str, Any]]], + ) -> None: + run_ids = [str(record["id"]) for record in table_records["workflow_runs"]] + node_ids = [str(record["id"]) for record in table_records["workflow_node_executions"]] + pause_ids = [str(record["id"]) for record in table_records["workflow_pauses"]] + + live_records = { + "workflow_runs": self._load_records_by_run_ids(session, WorkflowRun, run_ids), + "workflow_app_logs": self._load_records_by_run_ids(session, WorkflowAppLog, run_ids), + "workflow_node_executions": self._load_records_by_run_ids(session, WorkflowNodeExecutionModel, run_ids), + "workflow_node_execution_offload": self._load_records_by_column( + session, WorkflowNodeExecutionOffload, WorkflowNodeExecutionOffload.node_execution_id, node_ids + ), + "workflow_pauses": self._load_records_by_run_ids(session, WorkflowPause, run_ids), + "workflow_pause_reasons": self._load_records_by_column( + session, WorkflowPauseReason, WorkflowPauseReason.pause_id, pause_ids + ), + "workflow_trigger_logs": self._load_records_by_run_ids(session, WorkflowTriggerLog, run_ids), + } + for table_name in ARCHIVED_TABLES: + live_checksum = self._records_checksum(live_records[table_name]) + archive_checksum = self._records_checksum(table_records[table_name]) + if live_checksum != archive_checksum: + raise ValueError( + f"Live/archive content checksum mismatch for {table_name}: " + f"expected={archive_checksum}, actual={live_checksum}" + ) + + def _delete_bundle_rows( + self, + session: Session, + table_records: dict[str, list[dict[str, Any]]], + ) -> dict[str, int]: + run_ids = [str(record["id"]) for record in table_records["workflow_runs"]] + node_ids = [str(record["id"]) for record in 
table_records["workflow_node_executions"]] + pause_ids = [str(record["id"]) for record in table_records["workflow_pauses"]] + + deleted_counts = dict.fromkeys(ARCHIVED_TABLES, 0) + deleted_counts["workflow_pause_reasons"] = self._delete_by_column( + session, WorkflowPauseReason, WorkflowPauseReason.pause_id, pause_ids + ) + deleted_counts["workflow_node_execution_offload"] = self._delete_by_column( + session, WorkflowNodeExecutionOffload, WorkflowNodeExecutionOffload.node_execution_id, node_ids + ) + deleted_counts["workflow_trigger_logs"] = self._delete_by_run_ids(session, WorkflowTriggerLog, run_ids) + deleted_counts["workflow_app_logs"] = self._delete_by_run_ids(session, WorkflowAppLog, run_ids) + deleted_counts["workflow_node_executions"] = self._delete_by_run_ids( + session, WorkflowNodeExecutionModel, run_ids + ) + deleted_counts["workflow_pauses"] = self._delete_by_run_ids(session, WorkflowPause, run_ids) + deleted_counts["workflow_runs"] = self._delete_by_run_ids(session, WorkflowRun, run_ids) + return deleted_counts + + def _restore_bundle_rows( + self, + session: Session, + table_records: dict[str, list[dict[str, Any]]], + ) -> dict[str, int]: + restored_counts = dict.fromkeys(ARCHIVED_TABLES, 0) + for table_name in RESTORE_ORDER: + restored_counts[table_name] = self._restore_table_records(session, table_name, table_records[table_name]) + return restored_counts + + def _restore_table_records( + self, + session: Session, + table_name: str, + records: list[dict[str, Any]], + ) -> int: + if not records: + return 0 + model = TABLE_MODELS[table_name] + total = 0 + for chunk in self._chunks(records, _CHUNK_SIZE): + converted = [self._prepare_insert_record(model, record) for record in chunk] + stmt = pg_insert(cast(Any, model.__table__)).values(converted) + stmt = stmt.on_conflict_do_nothing(index_elements=["id"]) + result = session.execute(stmt) + total += cast(CursorResult, result).rowcount or 0 + return total + + def _prepare_insert_record( + self, + model: 
Any, + record: dict[str, Any], + ) -> dict[str, Any]: + table = model.__table__ + columns_by_name = {column.name: column for column in table.columns} + prepared = {key: value for key, value in record.items() if key in columns_by_name} + for column_name, value in list(prepared.items()): + column = columns_by_name[column_name] + if value is None: + continue + if isinstance(column.type, sa.DateTime) and isinstance(value, str): + prepared[column_name] = datetime.datetime.fromisoformat(value) + elif isinstance(column.type, sa.JSON) and isinstance(value, str): + prepared[column_name] = json.loads(value) + return prepared + + @staticmethod + def _row_to_dict(row: Any) -> dict[str, Any]: + mapper = inspect(row).mapper + return {str(column.name): getattr(row, mapper.get_property_by_column(column).key) for column in mapper.columns} + + @staticmethod + def _normalize_record_for_checksum(record: dict[str, Any]) -> dict[str, Any]: + def normalize(value: Any) -> Any: + if isinstance(value, Enum): + return value.value + if isinstance(value, dict | list): + return json.dumps(value, default=str, ensure_ascii=False) + return value + + return {key: normalize(value) for key, value in record.items()} + + @classmethod + def _records_checksum(cls, records: list[dict[str, Any]]) -> str: + normalized = [cls._normalize_record_for_checksum(record) for record in records] + normalized.sort(key=lambda record: json.dumps(record, sort_keys=True, default=str, ensure_ascii=False)) + payload = json.dumps(normalized, sort_keys=True, default=str, ensure_ascii=False, separators=(",", ":")) + return ArchiveStorage.compute_checksum(payload.encode("utf-8")) + + @staticmethod + def _lock_workflow_runs(session: Session, run_ids: Sequence[str]) -> None: + for chunk in WorkflowRunBundleArchiveMaintenance._chunks(run_ids, _CHUNK_SIZE): + list(session.scalars(select(WorkflowRun.id).where(WorkflowRun.id.in_(chunk)).with_for_update())) + + @staticmethod + def _select_ids_by_run_ids( + session: Session, + model: 
Any, + run_ids: Sequence[str], + ) -> list[str]: + if not run_ids: + return [] + ids: list[str] = [] + for chunk in WorkflowRunBundleArchiveMaintenance._chunks(run_ids, _CHUNK_SIZE): + ids.extend( + str(row_id) for row_id in session.scalars(select(model.id).where(model.workflow_run_id.in_(chunk))) + ) + return ids + + @staticmethod + def _count_by_run_ids( + session: Session, + model: Any, + run_ids: Sequence[str], + ) -> int: + return WorkflowRunBundleArchiveMaintenance._count_by_column( + session, model, WorkflowRunBundleArchiveMaintenance._run_id_column(model), run_ids + ) + + @staticmethod + def _count_by_column( + session: Session, + model: Any, + column: Any, + values: Sequence[str], + ) -> int: + if not values: + return 0 + total = 0 + for chunk in WorkflowRunBundleArchiveMaintenance._chunks(values, _CHUNK_SIZE): + total += session.scalar(select(func.count()).select_from(model).where(column.in_(chunk))) or 0 + return total + + def _load_records_by_run_ids( + self, + session: Session, + model: Any, + run_ids: Sequence[str], + ) -> list[dict[str, Any]]: + return self._load_records_by_column(session, model, self._run_id_column(model), run_ids) + + def _load_records_by_column( + self, + session: Session, + model: Any, + column: Any, + values: Sequence[str], + ) -> list[dict[str, Any]]: + if not values: + return [] + rows: list[Any] = [] + for chunk in self._chunks(values, _CHUNK_SIZE): + rows.extend(session.scalars(select(model).where(column.in_(chunk)))) + return [self._row_to_dict(row) for row in rows] + + @staticmethod + def _delete_by_run_ids( + session: Session, + model: Any, + run_ids: Sequence[str], + ) -> int: + return WorkflowRunBundleArchiveMaintenance._delete_by_column( + session, model, WorkflowRunBundleArchiveMaintenance._run_id_column(model), run_ids + ) + + @staticmethod + def _run_id_column(model: Any) -> Any: + if model is WorkflowRun: + return WorkflowRun.id + return model.workflow_run_id + + @staticmethod + def _delete_by_column( + session: 
Session, + model: Any, + column: Any, + values: Sequence[str], + ) -> int: + if not values: + return 0 + total = 0 + for chunk in WorkflowRunBundleArchiveMaintenance._chunks(values, _CHUNK_SIZE): + result = session.execute(delete(model).where(column.in_(chunk))) + total += cast(CursorResult, result).rowcount or 0 + return total + + @staticmethod + def _is_deleted(storage: ArchiveStorage, object_prefix: str) -> bool: + return storage.object_exists(f"{object_prefix}/{ARCHIVE_BUNDLE_DELETED_MARKER_NAME}") + + @staticmethod + def _is_delete_started(storage: ArchiveStorage, object_prefix: str) -> bool: + return storage.object_exists(f"{object_prefix}/{ARCHIVE_BUNDLE_DELETE_STARTED_MARKER_NAME}") + + @staticmethod + def _mark_deleted(storage: ArchiveStorage, object_prefix: str) -> None: + WorkflowRunBundleArchiveMaintenance._put_marker(storage, object_prefix, ARCHIVE_BUNDLE_DELETED_MARKER_NAME) + + @staticmethod + def _mark_restored(storage: ArchiveStorage, object_prefix: str) -> None: + WorkflowRunBundleArchiveMaintenance._delete_marker(storage, object_prefix, ARCHIVE_BUNDLE_DELETED_MARKER_NAME) + WorkflowRunBundleArchiveMaintenance._delete_marker( + storage, object_prefix, ARCHIVE_BUNDLE_RESTORE_STARTED_MARKER_NAME + ) + WorkflowRunBundleArchiveMaintenance._put_marker(storage, object_prefix, ARCHIVE_BUNDLE_RESTORED_MARKER_NAME) + + @staticmethod + def _put_marker(storage: ArchiveStorage, object_prefix: str, marker_name: str) -> None: + payload = json.dumps({"created_at": datetime.datetime.now(datetime.UTC).isoformat()}).encode("utf-8") + storage.put_object(f"{object_prefix}/{marker_name}", payload) + + @staticmethod + def _delete_marker(storage: ArchiveStorage, object_prefix: str, marker_name: str) -> None: + marker_key = f"{object_prefix}/{marker_name}" + if storage.object_exists(marker_key): + storage.delete_object(marker_key) + + @staticmethod + def _parse_manifest_datetime(value: str) -> datetime.datetime: + return 
WorkflowRunBundleArchiveMaintenance._to_naive_utc(datetime.datetime.fromisoformat(value)) + + @staticmethod + def _to_naive_utc(value: datetime.datetime) -> datetime.datetime: + if value.tzinfo is None: + return value + return value.astimezone(datetime.UTC).replace(tzinfo=None) + + @staticmethod + def _chunks(values: Sequence[Any], size: int) -> list[Sequence[Any]]: + return [values[index : index + size] for index in range(0, len(values), size)] + + @staticmethod + def _get_archive_storage() -> ArchiveStorage: + try: + return get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + raise RuntimeError(f"Archive storage not configured: {e}") from e + + @staticmethod + def _merge_result(summary: BundleOperationSummary, result: BundleOperationResult) -> None: + summary.results.append(result) + summary.bundles_processed += 1 + summary.validation_time += result.validation_time + if result.success: + summary.bundles_succeeded += 1 + summary.rows_processed += result.row_count + summary.runs_processed += result.run_count + summary.archive_bytes += result.archive_bytes + for table_name, count in result.table_counts.items(): + summary.table_counts[table_name] = summary.table_counts.get(table_name, 0) + count + else: + summary.bundles_failed += 1 diff --git a/api/services/retention/workflow_run/constants.py b/api/services/retention/workflow_run/constants.py index 162bb4947df..d8d2807bb16 100644 --- a/api/services/retention/workflow_run/constants.py +++ b/api/services/retention/workflow_run/constants.py @@ -1,2 +1,10 @@ ARCHIVE_SCHEMA_VERSION = "1.0" ARCHIVE_BUNDLE_NAME = f"archive.v{ARCHIVE_SCHEMA_VERSION}.zip" + +ARCHIVE_BUNDLE_SCHEMA_VERSION = "2.0" +ARCHIVE_BUNDLE_FORMAT = "parquet" +ARCHIVE_BUNDLE_MANIFEST_NAME = "manifest.json" +ARCHIVE_BUNDLE_DELETE_STARTED_MARKER_NAME = "_DELETE_STARTED" +ARCHIVE_BUNDLE_DELETED_MARKER_NAME = "_DELETED" +ARCHIVE_BUNDLE_RESTORE_STARTED_MARKER_NAME = "_RESTORE_STARTED" +ARCHIVE_BUNDLE_RESTORED_MARKER_NAME = "_RESTORED" diff 
--git a/api/services/retention/workflow_run/delete_archived_workflow_run.py b/api/services/retention/workflow_run/delete_archived_workflow_run.py index 937a1067105..8fd8b3c8989 100644 --- a/api/services/retention/workflow_run/delete_archived_workflow_run.py +++ b/api/services/retention/workflow_run/delete_archived_workflow_run.py @@ -2,20 +2,68 @@ Delete Archived Workflow Run Service. This service deletes archived workflow run data from the database while keeping -archive logs intact. +archive logs intact. Deletion is intentionally gated by archive-object validation: +the archive bundle must exist, have a supported manifest, pass zip/member checksum +checks, and match the live row counts for every cleanup-owned table before rows +are removed from the primary database. """ +import io +import json +import logging import time +import zipfile from collections.abc import Sequence from dataclasses import dataclass, field from datetime import datetime +from typing import TypedDict from sqlalchemy.orm import Session, sessionmaker from extensions.ext_database import db -from models.workflow import WorkflowRun +from libs.archive_storage import ArchiveStorage, ArchiveStorageNotConfiguredError, get_archive_storage +from models.workflow import WorkflowArchiveLog, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHIVE_SCHEMA_VERSION + +logger = logging.getLogger(__name__) + + +class _TableManifestEntry(TypedDict): + row_count: int + checksum: str + size_bytes: int + + +class _ArchiveManifest(TypedDict): + schema_version: str + workflow_run_id: str + tenant_id: str + app_id: str + workflow_id: str + tables: dict[str, _TableManifestEntry] + + +_ARCHIVED_TABLES = [ + "workflow_runs", + "workflow_app_logs", + 
"workflow_node_executions", + "workflow_node_execution_offload", + "workflow_pauses", + "workflow_pause_reasons", + "workflow_trigger_logs", +] + +_TABLE_TO_COUNT_KEY = { + "workflow_runs": "runs", + "workflow_app_logs": "app_logs", + "workflow_node_executions": "node_executions", + "workflow_node_execution_offload": "offloads", + "workflow_pauses": "pauses", + "workflow_pause_reasons": "pause_reasons", + "workflow_trigger_logs": "trigger_logs", +} @dataclass @@ -34,13 +82,49 @@ class DeleteResult: "pause_reasons": 0, } ) + validated_counts: RunsWithRelatedCountsDict = field( + default_factory=lambda: { # type: ignore[assignment] + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + ) + archive_key: str | None = None + restore_sampled: bool = False + restore_sample_success: bool | None = None error: str | None = None elapsed_time: float = 0.0 class ArchivedWorkflowRunDeletion: - def __init__(self, dry_run: bool = False): + """ + Delete archived workflow-run rows after validating the archive bundle. + + Args: + dry_run: Preview validation and row counts without deleting. + skip_bad_archives: Continue batch deletion after a validation/delete failure. + restore_sample_interval: Run restore dry-run for every Nth successful deletion; 0 disables sampling. 
+ """ + + _delete_attempt_count: int + + def __init__( + self, + dry_run: bool = False, + *, + skip_bad_archives: bool = False, + restore_sample_interval: int = 0, + ): self.dry_run = dry_run + self.skip_bad_archives = skip_bad_archives + if restore_sample_interval < 0: + raise ValueError("restore_sample_interval must be >= 0") + self.restore_sample_interval = restore_sample_interval + self._delete_attempt_count = 0 self.workflow_run_repo: APIWorkflowRunRepository | None = None def delete_by_run_id(self, run_id: str) -> DeleteResult: @@ -57,12 +141,13 @@ class ArchivedWorkflowRunDeletion: return result result.tenant_id = run.tenant_id - if not repo.get_archived_run_ids(session, [run.id]): + archive_log = repo.get_archived_log_by_run_id(run.id) + if archive_log is None: result.error = f"Workflow run {run_id} is not archived" result.elapsed_time = time.time() - start_time return result - result = self._delete_run(run) + result = self._delete_run(run, archive_log) result.elapsed_time = time.time() - start_time return result @@ -78,8 +163,8 @@ class ArchivedWorkflowRunDeletion: repo = self._get_workflow_run_repo() with session_maker() as session: - runs = list( - repo.get_archived_runs_by_time_range( + archive_logs = list( + repo.get_archived_logs_by_time_range( session=session, tenant_ids=tenant_ids, start_date=start_date, @@ -87,14 +172,44 @@ class ArchivedWorkflowRunDeletion: limit=limit, ) ) - for run in runs: - results.append(self._delete_run(run)) + run_ids = [archive_log.workflow_run_id for archive_log in archive_logs] + runs_by_id = {run.id: run for run in session.query(WorkflowRun).where(WorkflowRun.id.in_(run_ids)).all()} + for archive_log in archive_logs: + run = runs_by_id.get(archive_log.workflow_run_id) + if run is None: + result = DeleteResult( + run_id=archive_log.workflow_run_id, + tenant_id=archive_log.tenant_id, + success=False, + error=f"Workflow run {archive_log.workflow_run_id} not found", + ) + else: + result = self._delete_run(run, archive_log) 
+ results.append(result) + if not result.success and not self.skip_bad_archives: + logger.error("Stopping archived workflow run deletion after failure: %s", result.error) + break return results - def _delete_run(self, run: WorkflowRun) -> DeleteResult: + def _delete_run(self, run: WorkflowRun, archive_log: WorkflowArchiveLog | None = None) -> DeleteResult: start_time = time.time() result = DeleteResult(run_id=run.id, tenant_id=run.tenant_id, success=False) + if archive_log is None: + archive_log = self._get_workflow_run_repo().get_archived_log_by_run_id(run.id) + if archive_log is None: + result.error = f"Workflow run {run.id} is not archived" + result.elapsed_time = time.time() - start_time + return result + + try: + result.archive_key = self._validate_archive_before_delete(run, archive_log) + result.validated_counts = self._count_live_related_rows(run) + except Exception as e: + result.error = str(e) + result.elapsed_time = time.time() - start_time + return result + if self.dry_run: result.success = True result.elapsed_time = time.time() - start_time @@ -108,17 +223,202 @@ class ArchivedWorkflowRunDeletion: delete_trigger_logs=self._delete_trigger_logs, ) result.deleted_counts = deleted_counts + self._verify_post_delete(run.id) + if self._should_run_restore_sample(): + result.restore_sampled = True + result.restore_sample_success = self._run_restore_dry_run_sample(archive_log) + if not result.restore_sample_success: + raise RuntimeError(f"Restore dry-run sample failed for workflow run {run.id}") result.success = True except Exception as e: result.error = str(e) result.elapsed_time = time.time() - start_time return result + def _validate_archive_before_delete(self, run: WorkflowRun, archive_log: WorkflowArchiveLog) -> str: + storage = self._get_archive_storage() + archive_key = self._get_archive_key(archive_log) + if not storage.object_exists(archive_key): + raise FileNotFoundError(f"Archive bundle not found: {archive_key}") + + archive_data = 
storage.get_object(archive_key) + manifest = self._validate_archive_bundle( + archive_data, + run_id=run.id, + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_id=run.workflow_id, + ) + expected_counts = self._counts_from_manifest(manifest) + current_counts = self._count_live_related_rows(run) + if current_counts != expected_counts: + raise ValueError( + "Archive row count mismatch before delete: " + f"run_id={run.id}, expected={expected_counts}, current={current_counts}" + ) + return archive_key + + @staticmethod + def _validate_archive_bundle( + archive_data: bytes, + *, + run_id: str, + tenant_id: str, + app_id: str, + workflow_id: str, + ) -> _ArchiveManifest: + try: + with zipfile.ZipFile(io.BytesIO(archive_data), mode="r") as archive: + bad_member = archive.testzip() + if bad_member: + raise ValueError(f"zip CRC check failed for member {bad_member}") + try: + manifest_data = archive.read("manifest.json") + except KeyError as e: + raise ValueError("manifest.json missing from archive bundle") from e + loaded = json.loads(manifest_data) + if not isinstance(loaded, dict): + raise ValueError("manifest.json must be an object") + manifest = loaded + + required_fields = { + "schema_version", + "workflow_run_id", + "tenant_id", + "app_id", + "workflow_id", + "tables", + } + missing_fields = sorted(required_fields - set(manifest)) + if missing_fields: + raise ValueError(f"manifest missing required fields: {', '.join(missing_fields)}") + if manifest["schema_version"] != ARCHIVE_SCHEMA_VERSION: + raise ValueError( + f"unsupported archive schema_version: {manifest['schema_version']} " + f"(expected {ARCHIVE_SCHEMA_VERSION})" + ) + if manifest["workflow_run_id"] != run_id: + raise ValueError("manifest workflow_run_id does not match delete target") + if manifest["tenant_id"] != tenant_id: + raise ValueError("manifest tenant_id does not match delete target") + if manifest["app_id"] != app_id: + raise ValueError("manifest app_id does not match delete target") + if 
manifest["workflow_id"] != workflow_id: + raise ValueError("manifest workflow_id does not match delete target") + + tables = manifest["tables"] + if not isinstance(tables, dict): + raise ValueError("manifest tables must be an object") + missing_tables = [table_name for table_name in _ARCHIVED_TABLES if table_name not in tables] + if missing_tables: + raise ValueError(f"manifest missing tables: {', '.join(missing_tables)}") + + for table_name in _ARCHIVED_TABLES: + info = tables[table_name] + if not isinstance(info, dict): + raise ValueError(f"manifest table entry must be an object: {table_name}") + for key in ("row_count", "checksum", "size_bytes"): + if key not in info: + raise ValueError(f"manifest table {table_name} missing {key}") + member_path = f"{table_name}.jsonl" + try: + payload = archive.read(member_path) + except KeyError as e: + raise ValueError(f"archive member missing: {member_path}") from e + if len(payload) != info["size_bytes"]: + raise ValueError( + f"archive member size mismatch for {member_path}: " + f"expected={info['size_bytes']}, actual={len(payload)}" + ) + checksum = ArchiveStorage.compute_checksum(payload) + if checksum != info["checksum"]: + raise ValueError( + f"archive member checksum mismatch for {member_path}: " + f"expected={info['checksum']}, actual={checksum}" + ) + row_count = len(ArchiveStorage.deserialize_from_jsonl(payload)) + if row_count != info["row_count"]: + raise ValueError( + f"archive row count mismatch for {member_path}: " + f"expected={info['row_count']}, actual={row_count}" + ) + + return manifest # type: ignore[return-value] + except zipfile.BadZipFile as e: + raise ValueError("archive bundle is not a valid zip file") from e + + @staticmethod + def _counts_from_manifest(manifest: _ArchiveManifest) -> RunsWithRelatedCountsDict: + counts: RunsWithRelatedCountsDict = { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + for table_name, 
count_key in _TABLE_TO_COUNT_KEY.items(): + counts[count_key] = manifest["tables"][table_name]["row_count"] # type: ignore[literal-required] + return counts + + def _count_live_related_rows(self, run: WorkflowRun) -> RunsWithRelatedCountsDict: + repo = self._get_workflow_run_repo() + return repo.count_runs_with_related( + [run], + count_node_executions=self._count_node_executions, + count_trigger_logs=self._count_trigger_logs, + ) + + def _verify_post_delete(self, run_id: str) -> None: + with sessionmaker(bind=db.engine, expire_on_commit=False)() as session: + if session.get(WorkflowRun, run_id) is not None: + raise RuntimeError(f"Post-delete verification failed: workflow run {run_id} still exists") + + def _should_run_restore_sample(self) -> bool: + if self.restore_sample_interval == 0: + return False + self._delete_attempt_count += 1 + return self._delete_attempt_count % self.restore_sample_interval == 0 + + @staticmethod + def _run_restore_dry_run_sample(archive_log: WorkflowArchiveLog) -> bool: + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + restorer = WorkflowRunRestore(dry_run=True, workers=1) + # Reuse restore's dry-run path so the runbook exercises the actual restore code. 
+ result = restorer._restore_from_run( + archive_log, + session_maker=sessionmaker(bind=db.engine, expire_on_commit=False), + ) + return result.success + + @staticmethod + def _get_archive_key(archive_log: WorkflowArchiveLog) -> str: + created_at = archive_log.run_created_at + prefix = ( + f"{archive_log.tenant_id}/app_id={archive_log.app_id}/year={created_at.strftime('%Y')}/" + f"month={created_at.strftime('%m')}/workflow_run_id={archive_log.workflow_run_id}" + ) + return f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + + @staticmethod + def _get_archive_storage() -> ArchiveStorage: + try: + return get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + raise RuntimeError(f"Archive storage not configured: {e}") from e + @staticmethod def _delete_trigger_logs(session: Session, run_ids: Sequence[str]) -> int: trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) return trigger_repo.delete_by_run_ids(run_ids) + @staticmethod + def _count_trigger_logs(session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.count_by_run_ids(run_ids) + @staticmethod def _delete_node_executions( session: Session, @@ -132,6 +432,19 @@ class ArchivedWorkflowRunDeletion: ) return repo.delete_by_runs(session, run_ids) + @staticmethod + def _count_node_executions( + session: Session, + runs: Sequence[WorkflowRun], + ) -> tuple[int, int]: + from repositories.factory import DifyAPIRepositoryFactory + + run_ids = [run.id for run in runs] + repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False) + ) + return repo.count_by_runs(session, run_ids) + def _get_workflow_run_repo(self) -> APIWorkflowRunRepository: if self.workflow_run_repo is not None: return self.workflow_run_repo diff --git a/api/services/retention/workflow_run/tenant_prefix.py b/api/services/retention/workflow_run/tenant_prefix.py new file 
mode 100644 index 00000000000..34d1591c70d --- /dev/null +++ b/api/services/retention/workflow_run/tenant_prefix.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import sqlalchemy as sa + + +def tenant_prefix_bounds(prefix: str) -> tuple[str, str | None]: + prefix_value = int(prefix, 16) + lower_bound = f"{prefix}0000000-0000-0000-0000-000000000000" + if prefix_value == 15: + return lower_bound, None + upper_bound = f"{prefix_value + 1:x}0000000-0000-0000-0000-000000000000" + return lower_bound, upper_bound + + +def tenant_prefix_condition(column, prefix: str): + lower_bound, upper_bound = tenant_prefix_bounds(prefix) + condition = column >= lower_bound + if upper_bound is not None: + condition = sa.and_(condition, column < upper_bound) + return condition diff --git a/api/tests/integration_tests/services/retention/test_workflow_run_archiver.py b/api/tests/integration_tests/services/retention/test_workflow_run_archiver.py index 5728eacdfb4..9ce15c86e75 100644 --- a/api/tests/integration_tests/services/retention/test_workflow_run_archiver.py +++ b/api/tests/integration_tests/services/retention/test_workflow_run_archiver.py @@ -1,17 +1,17 @@ import datetime -import io import json import uuid -import zipfile from unittest.mock import MagicMock, patch +import pyarrow as pa +import pyarrow.parquet as pq import pytest from services.retention.workflow_run.archive_paid_plan_workflow_run import ( ArchiveSummary, WorkflowRunArchiver, ) -from services.retention.workflow_run.constants import ARCHIVE_SCHEMA_VERSION +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_FORMAT, ARCHIVE_BUNDLE_SCHEMA_VERSION class TestWorkflowRunArchiverInit: @@ -39,6 +39,22 @@ class TestWorkflowRunArchiverInit: with pytest.raises(ValueError, match="workers must be at least 1"): WorkflowRunArchiver(workers=0) + def test_run_shard_index_without_total_raises(self): + with pytest.raises(ValueError, match="run_shard_index and run_shard_total must be provided together"): + 
WorkflowRunArchiver(run_shard_index=0) + + def test_run_shard_total_without_index_raises(self): + with pytest.raises(ValueError, match="run_shard_index and run_shard_total must be provided together"): + WorkflowRunArchiver(run_shard_total=4) + + def test_run_shard_total_above_supported_range_raises(self): + with pytest.raises(ValueError, match="run_shard_total must be between 1 and 16"): + WorkflowRunArchiver(run_shard_index=0, run_shard_total=17) + + def test_run_shard_index_must_be_less_than_total(self): + with pytest.raises(ValueError, match="run_shard_index must be between 0 and run_shard_total - 1"): + WorkflowRunArchiver(run_shard_index=4, run_shard_total=4) + def test_valid_init_defaults(self): archiver = WorkflowRunArchiver(days=30, batch_size=50) assert archiver.days == 30 @@ -55,29 +71,93 @@ class TestWorkflowRunArchiverInit: assert archiver.end_before is not None assert archiver.workers == 2 + def test_delete_after_archive_is_not_supported_for_bundle_archive(self): + with pytest.raises(ValueError, match="delete_after_archive is not supported by bundle archive"): + WorkflowRunArchiver(delete_after_archive=True) + + def test_get_runs_batch_passes_shard_options(self): + repo = MagicMock() + repo.get_runs_batch_by_time_range.return_value = [] + archiver = WorkflowRunArchiver( + tenant_prefixes=["0", "a"], + run_shard_index=1, + run_shard_total=4, + workflow_run_repo=repo, + ) + + archiver._get_runs_batch(None) + + repo.get_runs_batch_by_time_range.assert_called_once() + assert repo.get_runs_batch_by_time_range.call_args.kwargs["tenant_prefixes"] == ["0", "a"] + assert repo.get_runs_batch_by_time_range.call_args.kwargs["run_shard_index"] == 1 + assert repo.get_runs_batch_by_time_range.call_args.kwargs["run_shard_total"] == 4 + + def test_get_runs_batch_prefers_planned_tenant_ids_over_prefix_filter(self): + repo = MagicMock() + repo.get_runs_batch_by_time_range.return_value = [] + archiver = WorkflowRunArchiver( + tenant_ids=["0tenant"], + 
tenant_prefixes=["0"], + paid_tenant_ids=["0tenant"], + workflow_run_repo=repo, + ) + + archiver._get_runs_batch(None) + + repo.get_runs_batch_by_time_range.assert_called_once() + assert repo.get_runs_batch_by_time_range.call_args.kwargs["tenant_ids"] == ["0tenant"] + assert repo.get_runs_batch_by_time_range.call_args.kwargs["tenant_prefixes"] is None + + def test_get_runs_batch_uses_current_tenant_scan_scope(self): + repo = MagicMock() + repo.get_runs_batch_by_time_range.return_value = [] + archiver = WorkflowRunArchiver( + tenant_ids=["tenant-a", "tenant-b"], + workflow_run_repo=repo, + ) + + archiver._get_runs_batch(None, tenant_scope=["tenant-b"]) + + repo.get_runs_batch_by_time_range.assert_called_once() + assert repo.get_runs_batch_by_time_range.call_args.kwargs["tenant_ids"] == ["tenant-b"] + + def test_start_message_includes_shard(self): + archiver = WorkflowRunArchiver(tenant_prefixes=["0"], run_shard_index=1, run_shard_total=4) + + message = archiver._build_start_message() + + assert "tenant_prefixes=0" in message + assert "run_shard=1/4" in message + + def test_start_message_summarizes_large_planned_tenant_list(self): + tenant_ids = [f"tenant-{index}" for index in range(11)] + archiver = WorkflowRunArchiver(tenant_ids=tenant_ids, tenant_prefixes=["0"]) + + message = archiver._build_start_message() + + assert "tenant_ids=11 planned tenants" in message + assert "tenant-10" not in message + class TestBuildArchiveBundle: - def test_bundle_contains_manifest_and_all_tables(self): + def test_bundle_contains_manifest_and_all_table_objects(self): archiver = WorkflowRunArchiver(days=90) + run = MagicMock() + run.id = str(uuid.uuid4()) + run.tenant_id = str(uuid.uuid4()) + run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0) + identity = archiver._build_bundle_identity([run]) + table_data = {"workflow_runs": [{"id": run.id, "tenant_id": run.tenant_id}]} - manifest_data = json.dumps({"schema_version": ARCHIVE_SCHEMA_VERSION}).encode("utf-8") - table_payloads = 
dict.fromkeys(archiver.ARCHIVED_TABLES, b"") + table_stats, table_payloads, manifest_data = archiver._build_archive_payload(identity, [run], table_data) + manifest = json.loads(manifest_data) - bundle_bytes = archiver._build_archive_bundle(manifest_data, table_payloads) - - with zipfile.ZipFile(io.BytesIO(bundle_bytes), "r") as zf: - names = set(zf.namelist()) - assert "manifest.json" in names - for table in archiver.ARCHIVED_TABLES: - assert f"{table}.jsonl" in names, f"Missing {table}.jsonl in bundle" - - def test_bundle_missing_table_payload_raises(self): - archiver = WorkflowRunArchiver(days=90) - manifest_data = b"{}" - incomplete_payloads = {archiver.ARCHIVED_TABLES[0]: b"data"} - - with pytest.raises(ValueError, match="Missing archive payload"): - archiver._build_archive_bundle(manifest_data, incomplete_payloads) + assert manifest["schema_version"] == ARCHIVE_BUNDLE_SCHEMA_VERSION + assert manifest["archive_format"] == ARCHIVE_BUNDLE_FORMAT + assert manifest["object_prefix"] == identity.object_prefix + assert set(table_payloads) == set(archiver.ARCHIVED_TABLES) + assert {stat.table_name for stat in table_stats} == set(archiver.ARCHIVED_TABLES) + assert pq.read_table(pa.BufferReader(table_payloads["workflow_runs"])).num_rows == 1 class TestGenerateManifest: @@ -88,25 +168,39 @@ class TestGenerateManifest: run = MagicMock() run.id = str(uuid.uuid4()) run.tenant_id = str(uuid.uuid4()) - run.app_id = str(uuid.uuid4()) - run.workflow_id = str(uuid.uuid4()) run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0) + identity = archiver._build_bundle_identity([run]) stats = [ - TableStats(table_name="workflow_runs", row_count=1, checksum="abc123", size_bytes=512), - TableStats(table_name="workflow_app_logs", row_count=2, checksum="def456", size_bytes=1024), + TableStats( + table_name="workflow_runs", + row_count=1, + checksum="abc123", + size_bytes=512, + object_key="workflow_runs.parquet", + ), + TableStats( + table_name="workflow_node_executions", + row_count=2, 
+ checksum="def456", + size_bytes=1024, + object_key="workflow_node_executions.parquet", + ), ] - manifest = archiver._generate_manifest(run, stats) + manifest = archiver._generate_manifest(identity, [run], stats) - assert manifest["schema_version"] == ARCHIVE_SCHEMA_VERSION - assert manifest["workflow_run_id"] == run.id + assert manifest["schema_version"] == ARCHIVE_BUNDLE_SCHEMA_VERSION + assert manifest["archive_format"] == ARCHIVE_BUNDLE_FORMAT + assert manifest["bundle_id"] == identity.bundle_id assert manifest["tenant_id"] == run.tenant_id - assert manifest["app_id"] == run.app_id + assert manifest["workflow_run_count"] == 1 + assert manifest["workflow_node_execution_count"] == 2 + assert manifest["run_ids"] == [run.id] assert "tables" in manifest assert manifest["tables"]["workflow_runs"]["row_count"] == 1 assert manifest["tables"]["workflow_runs"]["checksum"] == "abc123" - assert manifest["tables"]["workflow_app_logs"]["row_count"] == 2 + assert manifest["tables"]["workflow_node_executions"]["row_count"] == 2 class TestFilterPaidTenants: @@ -163,6 +257,19 @@ class TestFilterPaidTenants: assert result == set() + def test_planned_paid_tenants_skip_billing_lookup(self): + archiver = WorkflowRunArchiver(days=90, paid_tenant_ids=["t1", "t3"]) + + with ( + patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg, + patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing, + ): + cfg.BILLING_ENABLED = True + result = archiver._filter_paid_tenants({"t1", "t2", "t3"}) + + billing.get_plan_bulk_with_cache.assert_not_called() + assert result == {"t1", "t3"} + class TestDryRunArchive: @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage") @@ -175,3 +282,81 @@ class TestDryRunArchive: mock_get_storage.assert_not_called() assert isinstance(summary, ArchiveSummary) assert summary.runs_failed == 0 + + def test_dry_run_estimates_table_and_object_sizes(self): + 
archiver = WorkflowRunArchiver(days=90, dry_run=True) + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + run.app_id = "app-1" + run.workflow_id = "workflow-1" + run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0) + table_data = { + "workflow_runs": [{"id": "run-1", "tenant_id": "tenant-1"}], + "workflow_app_logs": [{"id": "log-1", "workflow_run_id": "run-1"}], + } + + with patch.object(archiver, "_extract_bundle_data", return_value=table_data): + result = archiver._archive_bundle(MagicMock(), None, [run]) + + stats_by_table = {stat.table_name: stat for stat in result.tables} + assert result.success is True + assert result.object_size_bytes > 0 + assert stats_by_table["workflow_runs"].row_count == 1 + assert stats_by_table["workflow_runs"].size_bytes > 0 + assert stats_by_table["workflow_app_logs"].row_count == 1 + assert stats_by_table["workflow_app_logs"].size_bytes > 0 + assert stats_by_table["workflow_node_executions"].row_count == 0 + assert stats_by_table["workflow_node_executions"].size_bytes > 0 + + def test_summary_merges_dry_run_estimates(self): + summary = ArchiveSummary() + result = MagicMock() + result.object_size_bytes = 128 + result.tables = [ + MagicMock(table_name="workflow_runs", row_count=1, size_bytes=64), + MagicMock(table_name="workflow_app_logs", row_count=2, size_bytes=32), + ] + + WorkflowRunArchiver._merge_result_stats(summary, result) + + assert summary.total_object_size_bytes == 128 + assert summary.table_stats["workflow_runs"].row_count == 1 + assert summary.table_stats["workflow_runs"].size_bytes == 64 + assert summary.table_stats["workflow_app_logs"].row_count == 2 + assert summary.table_stats["workflow_app_logs"].size_bytes == 32 + + +class TestArchiveRunIdempotency: + def test_locked_bundle_is_skipped(self): + archiver = WorkflowRunArchiver(days=90) + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0) + + with ( + 
patch.object(archiver, "_lock_runs_for_archive", return_value=[]), + ): + storage = MagicMock() + storage.object_exists.return_value = False + result = archiver._archive_bundle(MagicMock(), storage, [run]) + + assert result.success is True + assert result.skipped is True + assert result.error == "one or more runs locked or deleted by another archiver" + + def test_already_archived_bundle_is_skipped(self): + archiver = WorkflowRunArchiver(days=90) + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0) + storage = MagicMock() + storage.object_exists.return_value = True + + result = archiver._archive_bundle(MagicMock(), storage, [run]) + + assert result.success is True + assert result.skipped is True + assert result.error == "bundle already archived" diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index 69c39b8bfbb..62d6dd489a5 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -2,17 +2,44 @@ Testcontainers integration tests for archived workflow run deletion service. 
""" +import io +import json +import zipfile from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, patch from uuid import uuid4 from sqlalchemy import select from sqlalchemy.orm import Session from graphon.enums import WorkflowExecutionStatus +from libs.archive_storage import ArchiveStorage from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHIVE_SCHEMA_VERSION from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion +ARCHIVED_TABLES = [ + "workflow_runs", + "workflow_app_logs", + "workflow_node_executions", + "workflow_node_execution_offload", + "workflow_pauses", + "workflow_pause_reasons", + "workflow_trigger_logs", +] + + +class FakeArchiveStorage: + def __init__(self, objects: dict[str, bytes]): + self.objects = objects + + def object_exists(self, key: str) -> bool: + return key in self.objects + + def get_object(self, key: str) -> bytes: + return self.objects[key] + class TestArchivedWorkflowRunDeletion: def _create_workflow_run( @@ -47,7 +74,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.commit() return run - def _create_archive_log(self, db_session_with_containers: Session, *, run: WorkflowRun) -> None: + def _create_archive_log(self, db_session_with_containers: Session, *, run: WorkflowRun) -> WorkflowArchiveLog: archive_log = WorkflowArchiveLog( tenant_id=run.tenant_id, app_id=run.app_id, @@ -72,6 +99,59 @@ class TestArchivedWorkflowRunDeletion: ) db_session_with_containers.add(archive_log) db_session_with_containers.commit() + return archive_log + + def _archive_key(self, run: WorkflowRun) -> str: + return ( + f"{run.tenant_id}/app_id={run.app_id}/year={run.created_at.strftime('%Y')}/" + f"month={run.created_at.strftime('%m')}/workflow_run_id={run.id}/{ARCHIVE_BUNDLE_NAME}" + ) + + def _archive_bundle(self, 
run: WorkflowRun, *, workflow_run_rows: int = 1) -> bytes: + table_payloads: dict[str, bytes] = {} + table_counts = { + "workflow_runs": workflow_run_rows, + "workflow_app_logs": 0, + "workflow_node_executions": 0, + "workflow_node_execution_offload": 0, + "workflow_pauses": 0, + "workflow_pause_reasons": 0, + "workflow_trigger_logs": 0, + } + for table_name in ARCHIVED_TABLES: + records = [{"id": run.id}] if table_name == "workflow_runs" and workflow_run_rows else [] + table_payloads[table_name] = ArchiveStorage.serialize_to_jsonl(records) + + manifest = { + "schema_version": ARCHIVE_SCHEMA_VERSION, + "workflow_run_id": run.id, + "tenant_id": run.tenant_id, + "app_id": run.app_id, + "workflow_id": run.workflow_id, + "created_at": run.created_at.isoformat(), + "archived_at": datetime.now(UTC).isoformat(), + "tables": { + table_name: { + "row_count": table_counts[table_name], + "checksum": ArchiveStorage.compute_checksum(payload), + "size_bytes": len(payload), + } + for table_name, payload in table_payloads.items() + }, + } + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive: + archive.writestr("manifest.json", json.dumps(manifest).encode("utf-8")) + for table_name, payload in table_payloads.items(): + archive.writestr(f"{table_name}.jsonl", payload) + return buffer.getvalue() + + def _patch_storage(self, run: WorkflowRun): + storage = FakeArchiveStorage({self._archive_key(run): self._archive_bundle(run)}) + return patch( + "services.retention.workflow_run.delete_archived_workflow_run.get_archive_storage", + return_value=storage, + ) def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers: Session): deleter = ArchivedWorkflowRunDeletion() @@ -109,13 +189,23 @@ class TestArchivedWorkflowRunDeletion: self._create_archive_log(db_session_with_containers, run=run2) run_ids = [run1.id, run2.id] - deleter = ArchivedWorkflowRunDeletion() - results = deleter.delete_batch( - 
tenant_ids=[tenant_id], - start_date=base_time - timedelta(minutes=1), - end_date=base_time + timedelta(minutes=1), - limit=2, + storage = FakeArchiveStorage( + { + self._archive_key(run1): self._archive_bundle(run1), + self._archive_key(run2): self._archive_bundle(run2), + } ) + deleter = ArchivedWorkflowRunDeletion() + with patch( + "services.retention.workflow_run.delete_archived_workflow_run.get_archive_storage", + return_value=storage, + ): + results = deleter.delete_batch( + tenant_ids=[tenant_id], + start_date=base_time - timedelta(minutes=1), + end_date=base_time + timedelta(minutes=1), + limit=2, + ) assert len(results) == 2 assert all(result.success for result in results) @@ -133,9 +223,11 @@ class TestArchivedWorkflowRunDeletion: created_at=datetime.now(UTC), ) run_id = run.id + archive_log = self._create_archive_log(db_session_with_containers, run=run) deleter = ArchivedWorkflowRunDeletion() - result = deleter._delete_run(run) + with self._patch_storage(run): + result = deleter._delete_run(run, archive_log) assert result.success is True assert result.deleted_counts["runs"] == 1 @@ -152,9 +244,11 @@ class TestArchivedWorkflowRunDeletion: created_at=datetime.now(UTC), ) run_id = run.id + archive_log = self._create_archive_log(db_session_with_containers, run=run) deleter = ArchivedWorkflowRunDeletion(dry_run=True) - result = deleter._delete_run(run) + with self._patch_storage(run): + result = deleter._delete_run(run, archive_log) assert result.success is True assert result.run_id == run_id @@ -164,22 +258,33 @@ class TestArchivedWorkflowRunDeletion: def test_delete_run_exception_returns_error(self, db_session_with_containers: Session): """Exception during deletion should return failure result.""" - from unittest.mock import MagicMock, patch - tenant_id = str(uuid4()) run = self._create_workflow_run( db_session_with_containers, tenant_id=tenant_id, created_at=datetime.now(UTC), ) + archive_log = self._create_archive_log(db_session_with_containers, run=run) 
deleter = ArchivedWorkflowRunDeletion(dry_run=False) + expected_counts = { + "runs": 1, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo: mock_repo = MagicMock() mock_get_repo.return_value = mock_repo + mock_repo.get_archived_log_by_run_id.return_value = archive_log + mock_repo.count_runs_with_related.return_value = expected_counts mock_repo.delete_runs_with_related.side_effect = Exception("Database error") - result = deleter._delete_run(run) + with self._patch_storage(run): + result = deleter._delete_run(run, archive_log) assert result.success is False assert result.error == "Database error" @@ -197,7 +302,8 @@ class TestArchivedWorkflowRunDeletion: run_id = run.id deleter = ArchivedWorkflowRunDeletion() - result = deleter.delete_by_run_id(run_id) + with self._patch_storage(run): + result = deleter.delete_by_run_id(run_id) assert result.success is True db_session_with_containers.expunge_all() @@ -212,3 +318,48 @@ class TestArchivedWorkflowRunDeletion: assert repo1 is repo2 assert deleter.workflow_run_repo is repo1 + + def test_delete_run_fails_when_archive_object_missing(self, db_session_with_containers: Session): + tenant_id = str(uuid4()) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=datetime.now(UTC), + ) + archive_log = self._create_archive_log(db_session_with_containers, run=run) + deleter = ArchivedWorkflowRunDeletion() + storage = FakeArchiveStorage({}) + + with patch( + "services.retention.workflow_run.delete_archived_workflow_run.get_archive_storage", + return_value=storage, + ): + result = deleter._delete_run(run, archive_log) + + assert result.success is False + assert result.error == f"Archive bundle not found: {self._archive_key(run)}" + db_session_with_containers.expire_all() + assert db_session_with_containers.get(WorkflowRun, run.id) is not None + + def 
test_delete_run_fails_when_manifest_count_differs_from_live_rows(self, db_session_with_containers: Session): + tenant_id = str(uuid4()) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=datetime.now(UTC), + ) + archive_log = self._create_archive_log(db_session_with_containers, run=run) + bundle = self._archive_bundle(run, workflow_run_rows=0) + storage = FakeArchiveStorage({self._archive_key(run): bundle}) + deleter = ArchivedWorkflowRunDeletion() + + with patch( + "services.retention.workflow_run.delete_archived_workflow_run.get_archive_storage", + return_value=storage, + ): + result = deleter._delete_run(run, archive_log) + + assert result.success is False + assert "Archive row count mismatch before delete" in str(result.error) + db_session_with_containers.expire_all() + assert db_session_with_containers.get(WorkflowRun, run.id) is not None diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py new file mode 100644 index 00000000000..be1b94b2fad --- /dev/null +++ b/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py @@ -0,0 +1,96 @@ +import io +import json +import zipfile +from datetime import UTC, datetime + +import pytest + +from libs.archive_storage import ArchiveStorage +from services.retention.workflow_run.constants import ARCHIVE_SCHEMA_VERSION +from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + +ARCHIVED_TABLES = [ + "workflow_runs", + "workflow_app_logs", + "workflow_node_executions", + "workflow_node_execution_offload", + "workflow_pauses", + "workflow_pause_reasons", + "workflow_trigger_logs", +] + + +def _build_archive_bundle( + *, + run_id: str = "run-1", + tenant_id: str = "tenant-1", + app_id: str = "app-1", + workflow_id: str = "workflow-1", + corrupt_checksum_for: str | 
None = None, +) -> bytes: + table_payloads: dict[str, bytes] = {} + for table_name in ARCHIVED_TABLES: + records = [{"id": run_id}] if table_name == "workflow_runs" else [] + table_payloads[table_name] = ArchiveStorage.serialize_to_jsonl(records) + + manifest = { + "schema_version": ARCHIVE_SCHEMA_VERSION, + "workflow_run_id": run_id, + "tenant_id": tenant_id, + "app_id": app_id, + "workflow_id": workflow_id, + "created_at": datetime.now(UTC).isoformat(), + "archived_at": datetime.now(UTC).isoformat(), + "tables": { + table_name: { + "row_count": 1 if table_name == "workflow_runs" else 0, + "checksum": ArchiveStorage.compute_checksum(payload), + "size_bytes": len(payload), + } + for table_name, payload in table_payloads.items() + }, + } + if corrupt_checksum_for: + manifest["tables"][corrupt_checksum_for]["checksum"] = "bad-checksum" + + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive: + archive.writestr("manifest.json", json.dumps(manifest).encode("utf-8")) + for table_name, payload in table_payloads.items(): + archive.writestr(f"{table_name}.jsonl", payload) + return buffer.getvalue() + + +def test_validate_archive_bundle_accepts_valid_archive() -> None: + manifest = ArchivedWorkflowRunDeletion._validate_archive_bundle( + _build_archive_bundle(), + run_id="run-1", + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + ) + + assert manifest["schema_version"] == ARCHIVE_SCHEMA_VERSION + assert manifest["tables"]["workflow_runs"]["row_count"] == 1 + + +def test_validate_archive_bundle_rejects_checksum_mismatch() -> None: + with pytest.raises(ValueError, match="archive member checksum mismatch"): + ArchivedWorkflowRunDeletion._validate_archive_bundle( + _build_archive_bundle(corrupt_checksum_for="workflow_runs"), + run_id="run-1", + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + ) + + +def test_validate_archive_bundle_rejects_manifest_target_mismatch() -> None: + with 
pytest.raises(ValueError, match="manifest tenant_id does not match delete target"): + ArchivedWorkflowRunDeletion._validate_archive_bundle( + _build_archive_bundle(), + run_id="run-1", + tenant_id="different-tenant", + app_id="app-1", + workflow_id="workflow-1", + ) diff --git a/api/tests/unit_tests/services/test_archive_workflow_run_logs.py b/api/tests/unit_tests/services/test_archive_workflow_run_logs.py index bb9e5a8deab..a21f1de769c 100644 --- a/api/tests/unit_tests/services/test_archive_workflow_run_logs.py +++ b/api/tests/unit_tests/services/test_archive_workflow_run_logs.py @@ -9,8 +9,6 @@ This module contains tests for: from datetime import datetime from unittest.mock import MagicMock, patch -from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME - class TestWorkflowRunArchiver: """Tests for the WorkflowRunArchiver class.""" @@ -37,18 +35,20 @@ class TestWorkflowRunArchiver: assert archiver.limit == 50 assert archiver.dry_run is True - def test_get_archive_key(self): - """Test archive key generation.""" + def test_get_bundle_manifest_key(self): + """Test V2 bundle manifest key generation.""" from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver - archiver = WorkflowRunArchiver.__new__(WorkflowRunArchiver) + archiver = WorkflowRunArchiver(run_shard_index=1, run_shard_total=4) mock_run = MagicMock() - mock_run.tenant_id = "tenant-123" - mock_run.app_id = "app-999" + mock_run.tenant_id = "9enant-123" mock_run.id = "run-456" mock_run.created_at = datetime(2024, 1, 15, 12, 0, 0) - key = archiver._get_archive_key(mock_run) + identity = archiver._build_bundle_identity([mock_run]) + key = archiver._get_manifest_object_key(identity) - assert key == f"tenant-123/app_id=app-999/year=2024/month=01/workflow_run_id=run-456/{ARCHIVE_BUNDLE_NAME}" + assert key.endswith("/manifest.json") + assert "workflow-runs/v2/tenant_prefix=9/tenant_id=9enant-123/year=2024/month=01" in key + assert "/shard=01-of-04/" in key 
From 99c3d7d0f0d8ae6c8024ac89bf8725a8bd6c7fa4 Mon Sep 17 00:00:00 2001 From: Stephen Zhou Date: Tue, 23 Jun 2026 15:26:55 +0800 Subject: [PATCH 06/12] refactor(web): consolidate deployment state atoms (#37783) --- .../skills/how-to-write-component/SKILL.md | 20 +- web/app/(commonLayout)/deployments/layout.tsx | 2 + .../__tests__/state.spec.ts | 222 ++++++++++++++++ .../deployment-actions/delete-dialog.tsx | 56 ++-- .../deployment-actions/edit-dialog.tsx | 157 +++++------ .../deployment-actions/index.spec.tsx | 78 +++++- .../components/deployment-actions/index.tsx | 84 +++--- .../components/deployment-actions/state.ts | 168 ++++++++++++ .../state/__tests__/index.spec.ts | 169 +++++++++++- .../deployments/create-guide/state/index.ts | 101 ++++---- .../deployments/create-release/index.tsx | 2 +- .../state/__tests__/dsl-enabled.spec.ts | 245 ++++++++++++++++++ .../deployments/create-release/state/index.ts | 90 +++++-- .../create-release/ui/source-app-picker.tsx | 30 +-- .../deploy-drawer/__tests__/index.spec.tsx | 36 +++ .../deployments/deploy-drawer/index.tsx | 2 +- .../deployments/deploy-drawer/state/index.ts | 37 ++- .../deployments/deploy-drawer/ui/form.tsx | 26 +- .../detail/__tests__/state.spec.ts | 119 +++++++++ .../deployments/detail/access-tab.tsx | 10 +- .../deployments/detail/deploy-tab.tsx | 13 +- .../deploy-tab/deployment-row-actions.tsx | 5 +- .../deploy-tab/new-deployment-button.tsx | 14 +- .../deployments/detail/deployment-sidebar.tsx | 19 +- web/features/deployments/detail/index.tsx | 60 +++-- .../deployments/detail/overview-tab.tsx | 7 +- .../detail/overview-tab/release-hero.tsx | 29 ++- .../__tests__/api-key-generate-menu.spec.tsx | 25 ++ .../access/__tests__/permissions.spec.tsx | 16 +- .../access/api-key-generate-menu.tsx | 24 +- .../access/developer-api-section.tsx | 21 +- .../settings-tab/access/permissions.tsx | 20 +- .../detail/settings-tab/access/state.ts | 32 +++ web/features/deployments/detail/state.ts | 61 +++++ 
.../__tests__/deploy-release-menu.spec.tsx | 18 +- .../__tests__/release-history-rows.spec.tsx | 42 ++- .../versions-tab/__tests__/state.spec.ts | 143 ++++++++++ .../versions-tab/deploy-release-menu.tsx | 44 ++-- .../versions-tab/edit-release-dialog.tsx | 2 +- .../versions-tab/release-history-rows.tsx | 29 ++- .../versions-tab/release-history-table.tsx | 28 +- .../deployments/detail/versions-tab/state.ts | 90 +++++++ web/features/deployments/list/state/index.ts | 26 +- .../deployments/list/ui/instance-card.tsx | 1 - .../deployments/nav/__tests__/index.spec.tsx | 79 ++++++ .../deployments/nav/__tests__/state.spec.ts | 187 +++++++++++++ web/features/deployments/nav/index.tsx | 89 +------ web/features/deployments/nav/state.ts | 112 ++++++++ .../deployments/route-state-hydrator.tsx | 35 +++ web/features/deployments/route-state.ts | 6 + 50 files changed, 2364 insertions(+), 567 deletions(-) create mode 100644 web/features/deployments/components/deployment-actions/__tests__/state.spec.ts create mode 100644 web/features/deployments/components/deployment-actions/state.ts create mode 100644 web/features/deployments/create-release/state/__tests__/dsl-enabled.spec.ts create mode 100644 web/features/deployments/deploy-drawer/__tests__/index.spec.tsx create mode 100644 web/features/deployments/detail/__tests__/state.spec.ts create mode 100644 web/features/deployments/detail/settings-tab/access/state.ts create mode 100644 web/features/deployments/detail/state.ts create mode 100644 web/features/deployments/detail/versions-tab/__tests__/state.spec.ts create mode 100644 web/features/deployments/detail/versions-tab/state.ts create mode 100644 web/features/deployments/nav/__tests__/index.spec.tsx create mode 100644 web/features/deployments/nav/__tests__/state.spec.ts create mode 100644 web/features/deployments/nav/state.ts create mode 100644 web/features/deployments/route-state-hydrator.tsx create mode 100644 web/features/deployments/route-state.ts diff --git 
a/.agents/skills/how-to-write-component/SKILL.md b/.agents/skills/how-to-write-component/SKILL.md index f7e6e595092..8a480c8fd09 100644 --- a/.agents/skills/how-to-write-component/SKILL.md +++ b/.agents/skills/how-to-write-component/SKILL.md @@ -37,12 +37,16 @@ Use this as the decision guide for React/TypeScript component structure. Existin - Do not replace prop drilling with one top-level hook that returns a large view model and then thread that object through section props. Move each hook, query, derived value, and handler to the concrete section that consumes it, or use feature-scoped Jotai atoms for simple shared form/UI state when siblings need the same source of truth. - When using feature-scoped Jotai state for a form, drawer, or other secondary surface, scope the store to that surface instance when stale cross-instance state is possible. Initialize stable config at the owning boundary, then let descendants read only the atoms or purpose-named hooks they actually need. - For Jotai-backed surfaces, put shared query atoms, mutation atoms, derived state, and write actions in the feature state file when they coordinate multiple descendants. The lowest-owner rule still applies to independent visual surfaces that do not participate in shared state. +- For repeated row/menu action surfaces that need reset, hydrate the stable identity at the surface entry and scope only the primitives that truly need per-instance reset, such as open flags, drafts, or selected local options. - Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action. - Prefer uncontrolled DOM state and CSS variables before adding controlled props. 
## Feature-Scoped Jotai State - A module's feature-local state lives in one state file for Jotai-backed features: primitive atoms, query atoms, derived atoms, write-only action atoms, mutation atoms, submission orchestration, provider exports, and optional scope configuration. +- Keep state local when one component owns it, even inside Jotai-backed features. Dialog open flags, menu/popover visibility, confirmation visibility, form/input drafts, row-local pending flags, and in-flight refs usually belong in component state. +- Promote UI state to an atom only when siblings need the same source of truth, the value drives a query or mutation atom, a parent workflow coordinates the state, or the state intentionally persists across hidden or unmounted descendants within a scoped surface. +- Reflect atom-backed surface-wide locks or invariants in every affected trigger. If only one row, menu, or dialog should be disabled, keep the pending or lock state local to that row, menu, or dialog. - Atom order in the state file follows the dependency graph: types/constants, editable primitives, query atoms, query-data derived atoms, readiness/business derived atoms, write actions, mutation atoms, submission orchestration, provider exports. - Derived atom names read as business facts. Write atom names read as user or workflow commands. - UI components read and write the exact atom they use with `useAtomValue` or `useSetAtom`. Repeated workflow semantics live in named derived atoms or write atoms. @@ -51,7 +55,11 @@ Use this as the decision guide for React/TypeScript component structure. Existin - Avoid feature hooks that aggregate form values, query results, derived state, and commands for sibling components. Prefer named derived atoms and write atoms so UI components read the exact shared fact or command they need. - When a form library owns validation, keep submit orchestration in feature state when post-submit result or error state is shared by the surface. 
Avoid duplicating validation gates or request shaping in UI hooks. - `jotai-tanstack-query` atoms use the same QueryClient as the React Query provider. Query atoms belong in feature state when atoms are the feature's local state surface. -- Jotai scope is an optional instance-isolation tool for secondary surfaces with independent local state. Query atoms keep shared cache behavior through the shared QueryClient. +- Jotai scope is an optional instance-isolation tool for secondary surfaces with independent local state. Query and mutation atoms keep shared cache behavior through the shared QueryClient. +- Do not put `atomWithQuery`, `atomWithInfiniteQuery`, `atomWithMutation`, or broad derived orchestration atoms in a `ScopeProvider` just to reset a surface. Scoped derived atoms implicitly scope their dependencies, which can duplicate query client access and break shared invalidation. Leave query/mutation atoms unscoped; let them read scoped primitive inputs. +- Scope providers should list resettable primitive atoms and explicit hydration tuples. If a derived atom must be scoped, confirm that every dependency it implicitly scopes is meant to be private to that surface. +- Keep independent dialog lifecycles separate. Avoid a single discriminated "current action dialog" atom when edit, delete, and other dialogs have their own open state, loading guard, or reset behavior. +- Route-derived stable identities that do not need instance reset or scoped isolation can be hydrated at the route or layout boundary into a feature route atom. Use scoped atoms only when stale cross-instance state or per-surface reset semantics are needed. ## Components, Props, And Types @@ -74,6 +82,7 @@ Use this as the decision guide for React/TypeScript component structure. Existin - Use generated enum objects and union types directly in props, comparisons, status logic, and i18n keys. 
Do not add local enum constants or parallel frontend enum/status layers unless they model real product state not represented by the API. Presentation-only tone maps should be keyed by the generated enum. - Normalize or coerce only at a real boundary, such as user-entered forms, search, URL/query params, file names, DOM IDs, or legacy adapters. Preserve user-entered values when whitespace or formatting can be meaningful. - Do not coerce nullable or optional API strings to `''` in query, derived model, or payload-building code. Keep `undefined` or `null` until the final boundary that requires a string. +- Do not use `value || undefined` for mutation payload fields where an empty string means "clear this value". Trim or normalize at the form boundary, then preserve `''` when the API contract treats it as an intentional update. - Local UI models are fine for presentation, form state, select options, or guarded required-field refinements. Name them as UI concepts, not generated DTO mirrors. - Required-value refinements are allowed only after same-branch filtering or early return. Prefer nullable-tolerant props for render-only data. - When a component needs a stricter shape than a generated DTO, refine once at the API/query-to-UI boundary into a purpose-named UI type instead of hiding missing fields with generic fallback or coercion helpers. @@ -93,12 +102,17 @@ Use this as the decision guide for React/TypeScript component structure. Existin - Keep `web/contract/*` as the single source of truth for API shape; follow existing domain/router patterns and the `{ params, query?, body? }` input shape. - Consume queries directly with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`. +- In `atomWithQuery` and `atomWithInfiniteQuery`, return generated `queryOptions()` or `infiniteOptions()` directly. 
Pass `enabled`, `retry`, `placeholderData`, `select`, and pagination options into that call instead of spreading generated options into a hand-built object. +- In `atomWithMutation`, return generated `mutationOptions()` directly when using generated clients. Put request shaping and submit orchestration in write atoms; do not rebuild mutation option objects just to pass through the generated mutation function. +- For custom query functions that do not come from generated clients, wrap the options object with TanStack `queryOptions(...)` so query atoms still return a query options contract. - Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename `queryOptions()` or `mutationOptions()`. Extract a small `queryOptions` helper only when repeated call-site options justify it. - Keep feature hooks for real orchestration, workflow state, or shared domain behavior. - For TanStack cache data, use generated or query-derived types; do not create local wrappers for `getQueryData` or `getQueriesData`. -- For generated oRPC `queryOptions()` / `infiniteOptions()`, do not pass `skipToken` as `input`; keep a valid placeholder input shape and use `enabled` to gate missing required params because the OpenAPI codec encodes input eagerly. +- For generated oRPC `queryOptions()` / `infiniteOptions()`, keep returning the generated options directly. When required input is missing, use a whole-input branch such as `input: condition ? validInput : skipToken` together with `enabled: Boolean(condition)` so no request runs and no fake payload is built. +- Do not put `skipToken` inside a nested placeholder payload, such as `{ params: { appInstanceId: skipToken } }`. Do not create hand-written "missing queryOptions" objects or coerce required IDs to `''`. - Consume mutations directly with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))`; use oRPC clients as `mutationFn` only for custom flows. 
- Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`; components may add UI feedback callbacks, but should not own shared invalidation rules. +- Component or atom mutation callbacks can handle local UI feedback such as toasts, closing dialogs, or navigation. They should not replace shared invalidation or add local cache patches for shared server state. - Do not use deprecated `useInvalid` or `useReset`. - Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required, and wrap awaited calls in `try/catch`. @@ -110,6 +124,7 @@ Use this as the decision guide for React/TypeScript component structure. Existin - Keep cohesive forms, menu bodies, and one-off helpers local unless they need their own state, reuse, or semantic boundary. - Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow. - Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment. +- When a dialog, dropdown, or popover component already accepts controlled `open` state, mount the surface unconditionally unless unmounting is required for performance or reset semantics. Use keyed scope or local state reset for reset behavior instead of `{open && <Dialog />}` wrappers. - Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible. 
- Avoid shallow wrappers, hook-to-props adapter components, layout-only render-prop wrappers, children-as-pass-through composition, and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. If a component only calls a hook, forwards props, or passes trigger/content through to one child, move the logic into that child or make the wrapper own a real surface. @@ -120,6 +135,7 @@ Use this as the decision guide for React/TypeScript component structure. Existin - Do not use Effects to handle user actions. Put action-specific logic in the event handler where the cause is known. - Do not use Effects to copy one state value into another state value representing the same concept. Pick one source of truth and derive the rest during render. - Do not reset or adjust state from props with an Effect. Prefer a `key` reset, storing a stable ID and deriving the selected object, or guarded same-component render-time adjustment when truly necessary. +- For forms initialized from query data, prefer keyed remounts or surface-entry hydration of form/field atoms over an Effect that copies query data into form state. - Prefer framework data APIs or TanStack Query for data fetching instead of writing request Effects in components. - If an Effect still seems necessary, first name the external system it synchronizes with. If there is no external system, remove the Effect and restructure the state or event flow. 
diff --git a/web/app/(commonLayout)/deployments/layout.tsx b/web/app/(commonLayout)/deployments/layout.tsx index eb522444778..b8088169fd5 100644 --- a/web/app/(commonLayout)/deployments/layout.tsx +++ b/web/app/(commonLayout)/deployments/layout.tsx @@ -1,11 +1,13 @@ import type { ReactNode } from 'react' import { DeployDrawer } from '@/features/deployments/deploy-drawer' +import { DeploymentsRouteStateHydrator } from '@/features/deployments/route-state-hydrator' export default function DeploymentsLayout({ children }: { children: ReactNode }) { return ( <> + {children} diff --git a/web/features/deployments/components/deployment-actions/__tests__/state.spec.ts b/web/features/deployments/components/deployment-actions/__tests__/state.spec.ts new file mode 100644 index 00000000000..0137a339d1a --- /dev/null +++ b/web/features/deployments/components/deployment-actions/__tests__/state.spec.ts @@ -0,0 +1,222 @@ +import type { Getter } from 'jotai/vanilla' +import { skipToken } from '@tanstack/react-query' +import { atom, createStore } from 'jotai' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +type QueryOptions = { + enabled?: boolean + input?: unknown + queryKey?: readonly unknown[] +} + +type QueryResult = { + data?: unknown +} + +type MutationOptions = { + mutationKey?: readonly string[] +} + +type MutationResult = { + isPending: boolean + mutate: ReturnType + mutateAsync: ReturnType +} + +const mockQueryResults = vi.hoisted(() => ({ + current: new Map(), +})) + +const mockUpdateMutation = vi.hoisted<{ current: MutationResult }>(() => ({ + current: { + isPending: false, + mutate: vi.fn(), + mutateAsync: vi.fn(), + }, +})) + +const mockDeleteMutation = vi.hoisted<{ current: MutationResult }>(() => ({ + current: { + isPending: false, + mutate: vi.fn(), + mutateAsync: vi.fn(), + }, +})) + +vi.mock('jotai-tanstack-query', () => ({ + atomWithQuery: (createOptions: (get: Getter) => QueryOptions) => atom((get) => { + const options = createOptions(get) + 
const queryName = String(options.queryKey?.[0] ?? 'unknown') + const queryResult = options.enabled === false + ? undefined + : mockQueryResults.current.get(queryName) + + return { + ...options, + data: queryResult?.data, + isError: false, + isFetching: false, + isLoading: false, + isSuccess: Boolean(queryResult?.data), + } + }), + atomWithMutation: (createOptions: () => MutationOptions) => atom(() => { + const options = createOptions() + return options.mutationKey?.[0] === 'deleteAppInstance' + ? mockDeleteMutation.current + : mockUpdateMutation.current + }), +})) + +vi.mock('@/service/client', () => ({ + consoleQuery: { + enterprise: { + appInstanceService: { + getAppInstance: { + queryOptions: (options: QueryOptions) => ({ + ...options, + queryKey: ['getAppInstance', options.input], + }), + }, + updateAppInstance: { + mutationOptions: () => ({ mutationKey: ['updateAppInstance'] }), + }, + deleteAppInstance: { + mutationOptions: () => ({ mutationKey: ['deleteAppInstance'] }), + }, + }, + }, + }, +})) + +async function loadState() { + return await import('../state') +} + +async function mountedStore() { + const state = await loadState() + const store = createStore() + const unsubscribe = store.sub(state.editDeploymentFormCanSaveAtom, () => undefined) + + store.set(state.deploymentActionAppInstanceIdHydrationAtom, 'app-instance-1') + + return { + state, + store, + unsubscribe, + } +} + +function setAppInstance(overrides: Record = {}) { + mockQueryResults.current.set('getAppInstance', { + data: { + appInstance: { + id: 'app-instance-1', + displayName: 'Deployment 1', + description: 'Initial description', + ...overrides, + }, + }, + }) +} + +describe('deployment action state', () => { + beforeEach(() => { + vi.clearAllMocks() + mockQueryResults.current.clear() + mockUpdateMutation.current = { + isPending: false, + mutate: vi.fn(), + mutateAsync: vi.fn(), + } + mockDeleteMutation.current = { + isPending: false, + mutate: vi.fn(), + mutateAsync: vi.fn(), + } + }) + + 
it('should fetch app instance data only while an action dialog is open', async () => { + const { state, store, unsubscribe } = await mountedStore() + + expect(store.get(state.deploymentActionAppInstanceQueryAtom)).toMatchObject({ + enabled: false, + input: skipToken, + }) + + store.set(state.editDeploymentDialogOpenAtom, true) + expect(store.get(state.deploymentActionAppInstanceQueryAtom)).toMatchObject({ + enabled: true, + input: { params: { appInstanceId: 'app-instance-1' } }, + }) + + store.set(state.editDeploymentDialogOpenAtom, false) + store.set(state.deleteDeploymentDialogOpenAtom, true) + expect(store.get(state.deploymentActionAppInstanceQueryAtom)).toMatchObject({ + enabled: true, + input: { params: { appInstanceId: 'app-instance-1' } }, + }) + + unsubscribe() + }) + + it('should keep an edit dialog open while update is pending', async () => { + const { state, store, unsubscribe } = await mountedStore() + mockUpdateMutation.current = { + isPending: true, + mutate: vi.fn(), + mutateAsync: vi.fn(), + } + store.set(state.editDeploymentDialogOpenAtom, true) + + store.set(state.setEditDeploymentDialogOpenAtom, false) + + expect(store.get(state.editDeploymentDialogOpenAtom)).toBe(true) + + unsubscribe() + }) + + it('should submit edited deployment metadata with trimmed values', async () => { + const { state, store, unsubscribe } = await mountedStore() + const response = { appInstance: { id: 'app-instance-1' } } + setAppInstance() + mockUpdateMutation.current.mutateAsync.mockResolvedValue(response) + store.set(state.editDeploymentDialogOpenAtom, true) + store.set(state.editDeploymentNameFieldAtom, ' Deployment 2 ') + store.set(state.editDeploymentDescriptionFieldAtom, ' Updated description ') + + const result = await store.set(state.submitEditDeploymentFormAtom) + + expect(result).toBe(true) + expect(mockUpdateMutation.current.mutateAsync).toHaveBeenCalledWith({ + params: { + appInstanceId: 'app-instance-1', + }, + body: { + appInstanceId: 'app-instance-1', + 
displayName: 'Deployment 2', + description: 'Updated description', + }, + }) + + unsubscribe() + }) + + it('should submit delete with the hydrated app instance id and caller callbacks', async () => { + const { state, store, unsubscribe } = await mountedStore() + const onSuccess = vi.fn() + + store.set(state.submitDeleteDeploymentInstanceAtom, { onSuccess }) + + expect(mockDeleteMutation.current.mutate).toHaveBeenCalledWith( + { + params: { + appInstanceId: 'app-instance-1', + }, + }, + { onSuccess }, + ) + + unsubscribe() + }) +}) diff --git a/web/features/deployments/components/deployment-actions/delete-dialog.tsx b/web/features/deployments/components/deployment-actions/delete-dialog.tsx index e2a79e060b6..6af38106365 100644 --- a/web/features/deployments/components/deployment-actions/delete-dialog.tsx +++ b/web/features/deployments/components/deployment-actions/delete-dialog.tsx @@ -10,47 +10,37 @@ import { AlertDialogTitle, } from '@langgenius/dify-ui/alert-dialog' import { toast } from '@langgenius/dify-ui/toast' -import { useMutation } from '@tanstack/react-query' +import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { useTranslation } from 'react-i18next' import { useRouter } from '@/next/navigation' -import { consoleQuery } from '@/service/client' +import { + deleteDeploymentDialogOpenAtom, + deleteDeploymentInstanceMutationAtom, + deploymentActionDisplayNameAtom, + submitDeleteDeploymentInstanceAtom, +} from './state' -export function DeleteDeploymentDialog({ - appInstanceId, - appName, - open, - onOpenChange, -}: { - appInstanceId: string - appName?: string - open: boolean - onOpenChange: (open: boolean) => void -}) { +export function DeleteDeploymentDialog() { const { t } = useTranslation('deployments') const router = useRouter() - const deleteInstance = useMutation(consoleQuery.enterprise.appInstanceService.deleteAppInstance.mutationOptions()) - const displayName = appName || appInstanceId + const [open, setOpen] = 
useAtom(deleteDeploymentDialogOpenAtom) + const deleteInstance = useAtomValue(deleteDeploymentInstanceMutationAtom) + const submitDeleteInstance = useSetAtom(submitDeleteDeploymentInstanceAtom) + const displayName = useAtomValue(deploymentActionDisplayNameAtom) function handleDelete() { - deleteInstance.mutate( - { - params: { - appInstanceId, - }, + submitDeleteInstance({ + onSuccess: () => { + toast.success(t('settings.deleted')) + router.push('/deployments') }, - { - onSuccess: () => { - toast.success(t('settings.deleted')) - router.push('/deployments') - }, - onError: () => { - toast.error(t('settings.deleteFailed')) - }, - onSettled: () => { - onOpenChange(false) - }, + onError: () => { + toast.error(t('settings.deleteFailed')) }, - ) + onSettled: () => { + setOpen(false) + }, + }) } return ( @@ -59,7 +49,7 @@ export function DeleteDeploymentDialog({ onOpenChange={(nextOpen) => { if (!nextOpen && deleteInstance.isPending) return - onOpenChange(nextOpen) + setOpen(nextOpen) }} > diff --git a/web/features/deployments/components/deployment-actions/edit-dialog.tsx b/web/features/deployments/components/deployment-actions/edit-dialog.tsx index a1a367022a1..b433771b4e5 100644 --- a/web/features/deployments/components/deployment-actions/edit-dialog.tsx +++ b/web/features/deployments/components/deployment-actions/edit-dialog.tsx @@ -1,6 +1,5 @@ 'use client' -import type { AppInstance } from '@dify/contracts/enterprise/types.gen' import type { FormEvent } from 'react' import { Button } from '@langgenius/dify-ui/button' import { @@ -12,16 +11,22 @@ import { import { Input } from '@langgenius/dify-ui/input' import { Textarea } from '@langgenius/dify-ui/textarea' import { toast } from '@langgenius/dify-ui/toast' -import { useMutation, useQuery } from '@tanstack/react-query' -import { useState } from 'react' +import { useAtom, useAtomValue, useSetAtom } from 'jotai' +import { ScopeProvider } from 'jotai-scope' import { useTranslation } from 'react-i18next' import { 
SkeletonRectangle, SkeletonRow } from '@/app/components/base/skeleton' -import { consoleQuery } from '@/service/client' - -type EditDeploymentFormValues = { - name: string - description: string -} +import { + deploymentActionAppInstanceQueryAtom, + editDeploymentDescriptionFieldAtom, + editDeploymentDialogOpenAtom, + editDeploymentFormAtom, + editDeploymentFormCanSaveAtom, + editDeploymentFormSavePendingAtom, + editDeploymentNameFieldAtom, + setEditDeploymentDialogOpenAtom, + submitEditDeploymentFormAtom, + updateDeploymentInstanceMutationAtom, +} from './state' function EditDeploymentFormSkeleton() { return ( @@ -42,35 +47,34 @@ function EditDeploymentFormSkeleton() { ) } -function EditDeploymentForm({ - app, - isSaving, - onClose, - onSubmit, -}: { - app: AppInstance - isSaving: boolean - onClose: () => void - onSubmit: (values: EditDeploymentFormValues) => void -}) { +function EditDeploymentForm() { const { t } = useTranslation('deployments') - const initialName = app.displayName - const initialDescription = app.description - const [name, setName] = useState(initialName) - const [description, setDescription] = useState(initialDescription) - const normalizedName = name.trim() - const normalizedDescription = description.trim() - const canSave = Boolean(normalizedName && (normalizedName !== initialName || normalizedDescription !== initialDescription) && !isSaving) + const [nameField, setNameField] = useAtom(editDeploymentNameFieldAtom) + const [descriptionField, setDescriptionField] = useAtom(editDeploymentDescriptionFieldAtom) + const canSave = useAtomValue(editDeploymentFormCanSaveAtom) + const savePending = useAtomValue(editDeploymentFormSavePendingAtom) + const submitEditDeploymentForm = useSetAtom(submitEditDeploymentFormAtom) + const requestOpenChange = useSetAtom(setEditDeploymentDialogOpenAtom) + const setOpen = useSetAtom(editDeploymentDialogOpenAtom) - function handleSubmit(event: FormEvent) { + async function handleSubmit(event: FormEvent) { 
event.preventDefault() + event.stopPropagation() + if (!canSave) return - onSubmit({ - name: normalizedName, - description: normalizedDescription, - }) + try { + const didSubmit = await submitEditDeploymentForm() + if (!didSubmit) + return + + toast.success(t('settings.updated')) + setOpen(false) + } + catch { + toast.error(t('settings.updateFailed')) + } } return ( @@ -81,9 +85,10 @@ function EditDeploymentForm({ setName(event.target.value)} + value={nameField.value} + onChange={event => setNameField(event.target.value)} className="h-8" /> @@ -93,8 +98,9 @@ function EditDeploymentForm({