mirror of https://github.com/langgenius/dify.git
external knowledge api
This commit is contained in:
parent
37f7d5732a
commit
19c526120c
|
|
@ -37,7 +37,7 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p
|
|||
from .billing import billing
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website
|
||||
from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website, test_external
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class DatasetListApi(Resource):
|
|||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
ids = request.args.getlist("ids")
|
||||
provider = request.args.get("provider", default="vendor")
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
tag_ids = request.args.getlist("tag_ids")
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ class DatasetListApi(Resource):
|
|||
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
|
||||
else:
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids
|
||||
page, limit, current_user.current_tenant_id, current_user, search, tag_ids
|
||||
)
|
||||
|
||||
# check embedding setting
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
from werkzeug.exceptions import Forbidden, NotFound, InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
|
|
@ -11,7 +11,9 @@ from controllers.console.setup import setup_required
|
|||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
|
|
@ -249,6 +251,42 @@ class ExternalDatasetCreateApi(Resource):
|
|||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
try:
|
||||
response = HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise InternalServerError(str(e))
|
||||
|
||||
|
||||
api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing")
|
||||
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-api-template")
|
||||
api.add_resource(ExternalApiTemplateApi, "/datasets/external-api-template/<uuid:api_template_id>")
|
||||
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-api-template/<uuid:api_template_id>/use-check")
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class HitTestingApi(Resource):
|
|||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrival_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
|
@ -58,7 +58,7 @@ class HitTestingApi(Resource):
|
|||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrival_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -31,19 +31,24 @@ class TestExternalApi(Resource):
|
|||
required=True,
|
||||
type=float,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
result = ExternalDatasetService.test_external_knowledge_retrival(
|
||||
args["top_k"], args["score_threshold"]
|
||||
parser.add_argument(
|
||||
"query",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=str,
|
||||
)
|
||||
response = {
|
||||
"data": [item.to_dict() for item in api_templates],
|
||||
"has_more": len(api_templates) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return response, 200
|
||||
parser.add_argument(
|
||||
"external_knowledge_id",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=str,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
result = ExternalDatasetService.test_external_knowledge_retrieval(
|
||||
args["top_k"], args["score_threshold"], args["query"], args["external_knowledge_id"]
|
||||
)
|
||||
return result, 200
|
||||
|
||||
|
||||
|
||||
api.add_resource(TestExternalApi, "/dify/external-knowledge/retrival-documents")
|
||||
api.add_resource(TestExternalApi, "/dify/external-knowledge/retrieval-documents")
|
||||
|
|
|
|||
|
|
@ -28,11 +28,11 @@ class DatasetListApi(DatasetApiResource):
|
|||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
provider = request.args.get("provider", default="vendor")
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
tag_ids = request.args.getlist("tag_ids")
|
||||
|
||||
datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids)
|
||||
datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||
|
|
|
|||
|
|
@ -23,94 +23,110 @@ default_retrieval_model = {
|
|||
|
||||
class RetrievalService:
|
||||
@classmethod
|
||||
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
|
||||
top_k: int, score_threshold: Optional[float] = .0,
|
||||
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model',
|
||||
weights: Optional[dict] = None, provider: Optional[str] = None,
|
||||
external_retrieval_model: Optional[dict] = None):
|
||||
def retrieve(cls,
|
||||
retrieval_method: str,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float] = .0,
|
||||
reranking_model: Optional[dict] = None,
|
||||
reranking_mode: Optional[str] = 'reranking_model',
|
||||
weights: Optional[dict] = None
|
||||
):
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
return []
|
||||
if provider == 'external':
|
||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrival(
|
||||
dataset.tenant_id,
|
||||
dataset_id,
|
||||
query,
|
||||
external_retrieval_model
|
||||
|
||||
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return []
|
||||
all_documents = []
|
||||
threads = []
|
||||
exceptions = []
|
||||
# retrieval_model source with keyword
|
||||
if retrieval_method == 'keyword_search':
|
||||
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'top_k': top_k,
|
||||
'all_documents': all_documents,
|
||||
'exceptions': exceptions,
|
||||
})
|
||||
threads.append(keyword_thread)
|
||||
keyword_thread.start()
|
||||
# retrieval_model source with semantic
|
||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'top_k': top_k,
|
||||
'score_threshold': score_threshold,
|
||||
'reranking_model': reranking_model,
|
||||
'all_documents': all_documents,
|
||||
'retrieval_method': retrieval_method,
|
||||
'exceptions': exceptions,
|
||||
})
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval source with full text
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'retrieval_method': retrieval_method,
|
||||
'score_threshold': score_threshold,
|
||||
'top_k': top_k,
|
||||
'reranking_model': reranking_model,
|
||||
'all_documents': all_documents,
|
||||
'exceptions': exceptions,
|
||||
})
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if exceptions:
|
||||
exception_message = ';\n'.join(exceptions)
|
||||
raise Exception(exception_message)
|
||||
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
|
||||
reranking_model, weights, False)
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k
|
||||
)
|
||||
else:
|
||||
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return []
|
||||
all_documents = []
|
||||
threads = []
|
||||
exceptions = []
|
||||
# retrieval_model source with keyword
|
||||
if retrival_method == 'keyword_search':
|
||||
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'top_k': top_k,
|
||||
'all_documents': all_documents,
|
||||
'exceptions': exceptions,
|
||||
})
|
||||
threads.append(keyword_thread)
|
||||
keyword_thread.start()
|
||||
# retrieval_model source with semantic
|
||||
if RetrievalMethod.is_support_semantic_search(retrival_method):
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'top_k': top_k,
|
||||
'score_threshold': score_threshold,
|
||||
'reranking_model': reranking_model,
|
||||
'all_documents': all_documents,
|
||||
'retrival_method': retrival_method,
|
||||
'exceptions': exceptions,
|
||||
})
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
return all_documents
|
||||
|
||||
# retrieval source with full text
|
||||
if RetrievalMethod.is_support_fulltext_search(retrival_method):
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'retrival_method': retrival_method,
|
||||
'score_threshold': score_threshold,
|
||||
'top_k': top_k,
|
||||
'reranking_model': reranking_model,
|
||||
'all_documents': all_documents,
|
||||
'exceptions': exceptions,
|
||||
})
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if exceptions:
|
||||
exception_message = ';\n'.join(exceptions)
|
||||
raise Exception(exception_message)
|
||||
|
||||
if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
|
||||
reranking_model, weights, False)
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k
|
||||
)
|
||||
return all_documents
|
||||
@classmethod
|
||||
def external_retrieve(cls,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
external_retrieval_model: Optional[dict] = None):
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
return []
|
||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
dataset.tenant_id,
|
||||
dataset_id,
|
||||
query,
|
||||
external_retrieval_model
|
||||
)
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def keyword_search(
|
||||
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
|
||||
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
|
|
@ -125,16 +141,16 @@ class RetrievalService:
|
|||
|
||||
@classmethod
|
||||
def embedding_search(
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float],
|
||||
reranking_model: Optional[dict],
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float],
|
||||
reranking_model: Optional[dict],
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
|
|
@ -152,10 +168,10 @@ class RetrievalService:
|
|||
|
||||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
|
||||
|
|
@ -172,16 +188,16 @@ class RetrievalService:
|
|||
|
||||
@classmethod
|
||||
def full_text_index_search(
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float],
|
||||
reranking_model: Optional[dict],
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float],
|
||||
reranking_model: Optional[dict],
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
|
|
@ -194,10 +210,10 @@ class RetrievalService:
|
|||
documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
|
||||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""add-dataset-retrival-model
|
||||
"""add-dataset-retrieval-model
|
||||
|
||||
Revision ID: fca025d3b60f
|
||||
Revises: b3a09c049e8e
|
||||
|
|
|
|||
|
|
@ -58,8 +58,8 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
|
|||
|
||||
class DatasetService:
|
||||
@staticmethod
|
||||
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None):
|
||||
query = Dataset.query.filter(Dataset.provider == provider, Dataset.tenant_id == tenant_id).order_by(
|
||||
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None):
|
||||
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(
|
||||
Dataset.created_at.desc()
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,9 +20,7 @@ class ProcessStatusSetting(BaseModel):
|
|||
|
||||
|
||||
class ApiTemplateSetting(BaseModel):
|
||||
method: str
|
||||
url: str
|
||||
request_method: str
|
||||
api_token: str
|
||||
headers: Optional[dict] = None
|
||||
params: Optional[dict] = None
|
||||
|
|
|
|||
|
|
@ -15,10 +15,13 @@ from models.dataset import (
|
|||
ExternalApiTemplates,
|
||||
ExternalKnowledgeBindings,
|
||||
)
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from models.model import UploadFile
|
||||
from services.entities.external_knowledge_entities.external_knowledge_entities import ApiTemplateSetting, Authorization
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
# from tasks.external_document_indexing_task import external_document_indexing_task
|
||||
import requests
|
||||
import boto3
|
||||
|
||||
|
||||
class ExternalDatasetService:
|
||||
|
|
@ -173,7 +176,7 @@ class ExternalDatasetService:
|
|||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
db.session.commit()
|
||||
#external_document_indexing_task.delay(dataset.id, api_template_id, data_source, process_parameter)
|
||||
# external_document_indexing_task.delay(dataset.id, api_template_id, data_source, process_parameter)
|
||||
|
||||
return dataset
|
||||
|
||||
|
|
@ -189,7 +192,7 @@ class ExternalDatasetService:
|
|||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
response = getattr(ssrf_proxy, settings.request_method)(data=settings.params, files=files, **kwargs)
|
||||
response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs)
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -260,9 +263,9 @@ class ExternalDatasetService:
|
|||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def fetch_external_knowledge_retrival(
|
||||
tenant_id: str, dataset_id: str, query: str, external_retrival_parameters: dict
|
||||
):
|
||||
def fetch_external_knowledge_retrieval(
|
||||
tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
|
||||
) -> list:
|
||||
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
|
||||
dataset_id=dataset_id, tenant_id=tenant_id
|
||||
).first()
|
||||
|
|
@ -276,33 +279,58 @@ class ExternalDatasetService:
|
|||
raise ValueError("external api template not found")
|
||||
|
||||
settings = json.loads(external_api_template.settings)
|
||||
headers = {}
|
||||
if settings.get("api_token"):
|
||||
headers["Authorization"] = f"Bearer {settings.get('api_token')}"
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
if settings.get("api_key"):
|
||||
headers["Authorization"] = f"Bearer {settings.get('api_key')}"
|
||||
|
||||
external_retrival_parameters["query"] = query
|
||||
external_retrieval_parameters["query"] = query
|
||||
external_retrieval_parameters["external_knowledge_id"] = external_knowledge_binding.external_knowledge_id
|
||||
|
||||
api_template_setting = {
|
||||
"url": f"{settings.get('endpoint')}/dify/external-knowledge/retrival-documents",
|
||||
"url": f"{settings.get('endpoint')}/dify/external-knowledge/retrieval-documents",
|
||||
"request_method": "post",
|
||||
"headers": settings.get("headers"),
|
||||
"params": external_retrival_parameters,
|
||||
"headers": headers,
|
||||
"params": external_retrieval_parameters,
|
||||
}
|
||||
response = ExternalDatasetService.process_external_api(ApiTemplateSetting(**api_template_setting), None)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def test_external_knowledge_retrival(
|
||||
top_k: int, score_threshold: float
|
||||
def test_external_knowledge_retrieval(
|
||||
top_k: int, score_threshold: float, query: str, external_knowledge_id: str
|
||||
):
|
||||
api_template_setting = {
|
||||
"url": f"{settings.get('endpoint')}/dify/external-knowledge/retrival-documents",
|
||||
"request_method": "post",
|
||||
"headers": settings.get("headers"),
|
||||
"params": {
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
client = boto3.client(
|
||||
"bedrock-agent-runtime",
|
||||
aws_secret_access_key='',
|
||||
aws_access_key_id='',
|
||||
region_name='',
|
||||
)
|
||||
response = client.retrieve(
|
||||
knowledgeBaseId=external_knowledge_id,
|
||||
retrievalConfiguration={
|
||||
'vectorSearchConfiguration': {
|
||||
'numberOfResults': top_k,
|
||||
'overrideSearchType': 'HYBRID'
|
||||
}
|
||||
},
|
||||
}
|
||||
response = ExternalDatasetService.process_external_api(ApiTemplateSetting(**api_template_setting), None)
|
||||
return response.json()
|
||||
retrievalQuery={
|
||||
'text': query
|
||||
}
|
||||
)
|
||||
results = []
|
||||
if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
|
||||
if response.get("retrievalResults"):
|
||||
retrieval_results = response.get("retrievalResults")
|
||||
for retrieval_result in retrieval_results:
|
||||
result = {
|
||||
"metadata": retrieval_result.get("metadata"),
|
||||
"score": retrieval_result.get("score"),
|
||||
"title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"),
|
||||
"content": retrieval_result.get("content").get("text"),
|
||||
}
|
||||
results.append(result)
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -20,15 +20,15 @@ default_retrieval_model = {
|
|||
class HitTestingService:
|
||||
@classmethod
|
||||
def retrieve(
|
||||
cls,
|
||||
dataset: Dataset,
|
||||
query: str,
|
||||
account: Account,
|
||||
retrieval_model: dict,
|
||||
external_retrieval_model: dict,
|
||||
limit: int = 10,
|
||||
cls,
|
||||
dataset: Dataset,
|
||||
query: str,
|
||||
account: Account,
|
||||
retrieval_model: dict,
|
||||
external_retrieval_model: dict,
|
||||
limit: int = 10,
|
||||
) -> dict:
|
||||
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
if (dataset.available_document_count == 0 or dataset.available_segment_count == 0):
|
||||
return {
|
||||
"query": {
|
||||
"content": query,
|
||||
|
|
@ -56,8 +56,6 @@ class HitTestingService:
|
|||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
provider=dataset.provider,
|
||||
external_retrieval_model=external_retrieval_model,
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
|
|
@ -72,10 +70,45 @@ class HitTestingService:
|
|||
|
||||
return cls.compact_retrieve_response(dataset, query, all_documents)
|
||||
|
||||
@classmethod
|
||||
def external_retrieve(
|
||||
cls,
|
||||
dataset: Dataset,
|
||||
query: str,
|
||||
account: Account,
|
||||
external_retrieval_model: dict,
|
||||
) -> dict:
|
||||
if dataset.provider != "external":
|
||||
return {
|
||||
"query": {
|
||||
"content": query},
|
||||
"records": [],
|
||||
}
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
all_documents = RetrievalService.external_retrieve(
|
||||
dataset_id=dataset.id,
|
||||
query=cls.escape_query_for_search(query),
|
||||
external_retrieval_model=external_retrieval_model,
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
logging.debug(f"External knowledge hit testing retrieve in {end - start:0.4f} seconds")
|
||||
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
|
||||
)
|
||||
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
return cls.compact_external_retrieve_response(dataset, query, all_documents)
|
||||
|
||||
@classmethod
|
||||
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
|
||||
i = 0
|
||||
records = []
|
||||
|
||||
for document in documents:
|
||||
index_node_id = document.metadata["doc_id"]
|
||||
|
||||
|
|
@ -91,7 +124,6 @@ class HitTestingService:
|
|||
)
|
||||
|
||||
if not segment:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
record = {
|
||||
|
|
@ -101,8 +133,6 @@ class HitTestingService:
|
|||
|
||||
records.append(record)
|
||||
|
||||
i += 1
|
||||
|
||||
return {
|
||||
"query": {
|
||||
"content": query,
|
||||
|
|
@ -110,6 +140,25 @@ class HitTestingService:
|
|||
"records": records,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list):
|
||||
records = []
|
||||
if dataset.provider == "external":
|
||||
for document in documents:
|
||||
record = {
|
||||
"content": document.get("content", None),
|
||||
"title": document.get("title", None),
|
||||
"score": document.get("score", None),
|
||||
"metadata": document.get("metadata", None),
|
||||
}
|
||||
records.append(record)
|
||||
return {
|
||||
"query": {
|
||||
"content": query,
|
||||
},
|
||||
"records": records,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def hit_testing_args_check(cls, args):
|
||||
query = args["query"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue