mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 20:17:29 +08:00
knowledge fix
This commit is contained in:
parent
4a483a8754
commit
d5a404236a
@ -3,6 +3,7 @@ from typing import Any, Literal, Optional
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
from core.workflow.entities.variable_entities import VariableSelector
|
||||||
|
|
||||||
|
|
||||||
class RerankingModelConfig(BaseModel):
|
class RerankingModelConfig(BaseModel):
|
||||||
@ -44,7 +45,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
|||||||
Knowledge retrieval Node Data.
|
Knowledge retrieval Node Data.
|
||||||
"""
|
"""
|
||||||
type: str = 'knowledge-retrieval'
|
type: str = 'knowledge-retrieval'
|
||||||
query_variable_selector: list[str]
|
query_variable_selector: VariableSelector
|
||||||
dataset_ids: list[str]
|
dataset_ids: list[str]
|
||||||
retrieval_mode: Literal['single', 'multiple']
|
retrieval_mode: Literal['single', 'multiple']
|
||||||
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
|
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
|
||||||
|
|||||||
@ -43,9 +43,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
|
node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
|
||||||
|
|
||||||
# extract variables
|
# extract variables
|
||||||
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
|
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector.value_selector)
|
||||||
variables = {
|
variables = {
|
||||||
'_query': query
|
node_data.query_variable_selector.variable: query
|
||||||
}
|
}
|
||||||
# retrieve knowledge
|
# retrieve knowledge
|
||||||
try:
|
try:
|
||||||
@ -170,7 +170,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
node_data = node_data
|
node_data = node_data
|
||||||
node_data = cast(cls._node_data_cls, node_data)
|
node_data = cast(cls._node_data_cls, node_data)
|
||||||
variable_mapping = {}
|
variable_mapping = {}
|
||||||
variable_mapping['_query'] = node_data.query_variable_selector
|
variable_mapping[node_data.query_variable_selector.variable] = node_data.query_variable_selector.value_selector
|
||||||
return variable_mapping
|
return variable_mapping
|
||||||
|
|
||||||
def _single_retrieve(self, available_datasets, node_data, query):
|
def _single_retrieve(self, available_datasets, node_data, query):
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from typing import Any
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
from core.workflow.entities.variable_entities import VariableSelector
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
@ -42,7 +43,7 @@ class QuestionClassifierNodeData(BaseNodeData):
|
|||||||
"""
|
"""
|
||||||
Knowledge retrieval Node Data.
|
Knowledge retrieval Node Data.
|
||||||
"""
|
"""
|
||||||
query_variable_selector: list[str]
|
query_variable_selector: VariableSelector
|
||||||
type: str = 'question-classifier'
|
type: str = 'question-classifier'
|
||||||
model: ModelConfig
|
model: ModelConfig
|
||||||
classes: list[ClassConfig]
|
classes: list[ClassConfig]
|
||||||
|
|||||||
@ -43,9 +43,9 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||||
node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
|
node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
|
||||||
# extract variables
|
# extract variables
|
||||||
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
|
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector.value_selector)
|
||||||
variables = {
|
variables = {
|
||||||
'_query': query
|
node_data.query_variable_selector.variable: query
|
||||||
}
|
}
|
||||||
# fetch model config
|
# fetch model config
|
||||||
model_instance, model_config = self._fetch_model_config(node_data)
|
model_instance, model_config = self._fetch_model_config(node_data)
|
||||||
@ -104,7 +104,7 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||||
node_data = node_data
|
node_data = node_data
|
||||||
node_data = cast(cls._node_data_cls, node_data)
|
node_data = cast(cls._node_data_cls, node_data)
|
||||||
variable_mapping = {'_query': node_data.query_variable_selector}
|
variable_mapping = {node_data.query_variable_selector.variable: node_data.query_variable_selector.value_selector}
|
||||||
return variable_mapping
|
return variable_mapping
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user