From 797d044714f5e8da4afe9c7bf99cc56e6495067a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 29 May 2025 09:53:42 +0800 Subject: [PATCH] r2 --- .../rag_pipeline/rag_pipeline_workflow.py | 13 -------- api/core/plugin/impl/datasource.py | 11 +++---- .../nodes/datasource/datasource_node.py | 18 +++++++---- api/services/dataset_service.py | 27 +++++++++-------- api/services/rag_pipeline/rag_pipeline.py | 30 +++++++++++-------- 5 files changed, 50 insertions(+), 49 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index fc6eab529a..09ff07646f 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -462,18 +462,6 @@ class PublishedRagPipelineApi(Resource): if not isinstance(current_user, Account): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.") - args = parser.parse_args() - - if not args.get("knowledge_base_setting"): - raise ValueError("Missing knowledge base setting.") - - knowledge_base_setting_data = args.get("knowledge_base_setting") - if not knowledge_base_setting_data: - raise ValueError("Missing knowledge base setting.") - - knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data) rag_pipeline_service = RagPipelineService() with Session(db.engine) as session: pipeline = session.merge(pipeline) @@ -481,7 +469,6 @@ class PublishedRagPipelineApi(Resource): session=session, pipeline=pipeline, account=current_user, - knowledge_base_setting=knowledge_base_setting, ) pipeline.is_published = True pipeline.workflow_id = workflow.id diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 7847218bb9..51d5489c4c 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -22,11 +22,12 @@ class PluginDatasourceManager(BasePluginClient): """ def transformer(json_response: dict[str, Any]) -> dict: - for provider in json_response.get("data", []): - declaration = provider.get("declaration", {}) or {} - provider_name = declaration.get("identity", {}).get("name") - for datasource in declaration.get("datasources", []): - datasource["identity"]["provider"] = provider_name + if json_response.get("data"): + for provider in json_response.get("data", []): + declaration = provider.get("declaration", {}) or {} + provider_name = declaration.get("identity", {}).get("name") + for datasource in declaration.get("datasources", []): + datasource["identity"]["provider"] = provider_name return json_response diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 8f841f9564..b44039298c 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -9,6 +9,7 @@ from core.datasource.entities.datasource_entities import ( ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.file import File +from core.file.enums import FileTransferMethod, FileType from core.plugin.impl.exc import PluginDaemonClientSideError from core.variables.segments import ArrayAnySegment, FileSegment from core.variables.variables import ArrayAnyVariable @@ -118,7 +119,12 @@ class 
DatasourceNode(BaseNode[DatasourceNodeData]):
                     },
                 )
             case DatasourceProviderType.LOCAL_FILE:
-                upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first()
+                related_id = datasource_info.get("related_id")
+                if not related_id:
+                    raise DatasourceNodeError(
+                        "File does not exist"
+                    )
+                upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
 
                 if not upload_file:
                     raise ValueError("Invalid upload file Info")
@@ -128,14 +134,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     extension="." + upload_file.extension,
                     mime_type=upload_file.mime_type,
                     tenant_id=self.tenant_id,
-                    type=datasource_info.get("type", ""),
-                    transfer_method=datasource_info.get("transfer_method", ""),
+                    type=FileType.CUSTOM,
+                    transfer_method=FileTransferMethod.LOCAL_FILE,
                     remote_url=upload_file.source_url,
                     related_id=upload_file.id,
                     size=upload_file.size,
                     storage_key=upload_file.key,
                 )
-                variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)])
+                variable_pool.add([self.node_id, "file"], [file_info])
                 for key, value in datasource_info.items():
                     # construct new key list
                     new_key_list = ["file", key]
@@ -147,7 +153,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     inputs=parameters_for_log,
                     metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
                     outputs={
-                        "file_info": file_info,
+                        "file_info": datasource_info,
                         "datasource_type": datasource_type,
                     },
                 )
@@ -220,7 +226,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
         assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
         return list(variable.value) if variable else []
-    
+
     def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str,
                                       variable_key_list: list[str], variable_value: VariableValue):
         """
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 6d3891799c..7621784d37 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -53,6 +53,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
 )
 from services.entities.knowledge_entities.rag_pipeline_entities import (
     KnowledgeBaseUpdateConfiguration,
+    KnowledgeConfiguration,
     RagPipelineDatasetCreateEntity,
 )
 from services.errors.account import InvalidActionError, NoPermissionError
@@ -495,11 +496,11 @@ class DatasetService:
     @staticmethod
     def update_rag_pipeline_dataset_settings(session: Session,
                                              dataset: Dataset,
-                                             knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
+                                             knowledge_configuration: KnowledgeConfiguration,
                                              has_published: bool = False):
         if not has_published:
-            dataset.chunk_structure = knowledge_base_setting.chunk_structure
-            index_method = knowledge_base_setting.index_method
+            dataset.chunk_structure = knowledge_configuration.chunk_structure
+            index_method = knowledge_configuration.index_method
             dataset.indexing_technique = index_method.indexing_technique
             if index_method == "high_quality":
                 model_manager = ModelManager()
@@ -519,26 +520,26 @@ class DatasetService:
                 dataset.keyword_number = index_method.economy_setting.keyword_number
             else:
                 raise ValueError("Invalid index method")
-            dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
+            dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
             session.add(dataset)
         else:
-            if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
+            if dataset.chunk_structure and dataset.chunk_structure != 
knowledge_configuration.chunk_structure: raise ValueError("Chunk structure is not allowed to be updated.") action = None - if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique: + if dataset.indexing_technique != knowledge_configuration.index_method.indexing_technique: # if update indexing_technique - if knowledge_base_setting.index_method.indexing_technique == "economy": + if knowledge_configuration.index_method.indexing_technique == "economy": raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") - elif knowledge_base_setting.index_method.indexing_technique == "high_quality": + elif knowledge_configuration.index_method.indexing_technique == "high_quality": action = "add" # get embedding model setting try: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name, + provider=knowledge_configuration.index_method.embedding_setting.embedding_provider_name, model_type=ModelType.TEXT_EMBEDDING, - model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name, + model=knowledge_configuration.index_method.embedding_setting.embedding_model_name, ) dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider @@ -607,9 +608,9 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) elif dataset.indexing_technique == "economy": - if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number: - dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number - dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump() + if dataset.keyword_number != knowledge_configuration.index_method.economy_setting.keyword_number: + dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number + dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() session.add(dataset) session.commit() if action: diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 9e7a1d7fe2..79e793118a 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -47,7 +47,7 @@ from models.workflow import ( WorkflowType, ) from services.dataset_service import DatasetService -from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity +from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, PipelineTemplateInfoEntity from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory @@ -262,7 +262,6 @@ class RagPipelineService: session: Session, pipeline: Pipeline, account: Account, - knowledge_base_setting: KnowledgeBaseUpdateConfiguration, ) -> Workflow: draft_workflow_stmt = select(Workflow).where( Workflow.tenant_id == pipeline.tenant_id, @@ -291,16 +290,23 @@ class RagPipelineService: # commit db session changes session.add(workflow) - # update dataset - dataset = pipeline.dataset - if not dataset: - raise ValueError("Dataset not found") - DatasetService.update_rag_pipeline_dataset_settings( - session=session, - dataset=dataset, - 
knowledge_base_setting=knowledge_base_setting, - has_published=pipeline.is_published - ) + graph = workflow.graph_dict + nodes = graph.get("nodes", []) + for node in nodes: + if node.get("data", {}).get("type") == "knowledge_index": + knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) + knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + + # update dataset + dataset = pipeline.dataset + if not dataset: + raise ValueError("Dataset not found") + DatasetService.update_rag_pipeline_dataset_settings( + session=session, + dataset=dataset, + knowledge_configuration=knowledge_configuration, + has_published=pipeline.is_published + ) # return new workflow return workflow
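
Note on the api/core/plugin/impl/datasource.py hunk: the transformer's job is to copy each provider's
identity name onto every datasource that provider declares, so downstream consumers can resolve the
provider without walking back up the declaration; the added guard merely skips the loop when the daemon
returns no "data". A standalone sketch of that reshaping, assuming the
{"data": [{"declaration": {"identity": ..., "datasources": [...]}}]} response shape visible in the hunk:

    from typing import Any

    def flatten_provider_names(json_response: dict[str, Any]) -> dict[str, Any]:
        # Copy the provider name from each declaration's identity into every
        # datasource's identity (mirrors the transformer in the hunk above).
        for provider in json_response.get("data") or []:
            declaration = provider.get("declaration") or {}
            provider_name = declaration.get("identity", {}).get("name")
            for datasource in declaration.get("datasources", []):
                datasource["identity"]["provider"] = provider_name
        return json_response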
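Note on update_rag_pipeline_dataset_settings: for a pipeline that has already been published, the update
reduces to a small transition rule: chunk structure is frozen, downgrading the indexing technique to
"economy" is rejected, and switching to "high_quality" records an "add" action that drives an index
rebuild after commit. A minimal sketch of just that guard (a simplification for illustration; the real
method also swaps embedding models, updates retrieval settings, and handles the keyword_number case):

    def published_index_action(current: str | None, new: str) -> str | None:
        # Returns the deferred index action for an already-published dataset,
        # or None when the indexing technique is unchanged.
        if new == current:
            return None
        if new == "economy":
            raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
        if new == "high_quality":
            return "add"  # rebuild the vector index with the configured embedding model
        raise ValueError("Invalid index method")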
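Note on the rag_pipeline.py change: publishing no longer takes knowledge_base_setting from the request
body; the configuration is now derived from the knowledge_index node embedded in the workflow graph
itself, making the published graph the single source of truth. A minimal sketch of that lookup, where
KnowledgeConfigurationSketch is a hypothetical stand-in for the real KnowledgeConfiguration pydantic
entity (which also carries index_method and retrieval_setting):

    from typing import Any, Optional

    from pydantic import BaseModel

    class KnowledgeConfigurationSketch(BaseModel):
        # Hypothetical stand-in for rag_pipeline_entities.KnowledgeConfiguration.
        chunk_structure: str = ""

    def find_knowledge_configuration(graph: dict[str, Any]) -> Optional[KnowledgeConfigurationSketch]:
        # Scan the workflow graph for a knowledge_index node, as the loop added
        # above does, and parse its embedded configuration.
        for node in graph.get("nodes", []):
            data = node.get("data", {})
            if data.get("type") == "knowledge_index":
                return KnowledgeConfigurationSketch(**data.get("knowledge_configuration", {}))
        return None

As in the patch, the dataset update runs inside the loop for each matching node, so a graph with a
single knowledge_index node updates the dataset exactly once.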