From e03b371c4396d9c3f5cc83308847bb77ec7e0dd3 Mon Sep 17 00:00:00 2001
From: fatelei
Date: Sun, 28 Dec 2025 11:13:33 +0800
Subject: [PATCH 1/2] feat: allow fail fast

---
 api/core/rag/datasource/retrieval_service.py  | 16 +++-
 api/core/rag/retrieval/dataset_retrieval.py   | 94 ++++++++++++-------
 .../rag/retrieval/test_dataset_retrieval.py   | 13 ++-
 3 files changed, 87 insertions(+), 36 deletions(-)

diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 9807cb4e6a..0c1158a658 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -106,7 +106,12 @@ class RetrievalService:
                     )
                 )
 
-        concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
+        if futures:
+            for future in concurrent.futures.as_completed(futures, timeout=3600):
+                if future.exception():
+                    for f in futures:
+                        f.cancel()
+                    break
 
         if exceptions:
             raise ValueError(";\n".join(exceptions))
@@ -662,7 +667,14 @@ class RetrievalService:
                         document_ids_filter=document_ids_filter,
                     )
                 )
-            concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
+            # Use as_completed for early error propagation - cancel remaining futures on first error
+            if futures:
+                for future in concurrent.futures.as_completed(futures, timeout=300):
+                    if future.exception():
+                        # Cancel remaining futures to avoid unnecessary waiting
+                        for f in futures:
+                            f.cancel()
+                        break
 
         if exceptions:
             raise ValueError(";\n".join(exceptions))
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 8f6c620925..1880a1b8b7 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -516,6 +516,9 @@ class DatasetRetrieval:
                 ].embedding_model_provider
                 weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
         with measure_time() as timer:
+            cancel_event = threading.Event()
+            thread_exceptions: list[Exception] = []
+
             if query:
                 query_thread = threading.Thread(
                     target=self._multiple_retrieve_thread,
@@ -534,6 +537,8 @@ class DatasetRetrieval:
                         "score_threshold": score_threshold,
                         "query": query,
                         "attachment_id": None,
+                        "cancel_event": cancel_event,
+                        "thread_exceptions": thread_exceptions,
                     },
                 )
                 all_threads.append(query_thread)
@@ -557,12 +562,21 @@ class DatasetRetrieval:
                     "score_threshold": score_threshold,
                     "query": None,
                     "attachment_id": attachment_id,
+                    "cancel_event": cancel_event,
+                    "thread_exceptions": thread_exceptions,
                 },
            )
             all_threads.append(attachment_thread)
             attachment_thread.start()
+
         for thread in all_threads:
-            thread.join()
+            thread.join(timeout=300)
+            if thread_exceptions:
+                cancel_event.set()
+                break
+
+        if thread_exceptions:
+            raise thread_exceptions[0]
 
         self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
         if all_documents:
@@ -1402,40 +1416,49 @@ class DatasetRetrieval:
         score_threshold: float,
         query: str | None,
         attachment_id: str | None,
+        cancel_event: threading.Event | None = None,
+        thread_exceptions: list[Exception] | None = None,
     ):
-        with flask_app.app_context():
-            threads = []
-            all_documents_item: list[Document] = []
-            index_type = None
-            for dataset in available_datasets:
-                index_type = dataset.indexing_technique
-                document_ids_filter = None
-                if dataset.provider != "external":
-                    if metadata_condition and not metadata_filter_document_ids:
-                        continue
-                    if metadata_filter_document_ids:
-                        document_ids = metadata_filter_document_ids.get(dataset.id, [])
-                        if document_ids:
-                            document_ids_filter = document_ids
-                        else:
+        try:
+            with flask_app.app_context():
+                threads = []
+                all_documents_item: list[Document] = []
+                index_type = None
+                for dataset in available_datasets:
+                    # Check for cancellation signal
+                    if cancel_event and cancel_event.is_set():
+                        break
+                    index_type = dataset.indexing_technique
+                    document_ids_filter = None
+                    if dataset.provider != "external":
+                        if metadata_condition and not metadata_filter_document_ids:
                             continue
-                retrieval_thread = threading.Thread(
-                    target=self._retriever,
-                    kwargs={
-                        "flask_app": flask_app,
-                        "dataset_id": dataset.id,
-                        "query": query,
-                        "top_k": top_k,
-                        "all_documents": all_documents_item,
-                        "document_ids_filter": document_ids_filter,
-                        "metadata_condition": metadata_condition,
-                        "attachment_ids": [attachment_id] if attachment_id else None,
-                    },
-                )
-                threads.append(retrieval_thread)
-                retrieval_thread.start()
-            for thread in threads:
-                thread.join()
+                        if metadata_filter_document_ids:
+                            document_ids = metadata_filter_document_ids.get(dataset.id, [])
+                            if document_ids:
+                                document_ids_filter = document_ids
+                            else:
+                                continue
+                    retrieval_thread = threading.Thread(
+                        target=self._retriever,
+                        kwargs={
+                            "flask_app": flask_app,
+                            "dataset_id": dataset.id,
+                            "query": query,
+                            "top_k": top_k,
+                            "all_documents": all_documents_item,
+                            "document_ids_filter": document_ids_filter,
+                            "metadata_condition": metadata_condition,
+                            "attachment_ids": [attachment_id] if attachment_id else None,
+                        },
+                    )
+                    threads.append(retrieval_thread)
+                    retrieval_thread.start()
+                for thread in threads:
+                    thread.join(timeout=300)
+                    # Check for cancellation signal between threads
+                    if cancel_event and cancel_event.is_set():
+                        break
 
             if reranking_enable:
                 # do rerank for searched documents
@@ -1468,3 +1491,8 @@ class DatasetRetrieval:
                 all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
             if all_documents_item:
                 all_documents.extend(all_documents_item)
+        except Exception as e:
+            if cancel_event:
+                cancel_event.set()
+            if thread_exceptions is not None:
+                thread_exceptions.append(e)
diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py
index affd6c648f..6306d665e7 100644
--- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py
+++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py
@@ -421,7 +421,18 @@ class TestRetrievalService:
         # In real code, this waits for all futures to complete
         # In tests, futures complete immediately, so wait is a no-op
         with patch("core.rag.datasource.retrieval_service.concurrent.futures.wait"):
-            yield mock_executor
+            # Mock concurrent.futures.as_completed for early error propagation
+            # In real code, this yields futures as they complete
+            # In tests, we yield all futures immediately since they're already done
+            def mock_as_completed(futures_list, timeout=None):
+                """Mock as_completed that yields futures immediately."""
+                yield from futures_list
+
+            with patch(
+                "core.rag.datasource.retrieval_service.concurrent.futures.as_completed",
+                side_effect=mock_as_completed,
+            ):
+                yield mock_executor
 
 
 # ==================== Vector Search Tests ====================
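PATCH 1 replaces the blocking concurrent.futures.wait(..., return_when=ALL_COMPLETED) calls with iteration over concurrent.futures.as_completed, so the first failed future is observed as soon as it finishes instead of after every sibling completes. A minimal, self-contained sketch of that fail-fast pattern (the worker and its inputs are illustrative, not Dify code):

    import concurrent.futures
    import time

    def worker(seconds: int) -> int:
        # Illustrative stand-in for a retrieval task; one input fails early
        time.sleep(seconds)
        if seconds == 2:
            raise ValueError("boom")
        return seconds

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        futures = [executor.submit(worker, s) for s in (1, 2, 30, 30)]
        # as_completed yields each future as it finishes, so the failure at
        # t=2s is seen without waiting out the 30-second stragglers
        for future in concurrent.futures.as_completed(futures, timeout=60):
            if future.exception():
                for f in futures:
                    f.cancel()  # only cancels futures that have not started yet
                break

Note that Future.cancel cannot interrupt a task that is already running, and executor shutdown still waits for running tasks; that limitation is why the thread-based paths in this patch add a cooperative cancel_event instead.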
From 8818acba317f8cdfd557fe2a873f0517e0fcd1dc Mon Sep 17 00:00:00 2001
From: fatelei
Date: Sun, 28 Dec 2025 11:24:37 +0800
Subject: [PATCH 2/2] chore: resolve review issue

---
 api/core/rag/datasource/retrieval_service.py |  8 +++++++-
 api/core/rag/retrieval/dataset_retrieval.py  | 20 ++++++++++++++------
 2 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 0c1158a658..a4ac63004c 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -1,4 +1,5 @@
 import concurrent.futures
+import logging
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
 
@@ -36,6 +37,8 @@ default_retrieval_model = {
     "score_threshold_enabled": False,
 }
 
+logger = logging.getLogger(__name__)
+
 
 class RetrievalService:
     # Cache precompiled regular expressions to avoid repeated compilation
@@ -108,7 +111,7 @@ class RetrievalService:
 
         if futures:
             for future in concurrent.futures.as_completed(futures, timeout=3600):
-                if future.exception():
+                if exceptions:
                     for f in futures:
                         f.cancel()
                     break
@@ -215,6 +218,7 @@ class RetrievalService:
             )
             all_documents.extend(documents)
         except Exception as e:
+            logger.error(e, exc_info=True)
             exceptions.append(str(e))
 
     @classmethod
@@ -308,6 +312,7 @@ class RetrievalService:
             else:
                 all_documents.extend(documents)
         except Exception as e:
+            logger.error(e, exc_info=True)
             exceptions.append(str(e))
 
     @classmethod
@@ -356,6 +361,7 @@ class RetrievalService:
             else:
                 all_documents.extend(documents)
         except Exception as e:
+            logger.error(e, exc_info=True)
             exceptions.append(str(e))
 
     @staticmethod
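The future.exception() -> exceptions change above is the review fix: the submitted workers already trap their own errors and record them in the shared exceptions list (now with logging), so every future completes cleanly and Future.exception() can never be truthy. A short sketch of why the original check was dead code (names here are hypothetical):

    import concurrent.futures

    exceptions: list[str] = []

    def worker_that_traps_errors() -> None:
        # Mirrors the retrieval workers: the error is recorded in a shared
        # list instead of propagating, so nothing attaches to the future
        try:
            raise RuntimeError("backend unavailable")
        except Exception as e:
            exceptions.append(str(e))

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(worker_that_traps_errors)
        concurrent.futures.wait([future])

    print(future.exception())  # None - the worker swallowed the error
    print(exceptions)          # ['backend unavailable'] - only the list sees it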
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 1880a1b8b7..b32e3e8829 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -569,10 +569,14 @@ class DatasetRetrieval:
             all_threads.append(attachment_thread)
             attachment_thread.start()
 
-        for thread in all_threads:
-            thread.join(timeout=300)
+        # Poll threads with short timeout to detect errors quickly (fail-fast)
+        while any(t.is_alive() for t in all_threads):
+            for thread in all_threads:
+                thread.join(timeout=0.1)
+                if thread_exceptions:
+                    cancel_event.set()
+                    break
             if thread_exceptions:
-                cancel_event.set()
                 break
 
         if thread_exceptions:
@@ -1454,9 +1458,13 @@ class DatasetRetrieval:
                    )
                    threads.append(retrieval_thread)
                    retrieval_thread.start()
-                for thread in threads:
-                    thread.join(timeout=300)
-                    # Check for cancellation signal between threads
+
+                # Poll threads with short timeout to respond quickly to cancellation
+                while any(t.is_alive() for t in threads):
+                    for thread in threads:
+                        thread.join(timeout=0.1)
+                        if cancel_event and cancel_event.is_set():
+                            break
                     if cancel_event and cancel_event.is_set():
                         break
 
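The thread-based retrieval path has no futures to cancel, so this patch polls Thread.join with a short timeout and fans a threading.Event out to workers that check it cooperatively. A self-contained sketch of the pattern under those assumptions, with an illustrative worker standing in for the real _retriever:

    import threading
    import time

    cancel_event = threading.Event()
    thread_exceptions: list[Exception] = []

    def retrieval_stub(steps: int, fail: bool) -> None:
        # Illustrative stand-in for a per-dataset retrieval thread
        try:
            for _ in range(steps):
                if cancel_event.is_set():  # cooperative cancellation point
                    return
                time.sleep(0.1)
            if fail:
                raise RuntimeError("retrieval failed")
        except Exception as e:
            cancel_event.set()  # signal sibling threads to stop early
            thread_exceptions.append(e)

    threads = [
        threading.Thread(target=retrieval_stub, args=(3, True)),
        threading.Thread(target=retrieval_stub, args=(600, False)),
    ]
    for t in threads:
        t.start()

    # Poll with a short join timeout so the first recorded error surfaces in
    # roughly 0.1s instead of blocking on the slowest thread
    while any(t.is_alive() for t in threads):
        for t in threads:
            t.join(timeout=0.1)
            if thread_exceptions:
                cancel_event.set()
                break
        if thread_exceptions:
            break

    if thread_exceptions:
        raise thread_exceptions[0]

The event check inside the worker loop is what makes cancellation effective: join(timeout=...) only bounds how long the caller waits, it never stops the thread itself.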