diff --git a/api/commands.py b/api/commands.py index aa7b731a27..3d68de4cb4 100644 --- a/api/commands.py +++ b/api/commands.py @@ -950,6 +950,346 @@ def clean_workflow_runs( ) +@click.command( + "archive-workflow-runs", + help="Archive workflow runs for paid plan tenants to S3-compatible storage.", +) +@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.") +@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.") +@click.option( + "--from-days-ago", + default=None, + type=click.IntRange(min=0), + help="Lower bound in days ago (older). Must be paired with --to-days-ago.", +) +@click.option( + "--to-days-ago", + default=None, + type=click.IntRange(min=0), + help="Upper bound in days ago (newer). Must be paired with --from-days-ago.", +) +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Archive runs created at or after this timestamp (UTC if no timezone).", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Archive runs created before this timestamp (UTC if no timezone).", +) +@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.") +@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.") +@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.") +@click.option("--dry-run", is_flag=True, help="Preview without archiving.") +@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.") +def archive_workflow_runs( + tenant_ids: str | None, + before_days: int, + from_days_ago: int | None, + to_days_ago: int | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + batch_size: int, + workers: int, + limit: int | None, + dry_run: bool, + delete_after_archive: bool, +): + """ + Archive workflow runs for paid plan tenants older than the specified days. + + This command archives the following tables to storage: + - workflow_node_executions + - workflow_node_execution_offload + - workflow_pauses + - workflow_pause_reasons + - workflow_trigger_logs + + The workflow_runs and workflow_app_logs tables are preserved for UI listing. 
+ """ + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + run_started_at = datetime.datetime.now(datetime.UTC) + click.echo( + click.style( + f"Starting workflow run archiving at {run_started_at.isoformat()}.", + fg="white", + ) + ) + + if (start_from is None) ^ (end_before is None): + click.echo(click.style("start-from and end-before must be provided together.", fg="red")) + return + + if (from_days_ago is None) ^ (to_days_ago is None): + click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red")) + return + + if from_days_ago is not None and to_days_ago is not None: + if start_from or end_before: + click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red")) + return + if from_days_ago <= to_days_ago: + click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red")) + return + now = datetime.datetime.now() + start_from = now - datetime.timedelta(days=from_days_ago) + end_before = now - datetime.timedelta(days=to_days_ago) + before_days = 0 + + if start_from and end_before and start_from >= end_before: + click.echo(click.style("start-from must be earlier than end-before.", fg="red")) + return + if workers < 1: + click.echo(click.style("workers must be at least 1.", fg="red")) + return + + archiver = WorkflowRunArchiver( + days=before_days, + batch_size=batch_size, + start_from=start_from, + end_before=end_before, + workers=workers, + tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None, + limit=limit, + dry_run=dry_run, + delete_after_archive=delete_after_archive, + ) + summary = archiver.run() + click.echo( + click.style( + f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, " + f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " + f"time={summary.total_elapsed_time:.2f}s", + fg="cyan", + ) + ) + + run_finished_at = datetime.datetime.now(datetime.UTC) + elapsed = run_finished_at - run_started_at + click.echo( + click.style( + f"Workflow run archiving completed. start={run_started_at.isoformat()} " + f"end={run_finished_at.isoformat()} duration={elapsed}", + fg="green", + ) + ) + + +@click.command( + "restore-workflow-runs", + help="Restore archived workflow runs from S3-compatible storage.", +) +@click.option( + "--tenant-ids", + required=False, + help="Tenant IDs (comma-separated).", +) +@click.option("--run-id", required=False, help="Workflow run ID to restore.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.") +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.") +@click.option("--dry-run", is_flag=True, help="Preview without restoring.") +def restore_workflow_runs( + tenant_ids: str | None, + run_id: str | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + workers: int, + limit: int, + dry_run: bool, +): + """ + Restore an archived workflow run from storage to the database. 
+ + This restores the following tables: + - workflow_node_executions + - workflow_node_execution_offload + - workflow_pauses + - workflow_pause_reasons + - workflow_trigger_logs + """ + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + parsed_tenant_ids = None + if tenant_ids: + parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + if not parsed_tenant_ids: + raise click.BadParameter("tenant-ids must not be empty") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + if run_id is None and (start_from is None or end_before is None): + raise click.UsageError("--start-from and --end-before are required for batch restore.") + if workers < 1: + raise click.BadParameter("workers must be at least 1") + + start_time = datetime.datetime.now(datetime.UTC) + click.echo( + click.style( + f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.", + fg="white", + ) + ) + + restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers) + if run_id: + results = [restorer.restore_by_run_id(run_id)] + else: + assert start_from is not None + assert end_before is not None + results = restorer.restore_batch( + parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: + click.echo( + click.style( + f"Restore completed successfully. success={successes} duration={elapsed}", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}", + fg="red", + ) + ) + + +@click.command( + "delete-archived-workflow-runs", + help="Delete archived workflow runs from the database.", +) +@click.option( + "--tenant-ids", + required=False, + help="Tenant IDs (comma-separated).", +) +@click.option("--run-id", required=False, help="Workflow run ID to delete.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.") +@click.option("--dry-run", is_flag=True, help="Preview without deleting.") +def delete_archived_workflow_runs( + tenant_ids: str | None, + run_id: str | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + limit: int, + dry_run: bool, +): + """ + Delete archived workflow runs from the database. 
+ """ + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + parsed_tenant_ids = None + if tenant_ids: + parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + if not parsed_tenant_ids: + raise click.BadParameter("tenant-ids must not be empty") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + if run_id is None and (start_from is None or end_before is None): + raise click.UsageError("--start-from and --end-before are required for batch delete.") + + start_time = datetime.datetime.now(datetime.UTC) + target_desc = f"workflow run {run_id}" if run_id else "workflow runs" + click.echo( + click.style( + f"Starting delete of {target_desc} at {start_time.isoformat()}.", + fg="white", + ) + ) + + deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run) + if run_id: + results = [deleter.delete_by_run_id(run_id)] + else: + assert start_from is not None + assert end_before is not None + results = deleter.delete_batch( + parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + + for result in results: + if result.success: + click.echo( + click.style( + f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} " + f"workflow run {result.run_id} (tenant={result.tenant_id})", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Failed to delete workflow run {result.run_id}: {result.error}", + fg="red", + ) + ) + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: + click.echo( + click.style( + f"Delete completed successfully. success={successes} duration={elapsed}", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Delete completed with failures. 
success={successes} failed={failures} duration={elapsed}", + fg="red", + ) + ) + + @click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") @click.command("clear-orphaned-file-records", help="Clear orphaned file records.") def clear_orphaned_file_records(force: bool): diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index fa67fb8154..6736f24a2e 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -11,7 +11,10 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.workflow.enums import WorkflowExecutionStatus from extensions.ext_database import db -from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model +from fields.workflow_app_log_fields import ( + build_workflow_app_log_pagination_model, + build_workflow_archived_log_pagination_model, +) from libs.login import login_required from models import App from models.model import AppMode @@ -61,6 +64,7 @@ console_ns.schema_model( # Register model for flask_restx to avoid dict type issues in Swagger workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns) +workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns) @console_ns.route("/apps//workflow-app-logs") @@ -99,3 +103,33 @@ class WorkflowAppLogApi(Resource): ) return workflow_app_log_pagination + + +@console_ns.route("/apps//workflow-archived-logs") +class WorkflowArchivedLogApi(Resource): + @console_ns.doc("get_workflow_archived_logs") + @console_ns.doc(description="Get workflow archived execution logs") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__]) + @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + @marshal_with(workflow_archived_log_pagination_model) + def get(self, app_model: App): + """ + Get workflow archived logs + """ + args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + + workflow_app_service = WorkflowAppService() + with Session(db.engine) as session: + workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs( + session=session, + app_model=app_model, + page=args.page, + limit=args.limit, + ) + + return workflow_app_log_pagination diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 8f1871f1e9..fa74f8aea1 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,12 +1,15 @@ +from datetime import UTC, datetime, timedelta from typing import Literal, cast from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required +from extensions.ext_database import db from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields from fields.workflow_run_fields 
import ( @@ -19,14 +22,17 @@ from fields.workflow_run_fields import ( workflow_run_node_execution_list_fields, workflow_run_pagination_fields, ) +from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from libs.custom_inputs import time_duration from libs.helper import uuid_value from libs.login import current_user, login_required -from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom +from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME from services.workflow_run_service import WorkflowRunService # Workflow run status choices for filtering WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"] +EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600 # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -93,6 +99,15 @@ workflow_run_node_execution_list_model = console_ns.model( "WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy ) +workflow_run_export_fields = console_ns.model( + "WorkflowRunExport", + { + "status": fields.String(description="Export status: success/failed"), + "presigned_url": fields.String(description="Pre-signed URL for download", required=False), + "presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False), + }, +) + DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -181,6 +196,56 @@ class AdvancedChatAppWorkflowRunListApi(Resource): return result +@console_ns.route("/apps//workflow-runs//export") +class WorkflowRunExportApi(Resource): + @console_ns.doc("get_workflow_run_export_url") + @console_ns.doc(description="Generate a download URL for an archived workflow run.") + @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) + @console_ns.response(200, "Export URL generated", workflow_run_export_fields) + @setup_required + @login_required + @account_initialization_required + @get_app_model() + def get(self, app_model: App, run_id: str): + tenant_id = str(app_model.tenant_id) + app_id = str(app_model.id) + run_id_str = str(run_id) + + run_created_at = db.session.scalar( + select(WorkflowArchiveLog.run_created_at) + .where( + WorkflowArchiveLog.tenant_id == tenant_id, + WorkflowArchiveLog.app_id == app_id, + WorkflowArchiveLog.workflow_run_id == run_id_str, + ) + .limit(1) + ) + if not run_created_at: + return {"code": "archive_log_not_found", "message": "workflow run archive not found"}, 404 + + prefix = ( + f"{tenant_id}/app_id={app_id}/year={run_created_at.strftime('%Y')}/" + f"month={run_created_at.strftime('%m')}/workflow_run_id={run_id_str}" + ) + archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + + try: + archive_storage = get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + return {"code": "archive_storage_not_configured", "message": str(e)}, 500 + + presigned_url = archive_storage.generate_presigned_url( + archive_key, + expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS, + ) + expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS) + return { + "status": "success", + "presigned_url": presigned_url, + "presigned_url_expires_at": expires_at.isoformat(), + }, 200 + + @console_ns.route("/apps//advanced-chat/workflow-runs/count") class AdvancedChatAppWorkflowRunCountApi(Resource): @console_ns.doc("get_advanced_chat_workflow_runs_count") 
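For reference, a minimal client-side sketch of the export flow added in `WorkflowRunExportApi` above: request the pre-signed URL from the console API, then download the bundle it points at. The base URL, app/run IDs, and auth header below are placeholder assumptions, not values from this PR.

```python
# Hypothetical usage sketch (not part of the PR): fetch the pre-signed URL for an
# archived run, then download the archive bundle before the URL expires (1 hour).
import requests

CONSOLE_BASE = "https://dify.example.com/console/api"  # assumed deployment URL
APP_ID = "00000000-0000-0000-0000-000000000000"        # placeholder app ID
RUN_ID = "00000000-0000-0000-0000-000000000001"        # placeholder workflow run ID
HEADERS = {"Authorization": "Bearer <console-session-token>"}  # assumed auth scheme

resp = requests.get(
    f"{CONSOLE_BASE}/apps/{APP_ID}/workflow-runs/{RUN_ID}/export",
    headers=HEADERS,
    timeout=30,
)
resp.raise_for_status()
payload = resp.json()

if payload.get("status") == "success":
    # presigned_url_expires_at reports when the link stops working.
    bundle = requests.get(payload["presigned_url"], timeout=60)
    bundle.raise_for_status()
    with open("archive.v1.0.zip", "wb") as f:
        f.write(bundle.content)
```

A 404 with code `archive_log_not_found` means the run has no `workflow_archive_logs` record; a 500 with `archive_storage_not_configured` means the archive storage settings are missing.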
diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 51e2c6cdd5..46885761a1 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -4,6 +4,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): from commands import ( add_qdrant_index, + archive_workflow_runs, clean_expired_messages, clean_workflow_runs, cleanup_orphaned_draft_variables, @@ -11,6 +12,7 @@ def init_app(app: DifyApp): clear_orphaned_file_records, convert_to_agent_apps, create_tenant, + delete_archived_workflow_runs, extract_plugins, extract_unique_plugins, file_usage, @@ -24,6 +26,7 @@ def init_app(app: DifyApp): reset_email, reset_encrypt_key_pair, reset_password, + restore_workflow_runs, setup_datasource_oauth_client, setup_system_tool_oauth_client, setup_system_trigger_oauth_client, @@ -58,6 +61,9 @@ def init_app(app: DifyApp): setup_datasource_oauth_client, transform_datasource_credentials, install_rag_pipeline_plugins, + archive_workflow_runs, + delete_archived_workflow_runs, + restore_workflow_runs, clean_workflow_runs, clean_expired_messages, ] diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index 0ebc03a98c..ae70356322 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -2,7 +2,12 @@ from flask_restx import Namespace, fields from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields from fields.member_fields import build_simple_account_model, simple_account_fields -from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields +from fields.workflow_run_fields import ( + build_workflow_run_for_archived_log_model, + build_workflow_run_for_log_model, + workflow_run_for_archived_log_fields, + workflow_run_for_log_fields, +) from libs.helper import TimestampField workflow_app_log_partial_fields = { @@ -34,6 +39,33 @@ def build_workflow_app_log_partial_model(api_or_ns: Namespace): return api_or_ns.model("WorkflowAppLogPartial", copied_fields) +workflow_archived_log_partial_fields = { + "id": fields.String, + "workflow_run": fields.Nested(workflow_run_for_archived_log_fields, allow_null=True), + "trigger_metadata": fields.Raw, + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "created_at": TimestampField, +} + + +def build_workflow_archived_log_partial_model(api_or_ns: Namespace): + """Build the workflow archived log partial model for the API or Namespace.""" + workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns) + simple_account_model = build_simple_account_model(api_or_ns) + simple_end_user_model = build_simple_end_user_model(api_or_ns) + + copied_fields = workflow_archived_log_partial_fields.copy() + copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True) + copied_fields["created_by_account"] = fields.Nested( + simple_account_model, attribute="created_by_account", allow_null=True + ) + copied_fields["created_by_end_user"] = fields.Nested( + simple_end_user_model, attribute="created_by_end_user", allow_null=True + ) + return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields) + + workflow_app_log_pagination_fields = { "page": fields.Integer, "limit": fields.Integer, @@ -51,3 +83,21 @@ def build_workflow_app_log_pagination_model(api_or_ns: Namespace): copied_fields = 
workflow_app_log_pagination_fields.copy() copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model)) return api_or_ns.model("WorkflowAppLogPagination", copied_fields) + + +workflow_archived_log_pagination_fields = { + "page": fields.Integer, + "limit": fields.Integer, + "total": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(workflow_archived_log_partial_fields)), +} + + +def build_workflow_archived_log_pagination_model(api_or_ns: Namespace): + """Build the workflow archived log pagination model for the API or Namespace.""" + workflow_archived_log_partial_model = build_workflow_archived_log_partial_model(api_or_ns) + + copied_fields = workflow_archived_log_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(workflow_archived_log_partial_model)) + return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields) diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 476025064f..35bb442c59 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -23,6 +23,19 @@ def build_workflow_run_for_log_model(api_or_ns: Namespace): return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields) +workflow_run_for_archived_log_fields = { + "id": fields.String, + "status": fields.String, + "triggered_from": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, +} + + +def build_workflow_run_for_archived_log_model(api_or_ns: Namespace): + return api_or_ns.model("WorkflowRunForArchivedLog", workflow_run_for_archived_log_fields) + + workflow_run_for_list_fields = { "id": fields.String, "version": fields.String, diff --git a/api/libs/archive_storage.py b/api/libs/archive_storage.py index f84d226447..66b57ac661 100644 --- a/api/libs/archive_storage.py +++ b/api/libs/archive_storage.py @@ -7,7 +7,6 @@ to S3-compatible object storage. import base64 import datetime -import gzip import hashlib import logging from collections.abc import Generator @@ -39,7 +38,7 @@ class ArchiveStorage: """ S3-compatible storage client for archiving or exporting. - This client provides methods for storing and retrieving archived data in JSONL+gzip format. + This client provides methods for storing and retrieving archived data in JSONL format. 
""" def __init__(self, bucket: str): @@ -69,7 +68,10 @@ class ArchiveStorage: aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY, aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY, region_name=dify_config.ARCHIVE_STORAGE_REGION, - config=Config(s3={"addressing_style": "path"}), + config=Config( + s3={"addressing_style": "path"}, + max_pool_connections=64, + ), ) # Verify bucket accessibility @@ -100,12 +102,18 @@ class ArchiveStorage: """ checksum = hashlib.md5(data).hexdigest() try: - self.client.put_object( + response = self.client.put_object( Bucket=self.bucket, Key=key, Body=data, ContentMD5=self._content_md5(data), ) + etag = response.get("ETag") + if not etag: + raise ArchiveStorageError(f"Missing ETag for '{key}'") + normalized_etag = etag.strip('"') + if normalized_etag != checksum: + raise ArchiveStorageError(f"ETag mismatch for '{key}': expected={checksum}, actual={normalized_etag}") logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum) return checksum except ClientError as e: @@ -240,19 +248,18 @@ class ArchiveStorage: return base64.b64encode(hashlib.md5(data).digest()).decode() @staticmethod - def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes: + def serialize_to_jsonl(records: list[dict[str, Any]]) -> bytes: """ - Serialize records to gzipped JSONL format. + Serialize records to JSONL format. Args: records: List of dictionaries to serialize Returns: - Gzipped JSONL bytes + JSONL bytes """ lines = [] for record in records: - # Convert datetime objects to ISO format strings serialized = ArchiveStorage._serialize_record(record) lines.append(orjson.dumps(serialized)) @@ -260,23 +267,22 @@ class ArchiveStorage: if jsonl_content: jsonl_content += b"\n" - return gzip.compress(jsonl_content) + return jsonl_content @staticmethod - def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]: + def deserialize_from_jsonl(data: bytes) -> list[dict[str, Any]]: """ - Deserialize gzipped JSONL data to records. + Deserialize JSONL data to records. Args: - data: Gzipped JSONL bytes + data: JSONL bytes Returns: List of dictionaries """ - jsonl_content = gzip.decompress(data) records = [] - for line in jsonl_content.splitlines(): + for line in data.splitlines(): if line: records.append(orjson.loads(line)) diff --git a/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py new file mode 100644 index 0000000000..5e7298af54 --- /dev/null +++ b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py @@ -0,0 +1,95 @@ +"""create workflow_archive_logs + +Revision ID: 9d77545f524e +Revises: f9f6d18a37f9 +Create Date: 2026-01-06 17:18:56.292479 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +def _is_pg(conn): + return conn.dialect.name == "postgresql" + +# revision identifiers, used by Alembic. +revision = '9d77545f524e' +down_revision = 'f9f6d18a37f9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + conn = op.get_bind() + if _is_pg(conn): + op.create_table('workflow_archive_logs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), + sa.Column('log_id', models.types.StringUUID(), nullable=True), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('log_created_at', sa.DateTime(), nullable=True), + sa.Column('log_created_from', sa.String(length=255), nullable=True), + sa.Column('run_version', sa.String(length=255), nullable=False), + sa.Column('run_status', sa.String(length=255), nullable=False), + sa.Column('run_triggered_from', sa.String(length=255), nullable=False), + sa.Column('run_error', models.types.LongText(), nullable=True), + sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('run_created_at', sa.DateTime(), nullable=False), + sa.Column('run_finished_at', sa.DateTime(), nullable=True), + sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('trigger_metadata', models.types.LongText(), nullable=True), + sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey') + ) + else: + op.create_table('workflow_archive_logs', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('log_id', models.types.StringUUID(), nullable=True), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('log_created_at', sa.DateTime(), nullable=True), + sa.Column('log_created_from', sa.String(length=255), nullable=True), + sa.Column('run_version', sa.String(length=255), nullable=False), + sa.Column('run_status', sa.String(length=255), nullable=False), + sa.Column('run_triggered_from', sa.String(length=255), nullable=False), + sa.Column('run_error', models.types.LongText(), nullable=True), + sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('run_created_at', sa.DateTime(), nullable=False), + sa.Column('run_finished_at', sa.DateTime(), nullable=True), + sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('trigger_metadata', models.types.LongText(), nullable=True), + sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', 
name='workflow_archive_log_pkey') + ) + with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op: + batch_op.create_index('workflow_archive_log_app_idx', ['tenant_id', 'app_id'], unique=False) + batch_op.create_index('workflow_archive_log_run_created_at_idx', ['run_created_at'], unique=False) + batch_op.create_index('workflow_archive_log_workflow_run_id_idx', ['workflow_run_id'], unique=False) + + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op: + batch_op.drop_index('workflow_archive_log_workflow_run_id_idx') + batch_op.drop_index('workflow_archive_log_run_created_at_idx') + batch_op.drop_index('workflow_archive_log_app_idx') + + op.drop_table('workflow_archive_logs') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 91171a4bef..74b33130ef 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -103,6 +103,7 @@ from .workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, + WorkflowArchiveLog, WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom, @@ -203,6 +204,7 @@ __all__ = [ "Workflow", "WorkflowAppLog", "WorkflowAppLogCreatedFrom", + "WorkflowArchiveLog", "WorkflowNodeExecutionModel", "WorkflowNodeExecutionOffload", "WorkflowNodeExecutionTriggeredFrom", diff --git a/api/models/workflow.py b/api/models/workflow.py index 2ff47e87b9..0efb3a4e44 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1163,6 +1163,69 @@ class WorkflowAppLog(TypeBase): } +class WorkflowArchiveLog(TypeBase): + """ + Workflow archive log. + + Stores essential workflow run snapshot data for archived app logs. + + Field sources: + - Shared fields (tenant/app/workflow/run ids, created_by*): from WorkflowRun for consistency. + - log_* fields: from WorkflowAppLog when present; null if the run has no app log. + - run_* fields: workflow run snapshot fields from WorkflowRun. + - trigger_metadata: snapshot from WorkflowTriggerLog when present. 
+ """ + + __tablename__ = "workflow_archive_logs" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_archive_log_pkey"), + sa.Index("workflow_archive_log_app_idx", "tenant_id", "app_id"), + sa.Index("workflow_archive_log_workflow_run_id_idx", "workflow_run_id"), + sa.Index("workflow_archive_log_run_created_at_idx", "run_created_at"), + ) + + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False + ) + + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + + log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True) + + run_version: Mapped[str] = mapped_column(String(255), nullable=False) + run_status: Mapped[str] = mapped_column(String(255), nullable=False) + run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False) + run_error: Mapped[str | None] = mapped_column(LongText, nullable=True) + run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) + run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) + run_total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) + run_created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + run_finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + run_exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) + + trigger_metadata: Mapped[str | None] = mapped_column(LongText, nullable=True) + archived_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + + @property + def workflow_run_summary(self) -> dict[str, Any]: + return { + "id": self.workflow_run_id, + "status": self.run_status, + "triggered_from": self.run_triggered_from, + "elapsed_time": self.run_elapsed_time, + "total_tokens": self.run_total_tokens, + } + + class ConversationVariable(TypeBase): __tablename__ = "workflow_conversation_variables" diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 479eb1ff54..5b3f635301 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -16,7 +16,7 @@ from typing import Protocol from sqlalchemy.orm import Session from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from models.workflow import WorkflowNodeExecutionModel +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol): @@ -209,3 +209,23 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr The number of executions deleted """ ... 
+ + def get_offloads_by_execution_ids( + self, + session: Session, + node_execution_ids: Sequence[str], + ) -> Sequence[WorkflowNodeExecutionOffload]: + """ + Get offload records by node execution IDs. + + This method retrieves workflow node execution offload records + that belong to the given node execution IDs. + + Args: + session: The database session to use + node_execution_ids: List of node execution IDs to filter by + + Returns: + A sequence of WorkflowNodeExecutionOffload instances + """ + ... diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 1a2b84fdf9..1d3954571f 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -45,7 +45,7 @@ from core.workflow.enums import WorkflowType from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom -from models.workflow import WorkflowRun +from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( AverageInteractionStats, @@ -270,6 +270,58 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def get_archived_run_ids( + self, + session: Session, + run_ids: Sequence[str], + ) -> set[str]: + """ + Fetch workflow run IDs that already have archive log records. + """ + ... + + def get_archived_logs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowArchiveLog]: + """ + Fetch archived workflow logs by time range for restore. + """ + ... + + def get_archived_log_by_run_id( + self, + run_id: str, + ) -> WorkflowArchiveLog | None: + """ + Fetch a workflow archive log by workflow run ID. + """ + ... + + def delete_archive_log_by_run_id( + self, + session: Session, + run_id: str, + ) -> int: + """ + Delete archive log by workflow run ID. + + Used after restoring a workflow run to remove the archive log record, + allowing the run to be archived again if needed. + + Args: + session: Database session + run_id: Workflow run ID + + Returns: + Number of records deleted (0 or 1) + """ + ... + def delete_runs_with_related( self, runs: Sequence[WorkflowRun], @@ -282,6 +334,61 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def get_pause_records_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowPause]: + """ + Fetch workflow pause records by workflow run ID. + """ + ... + + def get_pause_reason_records_by_run_id( + self, + session: Session, + pause_ids: Sequence[str], + ) -> Sequence[WorkflowPauseReason]: + """ + Fetch workflow pause reason records by pause IDs. + """ + ... + + def get_app_logs_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowAppLog]: + """ + Fetch workflow app logs by workflow run ID. + """ + ... + + def create_archive_logs( + self, + session: Session, + run: WorkflowRun, + app_logs: Sequence[WorkflowAppLog], + trigger_metadata: str | None, + ) -> int: + """ + Create archive log records for a workflow run. + """ + ... 
+ + def get_archived_runs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowRun]: + """ + Return workflow runs that already have archive logs, for cleanup of `workflow_runs`. + """ + ... + def count_runs_with_related( self, runs: Sequence[WorkflowRun], diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 4a7c975d2c..b19cc73bd1 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -351,3 +351,27 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut ) return int(node_executions_count), int(offloads_count) + + @staticmethod + def get_by_run( + session: Session, + run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Fetch node executions for a run using workflow_run_id. + """ + stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.workflow_run_id == run_id) + return list(session.scalars(stmt)) + + def get_offloads_by_execution_ids( + self, + session: Session, + node_execution_ids: Sequence[str], + ) -> Sequence[WorkflowNodeExecutionOffload]: + if not node_execution_ids: + return [] + + stmt = select(WorkflowNodeExecutionOffload).where( + WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids) + ) + return list(session.scalars(stmt)) diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 9d2d06e99f..d5214be042 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -40,14 +40,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.workflow import ( - WorkflowAppLog, - WorkflowPauseReason, - WorkflowRun, -) -from models.workflow import ( - WorkflowPause as WorkflowPauseModel, -) +from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( @@ -369,6 +362,53 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): return session.scalars(stmt).all() + def get_archived_run_ids( + self, + session: Session, + run_ids: Sequence[str], + ) -> set[str]: + if not run_ids: + return set() + + stmt = select(WorkflowArchiveLog.workflow_run_id).where(WorkflowArchiveLog.workflow_run_id.in_(run_ids)) + return set(session.scalars(stmt).all()) + + def get_archived_log_by_run_id( + self, + run_id: str, + ) -> WorkflowArchiveLog | None: + with self._session_maker() as session: + stmt = select(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id).limit(1) + return session.scalar(stmt) + + def delete_archive_log_by_run_id( + self, + session: Session, + run_id: str, + ) -> int: + stmt = delete(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id) + result = session.execute(stmt) + return cast(CursorResult, result).rowcount or 0 + + def get_pause_records_by_run_id( + self, + session: Session, + run_id: str, + ) -> 
Sequence[WorkflowPause]: + stmt = select(WorkflowPause).where(WorkflowPause.workflow_run_id == run_id) + return list(session.scalars(stmt)) + + def get_pause_reason_records_by_run_id( + self, + session: Session, + pause_ids: Sequence[str], + ) -> Sequence[WorkflowPauseReason]: + if not pause_ids: + return [] + + stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)) + return list(session.scalars(stmt)) + def delete_runs_with_related( self, runs: Sequence[WorkflowRun], @@ -396,9 +436,8 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))) app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0 - pause_ids = session.scalars( - select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids)) - ).all() + pause_stmt = select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids)) + pause_ids = session.scalars(pause_stmt).all() pause_reasons_deleted = 0 pauses_deleted = 0 @@ -407,7 +446,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)) ) pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0 - pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids))) + pauses_result = session.execute(delete(WorkflowPause).where(WorkflowPause.id.in_(pause_ids))) pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0 trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0 @@ -427,6 +466,124 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): "pause_reasons": pause_reasons_deleted, } + def get_app_logs_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowAppLog]: + stmt = select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == run_id) + return list(session.scalars(stmt)) + + def create_archive_logs( + self, + session: Session, + run: WorkflowRun, + app_logs: Sequence[WorkflowAppLog], + trigger_metadata: str | None, + ) -> int: + if not app_logs: + archive_log = WorkflowArchiveLog( + log_id=None, + log_created_at=None, + log_created_from=None, + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_id=run.workflow_id, + workflow_run_id=run.id, + created_by_role=run.created_by_role, + created_by=run.created_by, + run_version=run.version, + run_status=run.status, + run_triggered_from=run.triggered_from, + run_error=run.error, + run_elapsed_time=run.elapsed_time, + run_total_tokens=run.total_tokens, + run_total_steps=run.total_steps, + run_created_at=run.created_at, + run_finished_at=run.finished_at, + run_exceptions_count=run.exceptions_count, + trigger_metadata=trigger_metadata, + ) + session.add(archive_log) + return 1 + + archive_logs = [ + WorkflowArchiveLog( + log_id=app_log.id, + log_created_at=app_log.created_at, + log_created_from=app_log.created_from, + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_id=run.workflow_id, + workflow_run_id=run.id, + created_by_role=run.created_by_role, + created_by=run.created_by, + run_version=run.version, + run_status=run.status, + run_triggered_from=run.triggered_from, + run_error=run.error, + run_elapsed_time=run.elapsed_time, + run_total_tokens=run.total_tokens, + run_total_steps=run.total_steps, + run_created_at=run.created_at, + run_finished_at=run.finished_at, + 
run_exceptions_count=run.exceptions_count, + trigger_metadata=trigger_metadata, + ) + for app_log in app_logs + ] + session.add_all(archive_logs) + return len(archive_logs) + + def get_archived_runs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowRun]: + """ + Retrieves WorkflowRun records by joining workflow_archive_logs. + + Used to identify runs that are already archived and ready for deletion. + """ + stmt = ( + select(WorkflowRun) + .join(WorkflowArchiveLog, WorkflowArchiveLog.workflow_run_id == WorkflowRun.id) + .where( + WorkflowArchiveLog.run_created_at >= start_date, + WorkflowArchiveLog.run_created_at < end_date, + ) + .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc()) + .limit(limit) + ) + if tenant_ids: + stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids)) + return list(session.scalars(stmt)) + + def get_archived_logs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowArchiveLog]: + # Returns WorkflowArchiveLog rows directly; use this when workflow_runs may be deleted. + stmt = ( + select(WorkflowArchiveLog) + .where( + WorkflowArchiveLog.run_created_at >= start_date, + WorkflowArchiveLog.run_created_at < end_date, + ) + .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc()) + .limit(limit) + ) + if tenant_ids: + stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids)) + return list(session.scalars(stmt)) + def count_runs_with_related( self, runs: Sequence[WorkflowRun], @@ -459,7 +616,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): ) pause_ids = session.scalars( - select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids)) + select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids)) ).all() pauses_count = len(pause_ids) pause_reasons_count = 0 @@ -511,9 +668,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): ValueError: If workflow_run_id is invalid or workflow run doesn't exist RuntimeError: If workflow is already paused or in invalid state """ - previous_pause_model_query = select(WorkflowPauseModel).where( - WorkflowPauseModel.workflow_run_id == workflow_run_id - ) + previous_pause_model_query = select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id) with self._session_maker() as session, session.begin(): # Get the workflow run workflow_run = session.get(WorkflowRun, workflow_run_id) @@ -538,7 +693,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): # Upload the state file # Create the pause record - pause_model = WorkflowPauseModel() + pause_model = WorkflowPause() pause_model.id = str(uuidv7()) pause_model.workflow_id = workflow_run.workflow_id pause_model.workflow_run_id = workflow_run.id @@ -710,13 +865,13 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): """ with self._session_maker() as session, session.begin(): # Get the pause model by ID - pause_model = session.get(WorkflowPauseModel, pause_entity.id) + pause_model = session.get(WorkflowPause, pause_entity.id) if pause_model is None: raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}") self._delete_pause_model(session, pause_model) @staticmethod - def _delete_pause_model(session: Session, pause_model: 
WorkflowPauseModel): + def _delete_pause_model(session: Session, pause_model: WorkflowPause): storage.delete(pause_model.state_object_key) # Delete the pause record @@ -751,15 +906,15 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): _limit: int = limit or 1000 pruned_record_ids: list[str] = [] cond = or_( - WorkflowPauseModel.created_at < expiration, + WorkflowPause.created_at < expiration, and_( - WorkflowPauseModel.resumed_at.is_not(null()), - WorkflowPauseModel.resumed_at < resumption_expiration, + WorkflowPause.resumed_at.is_not(null()), + WorkflowPause.resumed_at < resumption_expiration, ), ) # First, collect pause records to delete with their state files # Expired pauses (created before expiration time) - stmt = select(WorkflowPauseModel).where(cond).limit(_limit) + stmt = select(WorkflowPause).where(cond).limit(_limit) with self._session_maker(expire_on_commit=False) as session: # Old resumed pauses (resumed more than resumption_duration ago) @@ -770,7 +925,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): # Delete state files from storage for pause in pauses_to_delete: with self._session_maker(expire_on_commit=False) as session, session.begin(): - # todo: this issues a separate query for each WorkflowPauseModel record. + # todo: this issues a separate query for each WorkflowPause record. # consider batching this lookup. try: storage.delete(pause.state_object_key) @@ -1022,7 +1177,7 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): def __init__( self, *, - pause_model: WorkflowPauseModel, + pause_model: WorkflowPause, reason_models: Sequence[WorkflowPauseReason], human_input_form: Sequence = (), ) -> None: diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py index ebd3745d18..f3dc4cd60b 100644 --- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -46,6 +46,11 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): return self.session.scalar(query) + def list_by_run_id(self, run_id: str) -> Sequence[WorkflowTriggerLog]: + """List trigger logs for a workflow run.""" + query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id == run_id) + return list(self.session.scalars(query).all()) + def get_failed_for_retry( self, tenant_id: str, max_retry_count: int = 3, limit: int = 100 ) -> Sequence[WorkflowTriggerLog]: diff --git a/api/services/retention/workflow_run/__init__.py b/api/services/retention/workflow_run/__init__.py index e69de29bb2..18dd42c91e 100644 --- a/api/services/retention/workflow_run/__init__.py +++ b/api/services/retention/workflow_run/__init__.py @@ -0,0 +1 @@ +"""Workflow run retention services.""" diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py new file mode 100644 index 0000000000..ea5cbb7740 --- /dev/null +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -0,0 +1,531 @@ +""" +Archive Paid Plan Workflow Run Logs Service. + +This service archives workflow run logs for paid plan users older than the configured +retention period (default: 90 days) to S3-compatible storage. 
+ +Archived tables: +- workflow_runs +- workflow_app_logs +- workflow_node_executions +- workflow_node_execution_offload +- workflow_pauses +- workflow_pause_reasons +- workflow_trigger_logs + +""" + +import datetime +import io +import json +import logging +import time +import zipfile +from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Any + +import click +from sqlalchemy import inspect +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from core.workflow.enums import WorkflowType +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from libs.archive_storage import ( + ArchiveStorage, + ArchiveStorageNotConfiguredError, + get_archive_storage, +) +from models.workflow import WorkflowAppLog, WorkflowRun +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository +from services.billing_service import BillingService +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHIVE_SCHEMA_VERSION + +logger = logging.getLogger(__name__) + + +@dataclass +class TableStats: + """Statistics for a single archived table.""" + + table_name: str + row_count: int + checksum: str + size_bytes: int + + +@dataclass +class ArchiveResult: + """Result of archiving a single workflow run.""" + + run_id: str + tenant_id: str + success: bool + tables: list[TableStats] = field(default_factory=list) + error: str | None = None + elapsed_time: float = 0.0 + + +@dataclass +class ArchiveSummary: + """Summary of the entire archive operation.""" + + total_runs_processed: int = 0 + runs_archived: int = 0 + runs_skipped: int = 0 + runs_failed: int = 0 + total_elapsed_time: float = 0.0 + + +class WorkflowRunArchiver: + """ + Archive workflow run logs for paid plan users. + + Storage Layout: + {tenant_id}/app_id={app_id}/year={YYYY}/month={MM}/workflow_run_id={run_id}/ + └── archive.v1.0.zip + ├── manifest.json + ├── workflow_runs.jsonl + ├── workflow_app_logs.jsonl + ├── workflow_node_executions.jsonl + ├── workflow_node_execution_offload.jsonl + ├── workflow_pauses.jsonl + ├── workflow_pause_reasons.jsonl + └── workflow_trigger_logs.jsonl + """ + + ARCHIVED_TYPE = [ + WorkflowType.WORKFLOW, + WorkflowType.RAG_PIPELINE, + ] + ARCHIVED_TABLES = [ + "workflow_runs", + "workflow_app_logs", + "workflow_node_executions", + "workflow_node_execution_offload", + "workflow_pauses", + "workflow_pause_reasons", + "workflow_trigger_logs", + ] + + start_from: datetime.datetime | None + end_before: datetime.datetime + + def __init__( + self, + days: int = 90, + batch_size: int = 100, + start_from: datetime.datetime | None = None, + end_before: datetime.datetime | None = None, + workers: int = 1, + tenant_ids: Sequence[str] | None = None, + limit: int | None = None, + dry_run: bool = False, + delete_after_archive: bool = False, + workflow_run_repo: APIWorkflowRunRepository | None = None, + ): + """ + Initialize the archiver. 
+ + Args: + days: Archive runs older than this many days + batch_size: Number of runs to process per batch + start_from: Optional start time (inclusive) for archiving + end_before: Optional end time (exclusive) for archiving + workers: Number of concurrent workflow runs to archive + tenant_ids: Optional tenant IDs for grayscale rollout + limit: Maximum number of runs to archive (None for unlimited) + dry_run: If True, only preview without making changes + delete_after_archive: If True, delete runs and related data after archiving + """ + self.days = days + self.batch_size = batch_size + if start_from or end_before: + if start_from is None or end_before is None: + raise ValueError("start_from and end_before must be provided together") + if start_from >= end_before: + raise ValueError("start_from must be earlier than end_before") + self.start_from = start_from.replace(tzinfo=datetime.UTC) + self.end_before = end_before.replace(tzinfo=datetime.UTC) + else: + self.start_from = None + self.end_before = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days) + if workers < 1: + raise ValueError("workers must be at least 1") + self.workers = workers + self.tenant_ids = sorted(set(tenant_ids)) if tenant_ids else [] + self.limit = limit + self.dry_run = dry_run + self.delete_after_archive = delete_after_archive + self.workflow_run_repo = workflow_run_repo + + def run(self) -> ArchiveSummary: + """ + Main archiving loop. + + Returns: + ArchiveSummary with statistics about the operation + """ + summary = ArchiveSummary() + start_time = time.time() + + click.echo( + click.style( + self._build_start_message(), + fg="white", + ) + ) + + # Initialize archive storage (will raise if not configured) + try: + if not self.dry_run: + storage = get_archive_storage() + else: + storage = None + except ArchiveStorageNotConfiguredError as e: + click.echo(click.style(f"Archive storage not configured: {e}", fg="red")) + return summary + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + repo = self._get_workflow_run_repo() + + def _archive_with_session(run: WorkflowRun) -> ArchiveResult: + with session_maker() as session: + return self._archive_run(session, storage, run) + + last_seen: tuple[datetime.datetime, str] | None = None + archived_count = 0 + + with ThreadPoolExecutor(max_workers=self.workers) as executor: + while True: + # Check limit + if self.limit and archived_count >= self.limit: + click.echo(click.style(f"Reached limit of {self.limit} runs", fg="yellow")) + break + + # Fetch batch of runs + runs = self._get_runs_batch(last_seen) + + if not runs: + break + + run_ids = [run.id for run in runs] + with session_maker() as session: + archived_run_ids = repo.get_archived_run_ids(session, run_ids) + + last_seen = (runs[-1].created_at, runs[-1].id) + + # Filter to paid tenants only + tenant_ids = {run.tenant_id for run in runs} + paid_tenants = self._filter_paid_tenants(tenant_ids) + + runs_to_process: list[WorkflowRun] = [] + for run in runs: + summary.total_runs_processed += 1 + + # Skip non-paid tenants + if run.tenant_id not in paid_tenants: + summary.runs_skipped += 1 + continue + + # Skip already archived runs + if run.id in archived_run_ids: + summary.runs_skipped += 1 + continue + + # Check limit + if self.limit and archived_count + len(runs_to_process) >= self.limit: + break + + runs_to_process.append(run) + + if not runs_to_process: + continue + + results = list(executor.map(_archive_with_session, runs_to_process)) + + for run, result in zip(runs_to_process, results): 
+ if result.success: + summary.runs_archived += 1 + archived_count += 1 + click.echo( + click.style( + f"{'[DRY RUN] Would archive' if self.dry_run else 'Archived'} " + f"run {run.id} (tenant={run.tenant_id}, " + f"tables={len(result.tables)}, time={result.elapsed_time:.2f}s)", + fg="green", + ) + ) + else: + summary.runs_failed += 1 + click.echo( + click.style( + f"Failed to archive run {run.id}: {result.error}", + fg="red", + ) + ) + + summary.total_elapsed_time = time.time() - start_time + click.echo( + click.style( + f"{'[DRY RUN] ' if self.dry_run else ''}Archive complete: " + f"processed={summary.total_runs_processed}, archived={summary.runs_archived}, " + f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " + f"time={summary.total_elapsed_time:.2f}s", + fg="white", + ) + ) + + return summary + + def _get_runs_batch( + self, + last_seen: tuple[datetime.datetime, str] | None, + ) -> Sequence[WorkflowRun]: + """Fetch a batch of workflow runs to archive.""" + repo = self._get_workflow_run_repo() + return repo.get_runs_batch_by_time_range( + start_from=self.start_from, + end_before=self.end_before, + last_seen=last_seen, + batch_size=self.batch_size, + run_types=self.ARCHIVED_TYPE, + tenant_ids=self.tenant_ids or None, + ) + + def _build_start_message(self) -> str: + range_desc = f"before {self.end_before.isoformat()}" + if self.start_from: + range_desc = f"between {self.start_from.isoformat()} and {self.end_before.isoformat()}" + return ( + f"{'[DRY RUN] ' if self.dry_run else ''}Starting workflow run archiving " + f"for runs {range_desc} " + f"(batch_size={self.batch_size}, tenant_ids={','.join(self.tenant_ids) or 'all'})" + ) + + def _filter_paid_tenants(self, tenant_ids: set[str]) -> set[str]: + """Filter tenant IDs to only include paid tenants.""" + if not dify_config.BILLING_ENABLED: + # If billing is not enabled, treat all tenants as paid + return tenant_ids + + if not tenant_ids: + return set() + + try: + bulk_info = BillingService.get_plan_bulk_with_cache(list(tenant_ids)) + except Exception: + logger.exception("Failed to fetch billing plans for tenants") + # On error, skip all tenants in this batch + return set() + + # Filter to paid tenants (any plan except SANDBOX) + paid = set() + for tid, info in bulk_info.items(): + if info and info.get("plan") in (CloudPlan.PROFESSIONAL, CloudPlan.TEAM): + paid.add(tid) + + return paid + + def _archive_run( + self, + session: Session, + storage: ArchiveStorage | None, + run: WorkflowRun, + ) -> ArchiveResult: + """Archive a single workflow run.""" + start_time = time.time() + result = ArchiveResult(run_id=run.id, tenant_id=run.tenant_id, success=False) + + try: + # Extract data from all tables + table_data, app_logs, trigger_metadata = self._extract_data(session, run) + + if self.dry_run: + # In dry run, just report what would be archived + for table_name in self.ARCHIVED_TABLES: + records = table_data.get(table_name, []) + result.tables.append( + TableStats( + table_name=table_name, + row_count=len(records), + checksum="", + size_bytes=0, + ) + ) + result.success = True + else: + if storage is None: + raise ArchiveStorageNotConfiguredError("Archive storage not configured") + archive_key = self._get_archive_key(run) + + # Serialize tables for the archive bundle + table_stats: list[TableStats] = [] + table_payloads: dict[str, bytes] = {} + for table_name in self.ARCHIVED_TABLES: + records = table_data.get(table_name, []) + data = ArchiveStorage.serialize_to_jsonl(records) + table_payloads[table_name] = data + checksum = 
ArchiveStorage.compute_checksum(data) + + table_stats.append( + TableStats( + table_name=table_name, + row_count=len(records), + checksum=checksum, + size_bytes=len(data), + ) + ) + + # Generate and upload archive bundle + manifest = self._generate_manifest(run, table_stats) + manifest_data = json.dumps(manifest, indent=2, default=str).encode("utf-8") + archive_data = self._build_archive_bundle(manifest_data, table_payloads) + storage.put_object(archive_key, archive_data) + + repo = self._get_workflow_run_repo() + archived_log_count = repo.create_archive_logs(session, run, app_logs, trigger_metadata) + session.commit() + + deleted_counts = None + if self.delete_after_archive: + deleted_counts = repo.delete_runs_with_related( + [run], + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + + logger.info( + "Archived workflow run %s: tables=%s, archived_logs=%s, deleted=%s", + run.id, + {s.table_name: s.row_count for s in table_stats}, + archived_log_count, + deleted_counts, + ) + + result.tables = table_stats + result.success = True + + except Exception as e: + logger.exception("Failed to archive workflow run %s", run.id) + result.error = str(e) + session.rollback() + + result.elapsed_time = time.time() - start_time + return result + + def _extract_data( + self, + session: Session, + run: WorkflowRun, + ) -> tuple[dict[str, list[dict[str, Any]]], Sequence[WorkflowAppLog], str | None]: + table_data: dict[str, list[dict[str, Any]]] = {} + table_data["workflow_runs"] = [self._row_to_dict(run)] + repo = self._get_workflow_run_repo() + app_logs = repo.get_app_logs_by_run_id(session, run.id) + table_data["workflow_app_logs"] = [self._row_to_dict(row) for row in app_logs] + node_exec_repo = self._get_workflow_node_execution_repo(session) + node_exec_records = node_exec_repo.get_executions_by_workflow_run( + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_run_id=run.id, + ) + node_exec_ids = [record.id for record in node_exec_records] + offload_records = node_exec_repo.get_offloads_by_execution_ids(session, node_exec_ids) + table_data["workflow_node_executions"] = [self._row_to_dict(row) for row in node_exec_records] + table_data["workflow_node_execution_offload"] = [self._row_to_dict(row) for row in offload_records] + repo = self._get_workflow_run_repo() + pause_records = repo.get_pause_records_by_run_id(session, run.id) + pause_ids = [pause.id for pause in pause_records] + pause_reason_records = repo.get_pause_reason_records_by_run_id( + session, + pause_ids, + ) + table_data["workflow_pauses"] = [self._row_to_dict(row) for row in pause_records] + table_data["workflow_pause_reasons"] = [self._row_to_dict(row) for row in pause_reason_records] + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + trigger_records = trigger_repo.list_by_run_id(run.id) + table_data["workflow_trigger_logs"] = [self._row_to_dict(row) for row in trigger_records] + trigger_metadata = trigger_records[0].trigger_metadata if trigger_records else None + return table_data, app_logs, trigger_metadata + + @staticmethod + def _row_to_dict(row: Any) -> dict[str, Any]: + mapper = inspect(row).mapper + return {str(column.name): getattr(row, mapper.get_property_by_column(column).key) for column in mapper.columns} + + def _get_archive_key(self, run: WorkflowRun) -> str: + """Get the storage key for the archive bundle.""" + created_at = run.created_at + prefix = ( + f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/" + 
f"month={created_at.strftime('%m')}/workflow_run_id={run.id}" + ) + return f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + + def _generate_manifest( + self, + run: WorkflowRun, + table_stats: list[TableStats], + ) -> dict[str, Any]: + """Generate a manifest for the archived workflow run.""" + return { + "schema_version": ARCHIVE_SCHEMA_VERSION, + "workflow_run_id": run.id, + "tenant_id": run.tenant_id, + "app_id": run.app_id, + "workflow_id": run.workflow_id, + "created_at": run.created_at.isoformat(), + "archived_at": datetime.datetime.now(datetime.UTC).isoformat(), + "tables": { + stat.table_name: { + "row_count": stat.row_count, + "checksum": stat.checksum, + "size_bytes": stat.size_bytes, + } + for stat in table_stats + }, + } + + def _build_archive_bundle(self, manifest_data: bytes, table_payloads: dict[str, bytes]) -> bytes: + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive: + archive.writestr("manifest.json", manifest_data) + for table_name in self.ARCHIVED_TABLES: + data = table_payloads.get(table_name) + if data is None: + raise ValueError(f"Missing archive payload for {table_name}") + archive.writestr(f"{table_name}.jsonl", data) + return buffer.getvalue() + + def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.delete_by_run_ids(run_ids) + + def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: + run_ids = [run.id for run in runs] + return self._get_workflow_node_execution_repo(session).delete_by_runs(session, run_ids) + + def _get_workflow_node_execution_repo( + self, + session: Session, + ) -> DifyAPIWorkflowNodeExecutionRepository: + from repositories.factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=session.get_bind(), expire_on_commit=False) + return DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) + + def _get_workflow_run_repo(self) -> APIWorkflowRunRepository: + if self.workflow_run_repo is not None: + return self.workflow_run_repo + + from repositories.factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + return self.workflow_run_repo diff --git a/api/services/retention/workflow_run/constants.py b/api/services/retention/workflow_run/constants.py new file mode 100644 index 0000000000..162bb4947d --- /dev/null +++ b/api/services/retention/workflow_run/constants.py @@ -0,0 +1,2 @@ +ARCHIVE_SCHEMA_VERSION = "1.0" +ARCHIVE_BUNDLE_NAME = f"archive.v{ARCHIVE_SCHEMA_VERSION}.zip" diff --git a/api/services/retention/workflow_run/delete_archived_workflow_run.py b/api/services/retention/workflow_run/delete_archived_workflow_run.py new file mode 100644 index 0000000000..11873bf1b9 --- /dev/null +++ b/api/services/retention/workflow_run/delete_archived_workflow_run.py @@ -0,0 +1,134 @@ +""" +Delete Archived Workflow Run Service. + +This service deletes archived workflow run data from the database while keeping +archive logs intact. 
+""" + +import time +from collections.abc import Sequence +from dataclasses import dataclass, field +from datetime import datetime + +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_database import db +from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository + + +@dataclass +class DeleteResult: + run_id: str + tenant_id: str + success: bool + deleted_counts: dict[str, int] = field(default_factory=dict) + error: str | None = None + elapsed_time: float = 0.0 + + +class ArchivedWorkflowRunDeletion: + def __init__(self, dry_run: bool = False): + self.dry_run = dry_run + self.workflow_run_repo: APIWorkflowRunRepository | None = None + + def delete_by_run_id(self, run_id: str) -> DeleteResult: + start_time = time.time() + result = DeleteResult(run_id=run_id, tenant_id="", success=False) + + repo = self._get_workflow_run_repo() + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + with session_maker() as session: + run = session.get(WorkflowRun, run_id) + if not run: + result.error = f"Workflow run {run_id} not found" + result.elapsed_time = time.time() - start_time + return result + + result.tenant_id = run.tenant_id + if not repo.get_archived_run_ids(session, [run.id]): + result.error = f"Workflow run {run_id} is not archived" + result.elapsed_time = time.time() - start_time + return result + + result = self._delete_run(run) + result.elapsed_time = time.time() - start_time + return result + + def delete_batch( + self, + tenant_ids: list[str] | None, + start_date: datetime, + end_date: datetime, + limit: int = 100, + ) -> list[DeleteResult]: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + results: list[DeleteResult] = [] + + repo = self._get_workflow_run_repo() + with session_maker() as session: + runs = list( + repo.get_archived_runs_by_time_range( + session=session, + tenant_ids=tenant_ids, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + ) + for run in runs: + results.append(self._delete_run(run)) + + return results + + def _delete_run(self, run: WorkflowRun) -> DeleteResult: + start_time = time.time() + result = DeleteResult(run_id=run.id, tenant_id=run.tenant_id, success=False) + if self.dry_run: + result.success = True + result.elapsed_time = time.time() - start_time + return result + + repo = self._get_workflow_run_repo() + try: + deleted_counts = repo.delete_runs_with_related( + [run], + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + result.deleted_counts = deleted_counts + result.success = True + except Exception as e: + result.error = str(e) + result.elapsed_time = time.time() - start_time + return result + + @staticmethod + def _delete_trigger_logs(session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.delete_by_run_ids(run_ids) + + @staticmethod + def _delete_node_executions( + session: Session, + runs: Sequence[WorkflowRun], + ) -> tuple[int, int]: + from repositories.factory import DifyAPIRepositoryFactory + + run_ids = [run.id for run in runs] + repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False) + ) + return repo.delete_by_runs(session, run_ids) + + def _get_workflow_run_repo(self) -> 
APIWorkflowRunRepository: + if self.workflow_run_repo is not None: + return self.workflow_run_repo + + from repositories.factory import DifyAPIRepositoryFactory + + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository( + sessionmaker(bind=db.engine, expire_on_commit=False) + ) + return self.workflow_run_repo diff --git a/api/services/retention/workflow_run/restore_archived_workflow_run.py b/api/services/retention/workflow_run/restore_archived_workflow_run.py new file mode 100644 index 0000000000..d4a6e87585 --- /dev/null +++ b/api/services/retention/workflow_run/restore_archived_workflow_run.py @@ -0,0 +1,481 @@ +""" +Restore Archived Workflow Run Service. + +This service restores archived workflow run data from S3-compatible storage +back to the database. +""" + +import io +import json +import logging +import time +import zipfile +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from datetime import datetime +from typing import Any, cast + +import click +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.engine import CursorResult +from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker + +from extensions.ext_database import db +from libs.archive_storage import ( + ArchiveStorage, + ArchiveStorageNotConfiguredError, + get_archive_storage, +) +from models.trigger import WorkflowTriggerLog +from models.workflow import ( + WorkflowAppLog, + WorkflowArchiveLog, + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, + WorkflowPause, + WorkflowPauseReason, + WorkflowRun, +) +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.factory import DifyAPIRepositoryFactory +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME + +logger = logging.getLogger(__name__) + + +# Mapping of table names to SQLAlchemy models +TABLE_MODELS = { + "workflow_runs": WorkflowRun, + "workflow_app_logs": WorkflowAppLog, + "workflow_node_executions": WorkflowNodeExecutionModel, + "workflow_node_execution_offload": WorkflowNodeExecutionOffload, + "workflow_pauses": WorkflowPause, + "workflow_pause_reasons": WorkflowPauseReason, + "workflow_trigger_logs": WorkflowTriggerLog, +} + +SchemaMapper = Callable[[dict[str, Any]], dict[str, Any]] + +SCHEMA_MAPPERS: dict[str, dict[str, SchemaMapper]] = { + "1.0": {}, +} + + +@dataclass +class RestoreResult: + """Result of restoring a single workflow run.""" + + run_id: str + tenant_id: str + success: bool + restored_counts: dict[str, int] + error: str | None = None + elapsed_time: float = 0.0 + + +class WorkflowRunRestore: + """ + Restore archived workflow run data from storage to database. + + This service reads archived data from storage and restores it to the + database tables. It handles idempotency by skipping records that already + exist in the database. + """ + + def __init__(self, dry_run: bool = False, workers: int = 1): + """ + Initialize the restore service. 
+ + Args: + dry_run: If True, only preview without making changes + workers: Number of concurrent workflow runs to restore + """ + self.dry_run = dry_run + if workers < 1: + raise ValueError("workers must be at least 1") + self.workers = workers + self.workflow_run_repo: APIWorkflowRunRepository | None = None + + def _restore_from_run( + self, + run: WorkflowRun | WorkflowArchiveLog, + *, + session_maker: sessionmaker, + ) -> RestoreResult: + start_time = time.time() + run_id = run.workflow_run_id if isinstance(run, WorkflowArchiveLog) else run.id + created_at = run.run_created_at if isinstance(run, WorkflowArchiveLog) else run.created_at + result = RestoreResult( + run_id=run_id, + tenant_id=run.tenant_id, + success=False, + restored_counts={}, + ) + + if not self.dry_run: + click.echo( + click.style( + f"Starting restore for workflow run {run_id} (tenant={run.tenant_id})", + fg="white", + ) + ) + + try: + storage = get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + result.error = str(e) + click.echo(click.style(f"Archive storage not configured: {e}", fg="red")) + result.elapsed_time = time.time() - start_time + return result + + prefix = ( + f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/" + f"month={created_at.strftime('%m')}/workflow_run_id={run_id}" + ) + archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + try: + archive_data = storage.get_object(archive_key) + except FileNotFoundError: + result.error = f"Archive bundle not found: {archive_key}" + click.echo(click.style(result.error, fg="red")) + result.elapsed_time = time.time() - start_time + return result + + with session_maker() as session: + try: + with zipfile.ZipFile(io.BytesIO(archive_data), mode="r") as archive: + try: + manifest = self._load_manifest_from_zip(archive) + except ValueError as e: + result.error = f"Archive bundle invalid: {e}" + click.echo(click.style(result.error, fg="red")) + return result + + tables = manifest.get("tables", {}) + schema_version = self._get_schema_version(manifest) + for table_name, info in tables.items(): + row_count = info.get("row_count", 0) + if row_count == 0: + result.restored_counts[table_name] = 0 + continue + + if self.dry_run: + result.restored_counts[table_name] = row_count + continue + + member_path = f"{table_name}.jsonl" + try: + data = archive.read(member_path) + except KeyError: + click.echo( + click.style( + f" Warning: Table data not found in archive: {member_path}", + fg="yellow", + ) + ) + result.restored_counts[table_name] = 0 + continue + + records = ArchiveStorage.deserialize_from_jsonl(data) + restored = self._restore_table_records( + session, + table_name, + records, + schema_version=schema_version, + ) + result.restored_counts[table_name] = restored + if not self.dry_run: + click.echo( + click.style( + f" Restored {restored}/{len(records)} records to {table_name}", + fg="white", + ) + ) + + # Verify row counts match manifest + manifest_total = sum(info.get("row_count", 0) for info in tables.values()) + restored_total = sum(result.restored_counts.values()) + + if not self.dry_run: + # Note: restored count might be less than manifest count if records already exist + logger.info( + "Restore verification: manifest_total=%d, restored_total=%d", + manifest_total, + restored_total, + ) + + # Delete the archive log record after successful restore + repo = self._get_workflow_run_repo() + repo.delete_archive_log_by_run_id(session, run_id) + + session.commit() + + result.success = True + if not self.dry_run: + click.echo( + 
click.style( + f"Completed restore for workflow run {run_id}: restored={result.restored_counts}", + fg="green", + ) + ) + + except Exception as e: + logger.exception("Failed to restore workflow run %s", run_id) + result.error = str(e) + session.rollback() + click.echo(click.style(f"Restore failed: {e}", fg="red")) + + result.elapsed_time = time.time() - start_time + return result + + def _get_workflow_run_repo(self) -> APIWorkflowRunRepository: + if self.workflow_run_repo is not None: + return self.workflow_run_repo + + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository( + sessionmaker(bind=db.engine, expire_on_commit=False) + ) + return self.workflow_run_repo + + @staticmethod + def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]: + try: + data = archive.read("manifest.json") + except KeyError as e: + raise ValueError("manifest.json missing from archive bundle") from e + return json.loads(data.decode("utf-8")) + + def _restore_table_records( + self, + session: Session, + table_name: str, + records: list[dict[str, Any]], + *, + schema_version: str, + ) -> int: + """ + Restore records to a table. + + Uses INSERT ... ON CONFLICT DO NOTHING for idempotency. + + Args: + session: Database session + table_name: Name of the table + records: List of record dictionaries + schema_version: Archived schema version from manifest + + Returns: + Number of records actually inserted + """ + if not records: + return 0 + + model = TABLE_MODELS.get(table_name) + if not model: + logger.warning("Unknown table: %s", table_name) + return 0 + + column_names, required_columns, non_nullable_with_default = self._get_model_column_info(model) + unknown_fields: set[str] = set() + + # Apply schema mapping, filter to current columns, then convert datetimes + converted_records = [] + for record in records: + mapped = self._apply_schema_mapping(table_name, schema_version, record) + unknown_fields.update(set(mapped.keys()) - column_names) + filtered = {key: value for key, value in mapped.items() if key in column_names} + for key in non_nullable_with_default: + if key in filtered and filtered[key] is None: + filtered.pop(key) + missing_required = [key for key in required_columns if key not in filtered or filtered.get(key) is None] + if missing_required: + missing_cols = ", ".join(sorted(missing_required)) + raise ValueError( + f"Missing required columns for {table_name} (schema_version={schema_version}): {missing_cols}" + ) + converted = self._convert_datetime_fields(filtered, model) + converted_records.append(converted) + if unknown_fields: + logger.warning( + "Dropped unknown columns for %s (schema_version=%s): %s", + table_name, + schema_version, + ", ".join(sorted(unknown_fields)), + ) + + # Use INSERT ... 
ON CONFLICT DO NOTHING for idempotency + stmt = pg_insert(model).values(converted_records) + stmt = stmt.on_conflict_do_nothing(index_elements=["id"]) + + result = session.execute(stmt) + return cast(CursorResult, result).rowcount or 0 + + def _convert_datetime_fields( + self, + record: dict[str, Any], + model: type[DeclarativeBase] | Any, + ) -> dict[str, Any]: + """Convert ISO datetime strings to datetime objects.""" + from sqlalchemy import DateTime + + result = dict(record) + + for column in model.__table__.columns: + if isinstance(column.type, DateTime): + value = result.get(column.key) + if isinstance(value, str): + try: + result[column.key] = datetime.fromisoformat(value) + except ValueError: + pass + + return result + + def _get_schema_version(self, manifest: dict[str, Any]) -> str: + schema_version = manifest.get("schema_version") + if not schema_version: + logger.warning("Manifest missing schema_version; defaulting to 1.0") + schema_version = "1.0" + schema_version = str(schema_version) + if schema_version not in SCHEMA_MAPPERS: + raise ValueError(f"Unsupported schema_version {schema_version}. Add a mapping before restoring.") + return schema_version + + def _apply_schema_mapping( + self, + table_name: str, + schema_version: str, + record: dict[str, Any], + ) -> dict[str, Any]: + # Keep hook for forward/backward compatibility when schema evolves. + mapper = SCHEMA_MAPPERS.get(schema_version, {}).get(table_name) + if mapper is None: + return dict(record) + return mapper(record) + + def _get_model_column_info( + self, + model: type[DeclarativeBase] | Any, + ) -> tuple[set[str], set[str], set[str]]: + columns = list(model.__table__.columns) + column_names = {column.key for column in columns} + required_columns = { + column.key + for column in columns + if not column.nullable + and column.default is None + and column.server_default is None + and not column.autoincrement + } + non_nullable_with_default = { + column.key + for column in columns + if not column.nullable + and (column.default is not None or column.server_default is not None or column.autoincrement) + } + return column_names, required_columns, non_nullable_with_default + + def restore_batch( + self, + tenant_ids: list[str] | None, + start_date: datetime, + end_date: datetime, + limit: int = 100, + ) -> list[RestoreResult]: + """ + Restore multiple workflow runs by time range. 
+ + Args: + tenant_ids: Optional tenant IDs + start_date: Start date filter + end_date: End date filter + limit: Maximum number of runs to restore (default: 100) + + Returns: + List of RestoreResult objects + """ + results: list[RestoreResult] = [] + if tenant_ids is not None and not tenant_ids: + return results + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + repo = self._get_workflow_run_repo() + + with session_maker() as session: + archive_logs = repo.get_archived_logs_by_time_range( + session=session, + tenant_ids=tenant_ids, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + + click.echo( + click.style( + f"Found {len(archive_logs)} archived workflow runs to restore", + fg="white", + ) + ) + + def _restore_with_session(archive_log: WorkflowArchiveLog) -> RestoreResult: + return self._restore_from_run( + archive_log, + session_maker=session_maker, + ) + + with ThreadPoolExecutor(max_workers=self.workers) as executor: + results = list(executor.map(_restore_with_session, archive_logs)) + + total_counts: dict[str, int] = {} + for result in results: + for table_name, count in result.restored_counts.items(): + total_counts[table_name] = total_counts.get(table_name, 0) + count + success_count = sum(1 for result in results if result.success) + + if self.dry_run: + click.echo( + click.style( + f"[DRY RUN] Would restore {len(results)} workflow runs: totals={total_counts}", + fg="yellow", + ) + ) + else: + click.echo( + click.style( + f"Restored {success_count}/{len(results)} workflow runs: totals={total_counts}", + fg="green", + ) + ) + + return results + + def restore_by_run_id( + self, + run_id: str, + ) -> RestoreResult: + """ + Restore a single workflow run by run ID. + """ + repo = self._get_workflow_run_repo() + archive_log = repo.get_archived_log_by_run_id(run_id) + + if not archive_log: + click.echo(click.style(f"Workflow run archive {run_id} not found", fg="red")) + return RestoreResult( + run_id=run_id, + tenant_id="", + success=False, + restored_counts={}, + error=f"Workflow run archive {run_id} not found", + ) + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + result = self._restore_from_run(archive_log, session_maker=session_maker) + if self.dry_run and result.success: + click.echo( + click.style( + f"[DRY RUN] Would restore workflow run {run_id}: totals={result.restored_counts}", + fg="yellow", + ) + ) + return result diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 8574d30255..efc76c33bc 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from core.workflow.enums import WorkflowExecutionStatus -from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun +from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog from services.plugin.plugin_service import PluginService @@ -173,7 +173,80 @@ class WorkflowAppService: "data": items, } - def handle_trigger_metadata(self, tenant_id: str, meta_val: str) -> dict[str, Any]: + def get_paginate_workflow_archive_logs( + self, + *, + session: Session, + app_model: App, + page: int = 1, + limit: int = 20, + ): + """ + Get paginate workflow archive logs using SQLAlchemy 2.0 style. 
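+        Only archive logs linked to a workflow app log (log_id is not null) are returned.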
+ """ + stmt = select(WorkflowArchiveLog).where( + WorkflowArchiveLog.tenant_id == app_model.tenant_id, + WorkflowArchiveLog.app_id == app_model.id, + WorkflowArchiveLog.log_id.isnot(None), + ) + + stmt = stmt.order_by(WorkflowArchiveLog.run_created_at.desc()) + + count_stmt = select(func.count()).select_from(stmt.subquery()) + total = session.scalar(count_stmt) or 0 + + offset_stmt = stmt.offset((page - 1) * limit).limit(limit) + + logs = list(session.scalars(offset_stmt).all()) + account_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.ACCOUNT} + end_user_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.END_USER} + + accounts_by_id = {} + if account_ids: + accounts_by_id = { + account.id: account + for account in session.scalars(select(Account).where(Account.id.in_(account_ids))).all() + } + + end_users_by_id = {} + if end_user_ids: + end_users_by_id = { + end_user.id: end_user + for end_user in session.scalars(select(EndUser).where(EndUser.id.in_(end_user_ids))).all() + } + + items = [] + for log in logs: + if log.created_by_role == CreatorUserRole.ACCOUNT: + created_by_account = accounts_by_id.get(log.created_by) + created_by_end_user = None + elif log.created_by_role == CreatorUserRole.END_USER: + created_by_account = None + created_by_end_user = end_users_by_id.get(log.created_by) + else: + created_by_account = None + created_by_end_user = None + + items.append( + { + "id": log.id, + "workflow_run": log.workflow_run_summary, + "trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, log.trigger_metadata), + "created_by_account": created_by_account, + "created_by_end_user": created_by_end_user, + "created_at": log.log_created_at, + } + ) + + return { + "page": page, + "limit": limit, + "total": total, + "has_more": total > page * limit, + "data": items, + } + + def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]: metadata: dict[str, Any] | None = self._safe_json_loads(meta_val) if not metadata: return {} diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 4e5fb08870..817249845a 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -11,8 +11,10 @@ from sqlalchemy.engine import CursorResult from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker +from configs import dify_config from core.db.session_factory import session_factory from extensions.ext_database import db +from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from models import ( ApiToken, AppAnnotationHitHistory, @@ -43,6 +45,7 @@ from models.workflow import ( ConversationVariable, Workflow, WorkflowAppLog, + WorkflowArchiveLog, ) from repositories.factory import DifyAPIRepositoryFactory @@ -67,6 +70,9 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_app_workflow_runs(tenant_id, app_id) _delete_app_workflow_node_executions(tenant_id, app_id) _delete_app_workflow_app_logs(tenant_id, app_id) + if dify_config.BILLING_ENABLED and dify_config.ARCHIVE_STORAGE_ENABLED: + _delete_app_workflow_archive_logs(tenant_id, app_id) + _delete_archived_workflow_run_files(tenant_id, app_id) _delete_app_conversations(tenant_id, app_id) _delete_app_messages(tenant_id, app_id) _delete_workflow_tool_providers(tenant_id, app_id) @@ -252,6 +258,45 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): ) +def 
_delete_app_workflow_archive_logs(tenant_id: str, app_id: str): + def del_workflow_archive_log(workflow_archive_log_id: str): + db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( + synchronize_session=False + ) + + _delete_records( + """select id from workflow_archive_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", + {"tenant_id": tenant_id, "app_id": app_id}, + del_workflow_archive_log, + "workflow archive log", + ) + + +def _delete_archived_workflow_run_files(tenant_id: str, app_id: str): + prefix = f"{tenant_id}/app_id={app_id}/" + try: + archive_storage = get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + logger.info("Archive storage not configured, skipping archive file cleanup: %s", e) + return + + try: + keys = archive_storage.list_objects(prefix) + except Exception: + logger.exception("Failed to list archive files for app %s", app_id) + return + + deleted = 0 + for key in keys: + try: + archive_storage.delete_object(key) + deleted += 1 + except Exception: + logger.exception("Failed to delete archive object %s", key) + + logger.info("Deleted %s archive objects for app %s", deleted, app_id) + + def _delete_app_conversations(tenant_id: str, app_id: str): def del_conversation(session, conversation_id: str): session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py index 7cdb2328f2..078ec5f6ab 100644 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ b/api/tests/unit_tests/core/workflow/test_enums.py @@ -30,3 +30,12 @@ class TestWorkflowExecutionStatus: for status in non_ended_statuses: assert not status.is_ended(), f"{status} should not be considered ended" + + def test_ended_values(self): + """Test ended_values returns the expected status values.""" + assert set(WorkflowExecutionStatus.ended_values()) == { + WorkflowExecutionStatus.SUCCEEDED.value, + WorkflowExecutionStatus.FAILED.value, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, + WorkflowExecutionStatus.STOPPED.value, + } diff --git a/api/tests/unit_tests/libs/test_archive_storage.py b/api/tests/unit_tests/libs/test_archive_storage.py index 697760e33a..de3c9c4737 100644 --- a/api/tests/unit_tests/libs/test_archive_storage.py +++ b/api/tests/unit_tests/libs/test_archive_storage.py @@ -37,6 +37,20 @@ def _client_error(code: str) -> ClientError: def _mock_client(monkeypatch): client = MagicMock() client.head_bucket.return_value = None + # Configure put_object to return a proper ETag that matches the MD5 hash + # The ETag format is typically the MD5 hash wrapped in quotes + + def mock_put_object(**kwargs): + md5_hash = kwargs.get("Body", b"") + if isinstance(md5_hash, bytes): + md5_hash = hashlib.md5(md5_hash).hexdigest() + else: + md5_hash = hashlib.md5(md5_hash.encode()).hexdigest() + response = MagicMock() + response.get.return_value = f'"{md5_hash}"' + return response + + client.put_object.side_effect = mock_put_object boto_client = MagicMock(return_value=client) monkeypatch.setattr(storage_module.boto3, "client", boto_client) return client, boto_client @@ -254,8 +268,8 @@ def test_serialization_roundtrip(): {"id": "2", "value": 123}, ] - data = ArchiveStorage.serialize_to_jsonl_gz(records) - decoded = ArchiveStorage.deserialize_from_jsonl_gz(data) + data = ArchiveStorage.serialize_to_jsonl(records) + decoded = ArchiveStorage.deserialize_from_jsonl(data) assert 
decoded[0]["id"] == "1" assert decoded[0]["payload"]["nested"] == "value" diff --git a/api/tests/unit_tests/services/test_archive_workflow_run_logs.py b/api/tests/unit_tests/services/test_archive_workflow_run_logs.py new file mode 100644 index 0000000000..ef62dacd6b --- /dev/null +++ b/api/tests/unit_tests/services/test_archive_workflow_run_logs.py @@ -0,0 +1,54 @@ +""" +Unit tests for workflow run archiving functionality. + +This module contains tests for: +- Archive service +- Rollback service +""" + +from datetime import datetime +from unittest.mock import MagicMock, patch + +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME + + +class TestWorkflowRunArchiver: + """Tests for the WorkflowRunArchiver class.""" + + @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") + @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage") + def test_archiver_initialization(self, mock_get_storage, mock_config): + """Test archiver can be initialized with various options.""" + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + mock_config.BILLING_ENABLED = False + + archiver = WorkflowRunArchiver( + days=90, + batch_size=100, + tenant_ids=["test-tenant"], + limit=50, + dry_run=True, + ) + + assert archiver.days == 90 + assert archiver.batch_size == 100 + assert archiver.tenant_ids == ["test-tenant"] + assert archiver.limit == 50 + assert archiver.dry_run is True + + def test_get_archive_key(self): + """Test archive key generation.""" + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + archiver = WorkflowRunArchiver.__new__(WorkflowRunArchiver) + + mock_run = MagicMock() + mock_run.tenant_id = "tenant-123" + mock_run.app_id = "app-999" + mock_run.id = "run-456" + mock_run.created_at = datetime(2024, 1, 15, 12, 0, 0) + + key = archiver._get_archive_key(mock_run) + + assert key == f"tenant-123/app_id=app-999/year=2024/month=01/workflow_run_id=run-456/{ARCHIVE_BUNDLE_NAME}" diff --git a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py new file mode 100644 index 0000000000..2c9d946ea6 --- /dev/null +++ b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py @@ -0,0 +1,180 @@ +""" +Unit tests for archived workflow run deletion service. 
+""" + +from unittest.mock import MagicMock, patch + + +class TestArchivedWorkflowRunDeletion: + def test_delete_by_run_id_returns_error_when_run_missing(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + session = MagicMock() + session.get.return_value = None + + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + ): + result = deleter.delete_by_run_id("run-1") + + assert result.success is False + assert result.error == "Workflow run run-1 not found" + repo.get_archived_run_ids.assert_not_called() + + def test_delete_by_run_id_returns_error_when_not_archived(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + repo.get_archived_run_ids.return_value = set() + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + session = MagicMock() + session.get.return_value = run + + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + patch.object(deleter, "_delete_run") as mock_delete_run, + ): + result = deleter.delete_by_run_id("run-1") + + assert result.success is False + assert result.error == "Workflow run run-1 is not archived" + mock_delete_run.assert_not_called() + + def test_delete_by_run_id_calls_delete_run(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + repo.get_archived_run_ids.return_value = {"run-1"} + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + session = MagicMock() + session.get.return_value = run + + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + patch.object(deleter, "_delete_run", return_value=MagicMock(success=True)) as mock_delete_run, + ): + result = deleter.delete_by_run_id("run-1") + + assert result.success is True + mock_delete_run.assert_called_once_with(run) + + def test_delete_batch_uses_repo(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + 
run1 = MagicMock() + run1.id = "run-1" + run1.tenant_id = "tenant-1" + run2 = MagicMock() + run2.id = "run-2" + run2.tenant_id = "tenant-1" + repo.get_archived_runs_by_time_range.return_value = [run1, run2] + + session = MagicMock() + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + start_date = MagicMock() + end_date = MagicMock() + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + patch.object( + deleter, "_delete_run", side_effect=[MagicMock(success=True), MagicMock(success=True)] + ) as mock_delete_run, + ): + results = deleter.delete_batch( + tenant_ids=["tenant-1"], + start_date=start_date, + end_date=end_date, + limit=2, + ) + + assert len(results) == 2 + repo.get_archived_runs_by_time_range.assert_called_once_with( + session=session, + tenant_ids=["tenant-1"], + start_date=start_date, + end_date=end_date, + limit=2, + ) + assert mock_delete_run.call_count == 2 + + def test_delete_run_dry_run(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion(dry_run=True) + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo: + result = deleter._delete_run(run) + + assert result.success is True + mock_get_repo.assert_not_called() + + def test_delete_run_calls_repo(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + repo = MagicMock() + repo.delete_runs_with_related.return_value = {"runs": 1} + + with patch.object(deleter, "_get_workflow_run_repo", return_value=repo): + result = deleter._delete_run(run) + + assert result.success is True + assert result.deleted_counts == {"runs": 1} + repo.delete_runs_with_related.assert_called_once() diff --git a/api/tests/unit_tests/services/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/test_restore_archived_workflow_run.py new file mode 100644 index 0000000000..68aa8c0fe1 --- /dev/null +++ b/api/tests/unit_tests/services/test_restore_archived_workflow_run.py @@ -0,0 +1,65 @@ +""" +Unit tests for workflow run restore functionality. 
+""" + +from datetime import datetime +from unittest.mock import MagicMock + + +class TestWorkflowRunRestore: + """Tests for the WorkflowRunRestore class.""" + + def test_restore_initialization(self): + """Restore service should respect dry_run flag.""" + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + restore = WorkflowRunRestore(dry_run=True) + + assert restore.dry_run is True + + def test_convert_datetime_fields(self): + """ISO datetime strings should be converted to datetime objects.""" + from models.workflow import WorkflowRun + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + record = { + "id": "test-id", + "created_at": "2024-01-01T12:00:00", + "finished_at": "2024-01-01T12:05:00", + "name": "test", + } + + restore = WorkflowRunRestore() + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert isinstance(result["created_at"], datetime) + assert result["created_at"].year == 2024 + assert result["created_at"].month == 1 + assert result["name"] == "test" + + def test_restore_table_records_returns_rowcount(self): + """Restore should return inserted rowcount.""" + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + session = MagicMock() + session.execute.return_value = MagicMock(rowcount=2) + + restore = WorkflowRunRestore() + records = [{"id": "p1", "workflow_run_id": "r1", "created_at": "2024-01-01T00:00:00"}] + + restored = restore._restore_table_records(session, "workflow_pauses", records, schema_version="1.0") + + assert restored == 2 + session.execute.assert_called_once() + + def test_restore_table_records_unknown_table(self): + """Unknown table names should be ignored gracefully.""" + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + session = MagicMock() + + restore = WorkflowRunRestore() + restored = restore._restore_table_records(session, "unknown_table", [{"id": "x1"}], schema_version="1.0") + + assert restored == 0 + session.execute.assert_not_called() diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py index ccf43591f0..a14bbb01d0 100644 --- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,7 +2,11 @@ from unittest.mock import ANY, MagicMock, call, patch import pytest +from libs.archive_storage import ArchiveStorageNotConfiguredError +from models.workflow import WorkflowArchiveLog from tasks.remove_app_and_related_data_task import ( + _delete_app_workflow_archive_logs, + _delete_archived_workflow_run_files, _delete_draft_variable_offload_data, _delete_draft_variables, delete_draft_variables_batch, @@ -324,3 +328,68 @@ class TestDeleteDraftVariableOffloadData: # Verify error was logged mock_logging.exception.assert_called_once_with("Error deleting draft variable offload data:") + + +class TestDeleteWorkflowArchiveLogs: + @patch("tasks.remove_app_and_related_data_task._delete_records") + @patch("tasks.remove_app_and_related_data_task.db") + def test_delete_app_workflow_archive_logs_calls_delete_records(self, mock_db, mock_delete_records): + tenant_id = "tenant-1" + app_id = "app-1" + + _delete_app_workflow_archive_logs(tenant_id, app_id) + + mock_delete_records.assert_called_once() + query_sql, params, delete_func, name = mock_delete_records.call_args[0] + assert 
"workflow_archive_logs" in query_sql + assert params == {"tenant_id": tenant_id, "app_id": app_id} + assert name == "workflow archive log" + + mock_query = MagicMock() + mock_delete_query = MagicMock() + mock_query.where.return_value = mock_delete_query + mock_db.session.query.return_value = mock_query + + delete_func("log-1") + + mock_db.session.query.assert_called_once_with(WorkflowArchiveLog) + mock_query.where.assert_called_once() + mock_delete_query.delete.assert_called_once_with(synchronize_session=False) + + +class TestDeleteArchivedWorkflowRunFiles: + @patch("tasks.remove_app_and_related_data_task.get_archive_storage") + @patch("tasks.remove_app_and_related_data_task.logger") + def test_delete_archived_workflow_run_files_not_configured(self, mock_logger, mock_get_storage): + mock_get_storage.side_effect = ArchiveStorageNotConfiguredError("missing config") + + _delete_archived_workflow_run_files("tenant-1", "app-1") + + assert mock_logger.info.call_count == 1 + assert "Archive storage not configured" in mock_logger.info.call_args[0][0] + + @patch("tasks.remove_app_and_related_data_task.get_archive_storage") + @patch("tasks.remove_app_and_related_data_task.logger") + def test_delete_archived_workflow_run_files_list_failure(self, mock_logger, mock_get_storage): + storage = MagicMock() + storage.list_objects.side_effect = Exception("list failed") + mock_get_storage.return_value = storage + + _delete_archived_workflow_run_files("tenant-1", "app-1") + + storage.list_objects.assert_called_once_with("tenant-1/app_id=app-1/") + storage.delete_object.assert_not_called() + mock_logger.exception.assert_called_once_with("Failed to list archive files for app %s", "app-1") + + @patch("tasks.remove_app_and_related_data_task.get_archive_storage") + @patch("tasks.remove_app_and_related_data_task.logger") + def test_delete_archived_workflow_run_files_success(self, mock_logger, mock_get_storage): + storage = MagicMock() + storage.list_objects.return_value = ["key-1", "key-2"] + mock_get_storage.return_value = storage + + _delete_archived_workflow_run_files("tenant-1", "app-1") + + storage.list_objects.assert_called_once_with("tenant-1/app_id=app-1/") + storage.delete_object.assert_has_calls([call("key-1"), call("key-2")], any_order=False) + mock_logger.info.assert_called_with("Deleted %s archive objects for app %s", 2, "app-1")