diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index e017c2c5b8..4083ee2aed 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -5,6 +5,7 @@ from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field from pydantic_settings import BaseSettings from configs.middleware.cache.redis_config import RedisConfig +from configs.middleware.external.bedrock_config import BedrockConfig from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig @@ -221,5 +222,6 @@ class MiddlewareConfig( TiDBVectorConfig, WeaviateConfig, ElasticsearchConfig, + BedrockConfig, ): pass diff --git a/api/configs/middleware/external/bedrock_config.py b/api/configs/middleware/external/bedrock_config.py new file mode 100644 index 0000000000..be10da1432 --- /dev/null +++ b/api/configs/middleware/external/bedrock_config.py @@ -0,0 +1,19 @@ +from typing import Optional + +from pydantic import Field +from pydantic_settings import BaseSettings + + +class BedrockConfig(BaseSettings): + """ + AWS Bedrock configs + """ + AWS_SECRET_ACCESS_KEY: Optional[str] = Field( + description="AWS secret access key", + default=None, + ) + + AWS_ACCESS_KEY_ID: Optional[str] = Field( + description="AWS access key ID", + default=None, + ) diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index b1c76375e4..c0e07d1bae 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -231,7 +231,9 @@ class ExternalDatasetCreateApi(Resource): help="name is required. 
Name must be between 1 to 100 characters.", type=_validate_name, ) - parser.add_argument("description", type=str, required=True, nullable=True, location="json") + parser.add_argument("description", type=str, required=False, nullable=True, location="json") + parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") + args = parser.parse_args() @@ -287,6 +289,7 @@ class ExternalKnowledgeHitTestingApi(Resource): api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing") +api.add_resource(ExternalDatasetCreateApi, "/datasets/external") api.add_resource(ExternalApiTemplateListApi, "/datasets/external-api-template") api.add_resource(ExternalApiTemplateApi, "/datasets/external-api-template/<uuid:external_knowledge_api_id>") api.add_resource(ExternalApiUseCheckApi, "/datasets/external-api-template/<uuid:external_knowledge_api_id>/use-check") diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 7cf472d984..23dc092895 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler: for item in resource: dataset_retriever_resource = DatasetRetrieverResource( message_id=self._message_id, - position=item.get("position"), + position=item.get("position") or 0, dataset_id=item.get("dataset_id"), dataset_name=item.get("dataset_name"), document_id=item.get("document_id"), diff --git a/api/core/rag/entities/context_entities.py b/api/core/rag/entities/context_entities.py new file mode 100644 index 0000000000..dde3beccf6 --- /dev/null +++ b/api/core/rag/entities/context_entities.py @@ -0,0 +1,12 @@ +from typing import Optional + +from pydantic import BaseModel + + +class DocumentContext(BaseModel): + """ + Model class for document context. + """ + + content: str + score: Optional[float] = None diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 0ff1fdb81c..02b4bc82b0 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -17,6 +17,8 @@ class Document(BaseModel): """ metadata: Optional[dict] = Field(default_factory=dict) + provider: Optional[str] = "dify" + class BaseDocumentTransformer(ABC): """Abstract base class for document transformation systems. 
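Note: the `provider` field added to `Document` above is the hinge for everything that follows; the rerank and retrieval code below splits results into two streams based on it. A minimal sketch of that pattern, with `split_by_provider` as a hypothetical helper name (the PR inlines the list comprehensions instead):

```python
from typing import Optional

from pydantic import BaseModel, Field


class Document(BaseModel):
    page_content: str
    metadata: Optional[dict] = Field(default_factory=dict)
    provider: Optional[str] = "dify"  # "dify" (default) or "external"


def split_by_provider(documents: list[Document]) -> tuple[list[Document], list[Document]]:
    """Partition documents the way rerank_model.py below does: Dify-managed
    chunks are later deduplicated by doc_id, while external knowledge results
    (which carry no doc_id) pass straight through."""
    dify_docs = [d for d in documents if d.provider == "dify"]
    external_docs = [d for d in documents if d.provider == "external"]
    return dify_docs, external_docs
```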
diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 6356ff87ab..27f86aed34 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -28,11 +28,16 @@ class RerankModelRunner: docs = [] doc_id = [] unique_documents = [] - for document in documents: + dify_documents = [item for item in documents if item.provider == "dify"] + external_documents = [item for item in documents if item.provider == "external"] + for document in dify_documents: if document.metadata["doc_id"] not in doc_id: doc_id.append(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) + for document in external_documents: + docs.append(document.page_content) + unique_documents.append(document) documents = unique_documents @@ -46,14 +51,10 @@ class RerankModelRunner: # format document rerank_document = Document( page_content=result.text, - metadata={ - "doc_id": documents[result.index].metadata["doc_id"], - "doc_hash": documents[result.index].metadata["doc_hash"], - "document_id": documents[result.index].metadata["document_id"], - "dataset_id": documents[result.index].metadata["dataset_id"], - "score": result.score, - }, + metadata=documents[result.index].metadata, + provider=documents[result.index].provider, ) + rerank_document.metadata["score"] = result.score rerank_documents.append(rerank_document) return rerank_documents diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 286ecd4c03..8c404fb12c 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -20,6 +20,7 @@ from core.ops.utils import measure_time from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.context_entities import DocumentContext from core.rag.models.document import Document from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter @@ -30,6 +31,7 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from models.dataset import Dataset, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument +from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, @@ -110,7 +112,7 @@ class DatasetRetrieval: continue # pass if dataset is not available - if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: + if dataset and dataset.available_document_count == 0 and dataset.provider != "external": continue available_datasets.append(dataset) @@ -146,69 +148,84 @@ class DatasetRetrieval: message_id, ) - document_score_list = {} - for item in all_documents: - if item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - + dify_documents = [item for item in all_documents if item.provider == "dify"] + external_documents = [item for item in all_documents if item.provider == "external"] document_context_list = [] - index_node_ids = [document.metadata["doc_id"] for document in all_documents] - segments = 
DocumentSegment.query.filter( - DocumentSegment.dataset_id.in_(dataset_ids), - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() + retrieval_resource_list = [] + # deal with external documents + for item in external_documents: + document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score"))) + source = { + "dataset_id": item.metadata.get("dataset_id"), + "dataset_name": item.metadata.get("dataset_name"), + "document_name": item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": invoke_from.to_source(), + "score": item.metadata.get("score"), + "content": item.page_content, + } + retrieval_resource_list.append(source) + document_score_list = {} + # deal with dify documents + if dify_documents: + for item in dify_documents: + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - for segment in sorted_segments: - if segment.answer: - document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") - else: - document_context_list.append(segment.get_sign_content()) - if show_retrieve_source: - context_list = [] - resource_number = 1 + + index_node_ids = [document.metadata["doc_id"] for document in dify_documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id.in_(dataset_ids), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: - dataset = Dataset.query.filter_by(id=segment.dataset_id).first() - document = DatasetDocument.query.filter( - DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).first() - if dataset and document: - source = { - "position": resource_number, - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, - "segment_id": segment.id, - "retriever_from": invoke_from.to_source(), - "score": document_score_list.get(segment.index_node_id, None), - } + if segment.answer: + document_context_list.append(DocumentContext(content=f"question:{segment.get_sign_content()} answer:{segment.answer}", score=document_score_list.get(segment.index_node_id, None))) + else: + document_context_list.append(DocumentContext(content=segment.get_sign_content(), score=document_score_list.get(segment.index_node_id, None))) + if show_retrieve_source: + for segment in sorted_segments: + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = DatasetDocument.query.filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ).first() + if dataset and document: + source = { + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + 
"document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": invoke_from.to_source(), + "score": document_score_list.get(segment.index_node_id, None), + } - if invoke_from.to_source() == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash - if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" - else: - source["content"] = segment.content - context_list.append(source) - resource_number += 1 - if hit_callback: - hit_callback.return_retriever_resource_info(context_list) - - return str("\n".join(document_context_list)) + if invoke_from.to_source() == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash + if segment.answer: + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + else: + source["content"] = segment.content + retrieval_resource_list.append(source) + if hit_callback and retrieval_resource_list: + hit_callback.return_retriever_resource_info(retrieval_resource_list) + if document_context_list: + document_context_list = sorted(document_context_list, key=lambda x: x.score, reverse=True) + return str("\n".join([document_context.content for document_context in document_context_list])) return "" def single_retrieve( @@ -256,36 +273,56 @@ class DatasetRetrieval: # get retrieval model config dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if dataset: - retrieval_model_config = dataset.retrieval_model or default_retrieval_model - - # get top k - top_k = retrieval_model_config["top_k"] - # get retrieval method - if dataset.indexing_technique == "economy": - retrieval_method = "keyword_search" - else: - retrieval_method = retrieval_model_config["search_method"] - # get reranking model - reranking_model = ( - retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None - ) - # get score threshold - score_threshold = 0.0 - score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") - if score_threshold_enabled: - score_threshold = retrieval_model_config.get("score_threshold") - - with measure_time() as timer: - results = RetrievalService.retrieve( - retrieval_method=retrieval_method, - dataset_id=dataset.id, + results = [] + if dataset.provider == "external": + external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( + tenant_id=dataset.tenant_id, + dataset_id=dataset_id, query=query, - top_k=top_k, - score_threshold=score_threshold, - reranking_model=reranking_model, - reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"), - weights=retrieval_model_config.get("weights", None), + external_retrieval_parameters=dataset.retrieval_model ) + for external_document in external_documents: + document = Document( + page_content=external_document.get("content"), + metadata=external_document.get("metadata"), + provider="external", + ) + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset_id + document.metadata["dataset_name"] = dataset.name + results.append(document) + else: + retrieval_model_config = dataset.retrieval_model or 
default_retrieval_model + + # get top k + top_k = retrieval_model_config["top_k"] + # get retrieval method + if dataset.indexing_technique == "economy": + retrieval_method = "keyword_search" + else: + retrieval_method = retrieval_model_config["search_method"] + # get reranking model + reranking_model = ( + retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None + ) + # get score threshold + score_threshold = 0.0 + score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") + if score_threshold_enabled: + score_threshold = retrieval_model_config.get("score_threshold") + + with measure_time() as timer: + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"), + weights=retrieval_model_config.get("weights", None), + ) self._on_query(query, [dataset_id], app_id, user_from, user_id) if results: @@ -356,7 +393,8 @@ class DatasetRetrieval: self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None ) -> None: """Handle retrieval end.""" - for document in documents: + dify_documents = [document for document in documents if document.provider == "dify"] + for document in dify_documents: query = db.session.query(DocumentSegment).filter( DocumentSegment.index_node_id == document.metadata["doc_id"] ) @@ -409,35 +447,54 @@ class DatasetRetrieval: if not dataset: return [] - # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model or default_retrieval_model - - if dataset.indexing_technique == "economy": - # use keyword table query - documents = RetrievalService.retrieve( - retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k + if dataset.provider == "external": + external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( + tenant_id=dataset.tenant_id, + dataset_id=dataset_id, + query=query, + external_retrieval_parameters=dataset.retrieval_model ) - if documents: - all_documents.extend(documents) - else: - if top_k > 0: - # retrieval source - documents = RetrievalService.retrieve( - retrieval_method=retrieval_model["search_method"], - dataset_id=dataset.id, - query=query, - top_k=retrieval_model.get("top_k") or 2, - score_threshold=retrieval_model.get("score_threshold", 0.0) - if retrieval_model["score_threshold_enabled"] - else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) - if retrieval_model["reranking_enable"] - else None, - reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), + for external_document in external_documents: + document = Document( + page_content=external_document.get("content"), + metadata=external_document.get("metadata"), + provider="external", ) + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset_id + document.metadata["dataset_name"] = dataset.name + all_documents.append(document) + else: + # get retrieval model; if the model is not set, use the default + retrieval_model = dataset.retrieval_model or default_retrieval_model - all_documents.extend(documents) + if dataset.indexing_technique == "economy": + # use keyword table query + documents = 
RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k + ) + if documents: + all_documents.extend(documents) + else: + if top_k > 0: + # retrieval source + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model["search_method"], + dataset_id=dataset.id, + query=query, + top_k=retrieval_model.get("top_k") or 2, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights", None), + ) + + all_documents.extend(documents) def to_dataset_retriever_tool( self, diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index af55688a52..121c96e619 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -156,16 +156,34 @@ class KnowledgeRetrievalNode(BaseNode): weights, node_data.multiple_retrieval_config.reranking_enable, ) - - context_list = [] - if all_documents: + dify_documents = [item for item in all_documents if item.provider == "dify"] + external_documents = [item for item in all_documents if item.provider == "external"] + retrieval_resource_list = [] + # deal with external documents + for item in external_documents: + source = { + "metadata": { + "_source": "knowledge", + "dataset_id": item.metadata.get("dataset_id"), + "dataset_name": item.metadata.get("dataset_name"), + "document_name": item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": "workflow", + "score": item.metadata.get("score"), + }, + "title": item.metadata.get("title"), + "content": item.page_content, + } + retrieval_resource_list.append(source) + document_score_list = {} + # deal with dify documents + if dify_documents: document_score_list = {} - page_number_list = {} - for item in all_documents: + for item in dify_documents: if item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - index_node_ids = [document.metadata["doc_id"] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in dify_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(dataset_ids), DocumentSegment.completed_at.isnot(None), @@ -186,13 +204,10 @@ class KnowledgeRetrievalNode(BaseNode): Document.enabled == True, Document.archived == False, ).first() - - resource_number = 1 if dataset and document: source = { "metadata": { "_source": "knowledge", - "position": resource_number, "dataset_id": dataset.id, "dataset_name": dataset.name, "document_id": document.id, @@ -212,9 +227,14 @@ class KnowledgeRetrievalNode(BaseNode): source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" else: source["content"] = segment.get_sign_content() - context_list.append(source) - resource_number += 1 - return context_list + retrieval_resource_list.append(source) + if retrieval_resource_list: + retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x["metadata"].get("score") or 0.0, reverse=True) + position = 1 + for item in retrieval_resource_list: + item["metadata"]["position"] = position + position += 1 + return retrieval_resource_list 
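Note: in the merged resource list the score lives under each item's `metadata` and can be `None` when reranking is disabled, which is why the sort key above falls back to `0.0`. A standalone sketch of the sort-and-number step (the function name is illustrative, not part of the PR):

```python
def rank_resources(resources: list[dict]) -> list[dict]:
    """Order mixed Dify/external retrieval resources by score, treating a
    missing or None score as 0.0, then assign 1-based positions."""
    ordered = sorted(
        resources,
        key=lambda x: x["metadata"].get("score") or 0.0,  # None-safe
        reverse=True,
    )
    for position, item in enumerate(ordered, start=1):
        item["metadata"]["position"] = position
    return ordered
```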
@classmethod def _extract_variable_selector_to_variable_mapping( diff --git a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py new file mode 100644 index 0000000000..c79b7759db --- /dev/null +++ b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py @@ -0,0 +1,48 @@ +"""update-retrieval-resource + +Revision ID: 6af6a521a53e +Revises: ec3df697ebbb +Create Date: 2024-09-24 09:22:43.570120 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '6af6a521a53e' +down_revision = 'ec3df697ebbb' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.alter_column('document_id', + existing_type=sa.UUID(), + nullable=True) + batch_op.alter_column('data_source_type', + existing_type=sa.TEXT(), + nullable=True) + batch_op.alter_column('segment_id', + existing_type=sa.UUID(), + nullable=True) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.alter_column('segment_id', + existing_type=sa.UUID(), + nullable=False) + batch_op.alter_column('data_source_type', + existing_type=sa.TEXT(), + nullable=False) + batch_op.alter_column('document_id', + existing_type=sa.UUID(), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 7fb217f7c3..585c83e24a 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -72,6 +72,14 @@ class Dataset(db.Model): def index_struct_dict(self): return json.loads(self.index_struct) if self.index_struct else None + @property + def external_retrieval_model(self): + default_retrieval_model = { + "top_k": 2, + "score_threshold": 0.0, + } + return self.retrieval_model or default_retrieval_model + @property def created_by_account(self): return db.session.get(Account, self.created_by) diff --git a/api/models/model.py b/api/models/model.py index ae0bc3210b..7f91871991 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1422,10 +1422,10 @@ class DatasetRetrieverResource(db.Model): position = db.Column(db.Integer, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) dataset_name = db.Column(db.Text, nullable=False) - document_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=True) document_name = db.Column(db.Text, nullable=False) - data_source_type = db.Column(db.Text, nullable=False) - segment_id = db.Column(StringUUID, nullable=False) + data_source_type = db.Column(db.Text, nullable=True) + segment_id = db.Column(StringUUID, nullable=True) score = db.Column(db.Float, nullable=True) content = db.Column(db.Text, nullable=False) hit_count = db.Column(db.Integer, nullable=True) diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index e156b5a8ff..e1acac5c58 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -7,6 +7,7 @@ from typing import Any, Optional, Union import httpx +from configs import dify_config from core.helper import ssrf_proxy from 
extensions.ext_database import db from models.dataset import ( @@ -243,6 +244,7 @@ class ExternalDatasetService: name=args.get("name"), description=args.get("description", ""), provider="external", + retrieval_model=args.get("external_retrieval_model"), created_by=user_id, ) @@ -305,9 +307,9 @@ ): client = boto3.client( "bedrock-agent-runtime", - aws_secret_access_key='', - aws_access_key_id='', - region_name='', + aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY, + aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID, + region_name="us-east-1", ) response = client.retrieve( knowledgeBaseId=external_knowledge_id, @@ -326,6 +328,8 @@ if response.get("retrievalResults"): retrieval_results = response.get("retrievalResults") for retrieval_result in retrieval_results: + if (retrieval_result.get("score") or 0.0) < score_threshold: + continue result = { "metadata": retrieval_result.get("metadata"), "score": retrieval_result.get("score"),
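Note: for reference, a self-contained sketch of the Bedrock retrieval flow that `ExternalDatasetService` wraps above. The request and response shapes follow the public `bedrock-agent-runtime` API; the region is hardcoded to `us-east-1` only because the PR does the same, and this sketch lets boto3 resolve credentials from the environment rather than from `dify_config`:

```python
import boto3


def fetch_bedrock_results(
    knowledge_base_id: str,
    query: str,
    top_k: int = 2,
    score_threshold: float = 0.0,
    region_name: str = "us-east-1",  # hardcoded in the PR as well
) -> list[dict]:
    """Query a Bedrock knowledge base and drop results below the threshold."""
    client = boto3.client("bedrock-agent-runtime", region_name=region_name)
    response = client.retrieve(
        knowledgeBaseId=knowledge_base_id,
        retrievalQuery={"text": query},
        retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": top_k}},
    )
    results = []
    for item in response.get("retrievalResults", []):
        score = item.get("score")
        if score is not None and score < score_threshold:
            continue  # mirror the score filter added in the hunk above
        results.append(
            {
                "content": item.get("content", {}).get("text"),
                "score": score,
                "metadata": item.get("metadata"),
            }
        )
    return results
```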