chore: ignore redis lock not owned error (#29064)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
wangxiaolei 2025-12-04 10:14:28 +08:00 committed by GitHub
parent 4b969bdce3
commit e924dc7b30
2 changed files with 472 additions and 281 deletions
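
The hunks below all apply the same change in services/dataset_service.py (the module the new tests import from): each `with redis_client.lock(...)` block is wrapped in a try/except that swallows LockNotOwnedError, so a lock whose ownership can no longer be proven no longer fails the whole operation. A minimal sketch of the pattern — placeholder function name and body, not the actual service code:

from redis.exceptions import LockNotOwnedError
from extensions.ext_redis import redis_client  # the same client the service module imports

def do_work_under_lock(dataset_id: str) -> None:  # placeholder name/signature
    lock_name = f"add_document_lock_dataset_id_{dataset_id}"
    try:
        with redis_client.lock(lock_name, timeout=600):
            ...  # the existing lock-guarded body stays unchanged, just re-indented under the try
    except LockNotOwnedError:
        # raised when the lock cannot be proven ours at release (e.g. it expired mid-way);
        # the guarded work has already run, so the failed release is ignored
        pass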


@@ -10,6 +10,7 @@ from collections.abc import Sequence
from typing import Any, Literal
import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -1593,173 +1594,176 @@ class DocumentService:
db.session.add(dataset_process_rule)
db.session.flush()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
with redis_client.lock(lock_name, timeout=600):
assert dataset_process_rule
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
# raise error if file not found
if not file:
raise FileNotExistsError()
file_name = file.name
data_source_info: dict[str, str | bool] = {
"upload_file_id": file_id,
}
# check duplicate
if knowledge_config.duplicate:
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
try:
with redis_client.lock(lock_name, timeout=600):
assert dataset_process_rule
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
.all()
)
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
for page in notion_info.pages:
if page.page_id not in exist_page_ids:
data_source_info = {
"credential_id": notion_info.credential_id,
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore
"type": page.type,
}
# Truncate page name to 255 characters to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
truncated_page_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
else:
exist_document.pop(page.page_id)
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
if not website_info:
raise ValueError("No website info list found.")
urls = website_info.urls
for url in urls:
data_source_info = {
"url": url,
"provider": website_info.provider,
"job_id": website_info.job_id,
"only_main_content": website_info.only_main_content,
"mode": "crawl",
}
if len(url) > 255:
document_name = url[:200] + "..."
else:
document_name = url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
document_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()
# trigger async task
if document_ids:
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
# raise error if file not found
if not file:
raise FileNotExistsError()
file_name = file.name
data_source_info: dict[str, str | bool] = {
"upload_file_id": file_id,
}
# check duplicate
if knowledge_config.duplicate:
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
.first()
)
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
.all()
)
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
for page in notion_info.pages:
if page.page_id not in exist_page_ids:
data_source_info = {
"credential_id": notion_info.credential_id,
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore
"type": page.type,
}
# Truncate page name to 255 characters to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
truncated_page_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
else:
exist_document.pop(page.page_id)
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
if not website_info:
raise ValueError("No website info list found.")
urls = website_info.urls
for url in urls:
data_source_info = {
"url": url,
"provider": website_info.provider,
"job_id": website_info.job_id,
"only_main_content": website_info.only_main_content,
"mode": "crawl",
}
if len(url) > 255:
document_name = url[:200] + "..."
else:
document_name = url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
document_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()
# trigger async task
if document_ids:
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
except LockNotOwnedError:
pass
return documents, batch
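
For context on when this exception actually fires: redis-py's Lock releases by comparing a per-acquisition token, and if the lock key has already expired (here, the 600-second timeout elapsing before the with block finishes), release() in __exit__ raises LockNotOwnedError even though the guarded work completed. A small reproduction sketch, assuming a reachable local Redis and illustrative names:

import time

import redis
from redis.exceptions import LockNotOwnedError

client = redis.Redis()  # assumption: Redis on localhost:6379

try:
    with client.lock("demo_lock", timeout=1):
        time.sleep(2)  # the work outruns the lock timeout, so the key expires
except LockNotOwnedError:
    # __exit__ -> release() can no longer prove ownership of the expired lock;
    # ignoring it here mirrors what the service code above now does
    pass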
@@ -2699,50 +2703,55 @@ class SegmentService:
# calc embedding use tokens
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
lock_name = f"add_segment_lock_document_id_{document.id}"
with redis_client.lock(lock_name, timeout=600):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
status="completed",
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
try:
with redis_client.lock(lock_name, timeout=600):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
status="completed",
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.commit()
# save vector index
try:
VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form)
except Exception as e:
logger.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
return segment
# save vector index
try:
VectorService.create_segments_vector(
[args["keywords"]], [segment_document], dataset, document.doc_form
)
except Exception as e:
logger.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
return segment
except LockNotOwnedError:
pass
@classmethod
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
@@ -2751,84 +2760,89 @@ class SegmentService:
lock_name = f"multi_add_segment_lock_document_id_{document.id}"
increment_word_count = 0
with redis_client.lock(lock_name, timeout=600):
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
try:
with redis_client.lock(lock_name, timeout=600):
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
)
pre_segment_data_list = []
segment_data_list = []
keywords_list = []
position = max_position + 1 if max_position else 1
for segment_item in segments:
content = segment_item["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
# calc embedding use tokens
pre_segment_data_list = []
segment_data_list = []
keywords_list = []
position = max_position + 1 if max_position else 1
for segment_item in segments:
content = segment_item["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
# calc embedding use tokens
if document.doc_form == "qa_model":
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content + segment_item["answer"]]
)[0]
else:
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=position,
content=content,
word_count=len(content),
tokens=tokens,
keywords=segment_item.get("keywords", []),
status="completed",
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content + segment_item["answer"]]
)[0]
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
db.session.add(segment_document)
segment_data_list.append(segment_document)
position += 1
pre_segment_data_list.append(segment_document)
if "keywords" in segment_item:
keywords_list.append(segment_item["keywords"])
else:
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=position,
content=content,
word_count=len(content),
tokens=tokens,
keywords=segment_item.get("keywords", []),
status="completed",
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
db.session.add(segment_document)
segment_data_list.append(segment_document)
position += 1
pre_segment_data_list.append(segment_document)
if "keywords" in segment_item:
keywords_list.append(segment_item["keywords"])
else:
keywords_list.append(None)
# update document word count
assert document.word_count is not None
document.word_count += increment_word_count
db.session.add(document)
try:
# save vector index
VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form)
except Exception as e:
logger.exception("create segment index failed")
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
return segment_data_list
keywords_list.append(None)
# update document word count
assert document.word_count is not None
document.word_count += increment_word_count
db.session.add(document)
try:
# save vector index
VectorService.create_segments_vector(
keywords_list, pre_segment_data_list, dataset, document.doc_form
)
except Exception as e:
logger.exception("create segment index failed")
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
return segment_data_list
except LockNotOwnedError:
pass
@classmethod
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):


@@ -0,0 +1,177 @@
import types
from unittest.mock import Mock, create_autospec
import pytest
from redis.exceptions import LockNotOwnedError
from models.account import Account
from models.dataset import Dataset, Document
from services.dataset_service import DocumentService, SegmentService
class FakeLock:
"""Lock that always fails on enter with LockNotOwnedError."""
def __enter__(self):
raise LockNotOwnedError("simulated")
def __exit__(self, exc_type, exc, tb):
# Normal contextmanager signature; return False so exceptions propagate
return False
@pytest.fixture
def fake_current_user(monkeypatch):
user = create_autospec(Account, instance=True)
user.id = "user-1"
user.current_tenant_id = "tenant-1"
monkeypatch.setattr("services.dataset_service.current_user", user)
return user
@pytest.fixture
def fake_features(monkeypatch):
"""Features.billing.enabled == False to skip quota logic."""
features = types.SimpleNamespace(
billing=types.SimpleNamespace(enabled=False, subscription=types.SimpleNamespace(plan="ENTERPRISE")),
documents_upload_quota=types.SimpleNamespace(limit=10_000, size=0),
)
monkeypatch.setattr(
"services.dataset_service.FeatureService.get_features",
lambda tenant_id: features,
)
return features
@pytest.fixture
def fake_lock(monkeypatch):
"""Patch redis_client.lock to always raise LockNotOwnedError on enter."""
def _fake_lock(name, timeout=None, *args, **kwargs):
return FakeLock()
# services.dataset_service imports redis_client from extensions.ext_redis, so patch the lock on that module-level reference
monkeypatch.setattr("services.dataset_service.redis_client.lock", _fake_lock)
# ---------------------------------------------------------------------------
# 1. Knowledge Pipeline document creation (save_document_with_dataset_id)
# ---------------------------------------------------------------------------
def test_save_document_with_dataset_id_ignores_lock_not_owned(
monkeypatch,
fake_current_user,
fake_features,
fake_lock,
):
# Arrange
dataset = create_autospec(Dataset, instance=True)
dataset.id = "ds-1"
dataset.tenant_id = fake_current_user.current_tenant_id
dataset.data_source_type = "upload_file"
dataset.indexing_technique = "high_quality" # so we skip re-initialization branch
# Minimal knowledge_config stub that satisfies pre-lock code
info_list = types.SimpleNamespace(data_source_type="upload_file")
data_source = types.SimpleNamespace(info_list=info_list)
knowledge_config = types.SimpleNamespace(
doc_form="qa_model",
original_document_id=None, # go into "new document" branch
data_source=data_source,
indexing_technique="high_quality",
embedding_model=None,
embedding_model_provider=None,
retrieval_model=None,
process_rule=None,
duplicate=False,
doc_language="en",
)
account = fake_current_user
# Avoid touching real doc_form logic
monkeypatch.setattr("services.dataset_service.DatasetService.check_doc_form", lambda *a, **k: None)
# Avoid real DB interactions
monkeypatch.setattr("services.dataset_service.db", Mock())
# Act: this would hit the redis lock, whose __enter__ raises LockNotOwnedError.
# Our implementation should catch it and still return (documents, batch).
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=account,
)
# Assert
# We mainly care that:
# - No exception is raised
# - The function returns a sensible tuple
assert isinstance(documents, list)
assert isinstance(batch, str)
# ---------------------------------------------------------------------------
# 2. Single-segment creation (add_segment)
# ---------------------------------------------------------------------------
def test_add_segment_ignores_lock_not_owned(
monkeypatch,
fake_current_user,
fake_lock,
):
# Arrange
dataset = create_autospec(Dataset, instance=True)
dataset.id = "ds-1"
dataset.tenant_id = fake_current_user.current_tenant_id
dataset.indexing_technique = "economy" # skip embedding/token calculation branch
document = create_autospec(Document, instance=True)
document.id = "doc-1"
document.dataset_id = dataset.id
document.word_count = 0
document.doc_form = "qa_model"
# Minimal args required by add_segment
args = {
"content": "question text",
"answer": "answer text",
"keywords": ["k1", "k2"],
}
# Avoid real DB operations
db_mock = Mock()
db_mock.session = Mock()
monkeypatch.setattr("services.dataset_service.db", db_mock)
monkeypatch.setattr("services.dataset_service.VectorService", Mock())
# Act
result = SegmentService.create_segment(args=args, document=document, dataset=dataset)
# Assert
# With LockNotOwnedError swallowed by the lock wrapper, create_segment should simply return None.
assert result is None
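# Because the service method now falls off the end of its try/except when the lock is
# lost, create_segment (and likewise multi_create_segment) returns None in that case,
# which is exactly what the assertion above pins down. A hypothetical caller-side guard
# could look like this (illustrative only, not part of the service API):
#
#     segment = SegmentService.create_segment(args=args, document=document, dataset=dataset)
#     if segment is None:
#         # hypothetical handling: the Redis lock was lost inside the critical section;
#         # the segment may still have been persisted, so re-query or log rather than
#         # assuming the call failed outright
#         ...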
# ---------------------------------------------------------------------------
# 3. Multi-segment creation (multi_create_segment)
# ---------------------------------------------------------------------------
def test_multi_create_segment_ignores_lock_not_owned(
monkeypatch,
fake_current_user,
fake_lock,
):
# Arrange
dataset = create_autospec(Dataset, instance=True)
dataset.id = "ds-1"
dataset.tenant_id = fake_current_user.current_tenant_id
dataset.indexing_technique = "economy" # again, skip high_quality path
document = create_autospec(Document, instance=True)
document.id = "doc-1"
document.dataset_id = dataset.id
document.word_count = 0
document.doc_form = "qa_model"