diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 2b95938cb6..a497ff14ac 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -90,6 +90,7 @@ from .datasets.rag_pipeline import ( datasource_content_preview, rag_pipeline, rag_pipeline_datasets, + rag_pipeline_draft_variable, rag_pipeline_import, rag_pipeline_workflow, ) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index e5c211be93..485a73e517 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -53,6 +53,7 @@ class RagPipelineImportApi(Resource): yaml_content=args.get("yaml_content"), yaml_url=args.get("yaml_url"), pipeline_id=args.get("pipeline_id"), + dataset_name=args.get("name"), ) session.commit() diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index fb311482d8..e2908d83aa 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -123,6 +123,7 @@ class RagPipelineDslService: yaml_url: Optional[str] = None, pipeline_id: Optional[str] = None, dataset: Optional[Dataset] = None, + dataset_name: Optional[str] = None, ) -> RagPipelineImportInfo: """Import an app from YAML content or URL.""" import_id = str(uuid.uuid4()) @@ -265,7 +266,7 @@ class RagPipelineDslService: dependencies=check_dependencies_pending_data, ) # create dataset - name = pipeline.name + name = dataset_name or pipeline.name description = pipeline.description icon_type = data.get("rag_pipeline", {}).get("icon_type") icon = data.get("rag_pipeline", {}).get("icon") @@ -883,6 +884,7 @@ class RagPipelineDslService: import_mode=ImportMode.YAML_CONTENT.value, yaml_content=rag_pipeline_dataset_create_entity.yaml_content, dataset=None, + dataset_name=rag_pipeline_dataset_create_entity.name, ) return { "id": rag_pipeline_import_info.id,