From 3e810bc490b1f7b06b7007ba7f9d3c7952b856b7 Mon Sep 17 00:00:00 2001 From: jyong Date: Mon, 18 Mar 2024 21:22:16 +0800 Subject: [PATCH] knowledge fix --- api/core/workflow/nodes/knowledge_retrieval/entities.py | 9 ++++----- .../knowledge_retrieval/knowledge_retrieval_node.py | 6 +++--- api/core/workflow/nodes/question_classifier/entities.py | 3 +-- .../question_classifier/question_classifier_node.py | 6 +++--- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index a6975f2413..d6a5111a43 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -3,7 +3,6 @@ from typing import Any, Literal, Optional from pydantic import BaseModel from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.variable_entities import VariableSelector class RerankingModelConfig(BaseModel): @@ -11,7 +10,7 @@ class RerankingModelConfig(BaseModel): Reranking Model Config. """ provider: str - model: str + mode: str class MultipleRetrievalConfig(BaseModel): @@ -45,8 +44,8 @@ class KnowledgeRetrievalNodeData(BaseNodeData): Knowledge retrieval Node Data. """ type: str = 'knowledge-retrieval' - query_variable_selector: VariableSelector + query_variable_selector: list[str] dataset_ids: list[str] retrieval_mode: Literal['single', 'multiple'] - multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None - singleRetrievalConfig: Optional[SingleRetrievalConfig] = None + multiple_retrieval_config: Optional[MultipleRetrievalConfig] + singleRetrievalConfig: Optional[SingleRetrievalConfig] diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 7f145cfdf4..a4d16cc44f 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -43,9 +43,9 @@ class KnowledgeRetrievalNode(BaseNode): node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) # extract variables - query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector.value_selector) + query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) variables = { - node_data.query_variable_selector.variable: query + 'query': query } # retrieve knowledge try: @@ -170,7 +170,7 @@ class KnowledgeRetrievalNode(BaseNode): node_data = node_data node_data = cast(cls._node_data_cls, node_data) variable_mapping = {} - variable_mapping[node_data.query_variable_selector.variable] = node_data.query_variable_selector.value_selector + variable_mapping['query'] = node_data.query_variable_selector return variable_mapping def _single_retrieve(self, available_datasets, node_data, query): diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index c9e353572c..5371862ea8 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -3,7 +3,6 @@ from typing import Any from pydantic import BaseModel from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.variable_entities import VariableSelector class ModelConfig(BaseModel): @@ -43,7 +42,7 @@ class QuestionClassifierNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ - query_variable_selector: VariableSelector + query_variable_selector: list[str] type: str = 'question-classifier' model: ModelConfig classes: list[ClassConfig] diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 0b47e118f0..42bd141faf 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -43,9 +43,9 @@ class QuestionClassifierNode(BaseNode): def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data) # extract variables - query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector.value_selector) + query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) variables = { - node_data.query_variable_selector.variable: query + 'query': query } # fetch model config model_instance, model_config = self._fetch_model_config(node_data) @@ -104,7 +104,7 @@ class QuestionClassifierNode(BaseNode): def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: node_data = node_data node_data = cast(cls._node_data_cls, node_data) - variable_mapping = {node_data.query_variable_selector.variable: node_data.query_variable_selector.value_selector} + variable_mapping = {'query': node_data.query_variable_selector} return variable_mapping @classmethod