diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index e433cdb98b..3ec7a28bd1 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -18,6 +18,7 @@ from core.workflow.constants import ( ) from core.workflow.system_variable import SystemVariable from factories import variable_factory +from services.rag_pipeline import rag_pipeline VariableValue = Union[str, int, float, dict, list, File] @@ -66,10 +67,16 @@ class VariablePool(BaseModel): for var in self.conversation_variables: self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) # Add rag pipeline variables to the variable pool - for var in self.rag_pipeline_variables: - # Combine belong_to_node_id and variable into a single variable name - variable_name = f"{var.variable.belong_to_node_id}.{var.variable.variable}" - self.add((RAG_PIPELINE_VARIABLE_NODE_ID, variable_name), var.value) + if self.rag_pipeline_variables: + rag_pipeline_variables_map = defaultdict(dict) + for var in self.rag_pipeline_variables: + node_id = var.variable.belong_to_node_id + key = var.variable.variable + value = var.value + rag_pipeline_variables_map[node_id][key] = value + for key, value in rag_pipeline_variables_map.items(): + self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value) + def add(self, selector: Sequence[str], value: Any, /) -> None: """