diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 37ed5b8385..a6975f2413 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -3,6 +3,7 @@ from typing import Any, Literal, Optional from pydantic import BaseModel from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector class RerankingModelConfig(BaseModel): @@ -44,7 +45,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData): Knowledge retrieval Node Data. """ type: str = 'knowledge-retrieval' - query_variable_selector: list[str] + query_variable_selector: VariableSelector dataset_ids: list[str] retrieval_mode: Literal['single', 'multiple'] multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 0534695adb..dde89a3427 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -43,9 +43,9 @@ class KnowledgeRetrievalNode(BaseNode): node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) # extract variables - query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) + query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector.value_selector) variables = { - '_query': query + node_data.query_variable_selector.variable: query } # retrieve knowledge try: @@ -170,7 +170,7 @@ class KnowledgeRetrievalNode(BaseNode): node_data = node_data node_data = cast(cls._node_data_cls, node_data) variable_mapping = {} - variable_mapping['_query'] = node_data.query_variable_selector + variable_mapping[node_data.query_variable_selector.variable] = node_data.query_variable_selector.value_selector return variable_mapping def _single_retrieve(self, available_datasets, node_data, query): diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index 5371862ea8..c9e353572c 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -3,6 +3,7 @@ from typing import Any from pydantic import BaseModel from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector class ModelConfig(BaseModel): @@ -42,7 +43,7 @@ class QuestionClassifierNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ - query_variable_selector: list[str] + query_variable_selector: VariableSelector type: str = 'question-classifier' model: ModelConfig classes: list[ClassConfig] diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index d351dfb692..0b47e118f0 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -43,9 +43,9 @@ class QuestionClassifierNode(BaseNode): def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data) # extract variables - query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) + query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector.value_selector) variables = { - '_query': query + node_data.query_variable_selector.variable: query } # fetch model config model_instance, model_config = self._fetch_model_config(node_data) @@ -104,7 +104,7 @@ class QuestionClassifierNode(BaseNode): def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: node_data = node_data node_data = cast(cls._node_data_cls, node_data) - variable_mapping = {'_query': node_data.query_variable_selector} + variable_mapping = {node_data.query_variable_selector.variable: node_data.query_variable_selector.value_selector} return variable_mapping @classmethod