From e924dc7b30131ab905ea400694c8c09258f7883f Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Thu, 4 Dec 2025 10:14:28 +0800 Subject: [PATCH] chore: ignore redis lock not owned error (#29064) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/dataset_service.py | 576 +++++++++--------- .../test_dataset_service_lock_not_owned.py | 177 ++++++ 2 files changed, 472 insertions(+), 281 deletions(-) create mode 100644 api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 2bec61963c..208ebcb018 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -10,6 +10,7 @@ from collections.abc import Sequence from typing import Any, Literal import sqlalchemy as sa +from redis.exceptions import LockNotOwnedError from sqlalchemy import exists, func, select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -1593,173 +1594,176 @@ class DocumentService: db.session.add(dataset_process_rule) db.session.flush() lock_name = f"add_document_lock_dataset_id_{dataset.id}" - with redis_client.lock(lock_name, timeout=600): - assert dataset_process_rule - position = DocumentService.get_documents_position(dataset.id) - document_ids = [] - duplicate_document_ids = [] - if knowledge_config.data_source.info_list.data_source_type == "upload_file": - if not knowledge_config.data_source.info_list.file_info_list: - raise ValueError("File source info is required") - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids - for file_id in upload_file_list: - file = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) - .first() - ) - - # raise error if file not found - if not file: - raise FileNotExistsError() - - file_name = file.name - data_source_info: dict[str, str | bool] = { - "upload_file_id": file_id, - } - # check duplicate - if knowledge_config.duplicate: - document = ( - db.session.query(Document) - .filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="upload_file", - enabled=True, - name=file_name, - ) + try: + with redis_client.lock(lock_name, timeout=600): + assert dataset_process_rule + position = DocumentService.get_documents_position(dataset.id) + document_ids = [] + duplicate_document_ids = [] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + if not knowledge_config.data_source.info_list.file_info_list: + raise ValueError("File source info is required") + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids + for file_id in upload_file_list: + file = ( + db.session.query(UploadFile) + .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) .first() ) - if document: - document.dataset_process_rule_id = dataset_process_rule.id - document.updated_at = naive_utc_now() - document.created_from = created_from - document.doc_form = knowledge_config.doc_form - document.doc_language = knowledge_config.doc_language - document.data_source_info = json.dumps(data_source_info) - document.batch = batch - document.indexing_status = "waiting" - db.session.add(document) - documents.append(document) - duplicate_document_ids.append(document.id) - continue - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - knowledge_config.data_source.info_list.data_source_type, - 
knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - file_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": - notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore - if not notion_info_list: - raise ValueError("No notion info list found.") - exist_page_ids = [] - exist_document = {} - documents = ( - db.session.query(Document) - .filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="notion_import", - enabled=True, - ) - .all() - ) - if documents: - for document in documents: - data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info["notion_page_id"]) - exist_document[data_source_info["notion_page_id"]] = document.id - for notion_info in notion_info_list: - workspace_id = notion_info.workspace_id - for page in notion_info.pages: - if page.page_id not in exist_page_ids: - data_source_info = { - "credential_id": notion_info.credential_id, - "notion_workspace_id": workspace_id, - "notion_page_id": page.page_id, - "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore - "type": page.type, - } - # Truncate page name to 255 characters to prevent DB field length errors - truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - knowledge_config.data_source.info_list.data_source_type, - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - truncated_page_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - else: - exist_document.pop(page.page_id) - # delete not selected documents - if len(exist_document) > 0: - clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": - website_info = knowledge_config.data_source.info_list.website_info_list - if not website_info: - raise ValueError("No website info list found.") - urls = website_info.urls - for url in urls: - data_source_info = { - "url": url, - "provider": website_info.provider, - "job_id": website_info.job_id, - "only_main_content": website_info.only_main_content, - "mode": "crawl", - } - if len(url) > 255: - document_name = url[:200] + "..." 
- else: - document_name = url - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - knowledge_config.data_source.info_list.data_source_type, - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - document_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - db.session.commit() - # trigger async task - if document_ids: - DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay() - if duplicate_document_ids: - duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + # raise error if file not found + if not file: + raise FileNotExistsError() + + file_name = file.name + data_source_info: dict[str, str | bool] = { + "upload_file_id": file_id, + } + # check duplicate + if knowledge_config.duplicate: + document = ( + db.session.query(Document) + .filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="upload_file", + enabled=True, + name=file_name, + ) + .first() + ) + if document: + document.dataset_process_rule_id = dataset_process_rule.id + document.updated_at = naive_utc_now() + document.created_from = created_from + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language + document.data_source_info = json.dumps(data_source_info) + document.batch = batch + document.indexing_status = "waiting" + db.session.add(document) + documents.append(document) + duplicate_document_ids.append(document.id) + continue + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + file_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + if not notion_info_list: + raise ValueError("No notion info list found.") + exist_page_ids = [] + exist_document = {} + documents = ( + db.session.query(Document) + .filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="notion_import", + enabled=True, + ) + .all() + ) + if documents: + for document in documents: + data_source_info = json.loads(document.data_source_info) + exist_page_ids.append(data_source_info["notion_page_id"]) + exist_document[data_source_info["notion_page_id"]] = document.id + for notion_info in notion_info_list: + workspace_id = notion_info.workspace_id + for page in notion_info.pages: + if page.page_id not in exist_page_ids: + data_source_info = { + "credential_id": notion_info.credential_id, + "notion_workspace_id": workspace_id, + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore + "type": page.type, + } + # Truncate page name to 255 characters to prevent DB field length errors + truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + 
knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + truncated_page_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + else: + exist_document.pop(page.page_id) + # delete not selected documents + if len(exist_document) > 0: + clean_notion_document_task.delay(list(exist_document.values()), dataset.id) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + if not website_info: + raise ValueError("No website info list found.") + urls = website_info.urls + for url in urls: + data_source_info = { + "url": url, + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, + "mode": "crawl", + } + if len(url) > 255: + document_name = url[:200] + "..." + else: + document_name = url + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + document_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + db.session.commit() + + # trigger async task + if document_ids: + DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay() + if duplicate_document_ids: + duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + except LockNotOwnedError: + pass return documents, batch @@ -2699,50 +2703,55 @@ class SegmentService: # calc embedding use tokens tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0] lock_name = f"add_segment_lock_document_id_{document.id}" - with redis_client.lock(lock_name, timeout=600): - max_position = ( - db.session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == document.id) - .scalar() - ) - segment_document = DocumentSegment( - tenant_id=current_user.current_tenant_id, - dataset_id=document.dataset_id, - document_id=document.id, - index_node_id=doc_id, - index_node_hash=segment_hash, - position=max_position + 1 if max_position else 1, - content=content, - word_count=len(content), - tokens=tokens, - status="completed", - indexing_at=naive_utc_now(), - completed_at=naive_utc_now(), - created_by=current_user.id, - ) - if document.doc_form == "qa_model": - segment_document.word_count += len(args["answer"]) - segment_document.answer = args["answer"] + try: + with redis_client.lock(lock_name, timeout=600): + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .where(DocumentSegment.document_id == document.id) + .scalar() + ) + segment_document = DocumentSegment( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + status="completed", + indexing_at=naive_utc_now(), + completed_at=naive_utc_now(), + created_by=current_user.id, + ) + if document.doc_form == "qa_model": + segment_document.word_count += len(args["answer"]) + segment_document.answer = args["answer"] - db.session.add(segment_document) - # update document word count - assert document.word_count is 
not None - document.word_count += segment_document.word_count - db.session.add(document) - db.session.commit() - - # save vector index - try: - VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form) - except Exception as e: - logger.exception("create segment index failed") - segment_document.enabled = False - segment_document.disabled_at = naive_utc_now() - segment_document.status = "error" - segment_document.error = str(e) + db.session.add(segment_document) + # update document word count + assert document.word_count is not None + document.word_count += segment_document.word_count + db.session.add(document) db.session.commit() - segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() - return segment + + # save vector index + try: + VectorService.create_segments_vector( + [args["keywords"]], [segment_document], dataset, document.doc_form + ) + except Exception as e: + logger.exception("create segment index failed") + segment_document.enabled = False + segment_document.disabled_at = naive_utc_now() + segment_document.status = "error" + segment_document.error = str(e) + db.session.commit() + segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() + return segment + except LockNotOwnedError: + pass @classmethod def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): @@ -2751,84 +2760,89 @@ class SegmentService: lock_name = f"multi_add_segment_lock_document_id_{document.id}" increment_word_count = 0 - with redis_client.lock(lock_name, timeout=600): - embedding_model = None - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, + try: + with redis_client.lock(lock_name, timeout=600): + embedding_model = None + if dataset.indexing_technique == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .where(DocumentSegment.document_id == document.id) + .scalar() ) - max_position = ( - db.session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == document.id) - .scalar() - ) - pre_segment_data_list = [] - segment_data_list = [] - keywords_list = [] - position = max_position + 1 if max_position else 1 - for segment_item in segments: - content = segment_item["content"] - doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) - tokens = 0 - if dataset.indexing_technique == "high_quality" and embedding_model: - # calc embedding use tokens + pre_segment_data_list = [] + segment_data_list = [] + keywords_list = [] + position = max_position + 1 if max_position else 1 + for segment_item in segments: + content = segment_item["content"] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + tokens = 0 + if dataset.indexing_technique == "high_quality" and embedding_model: + # calc embedding use tokens + if document.doc_form == "qa_model": + tokens = embedding_model.get_text_embedding_num_tokens( + texts=[content + segment_item["answer"]] + 
)[0] + else: + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0] + + segment_document = DocumentSegment( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=position, + content=content, + word_count=len(content), + tokens=tokens, + keywords=segment_item.get("keywords", []), + status="completed", + indexing_at=naive_utc_now(), + completed_at=naive_utc_now(), + created_by=current_user.id, + ) if document.doc_form == "qa_model": - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content + segment_item["answer"]] - )[0] + segment_document.answer = segment_item["answer"] + segment_document.word_count += len(segment_item["answer"]) + increment_word_count += segment_document.word_count + db.session.add(segment_document) + segment_data_list.append(segment_document) + position += 1 + + pre_segment_data_list.append(segment_document) + if "keywords" in segment_item: + keywords_list.append(segment_item["keywords"]) else: - tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0] - - segment_document = DocumentSegment( - tenant_id=current_user.current_tenant_id, - dataset_id=document.dataset_id, - document_id=document.id, - index_node_id=doc_id, - index_node_hash=segment_hash, - position=position, - content=content, - word_count=len(content), - tokens=tokens, - keywords=segment_item.get("keywords", []), - status="completed", - indexing_at=naive_utc_now(), - completed_at=naive_utc_now(), - created_by=current_user.id, - ) - if document.doc_form == "qa_model": - segment_document.answer = segment_item["answer"] - segment_document.word_count += len(segment_item["answer"]) - increment_word_count += segment_document.word_count - db.session.add(segment_document) - segment_data_list.append(segment_document) - position += 1 - - pre_segment_data_list.append(segment_document) - if "keywords" in segment_item: - keywords_list.append(segment_item["keywords"]) - else: - keywords_list.append(None) - # update document word count - assert document.word_count is not None - document.word_count += increment_word_count - db.session.add(document) - try: - # save vector index - VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form) - except Exception as e: - logger.exception("create segment index failed") - for segment_document in segment_data_list: - segment_document.enabled = False - segment_document.disabled_at = naive_utc_now() - segment_document.status = "error" - segment_document.error = str(e) - db.session.commit() - return segment_data_list + keywords_list.append(None) + # update document word count + assert document.word_count is not None + document.word_count += increment_word_count + db.session.add(document) + try: + # save vector index + VectorService.create_segments_vector( + keywords_list, pre_segment_data_list, dataset, document.doc_form + ) + except Exception as e: + logger.exception("create segment index failed") + for segment_document in segment_data_list: + segment_document.enabled = False + segment_document.disabled_at = naive_utc_now() + segment_document.status = "error" + segment_document.error = str(e) + db.session.commit() + return segment_data_list + except LockNotOwnedError: + pass @classmethod def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): diff --git 
a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py new file mode 100644 index 0000000000..bd226f7536 --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py @@ -0,0 +1,177 @@ +import types +from unittest.mock import Mock, create_autospec + +import pytest +from redis.exceptions import LockNotOwnedError + +from models.account import Account +from models.dataset import Dataset, Document +from services.dataset_service import DocumentService, SegmentService + + +class FakeLock: + """Lock that always fails on enter with LockNotOwnedError.""" + + def __enter__(self): + raise LockNotOwnedError("simulated") + + def __exit__(self, exc_type, exc, tb): + # Normal contextmanager signature; return False so exceptions propagate + return False + + +@pytest.fixture +def fake_current_user(monkeypatch): + user = create_autospec(Account, instance=True) + user.id = "user-1" + user.current_tenant_id = "tenant-1" + monkeypatch.setattr("services.dataset_service.current_user", user) + return user + + +@pytest.fixture +def fake_features(monkeypatch): + """Features.billing.enabled == False to skip quota logic.""" + features = types.SimpleNamespace( + billing=types.SimpleNamespace(enabled=False, subscription=types.SimpleNamespace(plan="ENTERPRISE")), + documents_upload_quota=types.SimpleNamespace(limit=10_000, size=0), + ) + monkeypatch.setattr( + "services.dataset_service.FeatureService.get_features", + lambda tenant_id: features, + ) + return features + + +@pytest.fixture +def fake_lock(monkeypatch): + """Patch redis_client.lock to always raise LockNotOwnedError on enter.""" + + def _fake_lock(name, timeout=None, *args, **kwargs): + return FakeLock() + + # DatasetService imports redis_client directly from extensions.ext_redis + monkeypatch.setattr("services.dataset_service.redis_client.lock", _fake_lock) + + +# --------------------------------------------------------------------------- +# 1. Knowledge Pipeline document creation (save_document_with_dataset_id) +# --------------------------------------------------------------------------- + + +def test_save_document_with_dataset_id_ignores_lock_not_owned( + monkeypatch, + fake_current_user, + fake_features, + fake_lock, +): + # Arrange + dataset = create_autospec(Dataset, instance=True) + dataset.id = "ds-1" + dataset.tenant_id = fake_current_user.current_tenant_id + dataset.data_source_type = "upload_file" + dataset.indexing_technique = "high_quality" # so we skip re-initialization branch + + # Minimal knowledge_config stub that satisfies pre-lock code + info_list = types.SimpleNamespace(data_source_type="upload_file") + data_source = types.SimpleNamespace(info_list=info_list) + knowledge_config = types.SimpleNamespace( + doc_form="qa_model", + original_document_id=None, # go into "new document" branch + data_source=data_source, + indexing_technique="high_quality", + embedding_model=None, + embedding_model_provider=None, + retrieval_model=None, + process_rule=None, + duplicate=False, + doc_language="en", + ) + + account = fake_current_user + + # Avoid touching real doc_form logic + monkeypatch.setattr("services.dataset_service.DatasetService.check_doc_form", lambda *a, **k: None) + # Avoid real DB interactions + monkeypatch.setattr("services.dataset_service.db", Mock()) + + # Act: this would hit the redis lock, whose __enter__ raises LockNotOwnedError. + # Our implementation should catch it and still return (documents, batch). 
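+    # (Both returned values are bound before the lock is acquired: documents as the
+    # list built up so far and batch as the batch identifier generated earlier in the
+    # method, so the return statement after the swallowed error still has values to
+    # hand back.)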
+ documents, batch = DocumentService.save_document_with_dataset_id( + dataset=dataset, + knowledge_config=knowledge_config, + account=account, + ) + + # Assert + # We mainly care that: + # - No exception is raised + # - The function returns a sensible tuple + assert isinstance(documents, list) + assert isinstance(batch, str) + + +# --------------------------------------------------------------------------- +# 2. Single-segment creation (add_segment) +# --------------------------------------------------------------------------- + + +def test_add_segment_ignores_lock_not_owned( + monkeypatch, + fake_current_user, + fake_lock, +): + # Arrange + dataset = create_autospec(Dataset, instance=True) + dataset.id = "ds-1" + dataset.tenant_id = fake_current_user.current_tenant_id + dataset.indexing_technique = "economy" # skip embedding/token calculation branch + + document = create_autospec(Document, instance=True) + document.id = "doc-1" + document.dataset_id = dataset.id + document.word_count = 0 + document.doc_form = "qa_model" + + # Minimal args required by add_segment + args = { + "content": "question text", + "answer": "answer text", + "keywords": ["k1", "k2"], + } + + # Avoid real DB operations + db_mock = Mock() + db_mock.session = Mock() + monkeypatch.setattr("services.dataset_service.db", db_mock) + monkeypatch.setattr("services.dataset_service.VectorService", Mock()) + + # Act + result = SegmentService.create_segment(args=args, document=document, dataset=dataset) + + # Assert + # Under LockNotOwnedError except, add_segment should swallow the error and return None. + assert result is None + + +# --------------------------------------------------------------------------- +# 3. Multi-segment creation (multi_create_segment) +# --------------------------------------------------------------------------- + + +def test_multi_create_segment_ignores_lock_not_owned( + monkeypatch, + fake_current_user, + fake_lock, +): + # Arrange + dataset = create_autospec(Dataset, instance=True) + dataset.id = "ds-1" + dataset.tenant_id = fake_current_user.current_tenant_id + dataset.indexing_technique = "economy" # again, skip high_quality path + + document = create_autospec(Document, instance=True) + document.id = "doc-1" + document.dataset_id = dataset.id + document.word_count = 0 + document.doc_form = "qa_model"
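+    # As in the add_segment test above, multi_create_segment is expected to swallow the
+    # LockNotOwnedError raised by the fake lock and simply return None.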