diff --git a/api/controllers/console/snippets/snippet_workflow.py b/api/controllers/console/snippets/snippet_workflow.py index 5af885ab91b..2c6cbe9dd1f 100644 --- a/api/controllers/console/snippets/snippet_workflow.py +++ b/api/controllers/console/snippets/snippet_workflow.py @@ -80,6 +80,10 @@ class SnippetDraftConfigResponse(BaseModel): parallel_depth_limit: int +class SnippetWorkflowPaginationResponse(WorkflowPaginationResponse): + items: list[SnippetWorkflowResponse] + + register_schema_models( console_ns, SnippetDraftSyncPayload, @@ -98,6 +102,7 @@ register_response_schema_models( SimpleResultResponse, SnippetDraftConfigResponse, SnippetWorkflowResponse, + SnippetWorkflowPaginationResponse, WorkflowPublishResponse, WorkflowPaginationResponse, WorkflowRestoreResponse, @@ -329,7 +334,7 @@ class SnippetPublishedAllWorkflowApi(Resource): @console_ns.response( 200, "Published workflows retrieved successfully", - console_ns.models[WorkflowPaginationResponse.__name__], + console_ns.models[SnippetWorkflowPaginationResponse.__name__], ) @setup_required @login_required @@ -350,7 +355,7 @@ class SnippetPublishedAllWorkflowApi(Resource): limit=args.limit, ) - return WorkflowPaginationResponse.model_validate( + response = SnippetWorkflowPaginationResponse.model_validate( { "items": workflows, "page": args.page, @@ -359,6 +364,9 @@ class SnippetPublishedAllWorkflowApi(Resource): }, from_attributes=True, ).model_dump(mode="json") + for item in response["items"]: + item["input_fields"] = snippet.input_fields_list + return response @console_ns.route("/snippets//workflows//restore") diff --git a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py index 11916c87b68..b20dd3e30a7 100644 --- a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py +++ b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from datetime import datetime from inspect import unwrap from types import SimpleNamespace @@ -199,6 +200,54 @@ def test_default_block_configs_delegates_to_service(app: Flask, monkeypatch: pyt get_default_block_configs.assert_called_once() +def test_list_published_snippet_workflows_includes_input_fields(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + workflow = SimpleNamespace( + id="workflow-1", + graph_dict={"nodes": [], "edges": []}, + features_dict={}, + unique_hash="hash-1", + version="2024-01-01 00:00:00", + marked_name="", + marked_comment="", + created_by_account=None, + created_at=datetime(2024, 1, 1), + updated_by_account=None, + updated_at=datetime(2024, 1, 1), + tool_published=False, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + input_fields = [{"variable": "query", "type": "text"}] + snippet = _snippet(input_fields=json.dumps(input_fields)) + + class SessionContext: + def __init__(self, engine): + self.engine = engine + + def __enter__(self): + return Mock() + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(snippet_workflow_module, "Session", SessionContext) + monkeypatch.setattr(snippet_workflow_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + snippet_workflow_module, + "SnippetService", + lambda: SimpleNamespace(get_all_published_workflows=Mock(return_value=([workflow], False))), + ) + + api = snippet_workflow_module.SnippetPublishedAllWorkflowApi() + handler = unwrap(api.get) + + with app.test_request_context("/snippets/snippet-1/workflows?page=1&limit=20"): + response = handler(api, snippet=snippet) + + assert response["items"][0]["input_fields"] == input_fields + + def test_restore_published_snippet_workflow_to_draft_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow = SimpleNamespace( unique_hash="restored-hash",