mirror of https://github.com/langgenius/dify.git
dataset retrieval
This commit is contained in:
parent 12eb236364
commit 785dfc5c00
@@ -0,0 +1,194 @@
import threading
from typing import Optional

from flask import Flask, current_app
from langchain.tools import BaseTool
from pydantic import BaseModel, Field

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rerank.rerank import RerankRunner
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enabled': False
}


class DatasetMultiRetrieverToolInput(BaseModel):
    query: str = Field(..., description="Query used to retrieve and rerank documents across multiple datasets.")


class DatasetMultiRetrieverTool(BaseTool):
    """Tool for querying multiple datasets."""
    name: str = "dataset-"
    args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
    description: str = "dataset multi retriever and rerank. "
    tenant_id: str
    dataset_ids: list[str]
    top_k: int = 2
    score_threshold: Optional[float] = None
    reranking_provider_name: str
    reranking_model_name: str
    return_resource: bool
    retriever_from: str
    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []

    @classmethod
    def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
        return cls(
            name=f'dataset-{tenant_id}',
            tenant_id=tenant_id,
            dataset_ids=dataset_ids,
            **kwargs
        )

    def _run(self, query: str) -> str:
        # retrieve from every dataset in parallel, one thread per dataset
        threads = []
        all_documents = []
        for dataset_id in self.dataset_ids:
            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset_id,
                'query': query,
                'all_documents': all_documents,
                'hit_callbacks': self.hit_callbacks
            })
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()

        # rerank the merged documents
        model_manager = ModelManager()
        rerank_model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            provider=self.reranking_provider_name,
            model_type=ModelType.RERANK,
            model=self.reranking_model_name
        )

        rerank_runner = RerankRunner(rerank_model_instance)
        all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)

        for hit_callback in self.hit_callbacks:
            hit_callback.on_tool_end(all_documents)

        document_score_list = {}
        for item in all_documents:
            if 'score' in item.metadata and item.metadata['score']:
                document_score_list[item.metadata['doc_id']] = item.metadata['score']

        document_context_list = []
        index_node_ids = [document.metadata['doc_id'] for document in all_documents]
        segments = DocumentSegment.query.filter(
            DocumentSegment.dataset_id.in_(self.dataset_ids),
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True,
            DocumentSegment.index_node_id.in_(index_node_ids)
        ).all()

        if segments:
            index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
            sorted_segments = sorted(segments,
                                     key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
                                                                                       float('inf')))
            for segment in sorted_segments:
                if segment.answer:
                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
                else:
                    document_context_list.append(segment.content)

            if self.return_resource:
                context_list = []
                resource_number = 1
                for segment in sorted_segments:
                    dataset = Dataset.query.filter_by(
                        id=segment.dataset_id
                    ).first()
                    document = Document.query.filter(Document.id == segment.document_id,
                                                     Document.enabled == True,
                                                     Document.archived == False,
                                                     ).first()
                    if dataset and document:
                        source = {
                            'position': resource_number,
                            'dataset_id': dataset.id,
                            'dataset_name': dataset.name,
                            'document_id': document.id,
                            'document_name': document.name,
                            'data_source_type': document.data_source_type,
                            'segment_id': segment.id,
                            'retriever_from': self.retriever_from,
                            'score': document_score_list.get(segment.index_node_id, None)
                        }

                        if self.retriever_from == 'dev':
                            source['hit_count'] = segment.hit_count
                            source['word_count'] = segment.word_count
                            source['segment_position'] = segment.position
                            source['index_node_hash'] = segment.index_node_hash
                        if segment.answer:
                            source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
                        else:
                            source['content'] = segment.content
                        context_list.append(source)
                        resource_number += 1

                for hit_callback in self.hit_callbacks:
                    hit_callback.return_retriever_resource_info(context_list)

        return str("\n".join(document_context_list))

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()

    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list,
                   hit_callbacks: list[DatasetIndexToolCallbackHandler]):
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == self.tenant_id,
                Dataset.id == dataset_id
            ).first()

            if not dataset:
                return []

            for hit_callback in hit_callbacks:
                hit_callback.on_query(query, dataset.id)

            # get the dataset's retrieval model; fall back to the default if it is not set
            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

            if dataset.indexing_technique == "economy":
                # use keyword table query
                documents = RetrievalService.retrieve(retrival_method='keyword_search',
                                                      dataset_id=dataset.id,
                                                      query=query,
                                                      top_k=self.top_k
                                                      )
                if documents:
                    all_documents.extend(documents)
            else:
                if self.top_k > 0:
                    # retrieval source
                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
                                                          dataset_id=dataset.id,
                                                          query=query,
                                                          top_k=self.top_k,
                                                          score_threshold=retrieval_model['score_threshold']
                                                          if retrieval_model['score_threshold_enabled'] else None,
                                                          reranking_model=retrieval_model['reranking_model']
                                                          if retrieval_model['reranking_enable'] else None
                                                          )

                    all_documents.extend(documents)
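
For context, a minimal usage sketch of the multi-dataset tool added above. It assumes the DatasetMultiRetrieverTool class is in scope inside a configured Dify backend (database, model providers, Flask app context); every ID and model name below is a placeholder, not a value from this commit.

# Usage sketch only: assumes DatasetMultiRetrieverTool (defined above) is in scope
# and a configured Dify backend. All IDs and model names are placeholders.
tool = DatasetMultiRetrieverTool.from_dataset(
    dataset_ids=['<dataset-id-1>', '<dataset-id-2>'],
    tenant_id='<tenant-id>',
    reranking_provider_name='<rerank-provider>',
    reranking_model_name='<rerank-model>',
    return_resource=False,
    retriever_from='dev',
)
# BaseTool.run() dispatches to _run(), which starts one retrieval thread per dataset,
# reranks the merged documents, and returns the matched segment texts joined by newlines.
context = tool.run('example question about the datasets')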
@@ -0,0 +1,159 @@
from typing import Optional

from langchain.tools import BaseTool
from pydantic import BaseModel, Field

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.datasource.retrieval_service import RetrievalService
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enabled': False
}


class DatasetRetrieverToolInput(BaseModel):
    query: str = Field(..., description="Query used to retrieve documents from the dataset.")


class DatasetRetrieverTool(BaseTool):
    """Tool for querying a Dataset."""
    name: str = "dataset"
    args_schema: type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "

    tenant_id: str
    dataset_id: str
    top_k: int = 2
    score_threshold: Optional[float] = None
    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
    return_resource: bool
    retriever_from: str

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        description = description.replace('\n', '').replace('\r', '')
        return cls(
            name=f'dataset-{dataset.id}',
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs
        )

    def _run(self, query: str) -> str:
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == self.dataset_id
        ).first()

        if not dataset:
            return ''

        for hit_callback in self.hit_callbacks:
            hit_callback.on_query(query, dataset.id)

        # get the dataset's retrieval model; fall back to the default if it is not set
        retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
        if dataset.indexing_technique == "economy":
            # use keyword table query
            documents = RetrievalService.retrieve(retrival_method='keyword_search',
                                                  dataset_id=dataset.id,
                                                  query=query,
                                                  top_k=self.top_k
                                                  )
            return str("\n".join([document.page_content for document in documents]))
        else:
            if self.top_k > 0:
                # retrieval source
                documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
                                                      dataset_id=dataset.id,
                                                      query=query,
                                                      top_k=self.top_k,
                                                      score_threshold=retrieval_model['score_threshold']
                                                      if retrieval_model['score_threshold_enabled'] else None,
                                                      reranking_model=retrieval_model['reranking_model']
                                                      if retrieval_model['reranking_enable'] else None
                                                      )
            else:
                documents = []

            for hit_callback in self.hit_callbacks:
                hit_callback.on_tool_end(documents)

            document_score_list = {}
            if dataset.indexing_technique != "economy":
                for item in documents:
                    if 'score' in item.metadata and item.metadata['score']:
                        document_score_list[item.metadata['doc_id']] = item.metadata['score']

            document_context_list = []
            index_node_ids = [document.metadata['doc_id'] for document in documents]
            segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
                                                    DocumentSegment.completed_at.isnot(None),
                                                    DocumentSegment.status == 'completed',
                                                    DocumentSegment.enabled == True,
                                                    DocumentSegment.index_node_id.in_(index_node_ids)
                                                    ).all()

            if segments:
                index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
                sorted_segments = sorted(segments,
                                         key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
                                                                                           float('inf')))
                for segment in sorted_segments:
                    if segment.answer:
                        document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
                    else:
                        document_context_list.append(segment.content)

                if self.return_resource:
                    context_list = []
                    resource_number = 1
                    for segment in sorted_segments:
                        document = Document.query.filter(Document.id == segment.document_id,
                                                         Document.enabled == True,
                                                         Document.archived == False,
                                                         ).first()
                        if dataset and document:
                            source = {
                                'position': resource_number,
                                'dataset_id': dataset.id,
                                'dataset_name': dataset.name,
                                'document_id': document.id,
                                'document_name': document.name,
                                'data_source_type': document.data_source_type,
                                'segment_id': segment.id,
                                'retriever_from': self.retriever_from,
                                'score': document_score_list.get(segment.index_node_id, None)
                            }
                            if self.retriever_from == 'dev':
                                source['hit_count'] = segment.hit_count
                                source['word_count'] = segment.word_count
                                source['segment_position'] = segment.position
                                source['index_node_hash'] = segment.index_node_hash
                            if segment.answer:
                                source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
                            else:
                                source['content'] = segment.content
                            context_list.append(source)
                            resource_number += 1

                    for hit_callback in self.hit_callbacks:
                        hit_callback.return_retriever_resource_info(context_list)

            return str("\n".join(document_context_list))

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()
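
A similar hedged sketch for the single-dataset tool above: it assumes a Flask app context with the Dify database available, and the dataset ID is a placeholder.

# Usage sketch only: assumes DatasetRetrieverTool (defined above) is in scope and a
# Flask app context with the Dify database configured. The dataset ID is a placeholder.
from extensions.ext_database import db
from models.dataset import Dataset

dataset = db.session.query(Dataset).filter(Dataset.id == '<dataset-id>').first()
if dataset:
    tool = DatasetRetrieverTool.from_dataset(
        dataset=dataset,
        return_resource=True,
        retriever_from='dev',
        top_k=4,
    )
    # _run() uses keyword search for "economy" datasets and the dataset's configured
    # search method otherwise, then returns the matched segments joined by newlines.
    print(tool.run('example question about this dataset'))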
@@ -0,0 +1,52 @@
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel

from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector


class RerankingModelConfig(BaseModel):
    """
    Reranking Model Config.
    """
    provider: str
    mode: str


class MultipleRetrievalConfig(BaseModel):
    """
    Multiple Retrieval Config.
    """
    top_k: int
    score_threshold: Optional[float]
    reranking_model: RerankingModelConfig


class ModelConfig(BaseModel):
    """
    Model Config.
    """
    provider: str
    name: str
    mode: str
    completion_params: dict[str, Any] = {}


class SingleRetrievalConfig(BaseModel):
    """
    Single Retrieval Config.
    """
    model: ModelConfig


class KnowledgeRetrievalNodeData(BaseNodeData):
    """
    Knowledge retrieval Node Data.
    """
    variables: list[VariableSelector]
    dataset_ids: list[str]
    retrieval_mode: Literal['single', 'multiple']
    multiple_retrieval_config: MultipleRetrievalConfig
    singleRetrievalConfig: SingleRetrievalConfig
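
To illustrate the schema above, a small sketch of how these config models could be instantiated and validated (all values are placeholders):

# Sketch only: constructing the retrieval config models defined above with
# placeholder values; pydantic validates field names and types on construction.
multiple_config = MultipleRetrievalConfig(
    top_k=4,
    score_threshold=0.5,
    reranking_model=RerankingModelConfig(provider='<rerank-provider>', mode='<rerank-model>'),
)
single_config = SingleRetrievalConfig(
    model=ModelConfig(provider='<llm-provider>', name='<llm-model>', mode='chat',
                      completion_params={'temperature': 0.2}),
)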
@@ -1,13 +1,371 @@
import threading
from typing import Any, cast

from flask import Flask, current_app

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalService
from core.rerank.rerank import RerankRunner
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus

default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enabled': False
}


class KnowledgeRetrievalNode(BaseNode):
    _node_data_cls = KnowledgeRetrievalNodeData
    _node_type = NodeType.TOOL

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)

        # extract variables
        variables = {
            variable_selector.variable: variable_pool.get_variable_value(
                variable_selector=variable_selector.value_selector)
            for variable_selector in node_data.variables
        }

        # retrieve knowledge
        try:
            outputs = self._fetch_dataset_retriever(
                node_data=node_data, variables=variables
            )
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
                process_data=None,
                outputs=outputs
            )
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e)
            )

    def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData,
                                 variables: dict[str, Any]) -> list[dict[str, Any]]:
        """
        Retrieve context from the datasets configured on the node.

        :param node_data: node data
        :param variables: variables
        """
        available_datasets = []
        dataset_ids = node_data.dataset_ids
        for dataset_id in dataset_ids:
            # get dataset from dataset id
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == self.tenant_id,
                Dataset.id == dataset_id
            ).first()

            # skip if the dataset is not available
            if not dataset:
                continue

            # skip datasets that have no available documents
            if dataset.available_document_count == 0:
                continue

            available_datasets.append(dataset)

        all_documents = []
        if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
            all_documents = self._single_retrieve(available_datasets, node_data, variables)
        elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            all_documents = self._multiple_retrieve(available_datasets, node_data, variables)

        document_score_list = {}
        for item in all_documents:
            if 'score' in item.metadata and item.metadata['score']:
                document_score_list[item.metadata['doc_id']] = item.metadata['score']

        document_context_list = []
        index_node_ids = [document.metadata['doc_id'] for document in all_documents]
        segments = DocumentSegment.query.filter(
            DocumentSegment.dataset_id.in_(dataset_ids),
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True,
            DocumentSegment.index_node_id.in_(index_node_ids)
        ).all()
        context_list = []
        if segments:
            index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
            sorted_segments = sorted(segments,
                                     key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
                                                                                       float('inf')))
            for segment in sorted_segments:
                if segment.answer:
                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
                else:
                    document_context_list.append(segment.content)

            for segment in sorted_segments:
                dataset = Dataset.query.filter_by(
                    id=segment.dataset_id
                ).first()
                document = Document.query.filter(Document.id == segment.document_id,
                                                 Document.enabled == True,
                                                 Document.archived == False,
                                                 ).first()
                if dataset and document:
                    source = {
                        'metadata': {
                            '_source': 'knowledge',
                            'dataset_id': dataset.id,
                            'dataset_name': dataset.name,
                            'document_id': document.id,
                            'document_name': document.name,
                            'document_data_source_type': document.data_source_type,
                            'segment_id': segment.id,
                            'retriever_from': 'workflow',
                            'score': document_score_list.get(segment.index_node_id, None),
                            'segment_hit_count': segment.hit_count,
                            'segment_word_count': segment.word_count,
                            'segment_position': segment.position
                        }
                    }
                    if segment.answer:
                        source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
                    else:
                        source['content'] = segment.content
                    context_list.append(source)

        return context_list

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        node_data = cast(cls._node_data_cls, node_data)
        return {
            variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
        }

    def _single_retrieve(self, available_datasets, node_data, variables):
        # expose every available dataset as a function-call tool and let the LLM pick one
        tools = []
        for dataset in available_datasets:
            description = dataset.description
            if not description:
                description = 'useful for when you want to answer queries about the ' + dataset.name

            description = description.replace('\n', '').replace('\r', '')
            message_tool = PromptMessageTool(
                name=dataset.id,
                description=description,
                parameters={
                    "type": "object",
                    "properties": {},
                    "required": [],
                }
            )
            tools.append(message_tool)

        # fetch model config
        model_instance, model_config = self._fetch_model_config(node_data)
        prompt_messages = [
            SystemPromptMessage(content='You are a helpful AI assistant.'),
            UserPromptMessage(content=variables['#query#'])
        ]
        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            tools=tools,
            stream=False,
            model_parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )

        if result.message.tool_calls:
            # get retrieval model config
            function_call_name = result.message.tool_calls[0].function.name
            dataset = db.session.query(Dataset).filter(
                Dataset.id == function_call_name
            ).first()
            if dataset:
                retrieval_model_config = dataset.retrieval_model \
                    if dataset.retrieval_model else default_retrieval_model

                # get top k
                top_k = retrieval_model_config['top_k']
                # get retrieval method
                retrival_method = retrieval_model_config['search_method']
                # get reranking model
                reranking_model = retrieval_model_config['reranking_model']
                # get score threshold
                score_threshold = 0.0
                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                if score_threshold_enabled:
                    score_threshold = retrieval_model_config.get("score_threshold")

                results = RetrievalService.retrieve(retrival_method=retrival_method,
                                                    dataset_id=dataset.id,
                                                    query=variables['#query#'],
                                                    top_k=top_k,
                                                    score_threshold=score_threshold,
                                                    reranking_model=reranking_model)
                return results

        return []

    def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) \
            -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config
        :param node_data: node data
        :return:
        """
        model_name = node_data.singleRetrievalConfig.model.name
        provider_name = node_data.singleRetrievalConfig.model.provider

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            model_type=ModelType.LLM,
            provider=provider_name,
            model=model_name
        )

        provider_model_bundle = model_instance.provider_model_bundle
        model_type_instance = model_instance.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_credentials = model_instance.credentials

        # check model
        provider_model = provider_model_bundle.configuration.get_provider_model(
            model=model_name,
            model_type=ModelType.LLM
        )

        if provider_model is None:
            raise ValueError(f"Model {model_name} does not exist.")

        if provider_model.status == ModelStatus.NO_CONFIGURE:
            raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
        elif provider_model.status == ModelStatus.NO_PERMISSION:
            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

        # model config
        completion_params = node_data.singleRetrievalConfig.model.completion_params
        stop = []
        if 'stop' in completion_params:
            stop = completion_params['stop']
            del completion_params['stop']

        # get model mode
        model_mode = node_data.singleRetrievalConfig.model.mode
        if not model_mode:
            raise ValueError("LLM mode is required.")

        model_schema = model_type_instance.get_model_schema(
            model_name,
            model_credentials
        )

        if not model_schema:
            raise ValueError(f"Model {model_name} does not exist.")

        return model_instance, ModelConfigWithCredentialsEntity(
            provider=provider_name,
            model=model_name,
            model_schema=model_schema,
            mode=model_mode,
            provider_model_bundle=provider_model_bundle,
            credentials=model_credentials,
            parameters=completion_params,
            stop=stop,
        )

    def _multiple_retrieve(self, available_datasets, node_data, variables):
        # retrieve from every available dataset in parallel, then rerank the merged results
        threads = []
        all_documents = []
        dataset_ids = [dataset.id for dataset in available_datasets]
        for dataset in available_datasets:
            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset.id,
                'query': variables['#query#'],
                'top_k': node_data.multiple_retrieval_config.top_k,
                'all_documents': all_documents,
            })
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()

        # rerank the retrieved documents
        model_manager = ModelManager()
        rerank_model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            provider=node_data.multiple_retrieval_config.reranking_model.provider,
            model_type=ModelType.RERANK,
            model=node_data.multiple_retrieval_config.reranking_model.mode
        )

        rerank_runner = RerankRunner(rerank_model_instance)
        all_documents = rerank_runner.run(variables['#query#'], all_documents,
                                          node_data.multiple_retrieval_config.score_threshold,
                                          node_data.multiple_retrieval_config.top_k)

        return all_documents

    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == self.tenant_id,
                Dataset.id == dataset_id
            ).first()

            if not dataset:
                return []

            # get the dataset's retrieval model; fall back to the default if it is not set
            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

            if dataset.indexing_technique == "economy":
                # use keyword table query
                documents = RetrievalService.retrieve(retrival_method='keyword_search',
                                                      dataset_id=dataset.id,
                                                      query=query,
                                                      top_k=top_k
                                                      )
                if documents:
                    all_documents.extend(documents)
            else:
                if top_k > 0:
                    # retrieval source
                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
                                                          dataset_id=dataset.id,
                                                          query=query,
                                                          top_k=top_k,
                                                          score_threshold=retrieval_model['score_threshold']
                                                          if retrieval_model['score_threshold_enabled'] else None,
                                                          reranking_model=retrieval_model['reranking_model']
                                                          if retrieval_model['reranking_enable'] else None
                                                          )

                    all_documents.extend(documents)
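
For reference, the per-dataset retrieval call the node issues in _retriever can be sketched directly against RetrievalService, driven by the default_retrieval_model from this file; the dataset ID and query are placeholders, and a Flask app context with configured datasources is assumed.

# Sketch only: the same RetrievalService.retrieve call pattern used above, with the
# default retrieval model. Placeholder dataset ID and query; requires a Flask app
# context and a configured vector store / keyword index.
from core.rag.datasource.retrieval_service import RetrievalService

documents = RetrievalService.retrieve(
    retrival_method=default_retrieval_model['search_method'],  # 'semantic_search'
    dataset_id='<dataset-id>',
    query='example question',
    top_k=default_retrieval_model['top_k'],
    score_threshold=None,   # score threshold disabled in the default model
    reranking_model=None,   # reranking disabled in the default model
)
for document in documents:
    print(document.metadata.get('doc_id'), document.metadata.get('score'))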