diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index d7ed5d475d..b66a747121 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -677,6 +677,53 @@ class PublishedRagPipelineSecondStepApi(Resource):
             "variables": variables,
         }

+class PublishedRagPipelineFirstStepApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_rag_pipeline
+    def get(self, pipeline: Pipeline):
+        """
+        Get first step parameters of rag pipeline
+        """
+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_editor:
+            raise Forbidden()
+        parser = reqparse.RequestParser()
+        parser.add_argument("node_id", type=str, required=True, location="args")
+        args = parser.parse_args()
+        node_id = args.get("node_id")
+        if not node_id:
+            raise ValueError("Node ID is required")
+        rag_pipeline_service = RagPipelineService()
+        variables = rag_pipeline_service.get_published_first_step_parameters(pipeline=pipeline, node_id=node_id)
+        return {
+            "variables": variables,
+        }
+
+class DraftRagPipelineFirstStepApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_rag_pipeline
+    def get(self, pipeline: Pipeline):
+        """
+        Get first step parameters of rag pipeline
+        """
+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_editor:
+            raise Forbidden()
+        parser = reqparse.RequestParser()
+        parser.add_argument("node_id", type=str, required=True, location="args")
+        args = parser.parse_args()
+        node_id = args.get("node_id")
+        if not node_id:
+            raise ValueError("Node ID is required")
+        rag_pipeline_service = RagPipelineService()
+        variables = rag_pipeline_service.get_draft_first_step_parameters(pipeline=pipeline, node_id=node_id)
+        return {
+            "variables": variables,
+        }

 class DraftRagPipelineSecondStepApi(Resource):
     @setup_required
@@ -862,7 +909,15 @@ api.add_resource(
     PublishedRagPipelineSecondStepApi,
     "/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters",
 )
+api.add_resource(
+    PublishedRagPipelineFirstStepApi,
+    "/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters",
+)
 api.add_resource(
     DraftRagPipelineSecondStepApi,
     "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters",
 )
+api.add_resource(
+    DraftRagPipelineFirstStepApi,
+    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters",
+)
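The two new endpoints mirror the existing second-step "processing" routes under a "pre-processing" path. A minimal smoke-test sketch for the draft variant, assuming the console API is mounted at /console/api; the base URL, bearer token, pipeline UUID, and node ID below are placeholders, not values from this patch:

    import requests

    BASE_URL = "http://localhost:5001/console/api"  # assumed mount point, not from this patch
    PIPELINE_ID = "00000000-0000-0000-0000-000000000000"  # placeholder pipeline UUID
    NODE_ID = "datasource-node"  # placeholder: the datasource node's ID in the graph

    # node_id is a required query argument; a successful response is {"variables": [...]}.
    response = requests.get(
        f"{BASE_URL}/rag/pipelines/{PIPELINE_ID}/workflows/draft/pre-processing/parameters",
        params={"node_id": NODE_ID},
        headers={"Authorization": "Bearer <placeholder-console-token>"},
    )
    response.raise_for_status()
    print(response.json()["variables"])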
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index 92b2daea54..8c76fc161d 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -59,7 +59,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
             raise DatasourceNodeError("Datasource type is not set")

         datasource_runtime = DatasourceManager.get_datasource_runtime(
-            provider_id=node_data.provider_id,
+            provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
             datasource_name=node_data.datasource_name or "",
             tenant_id=self.tenant_id,
             datasource_type=DatasourceProviderType.value_of(datasource_type),
diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py
index dee3c1d2fb..b182928baa 100644
--- a/api/core/workflow/nodes/datasource/entities.py
+++ b/api/core/workflow/nodes/datasource/entities.py
@@ -7,7 +7,7 @@ from core.workflow.nodes.base.entities import BaseNodeData


 class DatasourceEntity(BaseModel):
-    provider_id: str
+    plugin_id: str
     provider_name: str  # redundancy
     provider_type: str
     datasource_name: Optional[str] = "local_file"
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 9c4a054184..80b961851a 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -1,4 +1,5 @@
 import json
+import re
 import threading
 import time
 from collections.abc import Callable, Generator, Sequence
@@ -434,14 +435,19 @@ class RagPipelineService:
         datasource_node_data = published_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {})
         if not datasource_node_data:
             raise ValueError("Datasource node data not found")
+        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+        for key, value in datasource_parameters.items():
+            if not user_inputs.get(key):
+                user_inputs[key] = value["value"]
         from core.datasource.datasource_manager import DatasourceManager

         datasource_runtime = DatasourceManager.get_datasource_runtime(
-            provider_id=datasource_node_data.get("provider_id"),
+            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
             datasource_name=datasource_node_data.get("datasource_name"),
             tenant_id=pipeline.tenant_id,
             datasource_type=DatasourceProviderType(datasource_type),
         )
+
         if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
             datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
             online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
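Besides switching provider_id to the composed f"{plugin_id}/{provider_name}" form at both call sites, the hunk above adds a fallback that fills user_inputs from the node's configured datasource_parameters. A standalone sketch of that loop with made-up values; note that the falsy check also replaces explicitly empty inputs:

    # Sketch of the fallback added above; all values here are illustrative.
    user_inputs = {"workspace": "", "page_size": 20}
    datasource_parameters = {
        "workspace": {"value": "default-workspace"},
        "depth": {"value": 2},
    }
    for key, value in datasource_parameters.items():
        if not user_inputs.get(key):  # falsy check: "" and 0 are also replaced
            user_inputs[key] = value["value"]
    print(user_inputs)
    # {'workspace': 'default-workspace', 'page_size': 20, 'depth': 2}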
@@ -648,6 +654,60 @@ class RagPipelineService:
             if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
         ]
         return datasource_provider_variables
+
+    def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
+        """
+        Get first step parameters of rag pipeline
+        """
+
+        published_workflow = self.get_published_workflow(pipeline=pipeline)
+        if not published_workflow:
+            raise ValueError("Workflow not initialized")
+
+        # get first step (datasource) node
+        datasource_node_data = published_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {})
+        if not datasource_node_data:
+            raise ValueError("Datasource node data not found")
+        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+        if datasource_parameters:
+            datasource_parameters_map = {
+                item["variable"]: item for item in datasource_parameters
+            }
+        else:
+            datasource_parameters_map = {}
+        variables = datasource_node_data.get("variables", {})
+        user_input_variables = []
+        for key, value in variables.items():
+            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
+                user_input_variables.append(datasource_parameters_map.get(key, {}))
+        return user_input_variables
+
+    def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
+        """
+        Get first step parameters of rag pipeline
+        """
+
+        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
+        if not draft_workflow:
+            raise ValueError("Workflow not initialized")
+
+        # get first step (datasource) node
+        datasource_node_data = draft_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {})
+        if not datasource_node_data:
+            raise ValueError("Datasource node data not found")
+        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+        if datasource_parameters:
+            datasource_parameters_map = {
+                item["variable"]: item for item in datasource_parameters
+            }
+        else:
+            datasource_parameters_map = {}
+        variables = datasource_node_data.get("variables", {})
+        user_input_variables = []
+        for key, value in variables.items():
+            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
+                user_input_variables.append(datasource_parameters_map.get(key, {}))
+        return user_input_variables

     def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
         """
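The filter in both first-step methods keeps only the parameters whose configured value is not a variable-selector reference of the form {{#node.field#}}: selector values are resolved from other nodes at runtime, so only literal values are surfaced as user inputs. A standalone sketch of that check using the same pattern; the sample values are made up:

    import re

    # Same selector pattern used in the methods above.
    SELECTOR_PATTERN = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"

    samples = {
        "url": {"value": "{{#start.url#}}"},       # selector reference, filtered out
        "api_key": {"value": "sk-user-supplied"},  # literal, surfaced as a user input
    }
    for key, value in samples.items():
        if re.match(SELECTOR_PATTERN, value["value"]):
            print(f"{key}: resolved from another node at runtime")
        else:
            print(f"{key}: ask the user for a value")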