diff --git a/api/app_factory.py b/api/app_factory.py index 17c376de77..933cf294d1 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -18,6 +18,7 @@ def create_flask_app_with_configs() -> DifyApp: """ dify_app = DifyApp(__name__) dify_app.config.from_mapping(dify_config.model_dump()) + dify_app.config["RESTX_INCLUDE_ALL_MODELS"] = True # add before request hook @dify_app.before_request diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index d413def27f..5e3b3428eb 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -1,7 +1,7 @@ from flask_restx import ( # type: ignore Resource, # type: ignore - reqparse, ) +from pydantic import BaseModel from werkzeug.exceptions import Forbidden from controllers.console import api, console_ns @@ -12,17 +12,21 @@ from models import Account from models.dataset import Pipeline from services.rag_pipeline.rag_pipeline import RagPipelineService -parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("credential_id", type=str, required=False, location="json") -) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class Parser(BaseModel): + inputs: dict + datasource_type: str + credential_id: str | None = None + + +console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview") class DataSourceContentPreviewApi(Resource): - @api.expect(parser) + @api.expect(console_ns.models[Parser.__name__], validate=True) @setup_required @login_required @account_initialization_required @@ -34,15 +38,10 @@ class 
DataSourceContentPreviewApi(Resource): if not isinstance(current_user, Account): raise Forbidden() - args = parser.parse_args() - - inputs = args.get("inputs") - if inputs is None: - raise ValueError("missing inputs") - datasource_type = args.get("datasource_type") - if datasource_type is None: - raise ValueError("missing datasource_type") + args = Parser.model_validate(api.payload) + inputs = args.inputs + datasource_type = args.datasource_type rag_pipeline_service = RagPipelineService() preview_content = rag_pipeline_service.run_datasource_node_preview( pipeline=pipeline, @@ -51,6 +50,6 @@ class DataSourceContentPreviewApi(Resource): account=current_user, datasource_type=datasource_type, is_published=True, - credential_id=args.get("credential_id"), + credential_id=args.credential_id, ) return preview_content, 200 diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 7c525845a7..ed47e706b6 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,7 +1,10 @@ import json +from typing import Self +from uuid import UUID from flask import request from flask_restx import marshal, reqparse +from pydantic import BaseModel, model_validator from sqlalchemy import desc, select from werkzeug.exceptions import Forbidden, NotFound @@ -31,7 +34,7 @@ from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment from services.dataset_service import DatasetService, DocumentService -from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig +from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.file_service import FileService # Define parsers for document operations @@ -51,15 +54,26 @@ document_text_create_parser = ( 
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") ) -document_text_update_parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, nullable=True, location="json") - .add_argument("text", type=str, required=False, nullable=True, location="json") - .add_argument("process_rule", type=dict, required=False, nullable=True, location="json") - .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") - .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") -) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class DocumentTextUpdate(BaseModel): + name: str | None = None + text: str | None = None + process_rule: ProcessRule | None = None + doc_form: str = "text_model" + doc_language: str = "English" + retrieval_model: RetrievalModel | None = None + + @model_validator(mode="after") + def check_text_and_name(self) -> Self: + if self.text is not None and self.name is None: + raise ValueError("name is required when text is provided") + return self + + +for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]: + service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore @service_api_ns.route( @@ -160,7 +174,7 @@ class DocumentAddByTextApi(DatasetApiResource): class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" - @service_api_ns.expect(document_text_update_parser) + @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True) @service_api_ns.doc("update_document_by_text") @service_api_ns.doc(description="Update an existing document by providing text content") @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @@ -173,12 +187,10 
@@ class DocumentUpdateByTextApi(DatasetApiResource): ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, document_id): + def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): """Update document by text.""" - args = document_text_update_parser.parse_args() - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True) + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first() if not dataset: raise ValueError("Dataset does not exist.") @@ -198,11 +210,9 @@ class DocumentUpdateByTextApi(DatasetApiResource): # indexing_technique is already set in dataset since this is an update args["indexing_technique"] = dataset.indexing_technique - if args["text"]: + if args.get("text"): text = args.get("text") name = args.get("name") - if text is None or name is None: - raise ValueError("Both text and name must be strings.") if not current_user: raise ValueError("current_user is required") upload_file = FileService(db.engine).upload_text(