knowledge fix

This commit is contained in:
jyong 2024-03-18 20:54:50 +08:00
parent 4a483a8754
commit d5a404236a
4 changed files with 10 additions and 8 deletions

View File

@ -3,6 +3,7 @@ from typing import Any, Literal, Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class RerankingModelConfig(BaseModel): class RerankingModelConfig(BaseModel):
@ -44,7 +45,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
Knowledge retrieval Node Data. Knowledge retrieval Node Data.
""" """
type: str = 'knowledge-retrieval' type: str = 'knowledge-retrieval'
query_variable_selector: list[str] query_variable_selector: VariableSelector
dataset_ids: list[str] dataset_ids: list[str]
retrieval_mode: Literal['single', 'multiple'] retrieval_mode: Literal['single', 'multiple']
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None

View File

@ -43,9 +43,9 @@ class KnowledgeRetrievalNode(BaseNode):
node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
# extract variables # extract variables
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector.value_selector)
variables = { variables = {
'_query': query node_data.query_variable_selector.variable: query
} }
# retrieve knowledge # retrieve knowledge
try: try:
@ -170,7 +170,7 @@ class KnowledgeRetrievalNode(BaseNode):
node_data = node_data node_data = node_data
node_data = cast(cls._node_data_cls, node_data) node_data = cast(cls._node_data_cls, node_data)
variable_mapping = {} variable_mapping = {}
variable_mapping['_query'] = node_data.query_variable_selector variable_mapping[node_data.query_variable_selector.variable] = node_data.query_variable_selector.value_selector
return variable_mapping return variable_mapping
def _single_retrieve(self, available_datasets, node_data, query): def _single_retrieve(self, available_datasets, node_data, query):

View File

@ -3,6 +3,7 @@ from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class ModelConfig(BaseModel): class ModelConfig(BaseModel):
@ -42,7 +43,7 @@ class QuestionClassifierNodeData(BaseNodeData):
""" """
Knowledge retrieval Node Data. Knowledge retrieval Node Data.
""" """
query_variable_selector: list[str] query_variable_selector: VariableSelector
type: str = 'question-classifier' type: str = 'question-classifier'
model: ModelConfig model: ModelConfig
classes: list[ClassConfig] classes: list[ClassConfig]

View File

@ -43,9 +43,9 @@ class QuestionClassifierNode(BaseNode):
def _run(self, variable_pool: VariablePool) -> NodeRunResult: def _run(self, variable_pool: VariablePool) -> NodeRunResult:
node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data) node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
# extract variables # extract variables
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector.value_selector)
variables = { variables = {
'_query': query node_data.query_variable_selector.variable: query
} }
# fetch model config # fetch model config
model_instance, model_config = self._fetch_model_config(node_data) model_instance, model_config = self._fetch_model_config(node_data)
@ -104,7 +104,7 @@ class QuestionClassifierNode(BaseNode):
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
node_data = node_data node_data = node_data
node_data = cast(cls._node_data_cls, node_data) node_data = cast(cls._node_data_cls, node_data)
variable_mapping = {'_query': node_data.query_variable_selector} variable_mapping = {node_data.query_variable_selector.variable: node_data.query_variable_selector.value_selector}
return variable_mapping return variable_mapping
@classmethod @classmethod