From 19c526120c7d094fe47e0d10244a24394029b138 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 19 Sep 2024 17:07:33 +0800 Subject: [PATCH] external knowledge api --- api/controllers/console/__init__.py | 2 +- api/controllers/console/datasets/datasets.py | 4 +- api/controllers/console/datasets/external.py | 40 ++- .../console/datasets/hit_testing.py | 4 +- .../console/datasets/test_external.py | 29 ++- .../service_api/dataset/dataset.py | 4 +- api/core/rag/datasource/retrieval_service.py | 228 ++++++++++-------- ...fca025d3b60f_add_dataset_retrival_model.py | 2 +- api/services/dataset_service.py | 4 +- .../external_knowledge_entities.py | 2 - api/services/external_knowledge_service.py | 78 ++++-- api/services/hit_testing_service.py | 77 ++++-- 12 files changed, 304 insertions(+), 170 deletions(-) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 77358acedb..a7dd97f51e 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -37,7 +37,7 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p from .billing import billing # Import datasets controllers -from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website +from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website, test_external # Import explore controllers from .explore import ( diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index b601014554..ebc5d31e7e 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -49,7 +49,7 @@ class DatasetListApi(Resource): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) ids = request.args.getlist("ids") - provider = request.args.get("provider", default="vendor") + # provider = request.args.get("provider", default="vendor") search = request.args.get("keyword", default=None, type=str) tag_ids = request.args.getlist("tag_ids") @@ -57,7 +57,7 @@ class DatasetListApi(Resource): datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) else: datasets, total = DatasetService.get_datasets( - page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids + page, limit, current_user.current_tenant_id, current_user, search, tag_ids ) # check embedding setting diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 397d602d28..b1c76375e4 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,7 +1,7 @@ from flask import request from flask_login import current_user from flask_restful import Resource, marshal, reqparse -from werkzeug.exceptions import Forbidden, NotFound +from werkzeug.exceptions import Forbidden, NotFound, InternalServerError import services from controllers.console import api @@ -11,7 +11,9 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from fields.dataset_fields import dataset_detail_fields from libs.login import login_required +from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService +from services.hit_testing_service import HitTestingService def _validate_name(name): @@ -249,6 +251,42 @@ class ExternalDatasetCreateApi(Resource): return marshal(dataset, dataset_detail_fields), 201 +class ExternalKnowledgeHitTestingApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, dataset_id): + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + parser = reqparse.RequestParser() + parser.add_argument("query", type=str, location="json") + parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") + args = parser.parse_args() + + HitTestingService.hit_testing_args_check(args) + + try: + response = HitTestingService.external_retrieve( + dataset=dataset, + query=args["query"], + account=current_user, + external_retrieval_model=args["external_retrieval_model"], + ) + + return response + except Exception as e: + raise InternalServerError(str(e)) + + +api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets//external-hit-testing") api.add_resource(ExternalApiTemplateListApi, "/datasets/external-api-template") api.add_resource(ExternalApiTemplateApi, "/datasets/external-api-template/") api.add_resource(ExternalApiUseCheckApi, "/datasets/external-api-template//use-check") diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index cf2e8af1a2..6e6d8c0bd7 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -47,7 +47,7 @@ class HitTestingApi(Resource): parser = reqparse.RequestParser() parser.add_argument("query", type=str, location="json") parser.add_argument("retrieval_model", type=dict, required=False, location="json") - parser.add_argument("external_retrival_model", type=dict, required=False, location="json") + parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") args = parser.parse_args() HitTestingService.hit_testing_args_check(args) @@ -58,7 +58,7 @@ class HitTestingApi(Resource): query=args["query"], account=current_user, retrieval_model=args["retrieval_model"], - external_retrieval_model=args["external_retrival_model"], + external_retrieval_model=args["external_retrieval_model"], limit=10, ) diff --git a/api/controllers/console/datasets/test_external.py b/api/controllers/console/datasets/test_external.py index 7772597839..7c46be6533 100644 --- a/api/controllers/console/datasets/test_external.py +++ b/api/controllers/console/datasets/test_external.py @@ -31,19 +31,24 @@ class TestExternalApi(Resource): required=True, type=float, ) - args = parser.parse_args() - result = ExternalDatasetService.test_external_knowledge_retrival( - args["top_k"], args["score_threshold"] + parser.add_argument( + "query", + nullable=False, + required=True, + type=str, ) - response = { - "data": [item.to_dict() for item in api_templates], - "has_more": len(api_templates) == limit, - "limit": limit, - "total": total, - "page": page, - } - return response, 200 + parser.add_argument( + "external_knowledge_id", + nullable=False, + required=True, + type=str, + ) + args = parser.parse_args() + result = ExternalDatasetService.test_external_knowledge_retrieval( + args["top_k"], args["score_threshold"], args["query"], args["external_knowledge_id"] + ) + return result, 200 -api.add_resource(TestExternalApi, "/dify/external-knowledge/retrival-documents") +api.add_resource(TestExternalApi, "/dify/external-knowledge/retrieval-documents") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 90d62d7c7f..7483b4b4d6 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -28,11 +28,11 @@ class DatasetListApi(DatasetApiResource): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - provider = request.args.get("provider", default="vendor") + # provider = request.args.get("provider", default="vendor") search = request.args.get("keyword", default=None, type=str) tag_ids = request.args.getlist("tag_ids") - datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids) + datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids) # check embedding setting provider_manager = ProviderManager() configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 9819278491..496c3e2678 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -23,94 +23,110 @@ default_retrieval_model = { class RetrievalService: @classmethod - def retrieve(cls, retrival_method: str, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float] = .0, - reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model', - weights: Optional[dict] = None, provider: Optional[str] = None, - external_retrieval_model: Optional[dict] = None): + def retrieve(cls, + retrieval_method: str, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float] = .0, + reranking_model: Optional[dict] = None, + reranking_mode: Optional[str] = 'reranking_model', + weights: Optional[dict] = None + ): dataset = db.session.query(Dataset).filter( Dataset.id == dataset_id ).first() if not dataset: return [] - if provider == 'external': - all_documents = ExternalDatasetService.fetch_external_knowledge_retrival( - dataset.tenant_id, - dataset_id, - query, - external_retrieval_model + + if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: + return [] + all_documents = [] + threads = [] + exceptions = [] + # retrieval_model source with keyword + if retrieval_method == 'keyword_search': + keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'top_k': top_k, + 'all_documents': all_documents, + 'exceptions': exceptions, + }) + threads.append(keyword_thread) + keyword_thread.start() + # retrieval_model source with semantic + if RetrievalMethod.is_support_semantic_search(retrieval_method): + embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'top_k': top_k, + 'score_threshold': score_threshold, + 'reranking_model': reranking_model, + 'all_documents': all_documents, + 'retrieval_method': retrieval_method, + 'exceptions': exceptions, + }) + threads.append(embedding_thread) + embedding_thread.start() + + # retrieval source with full text + if RetrievalMethod.is_support_fulltext_search(retrieval_method): + full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'retrieval_method': retrieval_method, + 'score_threshold': score_threshold, + 'top_k': top_k, + 'reranking_model': reranking_model, + 'all_documents': all_documents, + 'exceptions': exceptions, + }) + threads.append(full_text_index_thread) + full_text_index_thread.start() + + for thread in threads: + thread.join() + + if exceptions: + exception_message = ';\n'.join(exceptions) + raise Exception(exception_message) + + if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode, + reranking_model, weights, False) + all_documents = data_post_processor.invoke( + query=query, + documents=all_documents, + score_threshold=score_threshold, + top_n=top_k ) - else: - if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: - return [] - all_documents = [] - threads = [] - exceptions = [] - # retrieval_model source with keyword - if retrival_method == 'keyword_search': - keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) - threads.append(keyword_thread) - keyword_thread.start() - # retrieval_model source with semantic - if RetrievalMethod.is_support_semantic_search(retrival_method): - embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'score_threshold': score_threshold, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'retrival_method': retrival_method, - 'exceptions': exceptions, - }) - threads.append(embedding_thread) - embedding_thread.start() + return all_documents - # retrieval source with full text - if RetrievalMethod.is_support_fulltext_search(retrival_method): - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'retrival_method': retrival_method, - 'score_threshold': score_threshold, - 'top_k': top_k, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) - threads.append(full_text_index_thread) - full_text_index_thread.start() - - for thread in threads: - thread.join() - - if exceptions: - exception_message = ';\n'.join(exceptions) - raise Exception(exception_message) - - if retrival_method == RetrievalMethod.HYBRID_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode, - reranking_model, weights, False) - all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k - ) - return all_documents + @classmethod + def external_retrieve(cls, + dataset_id: str, + query: str, + external_retrieval_model: Optional[dict] = None): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + if not dataset: + return [] + all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( + dataset.tenant_id, + dataset_id, + query, + external_retrieval_model + ) + return all_documents @classmethod def keyword_search( - cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list + cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list ): with flask_app.app_context(): try: @@ -125,16 +141,16 @@ class RetrievalService: @classmethod def embedding_search( - cls, - flask_app: Flask, - dataset_id: str, - query: str, - top_k: int, - score_threshold: Optional[float], - reranking_model: Optional[dict], - all_documents: list, - retrieval_method: str, - exceptions: list, + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, ): with flask_app.app_context(): try: @@ -152,10 +168,10 @@ class RetrievalService: if documents: if ( - reranking_model - and reranking_model.get("reranking_model_name") - and reranking_model.get("reranking_provider_name") - and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value ): data_post_processor = DataPostProcessor( str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False @@ -172,16 +188,16 @@ class RetrievalService: @classmethod def full_text_index_search( - cls, - flask_app: Flask, - dataset_id: str, - query: str, - top_k: int, - score_threshold: Optional[float], - reranking_model: Optional[dict], - all_documents: list, - retrieval_method: str, - exceptions: list, + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, ): with flask_app.app_context(): try: @@ -194,10 +210,10 @@ class RetrievalService: documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k) if documents: if ( - reranking_model - and reranking_model.get("reranking_model_name") - and reranking_model.get("reranking_provider_name") - and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value ): data_post_processor = DataPostProcessor( str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False diff --git a/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py b/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py index 1f8250c3eb..52495be60a 100644 --- a/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py +++ b/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py @@ -1,4 +1,4 @@ -"""add-dataset-retrival-model +"""add-dataset-retrieval-model Revision ID: fca025d3b60f Revises: b3a09c049e8e diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 1c0f4de897..12e0418093 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -58,8 +58,8 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde class DatasetService: @staticmethod - def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None): - query = Dataset.query.filter(Dataset.provider == provider, Dataset.tenant_id == tenant_id).order_by( + def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None): + query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by( Dataset.created_at.desc() ) diff --git a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py index e84475c181..fee258dd22 100644 --- a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py +++ b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py @@ -20,9 +20,7 @@ class ProcessStatusSetting(BaseModel): class ApiTemplateSetting(BaseModel): - method: str url: str request_method: str - api_token: str headers: Optional[dict] = None params: Optional[dict] = None diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index f189aeacb8..e156b5a8ff 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -15,10 +15,13 @@ from models.dataset import ( ExternalApiTemplates, ExternalKnowledgeBindings, ) +from core.rag.models.document import Document as RetrievalDocument from models.model import UploadFile from services.entities.external_knowledge_entities.external_knowledge_entities import ApiTemplateSetting, Authorization from services.errors.dataset import DatasetNameDuplicateError # from tasks.external_document_indexing_task import external_document_indexing_task +import requests +import boto3 class ExternalDatasetService: @@ -173,7 +176,7 @@ class ExternalDatasetService: db.session.flush() document_ids.append(document.id) db.session.commit() - #external_document_indexing_task.delay(dataset.id, api_template_id, data_source, process_parameter) + # external_document_indexing_task.delay(dataset.id, api_template_id, data_source, process_parameter) return dataset @@ -189,7 +192,7 @@ class ExternalDatasetService: "follow_redirects": True, } - response = getattr(ssrf_proxy, settings.request_method)(data=settings.params, files=files, **kwargs) + response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs) return response @@ -260,9 +263,9 @@ class ExternalDatasetService: return dataset @staticmethod - def fetch_external_knowledge_retrival( - tenant_id: str, dataset_id: str, query: str, external_retrival_parameters: dict - ): + def fetch_external_knowledge_retrieval( + tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict + ) -> list: external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( dataset_id=dataset_id, tenant_id=tenant_id ).first() @@ -276,33 +279,58 @@ class ExternalDatasetService: raise ValueError("external api template not found") settings = json.loads(external_api_template.settings) - headers = {} - if settings.get("api_token"): - headers["Authorization"] = f"Bearer {settings.get('api_token')}" + headers = { + "Content-Type": "application/json" + } + if settings.get("api_key"): + headers["Authorization"] = f"Bearer {settings.get('api_key')}" - external_retrival_parameters["query"] = query + external_retrieval_parameters["query"] = query + external_retrieval_parameters["external_knowledge_id"] = external_knowledge_binding.external_knowledge_id api_template_setting = { - "url": f"{settings.get('endpoint')}/dify/external-knowledge/retrival-documents", + "url": f"{settings.get('endpoint')}/dify/external-knowledge/retrieval-documents", "request_method": "post", - "headers": settings.get("headers"), - "params": external_retrival_parameters, + "headers": headers, + "params": external_retrieval_parameters, } response = ExternalDatasetService.process_external_api(ApiTemplateSetting(**api_template_setting), None) - + if response.status_code == 200: + return response.json() + return [] @staticmethod - def test_external_knowledge_retrival( - top_k: int, score_threshold: float + def test_external_knowledge_retrieval( + top_k: int, score_threshold: float, query: str, external_knowledge_id: str ): - api_template_setting = { - "url": f"{settings.get('endpoint')}/dify/external-knowledge/retrival-documents", - "request_method": "post", - "headers": settings.get("headers"), - "params": { - "top_k": top_k, - "score_threshold": score_threshold, + client = boto3.client( + "bedrock-agent-runtime", + aws_secret_access_key='', + aws_access_key_id='', + region_name='', + ) + response = client.retrieve( + knowledgeBaseId=external_knowledge_id, + retrievalConfiguration={ + 'vectorSearchConfiguration': { + 'numberOfResults': top_k, + 'overrideSearchType': 'HYBRID' + } }, - } - response = ExternalDatasetService.process_external_api(ApiTemplateSetting(**api_template_setting), None) - return response.json() \ No newline at end of file + retrievalQuery={ + 'text': query + } + ) + results = [] + if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200: + if response.get("retrievalResults"): + retrieval_results = response.get("retrievalResults") + for retrieval_result in retrieval_results: + result = { + "metadata": retrieval_result.get("metadata"), + "score": retrieval_result.get("score"), + "title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"), + "content": retrieval_result.get("content").get("text"), + } + results.append(result) + return results diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 1f85d92e8c..196720882a 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -20,15 +20,15 @@ default_retrieval_model = { class HitTestingService: @classmethod def retrieve( - cls, - dataset: Dataset, - query: str, - account: Account, - retrieval_model: dict, - external_retrieval_model: dict, - limit: int = 10, + cls, + dataset: Dataset, + query: str, + account: Account, + retrieval_model: dict, + external_retrieval_model: dict, + limit: int = 10, ) -> dict: - if dataset.available_document_count == 0 or dataset.available_segment_count == 0: + if (dataset.available_document_count == 0 or dataset.available_segment_count == 0): return { "query": { "content": query, @@ -56,8 +56,6 @@ class HitTestingService: else None, reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", weights=retrieval_model.get("weights", None), - provider=dataset.provider, - external_retrieval_model=external_retrieval_model, ) end = time.perf_counter() @@ -72,10 +70,45 @@ class HitTestingService: return cls.compact_retrieve_response(dataset, query, all_documents) + @classmethod + def external_retrieve( + cls, + dataset: Dataset, + query: str, + account: Account, + external_retrieval_model: dict, + ) -> dict: + if dataset.provider != "external": + return { + "query": { + "content": query}, + "records": [], + } + + start = time.perf_counter() + + all_documents = RetrievalService.external_retrieve( + dataset_id=dataset.id, + query=cls.escape_query_for_search(query), + external_retrieval_model=external_retrieval_model, + ) + + end = time.perf_counter() + logging.debug(f"External knowledge hit testing retrieve in {end - start:0.4f} seconds") + + dataset_query = DatasetQuery( + dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id + ) + + db.session.add(dataset_query) + db.session.commit() + + return cls.compact_external_retrieve_response(dataset, query, all_documents) + @classmethod def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): - i = 0 records = [] + for document in documents: index_node_id = document.metadata["doc_id"] @@ -91,7 +124,6 @@ class HitTestingService: ) if not segment: - i += 1 continue record = { @@ -101,8 +133,6 @@ class HitTestingService: records.append(record) - i += 1 - return { "query": { "content": query, @@ -110,6 +140,25 @@ class HitTestingService: "records": records, } + @classmethod + def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list): + records = [] + if dataset.provider == "external": + for document in documents: + record = { + "content": document.get("content", None), + "title": document.get("title", None), + "score": document.get("score", None), + "metadata": document.get("metadata", None), + } + records.append(record) + return { + "query": { + "content": query, + }, + "records": records, + } + @classmethod def hit_testing_args_check(cls, args): query = args["query"]