def get_datasource_plugins(self, pipeline_id: str, is_published: bool) -> list[dict]:
    """
    Get datasource plugins

    Collect the plugin descriptors of every datasource node in a pipeline's
    workflow graph.

    :param pipeline_id: ID used to look up the Pipeline row
    :param is_published: when True read the published workflow, otherwise the draft workflow
    :return: one dict per datasource node that has non-empty ``data``, with the keys
        ``plugin_id``, ``provider_name``, ``datasource_name``,
        ``datasource_configurations`` and ``plugin_unique_identifier``
        (each value is None when the node data lacks that key)
    :raises ValueError: if the pipeline or the selected workflow is not found
    """
    pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
    # Validate before using the pipeline: the workflow getters below receive it
    # as an argument, so a missing pipeline must fail here with the intended error.
    if not pipeline:
        raise ValueError("Pipeline or workflow not found")

    workflow: Workflow | None = (
        self.get_published_workflow(pipeline=pipeline)
        if is_published
        else self.get_draft_workflow(pipeline=pipeline)
    )
    if not workflow:
        raise ValueError("Pipeline or workflow not found")

    # NOTE(review): a previous revision also scanned each node's
    # "datasource_parameters" for user-input variable keys and looked them up in
    # `variables_map`, a name that was never defined (NameError at runtime).
    # The collected values were never returned, so that step is omitted here.
    # Reintroduce it only together with a real source for the variable map.
    datasource_plugins: list[dict] = []
    for node in workflow.graph_dict.get("nodes", []):
        if node.get("type") != "datasource":
            continue

        node_data = node.get("data", {})
        if not node_data:
            # Nothing to describe for a datasource node without data.
            continue

        datasource_plugins.append(
            {
                "plugin_id": node_data.get("plugin_id"),
                "provider_name": node_data.get("provider_name"),
                "datasource_name": node_data.get("datasource_name"),
                "datasource_configurations": node_data.get("datasource_configurations"),
                "plugin_unique_identifier": node_data.get("plugin_unique_identifier"),
            }
        )

    return datasource_plugins