From 4cebaa331e1ddf48f0de44b3814977eec32b1849 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 25 Sep 2025 17:18:23 +0800 Subject: [PATCH] add pipeline template endpoint --- api/configs/enterprise/__init__.py | 15 ++ .../datasets/rag_pipeline/rag_pipeline.py | 191 +++++++++++++++++- ...0bcbf45396_remove_builtin_template_user.py | 37 ++++ api/models/dataset.py | 9 - .../rag_pipeline_entities.py | 7 + .../database/database_retrieval.py | 1 - api/services/rag_pipeline/rag_pipeline.py | 140 +++++++++++++ 7 files changed, 383 insertions(+), 17 deletions(-) create mode 100644 api/migrations/versions/2025_09_25_1650-bf0bcbf45396_remove_builtin_template_user.py diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index eda6345e14..af5640934e 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -18,3 +18,18 @@ class EnterpriseFeatureConfig(BaseSettings): description="Allow customization of the enterprise logo.", default=False, ) + + UPLOAD_KNOWLEDGE_PIPELINE_TEMPLATE_TOKEN: str = Field( + description="Token for uploading knowledge pipeline template.", + default="", + ) + + KNOWLEDGE_PIPELINE_TEMPLATE_COPYRIGHT: str = Field( + description="Knowledge pipeline template copyright.", + default="Copyright 2023 Dify", + ) + + KNOWLEDGE_PIPELINE_TEMPLATE_PRIVACY_POLICY: str = Field( + description="Knowledge pipeline template privacy policy.", + default="https://dify.ai", + ) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index f04b0e04c3..c8bb976c44 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -14,7 +14,10 @@ from controllers.console.wraps import ( from extensions.ext_database import db from libs.login import login_required from models.dataset import PipelineCustomizedTemplate -from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity +from services.entities.knowledge_entities.rag_pipeline_entities import ( + PipelineBuiltInTemplateEntity, + PipelineTemplateInfoEntity, +) from services.rag_pipeline.rag_pipeline import RagPipelineService logger = logging.getLogger(__name__) @@ -26,12 +29,6 @@ def _validate_name(name): return name -def _validate_description_length(description): - if len(description) > 400: - raise ValueError("Description cannot exceed 400 characters.") - return description - - class PipelineTemplateListApi(Resource): @setup_required @login_required @@ -146,6 +143,186 @@ class PublishCustomizedPipelineTemplateApi(Resource): return {"result": "success"} +class PipelineTemplateInstallApi(Resource): + """API endpoint for installing built-in pipeline templates""" + + def post(self): + """ + Install a built-in pipeline template + + Args: + template_id: The template ID from URL parameter + + Returns: + Success response or error with appropriate HTTP status + """ + try: + # Extract and validate Bearer token + auth_token = self._extract_bearer_token() + + # Parse and validate request parameters + template_args = self._parse_template_args() + + # Process uploaded template file + file_content = self._process_template_file() + + # Create template entity + pipeline_built_in_template_entity = PipelineBuiltInTemplateEntity(**template_args) + + # Install the template + rag_pipeline_service = RagPipelineService() + rag_pipeline_service.install_built_in_pipeline_template( + pipeline_built_in_template_entity, file_content, auth_token + ) + + return {"result": "success", "message": "Template installed successfully"}, 200 + + except ValueError as e: + logger.exception("Validation error in template installation") + return {"error": str(e)}, 400 + except Exception as e: + logger.exception("Unexpected error in template installation") + return {"error": "An unexpected error occurred during template installation"}, 500 + + def _extract_bearer_token(self) -> str: + """ + Extract and validate Bearer token from Authorization header + + Returns: + The extracted token string + + Raises: + ValueError: If token is missing or invalid + """ + auth_header = request.headers.get("Authorization", "").strip() + + if not auth_header: + raise ValueError("Authorization header is required") + + if not auth_header.startswith("Bearer "): + raise ValueError("Authorization header must start with 'Bearer '") + + token_parts = auth_header.split(" ", 1) + if len(token_parts) != 2: + raise ValueError("Invalid Authorization header format") + + auth_token = token_parts[1].strip() + if not auth_token: + raise ValueError("Bearer token cannot be empty") + + return auth_token + + def _parse_template_args(self) -> dict: + """ + Parse and validate template arguments from form data + + Args: + template_id: The template ID from URL + + Returns: + Dictionary of validated template arguments + """ + # Use reqparse for consistent parameter parsing + parser = reqparse.RequestParser() + + parser.add_argument( + "template_id", + type=str, + location="form", + required=False, + help="Template ID for updating existing template" + ) + parser.add_argument( + "language", + type=str, + location="form", + required=True, + default="en-US", + choices=["en-US", "zh-CN", "ja-JP"], + help="Template language code" + ) + parser.add_argument( + "name", + type=str, + location="form", + required=True, + default="New Pipeline Template", + help="Template name (1-200 characters)" + ) + parser.add_argument( + "description", + type=str, + location="form", + required=False, + default="", + help="Template description (max 1000 characters)" + ) + + args = parser.parse_args() + + # Additional validation + if args.get("name"): + args["name"] = self._validate_name(args["name"]) + + if args.get("description") and len(args["description"]) > 1000: + raise ValueError("Description must not exceed 1000 characters") + + # Filter out None values + return {k: v for k, v in args.items() if v is not None} + + def _validate_name(self, name: str) -> str: + """ + Validate template name + + Args: + name: Template name to validate + + Returns: + Validated and trimmed name + + Raises: + ValueError: If name is invalid + """ + name = name.strip() + if not name or len(name) < 1 or len(name) > 200: + raise ValueError("Template name must be between 1 and 200 characters") + return name + + def _process_template_file(self) -> str: + """ + Process and validate uploaded template file + + Returns: + File content as string + + Raises: + ValueError: If file is missing or invalid + """ + if "file" not in request.files: + raise ValueError("Template file is required") + + file = request.files["file"] + + # Validate file + if not file or not file.filename: + raise ValueError("No file selected") + + filename = file.filename.strip() + if not filename: + raise ValueError("File name cannot be empty") + + # Check file extension + if not filename.lower().endswith(".pipeline"): + raise ValueError("Template file must be a pipeline file (.pipeline)") + + try: + file_content = file.read().decode("utf-8") + except UnicodeDecodeError: + raise ValueError("Template file must be valid UTF-8 text") + + return file_content + + api.add_resource( PipelineTemplateListApi, "/rag/pipeline/templates", diff --git a/api/migrations/versions/2025_09_25_1650-bf0bcbf45396_remove_builtin_template_user.py b/api/migrations/versions/2025_09_25_1650-bf0bcbf45396_remove_builtin_template_user.py new file mode 100644 index 0000000000..67731111f9 --- /dev/null +++ b/api/migrations/versions/2025_09_25_1650-bf0bcbf45396_remove_builtin_template_user.py @@ -0,0 +1,37 @@ +"""remove-builtin-template-user + +Revision ID: bf0bcbf45396 +Revises: 68519ad5cd18 +Create Date: 2025-09-25 16:50:32.245503 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'bf0bcbf45396' +down_revision = '68519ad5cd18' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.drop_column('updated_by') + batch_op.drop_column('created_by') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False)) + batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True)) + + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 25ebe14738..d7d204e2bc 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1239,15 +1239,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] language = db.Column(db.String(255), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_by = db.Column(StringUUID, nullable=False) - updated_by = db.Column(StringUUID, nullable=True) - - @property - def created_user_name(self): - account = db.session.query(Account).where(Account.id == self.created_by).first() - if account: - return account.name - return "" class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index 860bfde401..022fa47525 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -128,3 +128,10 @@ class KnowledgeConfiguration(BaseModel): if v is None: return "" return v + + +class PipelineBuiltInTemplateEntity(BaseModel): + template_id: str | None = None + name: str + description: str + language: str diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index ec91f79606..908f9a2684 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -74,5 +74,4 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): "chunk_structure": pipeline_template.chunk_structure, "export_data": pipeline_template.yaml_content, "graph": graph_data, - "created_by": pipeline_template.created_user_name, } diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index fdaaa73bcc..14e3364690 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -8,6 +8,7 @@ from datetime import UTC, datetime from typing import Any, Union, cast from uuid import uuid4 +import yaml from flask_login import current_user from sqlalchemy import func, or_, select from sqlalchemy.orm import Session, sessionmaker @@ -60,6 +61,7 @@ from models.dataset import ( # type: ignore Document, DocumentPipelineExecutionLog, Pipeline, + PipelineBuiltInTemplate, PipelineCustomizedTemplate, PipelineRecommendedPlugin, ) @@ -76,6 +78,7 @@ from repositories.factory import DifyAPIRepositoryFactory from services.datasource_provider_service import DatasourceProviderService from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, + PipelineBuiltInTemplateEntity, PipelineTemplateInfoEntity, ) from services.errors.app import WorkflowHashNotEqualError @@ -1454,3 +1457,140 @@ class RagPipelineService: if not pipeline: raise ValueError("Pipeline not found") return pipeline + + def install_built_in_pipeline_template( + self, args: PipelineBuiltInTemplateEntity, file_content: str, auth_token: str + ) -> None: + """ + Install built-in pipeline template + + Args: + args: Pipeline built-in template entity with template metadata + file_content: YAML content of the pipeline template + auth_token: Authentication token for authorization + + Raises: + ValueError: If validation fails or template processing errors occur + """ + # Validate authentication + self._validate_auth_token(auth_token) + + # Parse and validate template content + pipeline_template_dsl = self._parse_template_content(file_content) + + # Extract template metadata + icon = self._extract_icon_metadata(pipeline_template_dsl) + chunk_structure = self._extract_chunk_structure(pipeline_template_dsl) + + # Prepare template data + template_data = { + "name": args.name, + "description": args.description, + "chunk_structure": chunk_structure, + "icon": icon, + "language": args.language, + "yaml_content": file_content, + } + + # Use transaction for database operations + try: + if args.template_id: + self._update_existing_template(args.template_id, template_data) + else: + self._create_new_template(template_data) + db.session.commit() + except Exception as e: + db.session.rollback() + raise ValueError(f"Failed to install pipeline template: {str(e)}") + + def _validate_auth_token(self, auth_token: str) -> None: + """Validate the authentication token""" + config_auth_token = dify_config.UPLOAD_KNOWLEDGE_PIPELINE_TEMPLATE_TOKEN + if not config_auth_token: + raise ValueError("Auth token configuration is required") + if config_auth_token != auth_token: + raise ValueError("Auth token is incorrect") + + def _parse_template_content(self, file_content: str) -> dict: + """Parse and validate YAML template content""" + try: + pipeline_template_dsl = yaml.safe_load(file_content) + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML content: {str(e)}") + + if not pipeline_template_dsl: + raise ValueError("Pipeline template DSL is required") + + return pipeline_template_dsl + + def _extract_icon_metadata(self, pipeline_template_dsl: dict) -> dict: + """Extract icon metadata from template DSL""" + rag_pipeline_info = pipeline_template_dsl.get("rag_pipeline", {}) + + return { + "icon": rag_pipeline_info.get("icon", "📙"), + "icon_type": rag_pipeline_info.get("icon_type", "emoji"), + "icon_background": rag_pipeline_info.get("icon_background", "#FFEAD5"), + "icon_url": rag_pipeline_info.get("icon_url"), + } + + def _extract_chunk_structure(self, pipeline_template_dsl: dict) -> str: + """Extract chunk structure from template DSL""" + nodes = pipeline_template_dsl.get("workflow", {}).get("graph", {}).get("nodes", []) + + # Use generator expression for efficiency + chunk_structure = next( + ( + node.get("data", {}).get("chunk_structure") + for node in nodes + if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_INDEX.value + ), + None + ) + + if not chunk_structure: + raise ValueError("Chunk structure is required in template") + + return chunk_structure + + def _update_existing_template(self, template_id: str, template_data: dict) -> None: + """Update an existing pipeline template""" + pipeline_built_in_template = ( + db.session.query(PipelineBuiltInTemplate) + .filter(PipelineBuiltInTemplate.id == template_id) + .first() + ) + + if not pipeline_built_in_template: + raise ValueError(f"Pipeline built-in template not found: {template_id}") + + # Update template fields + for key, value in template_data.items(): + setattr(pipeline_built_in_template, key, value) + + db.session.add(pipeline_built_in_template) + + def _create_new_template(self, template_data: dict) -> None: + """Create a new pipeline template""" + # Get the next available position + position = self._get_next_position(template_data["language"]) + + # Add additional fields for new template + template_data.update({ + "position": position, + "install_count": 0, + "copyright": dify_config.KNOWLEDGE_PIPELINE_TEMPLATE_COPYRIGHT, + "privacy_policy": dify_config.KNOWLEDGE_PIPELINE_TEMPLATE_PRIVACY_POLICY, + }) + + new_template = PipelineBuiltInTemplate(**template_data) + db.session.add(new_template) + + def _get_next_position(self, language: str) -> int: + """Get the next available position for a template in the specified language""" + max_position = ( + db.session.query(func.max(PipelineBuiltInTemplate.position)) + .filter(PipelineBuiltInTemplate.language == language) + .scalar() + ) + return (max_position or 0) + 1