diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 837245ecb1..d59aa44718 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -7,7 +7,7 @@ from flask import abort, request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services from controllers.console import console_ns @@ -46,13 +46,14 @@ from models import App from models.model import AppMode from models.workflow import Workflow from services.app_generate_service import AppGenerateService -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) LISTENING_RETRY_IN = 2000 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" +RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -284,7 +285,9 @@ class DraftWorkflowApi(Resource): workflow_service = WorkflowService() try: - environment_variables_list = args.get("environment_variables") or [] + environment_variables_list = Workflow.normalize_environment_variable_mappings( + args.get("environment_variables") or [], + ) environment_variables = [ variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list ] @@ -994,6 +997,43 @@ class PublishedAllWorkflowApi(Resource): } +@console_ns.route("/apps//workflows//restore") +class DraftWorkflowRestoreApi(Resource): + @console_ns.doc("restore_workflow_to_draft") + @console_ns.doc(description="Restore a published workflow version into the draft workflow") + @console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Published workflow ID"}) + @console_ns.response(200, "Workflow restored successfully") + @console_ns.response(400, "Source workflow must be published") + @console_ns.response(404, "Workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @edit_permission_required + def post(self, app_model: App, workflow_id: str): + current_user, _ = current_account_with_tenant() + workflow_service = WorkflowService() + + try: + workflow = workflow_service.restore_published_workflow_to_draft( + app_model=app_model, + workflow_id=workflow_id, + account=current_user, + ) + except IsDraftWorkflowError as exc: + raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc + except WorkflowNotFoundError as exc: + raise NotFound(str(exc)) from exc + except ValueError as exc: + raise BadRequest(str(exc)) from exc + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + @console_ns.route("/apps//workflows/") class WorkflowByIdApi(Resource): @console_ns.doc("update_workflow_by_id") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 51cdcc0c7a..3912cc73ca 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -6,7 +6,7 @@ from flask import abort, request from flask_restx import Resource, marshal_with # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services from controllers.common.schema import register_schema_models @@ -16,7 +16,11 @@ from controllers.console.app.error import ( DraftWorkflowNotExist, DraftWorkflowNotSync, ) -from controllers.console.app.workflow import workflow_model, workflow_pagination_model +from controllers.console.app.workflow import ( + RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE, + workflow_model, + workflow_pagination_model, +) from controllers.console.app.workflow_run import ( workflow_run_detail_model, workflow_run_node_execution_list_model, @@ -42,7 +46,8 @@ from libs.login import current_account_with_tenant, current_user, login_required from models import Account from models.dataset import Pipeline from models.model import EndUser -from services.errors.app import WorkflowHashNotEqualError +from models.workflow import Workflow +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -203,9 +208,12 @@ class DraftRagPipelineApi(Resource): abort(415) payload = DraftWorkflowSyncPayload.model_validate(payload_dict) + rag_pipeline_service = RagPipelineService() try: - environment_variables_list = payload.environment_variables or [] + environment_variables_list = Workflow.normalize_environment_variable_mappings( + payload.environment_variables or [], + ) environment_variables = [ variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list ] @@ -213,7 +221,6 @@ class DraftRagPipelineApi(Resource): conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] - rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, graph=payload.graph, @@ -705,6 +712,36 @@ class PublishedAllRagPipelineApi(Resource): } +@console_ns.route("/rag/pipelines//workflows//restore") +class RagPipelineDraftWorkflowRestoreApi(Resource): + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, workflow_id: str): + current_user, _ = current_account_with_tenant() + rag_pipeline_service = RagPipelineService() + + try: + workflow = rag_pipeline_service.restore_published_workflow_to_draft( + pipeline=pipeline, + workflow_id=workflow_id, + account=current_user, + ) + except IsDraftWorkflowError as exc: + # Use a stable, predefined message to keep the 400 response consistent + raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc + except WorkflowNotFoundError as exc: + raise NotFound(str(exc)) from exc + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + @console_ns.route("/rag/pipelines//workflows/") class RagPipelineByIdApi(Resource): @setup_required diff --git a/api/models/workflow.py b/api/models/workflow.py index e7b20d0e65..6e8dda429d 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,3 +1,4 @@ +import copy import json import logging from collections.abc import Generator, Mapping, Sequence @@ -302,26 +303,40 @@ class Workflow(Base): # bug def features(self) -> str: """ Convert old features structure to new features structure. + + This property avoids rewriting the underlying JSON when normalization + produces no effective change, to prevent marking the row dirty on read. """ if not self._features: return self._features - features = json.loads(self._features) - if features.get("file_upload", {}).get("image", {}).get("enabled", False): - image_enabled = True - image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS)) - image_transfer_methods = features["file_upload"]["image"].get( - "transfer_methods", ["remote_url", "local_file"] - ) - features["file_upload"]["enabled"] = image_enabled - features["file_upload"]["number_limits"] = image_number_limits - features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods - features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"]) - features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get( - "allowed_file_extensions", [] - ) - del features["file_upload"]["image"] - self._features = json.dumps(features) + # Parse once and deep-copy before normalization to detect in-place changes. + original_dict = self._decode_features_payload(self._features) + if original_dict is None: + return self._features + + # Fast-path: if the legacy file_upload.image.enabled shape is absent, skip + # deep-copy and normalization entirely and return the stored JSON. + file_upload_payload = original_dict.get("file_upload") + if not isinstance(file_upload_payload, dict): + return self._features + file_upload = cast(dict[str, Any], file_upload_payload) + + image_payload = file_upload.get("image") + if not isinstance(image_payload, dict): + return self._features + image = cast(dict[str, Any], image_payload) + if "enabled" not in image: + return self._features + + normalized_dict = self._normalize_features_payload(copy.deepcopy(original_dict)) + + if normalized_dict == original_dict: + # No effective change; return stored JSON unchanged. + return self._features + + # Normalization changed the payload: persist the normalized JSON. + self._features = json.dumps(normalized_dict) return self._features @features.setter @@ -332,6 +347,44 @@ class Workflow(Base): # bug def features_dict(self) -> dict[str, Any]: return json.loads(self.features) if self.features else {} + @property + def serialized_features(self) -> str: + """Return the stored features JSON without triggering compatibility rewrites.""" + return self._features + + @property + def normalized_features_dict(self) -> dict[str, Any]: + """Decode features with legacy normalization without mutating the model state.""" + if not self._features: + return {} + + features = self._decode_features_payload(self._features) + return self._normalize_features_payload(features) if features is not None else {} + + @staticmethod + def _decode_features_payload(features: str) -> dict[str, Any] | None: + """Decode workflow features JSON when it contains an object payload.""" + payload = json.loads(features) + return cast(dict[str, Any], payload) if isinstance(payload, dict) else None + + @staticmethod + def _normalize_features_payload(features: dict[str, Any]) -> dict[str, Any]: + if features.get("file_upload", {}).get("image", {}).get("enabled", False): + image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS)) + image_transfer_methods = features["file_upload"]["image"].get( + "transfer_methods", ["remote_url", "local_file"] + ) + features["file_upload"]["enabled"] = True + features["file_upload"]["number_limits"] = image_number_limits + features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods + features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"]) + features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get( + "allowed_file_extensions", [] + ) + del features["file_upload"]["image"] + + return features + def walk_nodes( self, specific_node_type: NodeType | None = None ) -> Generator[tuple[str, Mapping[str, Any]], None, None]: @@ -517,6 +570,31 @@ class Workflow(Base): # bug ) self._environment_variables = environment_variables_json + @staticmethod + def normalize_environment_variable_mappings( + mappings: Sequence[Mapping[str, Any]], + ) -> list[dict[str, Any]]: + """Convert masked secret placeholders into the draft hidden sentinel. + + Regular draft sync requests should preserve existing secrets without shipping + plaintext values back from the client. The dedicated restore endpoint now + copies published secrets server-side, so draft sync only needs to normalize + the UI mask into `HIDDEN_VALUE`. + """ + masked_secret_value = encrypter.full_mask_token() + normalized_mappings: list[dict[str, Any]] = [] + + for mapping in mappings: + normalized_mapping = dict(mapping) + if ( + normalized_mapping.get("value_type") == SegmentType.SECRET.value + and normalized_mapping.get("value") == masked_secret_value + ): + normalized_mapping["value"] = HIDDEN_VALUE + normalized_mappings.append(normalized_mapping) + + return normalized_mappings + def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict: environment_variables = list(self.environment_variables) environment_variables = [ @@ -564,6 +642,12 @@ class Workflow(Base): # bug ensure_ascii=False, ) + def copy_serialized_variable_storage_from(self, source_workflow: "Workflow") -> None: + """Copy stored variable JSON directly for same-tenant restore flows.""" + self._environment_variables = source_workflow._environment_variables + self._conversation_variables = source_workflow._conversation_variables + self._rag_pipeline_variables = source_workflow._rag_pipeline_variables + @staticmethod def version_from_datetime(d: datetime) -> str: return str(d) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f3aedafac9..296b9f0890 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -79,10 +79,11 @@ from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, PipelineTemplateInfoEntity, ) -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader +from services.workflow_restore import apply_published_workflow_snapshot_to_draft logger = logging.getLogger(__name__) @@ -234,6 +235,21 @@ class RagPipelineService: return workflow + def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None: + """Fetch a published workflow snapshot by ID for restore operations.""" + workflow = ( + db.session.query(Workflow) + .where( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.id == workflow_id, + ) + .first() + ) + if workflow and workflow.version == Workflow.VERSION_DRAFT: + raise IsDraftWorkflowError("source workflow must be published") + return workflow + def get_all_published_workflow( self, *, @@ -327,6 +343,42 @@ class RagPipelineService: # return draft workflow return workflow + def restore_published_workflow_to_draft( + self, + *, + pipeline: Pipeline, + workflow_id: str, + account: Account, + ) -> Workflow: + """Restore a published pipeline workflow snapshot into the draft workflow. + + Pipelines reuse the shared draft-restore field copy helper, but still own + the pipeline-specific flush/link step that wires a newly created draft + back onto ``pipeline.workflow_id``. + """ + source_workflow = self.get_published_workflow_by_id(pipeline=pipeline, workflow_id=workflow_id) + if not source_workflow: + raise WorkflowNotFoundError("Workflow not found.") + + draft_workflow = self.get_draft_workflow(pipeline=pipeline) + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + source_workflow=source_workflow, + draft_workflow=draft_workflow, + account=account, + updated_at_factory=lambda: datetime.now(UTC).replace(tzinfo=None), + ) + + if is_new_draft: + db.session.add(draft_workflow) + db.session.flush() + pipeline.workflow_id = draft_workflow.id + + db.session.commit() + + return draft_workflow + def publish_workflow( self, *, diff --git a/api/services/workflow_restore.py b/api/services/workflow_restore.py new file mode 100644 index 0000000000..083235d228 --- /dev/null +++ b/api/services/workflow_restore.py @@ -0,0 +1,58 @@ +"""Shared helpers for restoring published workflow snapshots into drafts. + +Both app workflows and RAG pipeline workflows restore the same workflow fields +from a published snapshot into a draft. Keeping that field-copy logic in one +place prevents the two restore paths from drifting when we add or adjust draft +state in the future. Restore stays within a tenant, so we can safely reuse the +serialized workflow storage blobs without decrypting and re-encrypting secrets. +""" + +from collections.abc import Callable +from datetime import datetime + +from models import Account +from models.workflow import Workflow, WorkflowType + +UpdatedAtFactory = Callable[[], datetime] + + +def apply_published_workflow_snapshot_to_draft( + *, + tenant_id: str, + app_id: str, + source_workflow: Workflow, + draft_workflow: Workflow | None, + account: Account, + updated_at_factory: UpdatedAtFactory, +) -> tuple[Workflow, bool]: + """Copy a published workflow snapshot into a draft workflow record. + + The caller remains responsible for source lookup, validation, flushing, and + post-commit side effects. This helper only centralizes the shared draft + creation/update semantics used by both restore entry points. Features are + copied from the stored JSON payload so restore does not normalize and dirty + the published source row before the caller commits. + """ + if not draft_workflow: + workflow_type = ( + source_workflow.type.value if isinstance(source_workflow.type, WorkflowType) else source_workflow.type + ) + draft_workflow = Workflow( + tenant_id=tenant_id, + app_id=app_id, + type=workflow_type, + version=Workflow.VERSION_DRAFT, + graph=source_workflow.graph, + features=source_workflow.serialized_features, + created_by=account.id, + ) + draft_workflow.copy_serialized_variable_storage_from(source_workflow) + return draft_workflow, True + + draft_workflow.graph = source_workflow.graph + draft_workflow.features = source_workflow.serialized_features + draft_workflow.updated_by = account.id + draft_workflow.updated_at = updated_at_factory() + draft_workflow.copy_serialized_variable_storage_from(source_workflow) + + return draft_workflow, False diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index e13cdd5f27..66976058c0 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -63,7 +63,12 @@ from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeEx from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService from services.enterprise.plugin_manager_service import PluginCredentialType -from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError +from services.errors.app import ( + IsDraftWorkflowError, + TriggerNodeLimitExceededError, + WorkflowHashNotEqualError, + WorkflowNotFoundError, +) from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError @@ -75,6 +80,7 @@ from .human_input_delivery_test_service import ( HumanInputDeliveryTestService, ) from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService +from .workflow_restore import apply_published_workflow_snapshot_to_draft class WorkflowService: @@ -279,6 +285,43 @@ class WorkflowService: # return draft workflow return workflow + def restore_published_workflow_to_draft( + self, + *, + app_model: App, + workflow_id: str, + account: Account, + ) -> Workflow: + """Restore a published workflow snapshot into the draft workflow. + + Secret environment variables are copied server-side from the selected + published workflow so the normal draft sync flow stays stateless. + """ + source_workflow = self.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id) + if not source_workflow: + raise WorkflowNotFoundError("Workflow not found.") + + self.validate_features_structure(app_model=app_model, features=source_workflow.normalized_features_dict) + self.validate_graph_structure(graph=source_workflow.graph_dict) + + draft_workflow = self.get_draft_workflow(app_model=app_model) + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + source_workflow=source_workflow, + draft_workflow=draft_workflow, + account=account, + updated_at_factory=naive_utc_now, + ) + + if is_new_draft: + db.session.add(draft_workflow) + + db.session.commit() + app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=draft_workflow) + + return draft_workflow + def publish_workflow( self, *, diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index 056db41750..a5fe052206 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -802,6 +802,81 @@ class TestWorkflowService: with pytest.raises(ValueError, match="No valid workflow found"): workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) + def test_restore_published_workflow_to_draft_does_not_persist_normalized_source_features( + self, db_session_with_containers: Session + ): + """Restore copies legacy feature JSON into draft without rewriting the source row.""" + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + app.mode = AppMode.ADVANCED_CHAT + + legacy_features = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + published_workflow = Workflow( + id=fake.uuid4(), + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW, + version="2026.03.19.001", + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps(legacy_features), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + ) + draft_workflow = Workflow( + id=fake.uuid4(), + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW, + version=Workflow.VERSION_DRAFT, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + ) + db_session_with_containers.add(published_workflow) + db_session_with_containers.add(draft_workflow) + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + restored_workflow = workflow_service.restore_published_workflow_to_draft( + app_model=app, + workflow_id=published_workflow.id, + account=account, + ) + + db_session_with_containers.expire_all() + refreshed_published_workflow = ( + db_session_with_containers.query(Workflow).filter_by(id=published_workflow.id).first() + ) + refreshed_draft_workflow = db_session_with_containers.query(Workflow).filter_by(id=draft_workflow.id).first() + + assert restored_workflow.id == draft_workflow.id + assert refreshed_published_workflow is not None + assert refreshed_draft_workflow is not None + assert refreshed_published_workflow.serialized_features == json.dumps(legacy_features) + assert refreshed_draft_workflow.serialized_features == json.dumps(legacy_features) + def test_get_default_block_configs(self, db_session_with_containers: Session): """ Test retrieval of default block configurations for all node types. diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index f100080eaa..0e22db9f9b 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -129,6 +129,136 @@ def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) handler(api, app_model=SimpleNamespace(id="app")) +def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow = SimpleNamespace( + unique_hash="restored-hash", + updated_at=None, + created_at=datetime(2024, 1, 1), + ) + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace(restore_published_workflow_to_draft=lambda **_kwargs: workflow), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/published-workflow/restore", + method="POST", + ): + response = handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="published-workflow", + ) + + assert response["result"] == "success" + assert response["hash"] == "restored-hash" + + +def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + workflow_module.WorkflowNotFoundError("Workflow not found") + ) + ), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/published-workflow/restore", + method="POST", + ): + with pytest.raises(NotFound): + handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="published-workflow", + ) + + +def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + workflow_module.IsDraftWorkflowError( + "Cannot use draft workflow version. Workflow ID: draft-workflow. " + "Please use a published workflow version or leave workflow_id empty." + ) + ) + ), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/draft-workflow/restore", + method="POST", + ): + with pytest.raises(HTTPException) as exc: + handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="draft-workflow", + ) + + assert exc.value.code == 400 + assert exc.value.description == workflow_module.RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE + + +def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure( + app, monkeypatch: pytest.MonkeyPatch +) -> None: + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + ValueError("invalid workflow graph") + ) + ), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/published-workflow/restore", + method="POST", + ): + with pytest.raises(HTTPException) as exc: + handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="published-workflow", + ) + + assert exc.value.code == 400 + assert exc.value.description == "invalid workflow graph" + + def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index 7775cbdd81..472d133349 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -2,7 +2,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch import pytest -from werkzeug.exceptions import Forbidden, NotFound +from werkzeug.exceptions import Forbidden, HTTPException, NotFound import services from controllers.console import console_ns @@ -19,13 +19,14 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import ( RagPipelineDraftNodeRunApi, RagPipelineDraftRunIterationNodeApi, RagPipelineDraftRunLoopNodeApi, + RagPipelineDraftWorkflowRestoreApi, RagPipelineRecommendedPluginApi, RagPipelineTaskStopApi, RagPipelineTransformApi, RagPipelineWorkflowLastRunApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError @@ -116,6 +117,86 @@ class TestDraftWorkflowApi: response, status = method(api, pipeline) assert status == 400 + def test_restore_published_workflow_to_draft_success(self, app): + api = RagPipelineDraftWorkflowRestoreApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="account-1") + workflow = MagicMock(unique_hash="restored-hash", updated_at=None, created_at=datetime(2024, 1, 1)) + + service = MagicMock() + service.restore_published_workflow_to_draft.return_value = workflow + + with ( + app.test_request_context("/", method="POST"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "published-workflow") + + assert result["result"] == "success" + assert result["hash"] == "restored-hash" + + def test_restore_published_workflow_to_draft_not_found(self, app): + api = RagPipelineDraftWorkflowRestoreApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="account-1") + + service = MagicMock() + service.restore_published_workflow_to_draft.side_effect = WorkflowNotFoundError("Workflow not found") + + with ( + app.test_request_context("/", method="POST"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(NotFound): + method(api, pipeline, "published-workflow") + + def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app): + api = RagPipelineDraftWorkflowRestoreApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="account-1") + + service = MagicMock() + service.restore_published_workflow_to_draft.side_effect = IsDraftWorkflowError( + "source workflow must be published" + ) + + with ( + app.test_request_context("/", method="POST"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(HTTPException) as exc: + method(api, pipeline, "draft-workflow") + + assert exc.value.code == 400 + assert exc.value.description == "source workflow must be published" + class TestDraftRunNodes: def test_iteration_node_success(self, app): diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index f3b72aa128..ef29b26a7a 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -4,12 +4,18 @@ from unittest import mock from uuid import uuid4 from constants import HIDDEN_VALUE +from core.helper import encrypter from dify_graph.file.enums import FileTransferMethod, FileType from dify_graph.file.models import File from dify_graph.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable from dify_graph.variables.segments import IntegerSegment, Segment from factories.variable_factory import build_segment -from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable +from models.workflow import ( + Workflow, + WorkflowDraftVariable, + WorkflowNodeExecutionModel, + is_system_variable_editable, +) def test_environment_variables(): @@ -144,6 +150,36 @@ def test_to_dict(): assert workflow_dict["environment_variables"][1]["value"] == "text" +def test_normalize_environment_variable_mappings_converts_full_mask_to_hidden_value(): + normalized = Workflow.normalize_environment_variable_mappings( + [ + { + "id": str(uuid4()), + "name": "secret", + "value": encrypter.full_mask_token(), + "value_type": "secret", + } + ] + ) + + assert normalized[0]["value"] == HIDDEN_VALUE + + +def test_normalize_environment_variable_mappings_keeps_hidden_value(): + normalized = Workflow.normalize_environment_variable_mappings( + [ + { + "id": str(uuid4()), + "name": "secret", + "value": HIDDEN_VALUE, + "value_type": "secret", + } + ] + ) + + assert normalized[0]["value"] == HIDDEN_VALUE + + class TestWorkflowNodeExecution: def test_execution_metadata_dict(self): node_exec = WorkflowNodeExecutionModel() diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 57c0464dc6..753cff8697 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -544,6 +544,89 @@ class TestWorkflowService: conversation_variables=[], ) + def test_restore_published_workflow_to_draft_keeps_source_features_unmodified( + self, workflow_service, mock_db_session + ): + app = TestWorkflowAssociatedDataFactory.create_app_mock() + account = TestWorkflowAssociatedDataFactory.create_account_mock() + legacy_features = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + normalized_features = { + "file_upload": { + "enabled": True, + "allowed_file_types": ["image"], + "allowed_file_extensions": [], + "allowed_file_upload_methods": ["remote_url", "local_file"], + "number_limits": 6, + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + source_workflow = Workflow( + id="published-workflow-id", + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW.value, + version="2026-03-19T00:00:00", + graph=json.dumps(TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()), + features=json.dumps(legacy_features), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + draft_workflow = Workflow( + id="draft-workflow-id", + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW.value, + version=Workflow.VERSION_DRAFT, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + + with ( + patch.object(workflow_service, "get_published_workflow_by_id", return_value=source_workflow), + patch.object(workflow_service, "get_draft_workflow", return_value=draft_workflow), + patch.object(workflow_service, "validate_graph_structure"), + patch.object(workflow_service, "validate_features_structure") as mock_validate_features, + patch("services.workflow_service.app_draft_workflow_was_synced"), + ): + result = workflow_service.restore_published_workflow_to_draft( + app_model=app, + workflow_id=source_workflow.id, + account=account, + ) + + mock_validate_features.assert_called_once_with(app_model=app, features=normalized_features) + assert result is draft_workflow + assert source_workflow.serialized_features == json.dumps(legacy_features) + assert draft_workflow.serialized_features == json.dumps(legacy_features) + mock_db_session.session.commit.assert_called_once() + # ==================== Workflow Validation Tests ==================== # These tests verify graph structure and feature configuration validation diff --git a/api/tests/unit_tests/services/workflow/test_workflow_restore.py b/api/tests/unit_tests/services/workflow/test_workflow_restore.py new file mode 100644 index 0000000000..179361de45 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_restore.py @@ -0,0 +1,77 @@ +import json +from types import SimpleNamespace + +from models.workflow import Workflow +from services.workflow_restore import apply_published_workflow_snapshot_to_draft + +LEGACY_FEATURES = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + +NORMALIZED_FEATURES = { + "file_upload": { + "enabled": True, + "allowed_file_types": ["image"], + "allowed_file_extensions": [], + "allowed_file_upload_methods": ["remote_url", "local_file"], + "number_limits": 6, + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + + +def _create_workflow(*, workflow_id: str, version: str, features: dict[str, object]) -> Workflow: + return Workflow( + id=workflow_id, + tenant_id="tenant-id", + app_id="app-id", + type="workflow", + version=version, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps(features), + created_by="account-id", + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + + +def test_apply_published_workflow_snapshot_to_draft_copies_serialized_features_without_mutating_source() -> None: + source_workflow = _create_workflow( + workflow_id="published-workflow-id", + version="2026-03-19T00:00:00", + features=LEGACY_FEATURES, + ) + + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id="tenant-id", + app_id="app-id", + source_workflow=source_workflow, + draft_workflow=None, + account=SimpleNamespace(id="account-id"), + updated_at_factory=lambda: source_workflow.updated_at, + ) + + assert is_new_draft is True + assert source_workflow.serialized_features == json.dumps(LEGACY_FEATURES) + assert source_workflow.normalized_features_dict == NORMALIZED_FEATURES + assert draft_workflow.serialized_features == json.dumps(LEGACY_FEATURES) diff --git a/web/app/components/plugins/plugin-page/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-page/__tests__/index.spec.tsx index dafcbe57c2..e02a2bcb57 100644 --- a/web/app/components/plugins/plugin-page/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-page/__tests__/index.spec.tsx @@ -8,6 +8,8 @@ import { usePluginInstallation } from '@/hooks/use-query-params' import { fetchBundleInfoFromMarketPlace, fetchManifestFromMarketPlace } from '@/service/plugins' import PluginPageWithContext from '../index' +let mockEnableMarketplace = true + // Mock external dependencies vi.mock('@/service/plugins', () => ({ fetchManifestFromMarketPlace: vi.fn(), @@ -31,7 +33,7 @@ vi.mock('@/context/global-public-context', () => ({ useGlobalPublicStore: vi.fn((selector) => { const state = { systemFeatures: { - enable_marketplace: true, + enable_marketplace: mockEnableMarketplace, }, } return selector(state) @@ -138,6 +140,7 @@ const createDefaultProps = (): PluginPageProps => ({ describe('PluginPage Component', () => { beforeEach(() => { vi.clearAllMocks() + mockEnableMarketplace = true // Reset to default mock values vi.mocked(usePluginInstallation).mockReturnValue([ { packageId: null, bundleInfo: null }, @@ -630,18 +633,7 @@ describe('PluginPage Component', () => { }) it('should handle marketplace disabled', () => { - // Mock marketplace disabled - vi.mock('@/context/global-public-context', async () => ({ - useGlobalPublicStore: vi.fn((selector) => { - const state = { - systemFeatures: { - enable_marketplace: false, - }, - } - return selector(state) - }), - })) - + mockEnableMarketplace = false vi.mocked(useQueryState).mockReturnValue(['discover', vi.fn()]) render() diff --git a/web/app/components/rag-pipeline/components/__tests__/index.spec.tsx b/web/app/components/rag-pipeline/components/__tests__/index.spec.tsx index 36454d33e4..c3341ecd83 100644 --- a/web/app/components/rag-pipeline/components/__tests__/index.spec.tsx +++ b/web/app/components/rag-pipeline/components/__tests__/index.spec.tsx @@ -1,5 +1,6 @@ import type { EnvironmentVariable } from '@/app/components/workflow/types' import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { useState } from 'react' import { createMockProviderContextValue } from '@/__mocks__/provider-context' import Conversion from '../conversion' @@ -347,11 +348,67 @@ vi.mock('@/app/components/workflow/dsl-export-confirm-modal', () => ({ ), })) +vi.mock('@/app/components/base/app-icon-picker', () => ({ + default: function MockAppIconPicker({ onSelect, onClose }: { + onSelect?: (payload: + | { type: 'emoji', icon: string, background: string } + | { type: 'image', fileId: string, url: string }, + ) => void + onClose?: () => void + }) { + const [activeTab, setActiveTab] = useState<'emoji' | 'image'>('emoji') + const [selectedEmoji, setSelectedEmoji] = useState({ icon: '😀', background: '#FFFFFF' }) + + return ( +
+ + + {activeTab === 'emoji' && ( + + )} + {activeTab === 'image' &&
picker-image-panel
} + + +
+ ) + }, +})) + // Silence expected console.error from Dialog/Modal rendering beforeEach(() => { vi.spyOn(console, 'error').mockImplementation(() => {}) }) +afterEach(() => { + vi.restoreAllMocks() +}) + // Helper to find the name input in PublishAsKnowledgePipelineModal function getNameInput() { return screen.getByPlaceholderText('pipeline.common.publishAsPipeline.namePlaceholder') @@ -708,10 +765,7 @@ describe('PublishAsKnowledgePipelineModal', () => { const appIcon = getAppIcon() fireEvent.click(appIcon) - // Click the first emoji in the grid (search full document since Dialog uses portal) - const gridEmojis = document.querySelectorAll('.grid em-emoji') - expect(gridEmojis.length).toBeGreaterThan(0) - fireEvent.click(gridEmojis[0].parentElement!.parentElement!) + fireEvent.click(screen.getByTestId('picker-emoji-option')) // Click OK to confirm selection fireEvent.click(screen.getByRole('button', { name: /iconPicker\.ok/ })) @@ -1031,11 +1085,8 @@ describe('Integration Tests', () => { // Open picker and select an emoji const appIcon = getAppIcon() fireEvent.click(appIcon) - const gridEmojis = document.querySelectorAll('.grid em-emoji') - if (gridEmojis.length > 0) { - fireEvent.click(gridEmojis[0].parentElement!.parentElement!) - fireEvent.click(screen.getByRole('button', { name: /iconPicker\.ok/ })) - } + fireEvent.click(screen.getByTestId('picker-emoji-option')) + fireEvent.click(screen.getByRole('button', { name: /iconPicker\.ok/ })) fireEvent.click(screen.getByRole('button', { name: /workflow\.common\.publish/i })) diff --git a/web/app/components/rag-pipeline/components/panel/index.tsx b/web/app/components/rag-pipeline/components/panel/index.tsx index 74cdd7034d..8f913956d1 100644 --- a/web/app/components/rag-pipeline/components/panel/index.tsx +++ b/web/app/components/rag-pipeline/components/panel/index.tsx @@ -62,6 +62,7 @@ const RagPipelinePanel = () => { return { getVersionListUrl: `/rag/pipelines/${pipelineId}/workflows`, deleteVersionUrl: (versionId: string) => `/rag/pipelines/${pipelineId}/workflows/${versionId}`, + restoreVersionUrl: (versionId: string) => `/rag/pipelines/${pipelineId}/workflows/${versionId}/restore`, updateVersionUrl: (versionId: string) => `/rag/pipelines/${pipelineId}/workflows/${versionId}`, latestVersionId: '', } diff --git a/web/app/components/rag-pipeline/hooks/__tests__/use-nodes-sync-draft.spec.ts b/web/app/components/rag-pipeline/hooks/__tests__/use-nodes-sync-draft.spec.ts index 82635a75b3..6d807565d9 100644 --- a/web/app/components/rag-pipeline/hooks/__tests__/use-nodes-sync-draft.spec.ts +++ b/web/app/components/rag-pipeline/hooks/__tests__/use-nodes-sync-draft.spec.ts @@ -231,6 +231,25 @@ describe('useNodesSyncDraft', () => { expect(mockSyncWorkflowDraft).toHaveBeenCalled() }) + it('should not include source_workflow_id in sync payloads', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft() + }) + + expect(mockSyncWorkflowDraft).toHaveBeenCalledWith(expect.objectContaining({ + params: expect.not.objectContaining({ + source_workflow_id: expect.anything(), + }), + })) + }) + it('should call onSuccess callback when sync succeeds', async () => { mockGetNodesReadOnly.mockReturnValue(false) mockGetNodes.mockReturnValue([ @@ -421,6 +440,21 @@ describe('useNodesSyncDraft', () => { expect(sentParams.rag_pipeline_variables).toEqual([{ variable: 'input', type: 'text-input' }]) }) + it('should not include source_workflow_id when syncing on page close', () => { + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + const sentParams = mockPostWithKeepalive.mock.calls[0][1] + expect(sentParams.source_workflow_id).toBeUndefined() + }) + it('should remove underscore-prefixed keys from edges', () => { mockStoreGetState.mockReturnValue({ getNodes: mockGetNodes, diff --git a/web/app/components/rag-pipeline/hooks/__tests__/use-pipeline-refresh-draft.spec.ts b/web/app/components/rag-pipeline/hooks/__tests__/use-pipeline-refresh-draft.spec.ts index 4ad8bc4582..b9cff292e6 100644 --- a/web/app/components/rag-pipeline/hooks/__tests__/use-pipeline-refresh-draft.spec.ts +++ b/web/app/components/rag-pipeline/hooks/__tests__/use-pipeline-refresh-draft.spec.ts @@ -35,6 +35,7 @@ describe('usePipelineRefreshDraft', () => { const mockSetIsSyncingWorkflowDraft = vi.fn() const mockSetEnvironmentVariables = vi.fn() const mockSetEnvSecrets = vi.fn() + const mockSetRagPipelineVariables = vi.fn() beforeEach(() => { vi.clearAllMocks() @@ -45,6 +46,7 @@ describe('usePipelineRefreshDraft', () => { setIsSyncingWorkflowDraft: mockSetIsSyncingWorkflowDraft, setEnvironmentVariables: mockSetEnvironmentVariables, setEnvSecrets: mockSetEnvSecrets, + setRagPipelineVariables: mockSetRagPipelineVariables, }) mockFetchWorkflowDraft.mockResolvedValue({ @@ -55,6 +57,7 @@ describe('usePipelineRefreshDraft', () => { }, hash: 'new-hash', environment_variables: [], + rag_pipeline_variables: [], }) }) @@ -116,6 +119,29 @@ describe('usePipelineRefreshDraft', () => { }) }) + it('should update rag pipeline variables after fetch', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'new-hash', + environment_variables: [], + rag_pipeline_variables: [{ variable: 'query', type: 'text-input' }], + }) + + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([{ variable: 'query', type: 'text-input' }]) + }) + }) + it('should set syncing state to false after completion', async () => { const { result } = renderHook(() => usePipelineRefreshDraft()) diff --git a/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts b/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts index 640da5e8f8..184adb582f 100644 --- a/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts @@ -1,3 +1,4 @@ +import type { SyncDraftCallback } from '@/app/components/workflow/hooks-store' import { produce } from 'immer' import { useCallback } from 'react' import { useStoreApi } from 'reactflow' @@ -83,11 +84,7 @@ export const useNodesSyncDraft = () => { const performSync = useCallback(async ( notRefreshWhenSyncError?: boolean, - callback?: { - onSuccess?: () => void - onError?: () => void - onSettled?: () => void - }, + callback?: SyncDraftCallback, ) => { if (getNodesReadOnly()) return diff --git a/web/app/components/rag-pipeline/hooks/use-pipeline-refresh-draft.ts b/web/app/components/rag-pipeline/hooks/use-pipeline-refresh-draft.ts index 8909af4c4c..c9966a90c5 100644 --- a/web/app/components/rag-pipeline/hooks/use-pipeline-refresh-draft.ts +++ b/web/app/components/rag-pipeline/hooks/use-pipeline-refresh-draft.ts @@ -16,6 +16,7 @@ export const usePipelineRefreshDraft = () => { setIsSyncingWorkflowDraft, setEnvironmentVariables, setEnvSecrets, + setRagPipelineVariables, } = workflowStore.getState() setIsSyncingWorkflowDraft(true) fetchWorkflowDraft(`/rag/pipelines/${pipelineId}/workflows/draft`).then((response) => { @@ -34,6 +35,7 @@ export const usePipelineRefreshDraft = () => { return acc }, {} as Record)) setEnvironmentVariables(response.environment_variables?.map(env => env.value_type === 'secret' ? { ...env, value: '[__HIDDEN__]' } : env) || []) + setRagPipelineVariables?.(response.rag_pipeline_variables || []) }).finally(() => setIsSyncingWorkflowDraft(false)) }, [handleUpdateWorkflowCanvas, workflowStore]) diff --git a/web/app/components/workflow-app/components/workflow-panel.tsx b/web/app/components/workflow-app/components/workflow-panel.tsx index 7f70c53e2e..4b145339d7 100644 --- a/web/app/components/workflow-app/components/workflow-panel.tsx +++ b/web/app/components/workflow-app/components/workflow-panel.tsx @@ -110,6 +110,7 @@ const WorkflowPanel = () => { return { getVersionListUrl: `/apps/${appId}/workflows`, deleteVersionUrl: (versionId: string) => `/apps/${appId}/workflows/${versionId}`, + restoreVersionUrl: (versionId: string) => `/apps/${appId}/workflows/${versionId}/restore`, updateVersionUrl: (versionId: string) => `/apps/${appId}/workflows/${versionId}`, latestVersionId: appDetail?.workflow?.id, } diff --git a/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts b/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts index d35e6e3612..fd808affc3 100644 --- a/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts +++ b/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts @@ -108,4 +108,18 @@ describe('useNodesSyncDraft — handleRefreshWorkflowDraft(true) on 409', () => expect(mockHandleRefreshWorkflowDraft).not.toHaveBeenCalled() }) + + it('should not include source_workflow_id in draft sync payloads', async () => { + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft(false) + }) + + expect(mockSyncWorkflowDraft).toHaveBeenCalledWith(expect.objectContaining({ + params: expect.not.objectContaining({ + source_workflow_id: expect.anything(), + }), + })) + }) }) diff --git a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts index 4f9e529d92..5f61997d9f 100644 --- a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts @@ -1,3 +1,4 @@ +import type { SyncDraftCallback } from '@/app/components/workflow/hooks-store' import { produce } from 'immer' import { useCallback } from 'react' import { useStoreApi } from 'reactflow' @@ -91,11 +92,7 @@ export const useNodesSyncDraft = () => { const performSync = useCallback(async ( notRefreshWhenSyncError?: boolean, - callback?: { - onSuccess?: () => void - onError?: () => void - onSettled?: () => void - }, + callback?: SyncDraftCallback, ) => { if (getNodesReadOnly()) return diff --git a/web/app/components/workflow/header/__tests__/header-in-restoring.spec.tsx b/web/app/components/workflow/header/__tests__/header-in-restoring.spec.tsx new file mode 100644 index 0000000000..6fa934b57d --- /dev/null +++ b/web/app/components/workflow/header/__tests__/header-in-restoring.spec.tsx @@ -0,0 +1,126 @@ +import type { VersionHistory } from '@/types/workflow' +import { screen } from '@testing-library/react' +import { FlowType } from '@/types/common' +import { renderWorkflowComponent } from '../../__tests__/workflow-test-env' +import { WorkflowVersion } from '../../types' +import HeaderInRestoring from '../header-in-restoring' + +const mockRestoreWorkflow = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockHandleLoadBackupDraft = vi.fn() +const mockHandleRefreshWorkflowDraft = vi.fn() + +vi.mock('@/hooks/use-theme', () => ({ + default: () => ({ + theme: 'light', + }), +})) + +vi.mock('@/hooks/use-timestamp', () => ({ + default: () => ({ + formatTime: vi.fn(() => '09:30:00'), + }), +})) + +vi.mock('@/hooks/use-format-time-from-now', () => ({ + useFormatTimeFromNow: () => ({ + formatTimeFromNow: vi.fn(() => '3 hours ago'), + }), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, + useRestoreWorkflow: () => ({ + mutateAsync: mockRestoreWorkflow, + }), +})) + +vi.mock('../../hooks', () => ({ + useWorkflowRun: () => ({ + handleLoadBackupDraft: mockHandleLoadBackupDraft, + }), + useWorkflowRefreshDraft: () => ({ + handleRefreshWorkflowDraft: mockHandleRefreshWorkflowDraft, + }), +})) + +const createVersion = (overrides: Partial = {}): VersionHistory => ({ + id: 'version-1', + graph: { + nodes: [], + edges: [], + }, + created_at: 1_700_000_000, + created_by: { + id: 'user-1', + name: 'Alice', + email: 'alice@example.com', + }, + hash: 'hash-1', + updated_at: 1_700_000_100, + updated_by: { + id: 'user-2', + name: 'Bob', + email: 'bob@example.com', + }, + tool_published: false, + version: 'v1', + marked_name: 'Release 1', + marked_comment: '', + ...overrides, +}) + +describe('HeaderInRestoring', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should disable restore when the flow id is not ready yet', () => { + renderWorkflowComponent(, { + initialStoreState: { + currentVersion: createVersion(), + }, + hooksStoreProps: { + configsMap: undefined, + }, + }) + + expect(screen.getByRole('button', { name: 'workflow.common.restore' })).toBeDisabled() + }) + + it('should enable restore when version and flow config are both ready', () => { + renderWorkflowComponent(, { + initialStoreState: { + currentVersion: createVersion(), + }, + hooksStoreProps: { + configsMap: { + flowId: 'app-1', + flowType: FlowType.appFlow, + fileSettings: {} as never, + }, + }, + }) + + expect(screen.getByRole('button', { name: 'workflow.common.restore' })).toBeEnabled() + }) + + it('should keep restore disabled for draft versions even when flow config is ready', () => { + renderWorkflowComponent(, { + initialStoreState: { + currentVersion: createVersion({ + version: WorkflowVersion.Draft, + }), + }, + hooksStoreProps: { + configsMap: { + flowId: 'app-1', + flowType: FlowType.appFlow, + fileSettings: {} as never, + }, + }, + }) + + expect(screen.getByRole('button', { name: 'workflow.common.restore' })).toBeDisabled() + }) +}) diff --git a/web/app/components/workflow/header/header-in-restoring.tsx b/web/app/components/workflow/header/header-in-restoring.tsx index e005ce64e8..2c5b4b9f08 100644 --- a/web/app/components/workflow/header/header-in-restoring.tsx +++ b/web/app/components/workflow/header/header-in-restoring.tsx @@ -5,11 +5,12 @@ import { import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import useTheme from '@/hooks/use-theme' -import { useInvalidAllLastRun } from '@/service/use-workflow' +import { useInvalidAllLastRun, useRestoreWorkflow } from '@/service/use-workflow' +import { getFlowPrefix } from '@/service/utils' import { cn } from '@/utils/classnames' import Toast from '../../base/toast' import { - useNodesSyncDraft, + useWorkflowRefreshDraft, useWorkflowRun, } from '../hooks' import { useHooksStore } from '../hooks-store' @@ -42,7 +43,9 @@ const HeaderInRestoring = ({ const { handleLoadBackupDraft, } = useWorkflowRun() - const { handleSyncWorkflowDraft } = useNodesSyncDraft() + const { handleRefreshWorkflowDraft } = useWorkflowRefreshDraft() + const { mutateAsync: restoreWorkflow } = useRestoreWorkflow() + const canRestore = !!currentVersion?.id && !!configsMap?.flowId && currentVersion.version !== WorkflowVersion.Draft const handleCancelRestore = useCallback(() => { handleLoadBackupDraft() @@ -50,30 +53,35 @@ const HeaderInRestoring = ({ setShowWorkflowVersionHistoryPanel(false) }, [workflowStore, handleLoadBackupDraft, setShowWorkflowVersionHistoryPanel]) - const handleRestore = useCallback(() => { + const handleRestore = useCallback(async () => { + if (!canRestore) + return + setShowWorkflowVersionHistoryPanel(false) - workflowStore.setState({ isRestoring: false }) - workflowStore.setState({ backupDraft: undefined }) - handleSyncWorkflowDraft(true, false, { - onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }), - }) - }, - onError: () => { - Toast.notify({ - type: 'error', - message: t('versionHistory.action.restoreFailure', { ns: 'workflow' }), - }) - }, - onSettled: () => { - onRestoreSettled?.() - }, - }) - deleteAllInspectVars() - invalidAllLastRun() - }, [setShowWorkflowVersionHistoryPanel, workflowStore, handleSyncWorkflowDraft, deleteAllInspectVars, invalidAllLastRun, t, onRestoreSettled]) + const restoreUrl = `/${getFlowPrefix(configsMap.flowType)}/${configsMap.flowId}/workflows/${currentVersion.id}/restore` + + try { + await restoreWorkflow(restoreUrl) + workflowStore.setState({ isRestoring: false }) + workflowStore.setState({ backupDraft: undefined }) + handleRefreshWorkflowDraft() + Toast.notify({ + type: 'success', + message: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }), + }) + deleteAllInspectVars() + invalidAllLastRun() + } + catch { + Toast.notify({ + type: 'error', + message: t('versionHistory.action.restoreFailure', { ns: 'workflow' }), + }) + } + finally { + onRestoreSettled?.() + } + }, [canRestore, currentVersion?.id, configsMap, setShowWorkflowVersionHistoryPanel, workflowStore, restoreWorkflow, handleRefreshWorkflowDraft, deleteAllInspectVars, invalidAllLastRun, t, onRestoreSettled]) return ( <> @@ -83,7 +91,7 @@ const HeaderInRestoring = ({
+ } + + return + }, })) vi.mock('@/app/components/app/app-publisher/version-info-modal', () => ({ default: () => null, })) +vi.mock('../version-history-item', () => ({ + default: (props: MockVersionHistoryItemProps) => { + const MockVersionHistoryItem = () => { + const { item, onClick, handleClickMenuItem } = props + + useEffect(() => { + if (item.version === WorkflowVersion.Draft) + onClick(item) + }, [item, onClick]) + + return ( +
+ + {item.version !== WorkflowVersion.Draft && ( + + )} +
+ ) + } + + return + }, +})) + describe('VersionHistoryPanel', () => { beforeEach(() => { vi.clearAllMocks() + mockCurrentVersion = null }) describe('Version Click Behavior', () => { @@ -134,10 +184,10 @@ describe('VersionHistoryPanel', () => { render( `/apps/app-1/workflows/${versionId}/restore`} />, ) - // Draft version auto-clicks on mount via useEffect in VersionHistoryItem expect(mockHandleLoadBackupDraft).toHaveBeenCalled() expect(mockHandleRestoreFromPublishedWorkflow).not.toHaveBeenCalled() }) @@ -148,17 +198,72 @@ describe('VersionHistoryPanel', () => { render( `/apps/app-1/workflows/${versionId}/restore`} />, ) - // Clear mocks after initial render (draft version auto-clicks on mount) vi.clearAllMocks() - const publishedItem = screen.getByText('v1.0') - fireEvent.click(publishedItem) + fireEvent.click(screen.getByText('v1.0')) expect(mockHandleRestoreFromPublishedWorkflow).toHaveBeenCalled() expect(mockHandleLoadBackupDraft).not.toHaveBeenCalled() }) }) + + it('should set current version before confirming restore from context menu', async () => { + const { VersionHistoryPanel } = await import('../index') + + render( + `/apps/app-1/workflows/${versionId}/restore`} + />, + ) + + vi.clearAllMocks() + + fireEvent.click(screen.getByText('restore-published-version-id')) + fireEvent.click(screen.getByText('confirm restore')) + + await waitFor(() => { + expect(mockSetCurrentVersion).toHaveBeenCalledWith(expect.objectContaining({ + id: 'published-version-id', + })) + expect(mockRestoreWorkflow).toHaveBeenCalledWith('/apps/app-1/workflows/published-version-id/restore') + expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ isRestoring: false }) + expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ backupDraft: undefined }) + expect(mockHandleRefreshWorkflowDraft).toHaveBeenCalled() + }) + }) + + it('should keep restore mode backup state when restore request fails', async () => { + const { VersionHistoryPanel } = await import('../index') + mockRestoreWorkflow.mockRejectedValueOnce(new Error('restore failed')) + mockCurrentVersion = createVersionHistory({ + id: 'draft-version-id', + version: WorkflowVersion.Draft, + }) + + render( + `/apps/app-1/workflows/${versionId}/restore`} + />, + ) + + vi.clearAllMocks() + + fireEvent.click(screen.getByText('restore-published-version-id')) + fireEvent.click(screen.getByText('confirm restore')) + + await waitFor(() => { + expect(mockRestoreWorkflow).toHaveBeenCalledWith('/apps/app-1/workflows/published-version-id/restore') + }) + + expect(mockWorkflowStoreSetState).not.toHaveBeenCalledWith({ isRestoring: false }) + expect(mockWorkflowStoreSetState).not.toHaveBeenCalledWith({ backupDraft: undefined }) + expect(mockSetCurrentVersion).not.toHaveBeenCalled() + expect(mockHandleRefreshWorkflowDraft).not.toHaveBeenCalled() + }) }) diff --git a/web/app/components/workflow/panel/version-history-panel/index.tsx b/web/app/components/workflow/panel/version-history-panel/index.tsx index 9439efc918..2815cbf28d 100644 --- a/web/app/components/workflow/panel/version-history-panel/index.tsx +++ b/web/app/components/workflow/panel/version-history-panel/index.tsx @@ -9,8 +9,8 @@ import VersionInfoModal from '@/app/components/app/app-publisher/version-info-mo import Divider from '@/app/components/base/divider' import { toast } from '@/app/components/base/ui/toast' import { useSelector as useAppContextSelector } from '@/context/app-context' -import { useDeleteWorkflow, useInvalidAllLastRun, useResetWorkflowVersionHistory, useUpdateWorkflow, useWorkflowVersionHistory } from '@/service/use-workflow' -import { useDSL, useNodesSyncDraft, useWorkflowRun } from '../../hooks' +import { useDeleteWorkflow, useInvalidAllLastRun, useResetWorkflowVersionHistory, useRestoreWorkflow, useUpdateWorkflow, useWorkflowVersionHistory } from '@/service/use-workflow' +import { useDSL, useWorkflowRefreshDraft, useWorkflowRun } from '../../hooks' import { useHooksStore } from '../../hooks-store' import { useStore, useWorkflowStore } from '../../store' import { VersionHistoryContextMenuOptions, WorkflowVersion, WorkflowVersionFilterOptions } from '../../types' @@ -27,12 +27,14 @@ const INITIAL_PAGE = 1 export type VersionHistoryPanelProps = { getVersionListUrl?: string deleteVersionUrl?: (versionId: string) => string + restoreVersionUrl: (versionId: string) => string updateVersionUrl?: (versionId: string) => string latestVersionId?: string } export const VersionHistoryPanel = ({ getVersionListUrl, deleteVersionUrl, + restoreVersionUrl, updateVersionUrl, latestVersionId, }: VersionHistoryPanelProps) => { @@ -43,8 +45,8 @@ export const VersionHistoryPanel = ({ const [deleteConfirmOpen, setDeleteConfirmOpen] = useState(false) const [editModalOpen, setEditModalOpen] = useState(false) const workflowStore = useWorkflowStore() - const { handleSyncWorkflowDraft } = useNodesSyncDraft() const { handleRestoreFromPublishedWorkflow, handleLoadBackupDraft } = useWorkflowRun() + const { handleRefreshWorkflowDraft } = useWorkflowRefreshDraft() const { handleExportDSL } = useDSL() const setShowWorkflowVersionHistoryPanel = useStore(s => s.setShowWorkflowVersionHistoryPanel) const currentVersion = useStore(s => s.currentVersion) @@ -144,32 +146,33 @@ export const VersionHistoryPanel = ({ }, []) const resetWorkflowVersionHistory = useResetWorkflowVersionHistory() + const { mutateAsync: restoreWorkflow } = useRestoreWorkflow() - const handleRestore = useCallback((item: VersionHistory) => { + const handleRestore = useCallback(async (item: VersionHistory) => { setShowWorkflowVersionHistoryPanel(false) - handleRestoreFromPublishedWorkflow(item) - workflowStore.setState({ isRestoring: false }) - workflowStore.setState({ backupDraft: undefined }) - handleSyncWorkflowDraft(true, false, { - onSuccess: () => { - toast.add({ - type: 'success', - title: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }), - }) - deleteAllInspectVars() - invalidAllLastRun() - }, - onError: () => { - toast.add({ - type: 'error', - title: t('versionHistory.action.restoreFailure', { ns: 'workflow' }), - }) - }, - onSettled: () => { - resetWorkflowVersionHistory() - }, - }) - }, [setShowWorkflowVersionHistoryPanel, handleRestoreFromPublishedWorkflow, workflowStore, handleSyncWorkflowDraft, deleteAllInspectVars, invalidAllLastRun, t, resetWorkflowVersionHistory]) + try { + await restoreWorkflow(restoreVersionUrl(item.id)) + setCurrentVersion(item) + workflowStore.setState({ isRestoring: false }) + workflowStore.setState({ backupDraft: undefined }) + handleRefreshWorkflowDraft() + toast.add({ + type: 'success', + title: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }), + }) + deleteAllInspectVars() + invalidAllLastRun() + } + catch { + toast.add({ + type: 'error', + title: t('versionHistory.action.restoreFailure', { ns: 'workflow' }), + }) + } + finally { + resetWorkflowVersionHistory() + } + }, [setShowWorkflowVersionHistoryPanel, setCurrentVersion, workflowStore, restoreWorkflow, restoreVersionUrl, handleRefreshWorkflowDraft, deleteAllInspectVars, invalidAllLastRun, t, resetWorkflowVersionHistory]) const { mutateAsync: deleteWorkflow } = useDeleteWorkflow() diff --git a/web/service/use-workflow.ts b/web/service/use-workflow.ts index fe20b906fc..949658d8ed 100644 --- a/web/service/use-workflow.ts +++ b/web/service/use-workflow.ts @@ -113,6 +113,13 @@ export const useDeleteWorkflow = () => { }) } +export const useRestoreWorkflow = () => { + return useMutation({ + mutationKey: [NAME_SPACE, 'restore'], + mutationFn: (url: string) => post(url, {}, { silent: true }), + }) +} + export const usePublishWorkflow = () => { return useMutation({ mutationKey: [NAME_SPACE, 'publish'],