diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 1ccdbf971c..a501113dc3 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -33,10 +33,10 @@ default_retrieval_model = { 'score_threshold_enabled': False } -class KnowledgeRetrievalNode(BaseNode): +class KnowledgeRetrievalNode(BaseNode): _node_data_cls = KnowledgeRetrievalNodeData - _node_type = NodeType.TOOL + _node_type = NodeType.KNOWLEDGE_RETRIEVAL def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) @@ -67,7 +67,9 @@ class KnowledgeRetrievalNode(BaseNode): inputs=variables, error=str(e) ) - def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[dict[str, Any]]: + + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[ + dict[str, Any]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param node_data: node data @@ -224,14 +226,14 @@ class KnowledgeRetrievalNode(BaseNode): if score_threshold_enabled: score_threshold = retrieval_model_config.get("score_threshold") - results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=variables['#query#'], - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, + query=variables['#query#'], + top_k=top_k, score_threshold=score_threshold, + reranking_model=reranking_model) return results - - - def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ + ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config :param node_data: node data @@ -333,7 +335,7 @@ class KnowledgeRetrievalNode(BaseNode): return all_documents - def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): + def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): with flask_app.app_context(): dataset = db.session.query(Dataset).filter( Dataset.tenant_id == self.tenant_id, @@ -368,4 +370,4 @@ class KnowledgeRetrievalNode(BaseNode): if retrieval_model['reranking_enable'] else None ) - all_documents.extend(documents) \ No newline at end of file + all_documents.extend(documents)