diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 7710a6770b..fef10b79a7 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -301,7 +301,7 @@ class PublishedRagPipelineRunApi(Resource):
         raise InvokeRateLimitHttpError(ex.description)
 
 
-class RagPipelineDatasourceNodeRunApi(Resource):
+class RagPipelinePublishedDatasourceNodeRunApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -336,10 +336,52 @@ class RagPipelineDatasourceNodeRunApi(Resource):
             user_inputs=inputs,
             account=current_user,
             datasource_type=datasource_type,
+            is_published=True,
         )
 
         return result
 
 
+class RagPipelineDraftDatasourceNodeRunApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_rag_pipeline
+    def post(self, pipeline: Pipeline, node_id: str):
+        """
+        Run rag pipeline datasource
+        """
+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        if not isinstance(current_user, Account):
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("datasource_type", type=str, required=True, location="json")
+        args = parser.parse_args()
+
+        inputs = args.get("inputs")
+        if inputs is None:
+            raise ValueError("missing inputs")
+        datasource_type = args.get("datasource_type")
+        if datasource_type is None:
+            raise ValueError("missing datasource_type")
+
+        rag_pipeline_service = RagPipelineService()
+        result = rag_pipeline_service.run_datasource_workflow_node(
+            pipeline=pipeline,
+            node_id=node_id,
+            user_inputs=inputs,
+            account=current_user,
+            datasource_type=datasource_type,
+            is_published=False,
+        )
+
+        return result
+
+
 class RagPipelinePublishedNodeRunApi(Resource):
     @setup_required
@@ -851,8 +893,13 @@ api.add_resource(
     "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
 )
 api.add_resource(
-    RagPipelineDatasourceNodeRunApi,
-    "/rag/pipelines/<uuid:pipeline_id>/workflows/datasource/nodes/<string:node_id>/run",
+    RagPipelinePublishedDatasourceNodeRunApi,
+    "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run",
+)
+
+api.add_resource(
+    RagPipelineDraftDatasourceNodeRunApi,
+    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run",
 )
 
 api.add_resource(
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 79925a1c1b..d899e89b02 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -414,27 +414,30 @@ class RagPipelineService:
         return workflow_node_execution
 
     def run_datasource_workflow_node(
-        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str
+        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool = True
     ) -> dict:
         """
-        Run published workflow datasource
+        Run datasource workflow node against the published or draft workflow
         """
-        # fetch published workflow by app_model
-        published_workflow = self.get_published_workflow(pipeline=pipeline)
-        if not published_workflow:
+        if is_published:
+            # fetch published workflow by app_model
+            workflow = self.get_published_workflow(pipeline=pipeline)
+        else:
+            workflow = self.get_draft_workflow(pipeline=pipeline)
+        if not workflow:
             raise ValueError("Workflow not initialized")
 
         # run draft workflow node
         datasource_node_data = None
         start_at = time.perf_counter()
 
-        datasource_nodes = published_workflow.graph_dict.get("nodes", [])
+        datasource_nodes = workflow.graph_dict.get("nodes", [])
         for datasource_node in datasource_nodes:
             if datasource_node.get("id") == node_id:
                 datasource_node_data = datasource_node.get("data", {})
                 break
         if not datasource_node_data:
             raise ValueError("Datasource node data not found")
 
         datasource_parameters = datasource_node_data.get("datasource_parameters", {})
         for key, value in datasource_parameters.items():
             if not user_inputs.get(key):
@@ -651,7 +654,7 @@ class RagPipelineService:
             if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
         ]
         return datasource_provider_variables
 
     def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
         """
         Get first step parameters of rag pipeline
@@ -683,7 +686,7 @@ class RagPipelineService:
             if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
                 user_input_variables.append(variables_map.get(key, {}))
         return user_input_variables
 
     def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
         """
         Get first step parameters of rag pipeline