From 5013ea09d59605b2a752f541514c296d507835a2 Mon Sep 17 00:00:00 2001 From: jyong Date: Sat, 16 Mar 2024 00:54:29 +0800 Subject: [PATCH] variable assigner node --- .../nodes/knowledge_retrieval/entities.py | 3 +++ .../knowledge_retrieval_node.py | 8 ++++---- .../nodes/question_classifier/entities.py | 3 ++- .../question_classifier_node.py | 9 ++++----- .../nodes/variable_assigner/entities.py | 17 +++++++++++++++++ .../variable_assigner/variable_assigner_node.py | 12 +++++++++++- 6 files changed, 41 insertions(+), 11 deletions(-) create mode 100644 api/core/workflow/nodes/variable_assigner/entities.py diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 89e62c7b9b..1a5c6f6d08 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -44,6 +44,9 @@ class KnowledgeRetrievalNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ + title: str + desc: str + type: str = 'knowledge-retrieval' query_variable_selector: list[str] dataset_ids: list[str] retrieval_mode: Literal['single', 'multiple'] diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index b054991537..4d5970aaef 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -44,7 +44,7 @@ class KnowledgeRetrievalNode(BaseNode): # extract variables query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) variables = { - 'query': query + '_query': query } # retrieve knowledge try: @@ -163,9 +163,9 @@ class KnowledgeRetrievalNode(BaseNode): def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: node_data = node_data node_data = cast(cls._node_data_cls, node_data) - return { - variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables - } + variable_mapping = {} + variable_mapping['_query'] = node_data.query_variable_selector + return variable_mapping def _single_retrieve(self, available_datasets, node_data, query): tools = [] diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index a407ea01c9..695e698694 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -45,7 +45,8 @@ class QuestionClassifierNodeData(BaseNodeData): """ query_variable_selector: list[str] title: str - description: str + desc: str + type: str = 'question-classifier' model: ModelConfig classes: list[ClassConfig] instruction: str diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index fdeb40c53d..158a4d1ac8 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -40,7 +40,7 @@ class QuestionClassifierNode(BaseNode): # extract variables query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) variables = { - 'query': query + '_query': query } # fetch model config model_instance, model_config = self._fetch_model_config(node_data) @@ -95,13 +95,12 @@ class QuestionClassifierNode(BaseNode): error=str(e) ) - @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + node_data = node_data node_data = cast(cls._node_data_cls, node_data) - return { - variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables - } + variable_mapping = {'_query': node_data.query_variable_selector} + return variable_mapping @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: diff --git a/api/core/workflow/nodes/variable_assigner/entities.py b/api/core/workflow/nodes/variable_assigner/entities.py new file mode 100644 index 0000000000..1e61fa94bf --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/entities.py @@ -0,0 +1,17 @@ +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class VariableAssignerNodeData(BaseNodeData): + """ + Knowledge retrieval Node Data. + """ + title: str + desc: str + type: str = 'variable-assigner' + output_type: str + variables: list[str] diff --git a/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py b/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py index 231a26a661..c6d11926ed 100644 --- a/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py +++ b/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py @@ -1,5 +1,15 @@ +from typing import cast + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode class VariableAssignerNode(BaseNode): - pass + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + pass + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + return {}