diff --git a/api/controllers/console/snippets/snippet_workflow.py b/api/controllers/console/snippets/snippet_workflow.py index d86c60ead5..7b8083cd80 100644 --- a/api/controllers/console/snippets/snippet_workflow.py +++ b/api/controllers/console/snippets/snippet_workflow.py @@ -6,12 +6,16 @@ from typing import ParamSpec, TypeVar from flask import request from flask_restx import Resource, fields, marshal, marshal_with from sqlalchemy.orm import Session -from werkzeug.exceptions import InternalServerError, NotFound +from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync -from controllers.console.app.workflow import workflow_model, workflow_pagination_model +from controllers.console.app.workflow import ( + RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE, + workflow_model, + workflow_pagination_model, +) from controllers.console.app.workflow_run import ( workflow_run_detail_model, workflow_run_node_execution_list_model, @@ -42,7 +46,7 @@ from libs import helper from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.snippet import CustomizedSnippet -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.snippet_generate_service import SnippetGenerateService from services.snippet_service import SnippetService @@ -286,6 +290,44 @@ class SnippetPublishedAllWorkflowApi(Resource): } +@console_ns.route("/snippets//workflows//restore") +class SnippetDraftWorkflowRestoreApi(Resource): + @console_ns.doc("restore_snippet_workflow_to_draft") + @console_ns.doc(description="Restore a published snippet workflow version into the draft workflow") + @console_ns.doc(params={"snippet_id": "Snippet ID", "workflow_id": "Published workflow ID"}) + @console_ns.response(200, "Workflow restored successfully") + @console_ns.response(400, "Source workflow must be published") + @console_ns.response(404, "Workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def post(self, snippet: CustomizedSnippet, workflow_id: str): + """Restore a published snippet workflow version into the draft workflow.""" + current_user, _ = current_account_with_tenant() + snippet_service = SnippetService() + + try: + workflow = snippet_service.restore_published_workflow_to_draft( + snippet=snippet, + workflow_id=workflow_id, + account=current_user, + ) + except IsDraftWorkflowError as exc: + raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc + except WorkflowNotFoundError as exc: + raise NotFound(str(exc)) from exc + except ValueError as exc: + raise BadRequest(str(exc)) from exc + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + @console_ns.route("/snippets//workflow-runs") class SnippetWorkflowRunsApi(Resource): @console_ns.doc("list_snippet_workflow_runs") diff --git a/api/services/snippet_service.py b/api/services/snippet_service.py index 0d525a6248..7cd3dba514 100644 --- a/api/services/snippet_service.py +++ b/api/services/snippet_service.py @@ -22,7 +22,8 @@ from models.workflow import ( WorkflowType, ) from repositories.factory import DifyAPIRepositoryFactory -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError +from services.workflow_restore import apply_published_workflow_snapshot_to_draft logger = logging.getLogger(__name__) @@ -306,6 +307,31 @@ class SnippetService: ) return workflow + def get_published_workflow_by_id(self, snippet: CustomizedSnippet, workflow_id: str) -> Workflow | None: + """ + Get a published workflow snapshot by ID for snippet history restore. + + :param snippet: CustomizedSnippet instance + :param workflow_id: Workflow ID + :return: Published Workflow or None + :raises IsDraftWorkflowError: If the workflow ID points to a draft workflow + """ + workflow = ( + db.session.query(Workflow) + .where( + Workflow.tenant_id == snippet.tenant_id, + Workflow.app_id == snippet.id, + self._snippet_kind_filter(), + Workflow.id == workflow_id, + ) + .first() + ) + if not workflow: + return None + if workflow.version == Workflow.VERSION_DRAFT: + raise IsDraftWorkflowError("source workflow must be published") + return workflow + def sync_draft_workflow( self, *, @@ -371,6 +397,46 @@ class SnippetService: db.session.commit() return workflow + def restore_published_workflow_to_draft( + self, + *, + snippet: CustomizedSnippet, + workflow_id: str, + account: Account, + ) -> Workflow: + """ + Restore a published snippet workflow snapshot into the draft workflow. + + :param snippet: CustomizedSnippet instance + :param workflow_id: Published workflow ID + :param account: Account making the change + :return: Restored draft Workflow + :raises WorkflowNotFoundError: If the source workflow does not exist + :raises IsDraftWorkflowError: If the source workflow is a draft + :raises ValueError: If the restored graph is invalid for snippets + """ + source_workflow = self.get_published_workflow_by_id(snippet=snippet, workflow_id=workflow_id) + if not source_workflow: + raise WorkflowNotFoundError("Workflow not found.") + + SnippetService.validate_snippet_graph_forbidden_nodes(source_workflow.graph_dict) + + draft_workflow = self.get_draft_workflow(snippet=snippet) + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id=snippet.tenant_id, + app_id=snippet.id, + source_workflow=source_workflow, + draft_workflow=draft_workflow, + account=account, + updated_at_factory=lambda: datetime.now(UTC).replace(tzinfo=None), + ) + + if is_new_draft: + db.session.add(draft_workflow) + + db.session.commit() + return draft_workflow + def publish_workflow( self, *, diff --git a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py new file mode 100644 index 0000000000..574cbcba59 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace + +import pytest +from werkzeug.exceptions import HTTPException, NotFound + +from controllers.console.snippets import snippet_workflow as snippet_workflow_module + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def test_restore_published_snippet_workflow_to_draft_success( + app, monkeypatch: pytest.MonkeyPatch +) -> None: + workflow = SimpleNamespace( + unique_hash="restored-hash", + updated_at=None, + created_at=datetime(2024, 1, 1), + ) + user = SimpleNamespace(id="account-1") + snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + + monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1")) + monkeypatch.setattr( + snippet_workflow_module, + "SnippetService", + lambda: SimpleNamespace(restore_published_workflow_to_draft=lambda **_kwargs: workflow), + ) + + api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/snippets/snippet-1/workflows/published-workflow/restore", + method="POST", + ): + response = handler(api, snippet=snippet, workflow_id="published-workflow") + + assert response["result"] == "success" + assert response["hash"] == "restored-hash" + + +def test_restore_published_snippet_workflow_to_draft_not_found( + app, monkeypatch: pytest.MonkeyPatch +) -> None: + user = SimpleNamespace(id="account-1") + snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + + monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1")) + monkeypatch.setattr( + snippet_workflow_module, + "SnippetService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + snippet_workflow_module.WorkflowNotFoundError("Workflow not found") + ) + ), + ) + + api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/snippets/snippet-1/workflows/published-workflow/restore", + method="POST", + ): + with pytest.raises(NotFound): + handler(api, snippet=snippet, workflow_id="published-workflow") + + +def test_restore_published_snippet_workflow_to_draft_returns_400_for_draft_source( + app, monkeypatch: pytest.MonkeyPatch +) -> None: + user = SimpleNamespace(id="account-1") + snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + + monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1")) + monkeypatch.setattr( + snippet_workflow_module, + "SnippetService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + snippet_workflow_module.IsDraftWorkflowError("source workflow must be published") + ) + ), + ) + + api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/snippets/snippet-1/workflows/draft-workflow/restore", + method="POST", + ): + with pytest.raises(HTTPException) as exc: + handler(api, snippet=snippet, workflow_id="draft-workflow") + + assert exc.value.code == 400 + assert exc.value.description == snippet_workflow_module.RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE + + +def test_restore_published_snippet_workflow_to_draft_returns_400_for_invalid_graph( + app, monkeypatch: pytest.MonkeyPatch +) -> None: + user = SimpleNamespace(id="account-1") + snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + + monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1")) + monkeypatch.setattr( + snippet_workflow_module, + "SnippetService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + ValueError("invalid snippet workflow graph") + ) + ), + ) + + api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/snippets/snippet-1/workflows/published-workflow/restore", + method="POST", + ): + with pytest.raises(HTTPException) as exc: + handler(api, snippet=snippet, workflow_id="published-workflow") + + assert exc.value.code == 400 + assert exc.value.description == "invalid snippet workflow graph" diff --git a/api/tests/unit_tests/services/test_snippet_service.py b/api/tests/unit_tests/services/test_snippet_service.py new file mode 100644 index 0000000000..96d78c16e4 --- /dev/null +++ b/api/tests/unit_tests/services/test_snippet_service.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from models.workflow import Workflow, WorkflowKind, WorkflowType +from services.errors.app import WorkflowNotFoundError +from services.snippet_service import SnippetService + + +def _create_workflow(*, workflow_id: str, version: str, graph: dict, features: dict) -> Workflow: + return Workflow( + id=workflow_id, + tenant_id="tenant-1", + app_id="snippet-1", + type=WorkflowType.WORKFLOW.value, + kind=WorkflowKind.SNIPPET.value, + version=version, + graph=json.dumps(graph), + features=json.dumps(features), + created_by="account-1", + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + + +def test_restore_published_snippet_workflow_to_draft_copies_source_snapshot( + monkeypatch: pytest.MonkeyPatch, +) -> None: + snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + account = SimpleNamespace(id="account-2") + source_graph = {"nodes": [{"id": "llm-1", "data": {"type": "llm"}}], "edges": []} + source_features = {"opening_statement": "hello"} + source_workflow = _create_workflow( + workflow_id="published-workflow", + version="2026-04-28 00:00:00", + graph=source_graph, + features=source_features, + ) + draft_workflow = _create_workflow( + workflow_id="draft-workflow", + version=Workflow.VERSION_DRAFT, + graph={"nodes": [], "edges": []}, + features={}, + ) + service = SnippetService.__new__(SnippetService) + session = SimpleNamespace(add=Mock(), commit=Mock()) + + monkeypatch.setattr(service, "get_published_workflow_by_id", Mock(return_value=source_workflow)) + monkeypatch.setattr(service, "get_draft_workflow", Mock(return_value=draft_workflow)) + monkeypatch.setattr("services.snippet_service.db.session", session) + + result = service.restore_published_workflow_to_draft( + snippet=snippet, + workflow_id=source_workflow.id, + account=account, + ) + + assert result is draft_workflow + assert draft_workflow.graph_dict == source_graph + assert draft_workflow.features_dict == source_features + assert draft_workflow.updated_by == account.id + session.add.assert_not_called() + session.commit.assert_called_once() + + +def test_restore_published_snippet_workflow_to_draft_raises_when_source_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + account = SimpleNamespace(id="account-2") + service = SnippetService.__new__(SnippetService) + + monkeypatch.setattr(service, "get_published_workflow_by_id", Mock(return_value=None)) + + with pytest.raises(WorkflowNotFoundError): + service.restore_published_workflow_to_draft( + snippet=snippet, + workflow_id="missing-workflow", + account=account, + )