diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index d88c63feef..e7ec1555cc 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -37,6 +37,7 @@ from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.workflow import Workflow, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import ( + IconInfo, KnowledgeConfiguration, RagPipelineDatasetCreateEntity, ) @@ -126,6 +127,7 @@ class RagPipelineDslService: pipeline_id: Optional[str] = None, dataset: Optional[Dataset] = None, dataset_name: Optional[str] = None, + icon_info: Optional[IconInfo] = None, ) -> RagPipelineImportInfo: """Import an app from YAML content or URL.""" import_id = str(uuid.uuid4()) @@ -274,10 +276,16 @@ class RagPipelineDslService: # create dataset name = dataset_name or pipeline.name description = pipeline.description - icon_type = data.get("rag_pipeline", {}).get("icon_type") - icon = data.get("rag_pipeline", {}).get("icon") - icon_background = data.get("rag_pipeline", {}).get("icon_background") - icon_url = data.get("rag_pipeline", {}).get("icon_url") + if icon_info: + icon_type = icon_info.icon_type + icon = icon_info.icon + icon_background = icon_info.icon_background + icon_url = icon_info.icon_url + else: + icon_type = data.get("rag_pipeline", {}).get("icon_type") + icon = data.get("rag_pipeline", {}).get("icon") + icon_background = data.get("rag_pipeline", {}).get("icon_background") + icon_url = data.get("rag_pipeline", {}).get("icon_url") workflow = data.get("workflow", {}) graph = workflow.get("graph", {}) nodes = graph.get("nodes", []) @@ -925,6 +933,7 @@ class RagPipelineDslService: yaml_content=rag_pipeline_dataset_create_entity.yaml_content, dataset=None, dataset_name=rag_pipeline_dataset_create_entity.name, + icon_info=rag_pipeline_dataset_create_entity.icon_info, ) return { "id": rag_pipeline_import_info.id,