diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 22b5b64009..df2cc3f892 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -3,7 +3,7 @@ import logging from typing import Any, Literal, cast from flask import abort, request -from flask_restx import Resource, marshal_with, reqparse # type: ignore +from flask_restx import Resource, marshal_with # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -120,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel): start_node_title: str +class RagPipelineRecommendedPluginQuery(BaseModel): + type: str = "all" + + register_schema_models( console_ns, DraftWorkflowSyncPayload, @@ -134,6 +138,7 @@ register_schema_models( NodeIdQuery, WorkflowRunQuery, DatasourceVariablesPayload, + RagPipelineRecommendedPluginQuery, ) @@ -974,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource): @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, location="args", required=False, default="all") - args = parser.parse_args() - type = args["type"] + query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict()) rag_pipeline_service = RagPipelineService() - recommended_plugins = rag_pipeline_service.get_recommended_plugins(type) + recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type) return recommended_plugins