external knowledge api

This commit is contained in:
jyong 2024-09-19 17:07:33 +08:00
parent 37f7d5732a
commit 19c526120c
12 changed files with 304 additions and 170 deletions

View File

@ -37,7 +37,7 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p
from .billing import billing
# Import datasets controllers
from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website
from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website, test_external
# Import explore controllers
from .explore import (

View File

@ -49,7 +49,7 @@ class DatasetListApi(Resource):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
provider = request.args.get("provider", default="vendor")
# provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
@ -57,7 +57,7 @@ class DatasetListApi(Resource):
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids
page, limit, current_user.current_tenant_id, current_user, search, tag_ids
)
# check embedding setting

View File

@ -1,7 +1,7 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from werkzeug.exceptions import Forbidden, NotFound, InternalServerError
import services
from controllers.console import api
@ -11,7 +11,9 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
def _validate_name(name):
@ -249,6 +251,42 @@ class ExternalDatasetCreateApi(Resource):
return marshal(dataset, dataset_detail_fields), 201
class ExternalKnowledgeHitTestingApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
try:
response = HitTestingService.external_retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
external_retrieval_model=args["external_retrieval_model"],
)
return response
except Exception as e:
raise InternalServerError(str(e))
api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing")
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-api-template")
api.add_resource(ExternalApiTemplateApi, "/datasets/external-api-template/<uuid:api_template_id>")
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-api-template/<uuid:api_template_id>/use-check")

View File

@ -47,7 +47,7 @@ class HitTestingApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrival_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
@ -58,7 +58,7 @@ class HitTestingApi(Resource):
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrival_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)

View File

@ -31,19 +31,24 @@ class TestExternalApi(Resource):
required=True,
type=float,
)
args = parser.parse_args()
result = ExternalDatasetService.test_external_knowledge_retrival(
args["top_k"], args["score_threshold"]
parser.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
response = {
"data": [item.to_dict() for item in api_templates],
"has_more": len(api_templates) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
parser.add_argument(
"external_knowledge_id",
nullable=False,
required=True,
type=str,
)
args = parser.parse_args()
result = ExternalDatasetService.test_external_knowledge_retrieval(
args["top_k"], args["score_threshold"], args["query"], args["external_knowledge_id"]
)
return result, 200
api.add_resource(TestExternalApi, "/dify/external-knowledge/retrival-documents")
api.add_resource(TestExternalApi, "/dify/external-knowledge/retrieval-documents")

View File

@ -28,11 +28,11 @@ class DatasetListApi(DatasetApiResource):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
provider = request.args.get("provider", default="vendor")
# provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids)
datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)

View File

@ -23,94 +23,110 @@ default_retrieval_model = {
class RetrievalService:
@classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0,
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model',
weights: Optional[dict] = None, provider: Optional[str] = None,
external_retrieval_model: Optional[dict] = None):
def retrieve(cls,
retrieval_method: str,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float] = .0,
reranking_model: Optional[dict] = None,
reranking_mode: Optional[str] = 'reranking_model',
weights: Optional[dict] = None
):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if not dataset:
return []
if provider == 'external':
all_documents = ExternalDatasetService.fetch_external_knowledge_retrival(
dataset.tenant_id,
dataset_id,
query,
external_retrieval_model
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []
threads = []
exceptions = []
# retrieval_model source with keyword
if retrieval_method == 'keyword_search':
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'all_documents': all_documents,
'exceptions': exceptions,
})
threads.append(keyword_thread)
keyword_thread.start()
# retrieval_model source with semantic
if RetrievalMethod.is_support_semantic_search(retrieval_method):
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'score_threshold': score_threshold,
'reranking_model': reranking_model,
'all_documents': all_documents,
'retrieval_method': retrieval_method,
'exceptions': exceptions,
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval source with full text
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'retrieval_method': retrieval_method,
'score_threshold': score_threshold,
'top_k': top_k,
'reranking_model': reranking_model,
'all_documents': all_documents,
'exceptions': exceptions,
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
if exceptions:
exception_message = ';\n'.join(exceptions)
raise Exception(exception_message)
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k
)
else:
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []
threads = []
exceptions = []
# retrieval_model source with keyword
if retrival_method == 'keyword_search':
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'all_documents': all_documents,
'exceptions': exceptions,
})
threads.append(keyword_thread)
keyword_thread.start()
# retrieval_model source with semantic
if RetrievalMethod.is_support_semantic_search(retrival_method):
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'score_threshold': score_threshold,
'reranking_model': reranking_model,
'all_documents': all_documents,
'retrival_method': retrival_method,
'exceptions': exceptions,
})
threads.append(embedding_thread)
embedding_thread.start()
return all_documents
# retrieval source with full text
if RetrievalMethod.is_support_fulltext_search(retrival_method):
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'retrival_method': retrival_method,
'score_threshold': score_threshold,
'top_k': top_k,
'reranking_model': reranking_model,
'all_documents': all_documents,
'exceptions': exceptions,
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
if exceptions:
exception_message = ';\n'.join(exceptions)
raise Exception(exception_message)
if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k
)
return all_documents
@classmethod
def external_retrieve(cls,
dataset_id: str,
query: str,
external_retrieval_model: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if not dataset:
return []
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id,
dataset_id,
query,
external_retrieval_model
)
return all_documents
@classmethod
def keyword_search(
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
):
with flask_app.app_context():
try:
@ -125,16 +141,16 @@ class RetrievalService:
@classmethod
def embedding_search(
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float],
reranking_model: Optional[dict],
all_documents: list,
retrieval_method: str,
exceptions: list,
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float],
reranking_model: Optional[dict],
all_documents: list,
retrieval_method: str,
exceptions: list,
):
with flask_app.app_context():
try:
@ -152,10 +168,10 @@ class RetrievalService:
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
@ -172,16 +188,16 @@ class RetrievalService:
@classmethod
def full_text_index_search(
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float],
reranking_model: Optional[dict],
all_documents: list,
retrieval_method: str,
exceptions: list,
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float],
reranking_model: Optional[dict],
all_documents: list,
retrieval_method: str,
exceptions: list,
):
with flask_app.app_context():
try:
@ -194,10 +210,10 @@ class RetrievalService:
documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False

View File

@ -1,4 +1,4 @@
"""add-dataset-retrival-model
"""add-dataset-retrieval-model
Revision ID: fca025d3b60f
Revises: b3a09c049e8e

View File

@ -58,8 +58,8 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService:
@staticmethod
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None):
query = Dataset.query.filter(Dataset.provider == provider, Dataset.tenant_id == tenant_id).order_by(
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None):
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(
Dataset.created_at.desc()
)

View File

@ -20,9 +20,7 @@ class ProcessStatusSetting(BaseModel):
class ApiTemplateSetting(BaseModel):
method: str
url: str
request_method: str
api_token: str
headers: Optional[dict] = None
params: Optional[dict] = None

View File

@ -15,10 +15,13 @@ from models.dataset import (
ExternalApiTemplates,
ExternalKnowledgeBindings,
)
from core.rag.models.document import Document as RetrievalDocument
from models.model import UploadFile
from services.entities.external_knowledge_entities.external_knowledge_entities import ApiTemplateSetting, Authorization
from services.errors.dataset import DatasetNameDuplicateError
# from tasks.external_document_indexing_task import external_document_indexing_task
import requests
import boto3
class ExternalDatasetService:
@ -173,7 +176,7 @@ class ExternalDatasetService:
db.session.flush()
document_ids.append(document.id)
db.session.commit()
#external_document_indexing_task.delay(dataset.id, api_template_id, data_source, process_parameter)
# external_document_indexing_task.delay(dataset.id, api_template_id, data_source, process_parameter)
return dataset
@ -189,7 +192,7 @@ class ExternalDatasetService:
"follow_redirects": True,
}
response = getattr(ssrf_proxy, settings.request_method)(data=settings.params, files=files, **kwargs)
response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs)
return response
@ -260,9 +263,9 @@ class ExternalDatasetService:
return dataset
@staticmethod
def fetch_external_knowledge_retrival(
tenant_id: str, dataset_id: str, query: str, external_retrival_parameters: dict
):
def fetch_external_knowledge_retrieval(
tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
) -> list:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
@ -276,33 +279,58 @@ class ExternalDatasetService:
raise ValueError("external api template not found")
settings = json.loads(external_api_template.settings)
headers = {}
if settings.get("api_token"):
headers["Authorization"] = f"Bearer {settings.get('api_token')}"
headers = {
"Content-Type": "application/json"
}
if settings.get("api_key"):
headers["Authorization"] = f"Bearer {settings.get('api_key')}"
external_retrival_parameters["query"] = query
external_retrieval_parameters["query"] = query
external_retrieval_parameters["external_knowledge_id"] = external_knowledge_binding.external_knowledge_id
api_template_setting = {
"url": f"{settings.get('endpoint')}/dify/external-knowledge/retrival-documents",
"url": f"{settings.get('endpoint')}/dify/external-knowledge/retrieval-documents",
"request_method": "post",
"headers": settings.get("headers"),
"params": external_retrival_parameters,
"headers": headers,
"params": external_retrieval_parameters,
}
response = ExternalDatasetService.process_external_api(ApiTemplateSetting(**api_template_setting), None)
if response.status_code == 200:
return response.json()
return []
@staticmethod
def test_external_knowledge_retrival(
top_k: int, score_threshold: float
def test_external_knowledge_retrieval(
top_k: int, score_threshold: float, query: str, external_knowledge_id: str
):
api_template_setting = {
"url": f"{settings.get('endpoint')}/dify/external-knowledge/retrival-documents",
"request_method": "post",
"headers": settings.get("headers"),
"params": {
"top_k": top_k,
"score_threshold": score_threshold,
client = boto3.client(
"bedrock-agent-runtime",
aws_secret_access_key='',
aws_access_key_id='',
region_name='',
)
response = client.retrieve(
knowledgeBaseId=external_knowledge_id,
retrievalConfiguration={
'vectorSearchConfiguration': {
'numberOfResults': top_k,
'overrideSearchType': 'HYBRID'
}
},
}
response = ExternalDatasetService.process_external_api(ApiTemplateSetting(**api_template_setting), None)
return response.json()
retrievalQuery={
'text': query
}
)
results = []
if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
if response.get("retrievalResults"):
retrieval_results = response.get("retrievalResults")
for retrieval_result in retrieval_results:
result = {
"metadata": retrieval_result.get("metadata"),
"score": retrieval_result.get("score"),
"title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"),
"content": retrieval_result.get("content").get("text"),
}
results.append(result)
return results

View File

@ -20,15 +20,15 @@ default_retrieval_model = {
class HitTestingService:
@classmethod
def retrieve(
cls,
dataset: Dataset,
query: str,
account: Account,
retrieval_model: dict,
external_retrieval_model: dict,
limit: int = 10,
cls,
dataset: Dataset,
query: str,
account: Account,
retrieval_model: dict,
external_retrieval_model: dict,
limit: int = 10,
) -> dict:
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
if (dataset.available_document_count == 0 or dataset.available_segment_count == 0):
return {
"query": {
"content": query,
@ -56,8 +56,6 @@ class HitTestingService:
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
provider=dataset.provider,
external_retrieval_model=external_retrieval_model,
)
end = time.perf_counter()
@ -72,10 +70,45 @@ class HitTestingService:
return cls.compact_retrieve_response(dataset, query, all_documents)
@classmethod
def external_retrieve(
cls,
dataset: Dataset,
query: str,
account: Account,
external_retrieval_model: dict,
) -> dict:
if dataset.provider != "external":
return {
"query": {
"content": query},
"records": [],
}
start = time.perf_counter()
all_documents = RetrievalService.external_retrieve(
dataset_id=dataset.id,
query=cls.escape_query_for_search(query),
external_retrieval_model=external_retrieval_model,
)
end = time.perf_counter()
logging.debug(f"External knowledge hit testing retrieve in {end - start:0.4f} seconds")
dataset_query = DatasetQuery(
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
)
db.session.add(dataset_query)
db.session.commit()
return cls.compact_external_retrieve_response(dataset, query, all_documents)
@classmethod
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
i = 0
records = []
for document in documents:
index_node_id = document.metadata["doc_id"]
@ -91,7 +124,6 @@ class HitTestingService:
)
if not segment:
i += 1
continue
record = {
@ -101,8 +133,6 @@ class HitTestingService:
records.append(record)
i += 1
return {
"query": {
"content": query,
@ -110,6 +140,25 @@ class HitTestingService:
"records": records,
}
@classmethod
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list):
records = []
if dataset.provider == "external":
for document in documents:
record = {
"content": document.get("content", None),
"title": document.get("title", None),
"score": document.get("score", None),
"metadata": document.get("metadata", None),
}
records.append(record)
return {
"query": {
"content": query,
},
"records": records,
}
@classmethod
def hit_testing_args_check(cls, args):
query = args["query"]