From 785dfc5c0085cc93e64480d5f03b7eaf00c1b57c Mon Sep 17 00:00:00 2001
From: jyong
Date: Fri, 15 Mar 2024 14:40:53 +0800
Subject: [PATCH] dataset retrieval

---
 .../dataset_multi_retriever_tool.py           | 194 ++++++++++
 .../dataset_retriever_tool.py                 | 159 ++++++++
 .../nodes/knowledge_retrieval/entities.py     |  52 +++
 .../knowledge_retrieval.py                    |   0
 .../knowledge_retrieval_node.py               | 364 +++++++++++++++++-
 5 files changed, 766 insertions(+), 3 deletions(-)
 create mode 100644 api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py
 create mode 100644 api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py
 create mode 100644 api/core/workflow/nodes/knowledge_retrieval/entities.py
 create mode 100644 api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval.py

diff --git a/api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py b/api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py
new file mode 100644
index 0000000000..d9934acff9
--- /dev/null
+++ b/api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py
@@ -0,0 +1,194 @@
+import threading
+from typing import Optional
+
+from flask import Flask, current_app
+from langchain.tools import BaseTool
+from pydantic import BaseModel, Field
+
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rerank.rerank import RerankRunner
+from extensions.ext_database import db
+from models.dataset import Dataset, Document, DocumentSegment
+
+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enabled': False
+}
+
+
+class DatasetMultiRetrieverToolInput(BaseModel):
+    query: str = Field(..., description="dataset multi retriever and rerank")
+
+
+class DatasetMultiRetrieverTool(BaseTool):
+    """Tool for querying multi dataset."""
+    name: str = "dataset-"
+    args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
+    description: str = "dataset multi retriever and rerank. "
" + tenant_id: str + dataset_ids: list[str] + top_k: int = 2 + score_threshold: Optional[float] = None + reranking_provider_name: str + reranking_model_name: str + return_resource: bool + retriever_from: str + hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + + @classmethod + def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): + return cls( + name=f'dataset-{tenant_id}', + tenant_id=tenant_id, + dataset_ids=dataset_ids, + **kwargs + ) + + def _run(self, query: str) -> str: + threads = [] + all_documents = [] + for dataset_id in self.dataset_ids: + retrieval_thread = threading.Thread(target=self._retriever, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'all_documents': all_documents, + 'hit_callbacks': self.hit_callbacks + }) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + # do rerank for searched documents + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + provider=self.reranking_provider_name, + model_type=ModelType.RERANK, + model=self.reranking_model_name + ) + + rerank_runner = RerankRunner(rerank_model_instance) + all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k) + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(all_documents) + + document_score_list = {} + for item in all_documents: + if 'score' in item.metadata and item.metadata['score']: + document_score_list[item.metadata['doc_id']] = item.metadata['score'] + + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in all_documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id.in_(self.dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted(segments, + key=lambda segment: index_node_id_to_position.get(segment.index_node_id, + float('inf'))) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + else: + document_context_list.append(segment.content) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + dataset = Dataset.query.filter_by( + id=segment.dataset_id + ).first() + document = Document.query.filter(Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + 'position': resource_number, + 'dataset_id': dataset.id, + 'dataset_name': dataset.name, + 'document_id': document.id, + 'document_name': document.name, + 'data_source_type': document.data_source_type, + 'segment_id': segment.id, + 'retriever_from': self.retriever_from, + 'score': document_score_list.get(segment.index_node_id, None) + } + + if self.retriever_from == 'dev': + source['hit_count'] = segment.hit_count + source['word_count'] = segment.word_count + source['segment_position'] = segment.position + source['index_node_hash'] = segment.index_node_hash + if segment.answer: + source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + else: + source['content'] = segment.content + context_list.append(source) + 
+                        resource_number += 1
+
+                for hit_callback in self.hit_callbacks:
+                    hit_callback.return_retriever_resource_info(context_list)
+
+            return str("\n".join(document_context_list))
+
+    async def _arun(self, tool_input: str) -> str:
+        raise NotImplementedError()
+
+    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list,
+                   hit_callbacks: list[DatasetIndexToolCallbackHandler]):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.tenant_id == self.tenant_id,
+                Dataset.id == dataset_id
+            ).first()
+
+            if not dataset:
+                return []
+
+            for hit_callback in hit_callbacks:
+                hit_callback.on_query(query, dataset.id)
+
+            # get retrieval model; if the model is not set, use the default
+            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+
+            if dataset.indexing_technique == "economy":
+                # use keyword table query
+                documents = RetrievalService.retrieve(retrival_method='keyword_search',
+                                                      dataset_id=dataset.id,
+                                                      query=query,
+                                                      top_k=self.top_k
+                                                      )
+                if documents:
+                    all_documents.extend(documents)
+            else:
+                if self.top_k > 0:
+                    # retrieval source
+                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                          dataset_id=dataset.id,
+                                                          query=query,
+                                                          top_k=self.top_k,
+                                                          score_threshold=retrieval_model['score_threshold']
+                                                          if retrieval_model['score_threshold_enabled'] else None,
+                                                          reranking_model=retrieval_model['reranking_model']
+                                                          if retrieval_model['reranking_enable'] else None
+                                                          )
+
+                    all_documents.extend(documents)
\ No newline at end of file
diff --git a/api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py b/api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py
new file mode 100644
index 0000000000..13331d981b
--- /dev/null
+++ b/api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py
@@ -0,0 +1,159 @@
+from typing import Optional
+
+from langchain.tools import BaseTool
+from pydantic import BaseModel, Field
+
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.rag.datasource.retrieval_service import RetrievalService
+from extensions.ext_database import db
+from models.dataset import Dataset, Document, DocumentSegment
+
+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enabled': False
+}
+
+
+class DatasetRetrieverToolInput(BaseModel):
+    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
+
+
+class DatasetRetrieverTool(BaseTool):
+    """Tool for querying a Dataset."""
+    name: str = "dataset"
+    args_schema: type[BaseModel] = DatasetRetrieverToolInput
+    description: str = "use this to retrieve a dataset. "
" + + tenant_id: str + dataset_id: str + top_k: int = 2 + score_threshold: Optional[float] = None + hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + return_resource: bool + retriever_from: str + + @classmethod + def from_dataset(cls, dataset: Dataset, **kwargs): + description = dataset.description + if not description: + description = 'useful for when you want to answer queries about the ' + dataset.name + + description = description.replace('\n', '').replace('\r', '') + return cls( + name=f'dataset-{dataset.id}', + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + description=description, + **kwargs + ) + + def _run(self, query: str) -> str: + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id == self.dataset_id + ).first() + + if not dataset: + return '' + + for hit_callback in self.hit_callbacks: + hit_callback.on_query(query, dataset.id) + + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve(retrival_method='keyword_search', + dataset_id=dataset.id, + query=query, + top_k=self.top_k + ) + return str("\n".join([document.page_content for document in documents])) + else: + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) + else: + documents = [] + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(documents) + document_score_list = {} + if dataset.indexing_technique != "economy": + for item in documents: + if 'score' in item.metadata and item.metadata['score']: + document_score_list[item.metadata['doc_id']] = item.metadata['score'] + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in documents] + segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id, + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted(segments, + key=lambda segment: index_node_id_to_position.get(segment.index_node_id, + float('inf'))) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + else: + document_context_list.append(segment.content) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + context = {} + document = Document.query.filter(Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + 'position': resource_number, + 'dataset_id': dataset.id, + 'dataset_name': dataset.name, + 'document_id': document.id, + 'document_name': document.name, + 'data_source_type': document.data_source_type, + 'segment_id': segment.id, + 'retriever_from': self.retriever_from, + 'score': 
+
+                            }
+                            if self.retriever_from == 'dev':
+                                source['hit_count'] = segment.hit_count
+                                source['word_count'] = segment.word_count
+                                source['segment_position'] = segment.position
+                                source['index_node_hash'] = segment.index_node_hash
+                            if segment.answer:
+                                source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                            else:
+                                source['content'] = segment.content
+                            context_list.append(source)
+                        resource_number += 1
+
+                    for hit_callback in self.hit_callbacks:
+                        hit_callback.return_retriever_resource_info(context_list)
+
+            return str("\n".join(document_context_list))
+
+    async def _arun(self, tool_input: str) -> str:
+        raise NotImplementedError()
diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py
new file mode 100644
index 0000000000..905ee1f80d
--- /dev/null
+++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py
@@ -0,0 +1,52 @@
+from typing import Any, Literal, Optional, Union
+
+from pydantic import BaseModel
+
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
+
+
+class RerankingModelConfig(BaseModel):
+    """
+    Reranking Model Config.
+    """
+    provider: str
+    name: str
+
+
+class MultipleRetrievalConfig(BaseModel):
+    """
+    Multiple Retrieval Config.
+    """
+    top_k: int
+    score_threshold: Optional[float]
+    reranking_model: RerankingModelConfig
+
+
+class ModelConfig(BaseModel):
+    """
+    Model Config.
+    """
+    provider: str
+    name: str
+    mode: str
+    completion_params: dict[str, Any] = {}
+
+
+class SingleRetrievalConfig(BaseModel):
+    """
+    Single Retrieval Config.
+    """
+    model: ModelConfig
+
+
+class KnowledgeRetrievalNodeData(BaseNodeData):
+    """
+    Knowledge retrieval Node Data.
+ """ + variables: list[VariableSelector] + dataset_ids: list[str] + retrieval_mode: Literal['single', 'multiple'] + multiple_retrieval_config: MultipleRetrievalConfig + singleRetrievalConfig: SingleRetrievalConfig diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 7b8344418b..1ccdbf971c 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,13 +1,371 @@ +import threading +from typing import cast, Any + +from flask import current_app, Flask + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelStatus +from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.rag.datasource.retrieval_service import RetrievalService +from core.rerank.rerank import RerankRunner from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from extensions.ext_database import db +from models.dataset import Dataset, DocumentSegment, Document +from models.workflow import WorkflowNodeExecutionStatus +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} class KnowledgeRetrievalNode(BaseNode): + + _node_data_cls = KnowledgeRetrievalNodeData + _node_type = NodeType.TOOL + def _run(self, variable_pool: VariablePool) -> NodeRunResult: - pass + node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) + + # extract variables + variables = { + variable_selector.variable: variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector) + for variable_selector in node_data.variables + } + + # retrieve knowledge + try: + outputs = self._fetch_dataset_retriever( + node_data=node_data, variables=variables + ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + process_data=None, + outputs=outputs + ) + + except Exception as e: + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e) + ) + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[dict[str, Any]]: + """ + A dataset tool is a tool that can be used to retrieve information from a dataset + :param 
+        :param variables: variables
+        """
+        tools = []
+        available_datasets = []
+        dataset_ids = node_data.dataset_ids
+        for dataset_id in dataset_ids:
+            # get dataset from dataset id
+            dataset = db.session.query(Dataset).filter(
+                Dataset.tenant_id == self.tenant_id,
+                Dataset.id == dataset_id
+            ).first()
+
+            # pass if dataset is not available
+            if not dataset:
+                continue
+
+            # pass if dataset is not available
+            if (dataset and dataset.available_document_count == 0
+                    and dataset.available_segment_count == 0):
+                continue
+
+            available_datasets.append(dataset)
+        all_documents = []
+        if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
+            all_documents = self._single_retrieve(available_datasets, node_data, variables)
+        elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
+            all_documents = self._multiple_retrieve(available_datasets, node_data, variables)
+
+        document_score_list = {}
+        for item in all_documents:
+            if 'score' in item.metadata and item.metadata['score']:
+                document_score_list[item.metadata['doc_id']] = item.metadata['score']
+
+        document_context_list = []
+        index_node_ids = [document.metadata['doc_id'] for document in all_documents]
+        segments = DocumentSegment.query.filter(
+            DocumentSegment.dataset_id.in_(dataset_ids),
+            DocumentSegment.completed_at.isnot(None),
+            DocumentSegment.status == 'completed',
+            DocumentSegment.enabled == True,
+            DocumentSegment.index_node_id.in_(index_node_ids)
+        ).all()
+        context_list = []
+        if segments:
+            index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
+            sorted_segments = sorted(segments,
+                                     key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
+                                                                                       float('inf')))
+            for segment in sorted_segments:
+                if segment.answer:
+                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+                else:
+                    document_context_list.append(segment.content)
+
+            for segment in sorted_segments:
+                dataset = Dataset.query.filter_by(
+                    id=segment.dataset_id
+                ).first()
+                document = Document.query.filter(Document.id == segment.document_id,
+                                                 Document.enabled == True,
+                                                 Document.archived == False,
+                                                 ).first()
+                if dataset and document:
+
+                    source = {
+                        'metadata': {
+                            '_source': 'knowledge',
+                            'dataset_id': dataset.id,
+                            'dataset_name': dataset.name,
+                            'document_id': document.id,
+                            'document_name': document.name,
+                            'document_data_source_type': document.data_source_type,
+                            'segment_id': segment.id,
+                            'retriever_from': 'workflow',
+                            'score': document_score_list.get(segment.index_node_id, None),
+                            'segment_hit_count': segment.hit_count,
+                            'segment_word_count': segment.word_count,
+                            'segment_position': segment.position
+                        }
+                    }
+                    if segment.answer:
+                        source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                    else:
+                        source['content'] = segment.content
+                    context_list.append(source)
+
+        return context_list
 
     @classmethod
     def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
-        pass
+        node_data = node_data
+        node_data = cast(cls._node_data_cls, node_data)
+        return {
+            variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
+        }
+
+    def _single_retrieve(self, available_datasets, node_data, variables):
+        tools = []
+        for dataset in available_datasets:
+            description = dataset.description
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+
+            description = description.replace('\n', '').replace('\r', '')
+            message_tool = PromptMessageTool(
+                name=dataset.id,
+                description=description,
+                parameters={
+                    "type": "object",
+                    "properties": {},
+                    "required": [],
+                }
+            )
+            tools.append(message_tool)
+        # fetch model config
+        model_instance, model_config = self._fetch_model_config(node_data)
+        prompt_messages = [
+            SystemPromptMessage(content='You are a helpful AI assistant.'),
+            UserPromptMessage(content=variables['#query#'])
+        ]
+        result = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            tools=tools,
+            stream=False,
+            model_parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
+        )
+
+        if result.message.tool_calls:
+            # get retrieval model config
+            function_call_name = result.message.tool_calls[0].function.name
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == function_call_name
+            ).first()
+            if dataset:
+                retrieval_model_config = dataset.retrieval_model \
+                    if dataset.retrieval_model else default_retrieval_model
+
+                # get top k
+                top_k = retrieval_model_config['top_k']
+                # get retrieval method
+                retrival_method = retrieval_model_config['search_method']
+                # get reranking model
+                reranking_model = retrieval_model_config['reranking_model']
+                # get score threshold
+                score_threshold = .0
+                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
+                if score_threshold_enabled:
+                    score_threshold = retrieval_model_config.get("score_threshold")
+
+                results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
+                                                    query=variables['#query#'],
+                                                    top_k=top_k, score_threshold=score_threshold,
+                                                    reranking_model=reranking_model)
+                return results
+
+
+    def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+        """
+        Fetch model config
+        :param node_data: node data
+        :return:
+        """
+        model_name = node_data.singleRetrievalConfig.model.name
+        provider_name = node_data.singleRetrievalConfig.model.provider
+
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=self.tenant_id,
+            model_type=ModelType.LLM,
+            provider=provider_name,
+            model=model_name
+        )
+
+        provider_model_bundle = model_instance.provider_model_bundle
+        model_type_instance = model_instance.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_credentials = model_instance.credentials
+
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_name,
+            model_type=ModelType.LLM
+        )
+
+        if provider_model is None:
+            raise ValueError(f"Model {model_name} does not exist.")
+
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+
+        # model config
+        completion_params = node_data.singleRetrievalConfig.model.completion_params
+        stop = []
+        if 'stop' in completion_params:
+            stop = completion_params['stop']
+            del completion_params['stop']
+
+        # get model mode
+        model_mode = node_data.singleRetrievalConfig.model.mode
+        if not model_mode:
+            raise ValueError("LLM mode is required.")
+
+        model_schema = model_type_instance.get_model_schema(
+            model_name,
+            model_credentials
+        )
+
+        if not model_schema:
+            raise ValueError(f"Model {model_name} does not exist.")
+
+        return model_instance, ModelConfigWithCredentialsEntity(
+            provider=provider_name,
+            model=model_name,
+            model_schema=model_schema,
+            mode=model_mode,
+            provider_model_bundle=provider_model_bundle,
+            credentials=model_credentials,
+            parameters=completion_params,
+            stop=stop,
+        )
+
+    def _multiple_retrieve(self, available_datasets, node_data, variables):
+        threads = []
+        all_documents = []
+        dataset_ids = [dataset.id for dataset in available_datasets]
+        for dataset in available_datasets:
+            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset_id': dataset.id,
+                'query': variables['#query#'],
+                'top_k': node_data.multiple_retrieval_config.top_k,
+                'all_documents': all_documents,
+            })
+            threads.append(retrieval_thread)
+            retrieval_thread.start()
+        for thread in threads:
+            thread.join()
+        # do rerank for searched documents
+        model_manager = ModelManager()
+        rerank_model_instance = model_manager.get_model_instance(
+            tenant_id=self.tenant_id,
+            provider=node_data.multiple_retrieval_config.reranking_model.provider,
+            model_type=ModelType.RERANK,
+            model=node_data.multiple_retrieval_config.reranking_model.name
+        )
+
+        rerank_runner = RerankRunner(rerank_model_instance)
+        all_documents = rerank_runner.run(variables['#query#'], all_documents,
+                                          node_data.multiple_retrieval_config.score_threshold,
+                                          node_data.multiple_retrieval_config.top_k)
+
+        return all_documents
+
+    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.tenant_id == self.tenant_id,
+                Dataset.id == dataset_id
+            ).first()
+
+            if not dataset:
+                return []
+
+            # get retrieval model; if the model is not set, use the default
+            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+
+            if dataset.indexing_technique == "economy":
+                # use keyword table query
+                documents = RetrievalService.retrieve(retrival_method='keyword_search',
+                                                      dataset_id=dataset.id,
+                                                      query=query,
+                                                      top_k=top_k
+                                                      )
+                if documents:
+                    all_documents.extend(documents)
+            else:
+                if top_k > 0:
+                    # retrieval source
+                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                          dataset_id=dataset.id,
+                                                          query=query,
+                                                          top_k=top_k,
+                                                          score_threshold=retrieval_model['score_threshold']
+                                                          if retrieval_model['score_threshold_enabled'] else None,
+                                                          reranking_model=retrieval_model['reranking_model']
+                                                          if retrieval_model['reranking_enable'] else None
+                                                          )
+
+                    all_documents.extend(documents)
\ No newline at end of file
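
For reference, a minimal sketch of the node data that KnowledgeRetrievalNodeData (entities.py above) is expected to parse. Field names follow this patch; the dataset id, provider, and model values are illustrative placeholders, and the required 'title' field is assumed to be inherited from BaseNodeData rather than defined in this patch:

    from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData

    # Illustrative configuration for the 'multiple' retrieval mode; all concrete values are placeholders.
    node_data = KnowledgeRetrievalNodeData(
        title='Knowledge Retrieval',  # assumed to be required by BaseNodeData
        variables=[{'variable': 'query', 'value_selector': ['start', 'query']}],
        dataset_ids=['<dataset-id>'],
        retrieval_mode='multiple',
        multiple_retrieval_config={
            'top_k': 2,
            'score_threshold': None,
            'reranking_model': {'provider': '<rerank-provider>', 'name': '<rerank-model-name>'},
        },
        # singleRetrievalConfig is also required by the entity; it is only read when retrieval_mode == 'single'.
        singleRetrievalConfig={
            'model': {'provider': '<llm-provider>', 'name': '<llm-model-name>', 'mode': 'chat', 'completion_params': {}},
        },
    )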