class RagPipelineDatasourceVariableApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_rag_pipeline
    @marshal_with(workflow_run_node_execution_fields)
    def post(self, pipeline: Pipeline):
        """
        Set datasource variables on the pipeline's draft workflow.

        Accepts the datasource selection made in the UI and delegates to
        RagPipelineService, which records it as a synthetic datasource
        node execution; the response is marshalled with
        workflow_run_node_execution_fields.
        """
        # Only editor accounts may mutate draft-workflow state.
        # (is_editor is checked first, mirroring the original check order.)
        if not current_user.is_editor or not isinstance(current_user, Account):
            raise Forbidden()

        payload_parser = reqparse.RequestParser()
        for field_name, field_type in (
            ("datasource_type", str),
            ("datasource_info", dict),
            ("start_node_id", str),
            ("start_node_title", str),
        ):
            payload_parser.add_argument(field_name, type=field_type, required=True, location="json")
        payload = payload_parser.parse_args()

        return RagPipelineService().set_datasource_variables(
            pipeline=pipeline,
            args=payload,
            current_user=current_user,
        )
RagPipelineDatasourceVariableApi, + "/rag/pipelines//workflows/draft/datasource/variables-inspect", +) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 9bd17778bd..deb645273f 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1148,3 +1148,66 @@ class RagPipelineService: .first() ) return node_exec + + def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account | EndUser): + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(pipeline=pipeline) + if not draft_workflow: + raise ValueError("Workflow not initialized") + workflow_node_execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=draft_workflow.id, + index=1, + node_id=args.get("start_node_id", ""), + node_type=NodeType.DATASOURCE, + title=args.get("start_node_title", "Datasource"), + elapsed_time=0, + finished_at=datetime.now(UTC).replace(tzinfo=None), + created_at=datetime.now(UTC).replace(tzinfo=None), + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=None, + metadata=None, + ) + outputs = { + **args.get("datasource_info", {}), + "datasource_type": args.get("datasource_type", ""), + } + workflow_node_execution.outputs = outputs + node_config = draft_workflow.get_node_config_by_id(args.get("start_node_id", "")) + + eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) + if eclosing_node_type_and_id: + _, enclosing_node_id = eclosing_node_type_and_id + else: + enclosing_node_id = None + + # Create repository and save the node execution + repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=db.engine, + user=current_user, + app_id=pipeline.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + ) + repository.save(workflow_node_execution) + + # Convert node_execution to WorkflowNodeExecution after save + workflow_node_execution_db_model = 
repository.to_db_model(workflow_node_execution) + + with Session(bind=db.engine) as session, session.begin(): + draft_var_saver = DraftVariableSaver( + session=session, + app_id=pipeline.id, + node_id=workflow_node_execution_db_model.node_id, + node_type=NodeType(workflow_node_execution_db_model.node_type), + enclosing_node_id=enclosing_node_id, + node_execution_id=workflow_node_execution.id, + ) + draft_var_saver.save( + process_data=workflow_node_execution.process_data, + outputs=workflow_node_execution.outputs, + ) + session.commit() + return workflow_node_execution_db_model + + +