From 611f0fb3f64a3c08a6428a1d37cd81b8f98e8103 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 26 Sep 2024 16:38:53 +0800 Subject: [PATCH] update to external knowledge api --- .../console/datasets/test_external.py | 15 ++++-------- api/services/external_knowledge_service.py | 24 ++++++++++++------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/api/controllers/console/datasets/test_external.py b/api/controllers/console/datasets/test_external.py index 3f3b760b9c..044aceb0c7 100644 --- a/api/controllers/console/datasets/test_external.py +++ b/api/controllers/console/datasets/test_external.py @@ -14,16 +14,11 @@ class TestExternalApi(Resource): def post(self): parser = reqparse.RequestParser() parser.add_argument( - "top_k", + "retrieval_setting", nullable=False, required=True, - type=int, - ) - parser.add_argument( - "score_threshold", - nullable=False, - required=True, - type=float, + type=dict, + location="json" ) parser.add_argument( "query", @@ -32,14 +27,14 @@ class TestExternalApi(Resource): type=str, ) parser.add_argument( - "external_knowledge_id", + "knowledge_id", nullable=False, required=True, type=str, ) args = parser.parse_args() result = ExternalDatasetService.test_external_knowledge_retrieval( - args["top_k"], args["score_threshold"], args["query"], args["external_knowledge_id"] + args["retrieval_setting"], args["query"], args["knowledge_id"] ) return result, 200 diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 441c431bb1..ca7bdc1439 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -283,22 +283,28 @@ class ExternalDatasetService: if settings.get("api_key"): headers["Authorization"] = f"Bearer {settings.get('api_key')}" - external_retrieval_parameters["query"] = query - external_retrieval_parameters["external_knowledge_id"] = external_knowledge_binding.external_knowledge_id + request_params = { + "retrieval_setting": { + "top_k": external_retrieval_parameters.get("top_k"), + "score_threshold": external_retrieval_parameters.get("score_threshold"), + }, + "query": query, + "knowledge_id": external_knowledge_binding.external_knowledge_id, + } external_knowledge_api_setting = { "url": f"{settings.get('endpoint')}/dify/external-knowledge/retrieval-documents", "request_method": "post", "headers": headers, - "params": external_retrieval_parameters, + "params": request_params, } response = ExternalDatasetService.process_external_api(ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None) if response.status_code == 200: - return response.json() + return response.json().get("records", []) return [] @staticmethod - def test_external_knowledge_retrieval(top_k: int, score_threshold: float, query: str, external_knowledge_id: str): + def test_external_knowledge_retrieval(retrieval_setting: dict, query: str, external_knowledge_id: str): client = boto3.client( "bedrock-agent-runtime", aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY, @@ -308,7 +314,7 @@ class ExternalDatasetService: response = client.retrieve( knowledgeBaseId=external_knowledge_id, retrievalConfiguration={ - "vectorSearchConfiguration": {"numberOfResults": top_k, "overrideSearchType": "HYBRID"} + "vectorSearchConfiguration": {"numberOfResults": retrieval_setting.get("top_k"), "overrideSearchType": "HYBRID"} }, retrievalQuery={"text": query}, ) @@ -317,7 +323,7 @@ class ExternalDatasetService: if response.get("retrievalResults"): retrieval_results = response.get("retrievalResults") for retrieval_result in retrieval_results: - if retrieval_result.get("score") < score_threshold: + if retrieval_result.get("score") < retrieval_setting.get("score_threshold", .0): continue result = { "metadata": retrieval_result.get("metadata"), @@ -326,4 +332,6 @@ class ExternalDatasetService: "content": retrieval_result.get("content").get("text"), } results.append(result) - return results + return { + "records": results + }