chore: DocumentSegment to TypeBase (#35635)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Asuka Minato 2026-05-12 16:02:17 +09:00 committed by GitHub
parent bb73776339
commit 51a8f79d67
30 changed files with 132 additions and 174 deletions

View File

@ -245,6 +245,7 @@ class Jieba(BaseKeyword):
segment = pre_segment_data["segment"]
if pre_segment_data["keywords"]:
segment.keywords = pre_segment_data["keywords"]
assert segment.index_node_id
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
@ -253,6 +254,7 @@ class Jieba(BaseKeyword):
keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
segment.keywords = list(keywords)
assert segment.index_node_id
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, segment.index_node_id, list(keywords)
)
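
Note: with DocumentSegment mapped through TypeBase, index_node_id is typed Mapped[str | None], so call sites like the hunk above assert it to narrow the type before handing it to APIs that expect str, while the bulk-cleanup tasks further down filter out None in comprehensions instead. A minimal runnable sketch of the two idioms (Segment and add_text are hypothetical stand-ins, not names from this codebase):

# Sketch of the two narrowing idioms used throughout this commit.
from dataclasses import dataclass


@dataclass
class Segment:
    index_node_id: str | None


def add_text(node_id: str) -> None:
    print(f"indexed {node_id}")


segments = [Segment("node-1"), Segment(None), Segment("node-2")]

# Idiom 1: assert to narrow str | None to str for a single object
segment = segments[0]
assert segment.index_node_id
add_text(segment.index_node_id)

# Idiom 2: drop None values when collecting ids in bulk
index_node_ids = [s.index_node_id for s in segments if s.index_node_id]
assert index_node_ids == ["node-1", "node-2"]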

View File

@ -1,5 +1,6 @@
import concurrent.futures
import logging
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, NotRequired, TypedDict
@ -526,7 +527,7 @@ class RetrievalService:
index_node_ids = [i for i in index_node_ids if i]
segment_ids: list[str] = []
index_node_segments: list[DocumentSegment] = []
index_node_segments: Sequence[DocumentSegment] = []
segments: list[DocumentSegment] = []
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
@ -568,8 +569,9 @@ class RetrievalService:
DocumentSegment.status == "completed",
DocumentSegment.index_node_id.in_(index_node_ids),
)
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
index_node_segments = session.execute(document_segment_stmt).scalars().all()
for index_node_segment in index_node_segments:
assert index_node_segment.index_node_id
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
if segment_ids:

View File

@ -50,6 +50,7 @@ class DatasetDocumentStore:
output = {}
for document_segment in document_segments:
assert document_segment.index_node_id
doc_id = document_segment.index_node_id
output[doc_id] = Document(
page_content=document_segment.content,
@ -103,7 +104,7 @@ class DatasetDocumentStore:
if not segment_document:
max_position += 1
assert self._document_id
segment_document = DocumentSegment(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,

View File

@ -84,7 +84,7 @@ class IndexProcessor:
select(DocumentSegment).where(DocumentSegment.document_id == original_document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
indexing_start_at = time.perf_counter()
# delete from vector index

View File

@ -8,7 +8,6 @@ import os
import pickle
import re
import time
from collections.abc import Sequence
from datetime import datetime
from json import JSONDecodeError
from typing import Any, ClassVar, TypedDict, cast
@ -831,7 +830,7 @@ class Document(Base):
)
class DocumentSegment(Base):
class DocumentSegment(TypeBase):
__tablename__ = "document_segments"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="document_segment_pkey"),
@ -844,35 +843,40 @@ class DocumentSegment(Base):
)
# initial fields
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int]
content = mapped_column(LongText, nullable=False)
answer = mapped_column(LongText, nullable=True)
content: Mapped[str] = mapped_column(LongText, nullable=False)
word_count: Mapped[int]
tokens: Mapped[int]
# indexing fields
keywords = mapped_column(sa.JSON, nullable=True)
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
# basic fields
# indexing fields
index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
keywords: Mapped[Any] = mapped_column(sa.JSON, nullable=True, default=None)
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
status: Mapped[SegmentStatus] = mapped_column(
EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"), default=SegmentStatus.WAITING
)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"))
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(LongText, nullable=True)
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
@property
def dataset(self):
@ -899,7 +903,7 @@ class DocumentSegment(Base):
)
@property
def child_chunks(self) -> Sequence[Any]:
def child_chunks(self):
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@ -914,7 +918,7 @@ class DocumentSegment(Base):
return child_chunks or []
return []
def get_child_chunks(self) -> Sequence[Any]:
def get_child_chunks(self):
if not self.document:
return []
process_rule = self.document.dataset_process_rule
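
Note: the real TypeBase is defined elsewhere in the codebase and is not shown in this diff; the sketch below only assumes it combines MappedAsDataclass with DeclarativeBase, which is consistent with the default=, default_factory= and init=False arguments added above. Under that assumption, columns marked init=False or given defaults are no longer constructor arguments:

# Hypothetical TypeBase-style base; illustrates the MappedAsDataclass pattern only.
from datetime import datetime
from uuid import uuid4

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column


class TypeBase(MappedAsDataclass, DeclarativeBase, kw_only=True):
    """Declarative base whose models get a generated keyword-only __init__."""


class ExampleSegment(TypeBase):
    __tablename__ = "example_segments"

    # init=False: generated value, not a constructor argument
    id: Mapped[str] = mapped_column(
        sa.String(36), primary_key=True, default_factory=lambda: str(uuid4()), init=False
    )
    content: Mapped[str] = mapped_column(sa.Text, nullable=False)
    # optional columns need an explicit default=None to keep the dataclass __init__ valid
    index_node_id: Mapped[str | None] = mapped_column(sa.String(255), nullable=True, default=None)
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime, server_default=sa.func.current_timestamp(), init=False
    )


segment = ExampleSegment(content="hello")  # id and created_at cannot be passed here
assert segment.index_node_id is None       # optional fields start as None and must be narrowed

This is also why the test hunks below stop passing id=..., created_at=... and updated_at=... to the DocumentSegment constructor and, where a fixed id is needed, assign it after construction (segment.id = "segment-1") instead.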

View File

@ -111,6 +111,7 @@ class VectorService:
"dataset_id": segment.dataset_id,
},
)
assert segment.index_node_id
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
vector = Vector(dataset=dataset)
@ -138,6 +139,7 @@ class VectorService:
regenerate: bool = False,
):
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
assert segment.index_node_id
if regenerate:
# delete child chunks
index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)

View File

@ -50,7 +50,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
segment_ids = [segment.id for segment in segments]
# Collect image file IDs from segment content

View File

@ -19,6 +19,7 @@ from graphon.model_runtime.entities.model_entities import ModelType
from libs import helper
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import SegmentStatus
from models.model import UploadFile
from services.vector_service import VectorService
@ -156,7 +157,7 @@ def batch_create_segment_to_index_task(
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
status=SegmentStatus.COMPLETED,
completed_at=naive_utc_now(),
)
if document_config["doc_form"] == IndexStructureType.QA_INDEX:

View File

@ -53,7 +53,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
segment_contents = [segment.content for segment in segments]
except Exception:
logger.exception("Cleaned document when document deleted failed")

View File

@ -38,7 +38,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
for document_id in document_ids:
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
total_index_node_ids.extend([segment.index_node_id for segment in segments])
total_index_node_ids.extend([segment.index_node_id for segment in segments if segment.index_node_id])
# Wrap vector / keyword index cleanup in try/except so that a transient
# failure here (e.g. billing API hiccup propagated via FeatureService when

View File

@ -9,6 +9,7 @@ from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
from models.enums import SegmentStatus
logger = logging.getLogger(__name__)
@ -30,7 +31,7 @@ def disable_segment_from_index_task(segment_id: str):
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
if segment.status != "completed":
if segment.status != SegmentStatus.COMPLETED:
logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
return
@ -59,6 +60,7 @@ def disable_segment_from_index_task(segment_id: str):
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
assert segment.index_node_id
index_processor.clean(dataset, [segment.index_node_id])
# Disable summary index for this segment
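
Note: raw status strings are replaced by the SegmentStatus enum in this task and in the test factory hunks further down. Assuming SegmentStatus is a str-backed enum (suggested by the EnumText(SegmentStatus, length=255) column and the 'waiting' server default), both spellings compare equal, but the enum member is explicit and type-checkable; a small sketch, with StrEnum used here purely for illustration:

# Assumed str-backed enum; the real SegmentStatus lives in models.enums and may
# differ in detail, but the members below all appear in this diff.
from enum import StrEnum


class SegmentStatus(StrEnum):
    WAITING = "waiting"
    INDEXING = "indexing"
    COMPLETED = "completed"
    ERROR = "error"


status = SegmentStatus.COMPLETED
assert status == "completed"               # a StrEnum still equals its raw value
assert status is SegmentStatus.COMPLETED   # but the member is explicit and typo-proof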

View File

@ -55,7 +55,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
return
try:
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
if dataset.is_multimodal:
segment_ids = [segment.id for segment in segments]
segment_attachment_bindings = session.scalars(

View File

@ -69,7 +69,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()

View File

@ -45,7 +45,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
clean_success = False
try:

View File

@ -137,7 +137,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
select(DocumentSegment).where(DocumentSegment.document_id == document.id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

View File

@ -61,7 +61,7 @@ def remove_document_from_index_task(document_id: str):
except Exception as e:
logger.warning("Failed to disable summaries for document %s: %s", document.id, str(e))
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
if index_node_ids:
try:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)

View File

@ -85,7 +85,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

View File

@ -70,7 +70,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

View File

@ -13,9 +13,9 @@ from uuid import uuid4
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom
from models.enums import DataSourceType, DocumentCreatedFrom, SegmentStatus
from services.dataset_service import SegmentService
@ -35,13 +35,13 @@ class SegmentServiceTestDataFactory:
email=f"{uuid4()}@example.com",
name=f"user-{uuid4()}",
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
if tenant is None:
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
tenant = Tenant(name=f"tenant-{uuid4()}", status=TenantStatus.NORMAL)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
@ -103,7 +103,7 @@ class SegmentServiceTestDataFactory:
created_by: str,
position: int = 1,
content: str = "Test content",
status: str = "completed",
status: SegmentStatus = SegmentStatus.COMPLETED,
word_count: int = 10,
tokens: int = 15,
) -> DocumentSegment:
@ -203,7 +203,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=1,
status="completed",
status=SegmentStatus.COMPLETED,
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
@ -212,7 +212,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=2,
status="indexing",
status=SegmentStatus.INDEXING,
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
@ -221,7 +221,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=3,
status="waiting",
status=SegmentStatus.WAITING,
)
# Act
@ -257,7 +257,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=1,
status="completed",
status=SegmentStatus.COMPLETED,
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
@ -266,7 +266,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=2,
status="indexing",
status=SegmentStatus.INDEXING,
)
# Act
@ -415,7 +415,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=1,
status="completed",
status=SegmentStatus.COMPLETED,
content="This is important information",
)
SegmentServiceTestDataFactory.create_segment(
@ -425,7 +425,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=2,
status="indexing",
status=SegmentStatus.INDEXING,
content="This is also important",
)
SegmentServiceTestDataFactory.create_segment(
@ -435,7 +435,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=3,
status="completed",
status=SegmentStatus.COMPLETED,
content="This is irrelevant",
)
@ -477,7 +477,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=1,
status="completed",
status=SegmentStatus.COMPLETED,
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
@ -486,7 +486,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=2,
status="waiting",
status=SegmentStatus.WAITING,
)
# Act

View File

@ -128,7 +128,6 @@ class TestAddDocumentToIndexTask:
for i in range(3):
segment = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
@ -451,7 +450,6 @@ class TestAddDocumentToIndexTask:
segments = []
for i in range(3):
segment = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
@ -630,7 +628,6 @@ class TestAddDocumentToIndexTask:
# Segment 1: Should be processed (enabled=False, status=SegmentStatus.COMPLETED)
segment1 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
@ -650,7 +647,6 @@ class TestAddDocumentToIndexTask:
# Segment 2: Should be processed (enabled=True, status=SegmentStatus.COMPLETED)
# Note: Implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED
segment2 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
@ -669,7 +665,6 @@ class TestAddDocumentToIndexTask:
# Segment 3: Should NOT be processed (enabled=False, status="processing")
segment3 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
@ -688,7 +683,6 @@ class TestAddDocumentToIndexTask:
# Segment 4: Should be processed (enabled=False, status=SegmentStatus.COMPLETED)
segment4 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,

View File

@ -177,7 +177,6 @@ class TestBatchCleanDocumentTask:
fake = Faker()
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=account.current_tenant.id,
dataset_id=document.dataset_id,
document_id=document.id,
@ -290,10 +289,9 @@ class TestBatchCleanDocumentTask:
account = self._create_test_account(db_session_with_containers)
dataset = self._create_test_dataset(db_session_with_containers, account)
document = self._create_test_document(db_session_with_containers, dataset, account)
assert account.current_tenant
# Create segment with simple content (no image references)
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=account.current_tenant.id,
dataset_id=document.dataset_id,
document_id=document.id,
@ -692,9 +690,9 @@ class TestBatchCleanDocumentTask:
# Create multiple segments for the document
segments = []
assert account.current_tenant
for i in range(3):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=account.current_tenant.id,
dataset_id=document.dataset_id,
document_id=document.id,

View File

@ -220,7 +220,6 @@ class TestCleanDatasetTask:
DocumentSegment: Created document segment instance
"""
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -232,8 +231,6 @@ class TestCleanDatasetTask:
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash="test_hash",
created_at=datetime.now(),
updated_at=datetime.now(),
)
db_session_with_containers.add(segment)
@ -614,7 +611,6 @@ class TestCleanDatasetTask:
"""
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -626,8 +622,6 @@ class TestCleanDatasetTask:
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash="test_hash",
created_at=datetime.now(),
updated_at=datetime.now(),
)
db_session_with_containers.add(segment)
@ -729,8 +723,6 @@ class TestCleanDatasetTask:
type=DatasetMetadataType.STRING,
created_by=account.id,
)
metadata.id = str(uuid.uuid4())
metadata.created_at = datetime.now()
metadata_items.append(metadata)
# Create binding for each metadata item
@ -741,8 +733,6 @@ class TestCleanDatasetTask:
document_id=documents[i % len(documents)].id,
created_by=account.id,
)
binding.id = str(uuid.uuid4())
binding.created_at = datetime.now()
bindings.append(binding)
db_session_with_containers.add_all(metadata_items)
@ -946,7 +936,6 @@ class TestCleanDatasetTask:
long_content = "Very long content " * 100 # Long content within reasonable limits
segment_content = f"Segment with special chars: {special_content}\n{long_content}"
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -958,8 +947,6 @@ class TestCleanDatasetTask:
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash="test_hash_" + "x" * 50, # Long hash within limits
created_at=datetime.now(),
updated_at=datetime.now(),
)
db_session_with_containers.add(segment)
db_session_with_containers.commit()

View File

@ -132,11 +132,10 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.add(document)
db_session_with_containers.flush()
document_ids.append(document.id)
assert tenant
# Create segments for each document
for j in range(2):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -297,10 +296,9 @@ class TestCleanNotionDocumentTask:
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
assert tenant
# Create test segment
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -379,12 +377,11 @@ class TestCleanNotionDocumentTask:
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
assert tenant
# Create segments without index_node_ids
segments = []
for i in range(3):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -468,11 +465,10 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.add(document)
db_session_with_containers.flush()
documents.append(document)
assert tenant
# Create segments for each document
for j in range(2):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -569,10 +565,9 @@ class TestCleanNotionDocumentTask:
segment_statuses = [SegmentStatus.WAITING, SegmentStatus.INDEXING, SegmentStatus.COMPLETED, SegmentStatus.ERROR]
segments = []
index_node_ids = []
assert tenant
for i, status in enumerate(segment_statuses):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -665,10 +660,9 @@ class TestCleanNotionDocumentTask:
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
assert tenant
# Create segment
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -765,12 +759,11 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.add(document)
db_session_with_containers.flush()
documents.append(document)
assert tenant
# Create multiple segments for each document
num_segments_per_doc = 5
for j in range(num_segments_per_doc):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -875,7 +868,6 @@ class TestCleanNotionDocumentTask:
# Create segments for each document
for j in range(3):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=account.current_tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -984,11 +976,10 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.add(document)
db_session_with_containers.flush()
documents.append(document)
assert tenant
# Create segments for each document
for j in range(2):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -1093,10 +1084,9 @@ class TestCleanNotionDocumentTask:
# Create segments with metadata
segments = []
index_node_ids = []
assert tenant
for i in range(3):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,

View File

@ -90,7 +90,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -150,7 +149,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -202,7 +200,6 @@ class TestDealDatasetVectorIndexTask:
# Create segments
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -253,7 +250,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset with parent-child index
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -305,7 +301,6 @@ class TestDealDatasetVectorIndexTask:
# Create segments
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -371,7 +366,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset without documents
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -403,7 +397,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -461,7 +454,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset without documents
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -494,7 +486,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -546,7 +537,6 @@ class TestDealDatasetVectorIndexTask:
# Create segments
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -592,7 +582,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset with custom index type
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -624,7 +613,6 @@ class TestDealDatasetVectorIndexTask:
# Create segments
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -670,7 +658,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset without doc_form (should use default)
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -702,7 +689,6 @@ class TestDealDatasetVectorIndexTask:
# Create segments
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -748,7 +734,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -806,7 +791,6 @@ class TestDealDatasetVectorIndexTask:
for i, document in enumerate(documents):
for j in range(2):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -832,6 +816,7 @@ class TestDealDatasetVectorIndexTask:
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == document.id).limit(1)
)
assert updated_document
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor load was called multiple times
@ -853,7 +838,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -905,7 +889,6 @@ class TestDealDatasetVectorIndexTask:
# Create segments
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -952,7 +935,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -1024,7 +1006,6 @@ class TestDealDatasetVectorIndexTask:
# Create segments for enabled document only
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=enabled_document.id,
@ -1075,7 +1056,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -1147,7 +1127,6 @@ class TestDealDatasetVectorIndexTask:
# Create segments for active document only
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=active_document.id,
@ -1198,7 +1177,6 @@ class TestDealDatasetVectorIndexTask:
# Create dataset
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
@ -1270,7 +1248,6 @@ class TestDealDatasetVectorIndexTask:
# Create segments for completed document only
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=completed_document.id,

View File

@ -209,26 +209,25 @@ class TestDeleteSegmentFromIndexTask:
segments = []
for i in range(count):
segment = DocumentSegment()
segment.id = fake.uuid4()
segment.tenant_id = document.tenant_id
segment.dataset_id = document.dataset_id
segment.document_id = document.id
segment.position = i + 1
segment.content = f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}"
segment.answer = f"Test segment answer {i + 1}: {fake.text(max_nb_chars=100)}"
segment.word_count = fake.random_int(min=10, max=100)
segment.tokens = fake.random_int(min=5, max=50)
segment.keywords = [fake.word() for _ in range(3)]
segment.index_node_id = f"node_{fake.uuid4()}"
segment.index_node_hash = fake.sha256()
segment.hit_count = 0
segment.enabled = True
segment.status = SegmentStatus.COMPLETED
segment.created_by = account.id
segment.created_at = fake.date_time_this_year()
segment.updated_by = account.id
segment.updated_at = segment.created_at
created_at = fake.date_time_this_year()
segment = DocumentSegment(
tenant_id=document.tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
position=i + 1,
content=f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}",
answer=f"Test segment answer {i + 1}: {fake.text(max_nb_chars=100)}",
word_count=fake.random_int(min=10, max=100),
tokens=fake.random_int(min=5, max=50),
keywords=[fake.word() for _ in range(3)],
index_node_id=f"node_{fake.uuid4()}",
index_node_hash=fake.sha256(),
hit_count=0,
enabled=True,
status=SegmentStatus.COMPLETED,
created_by=account.id,
updated_by=account.id,
)
db_session_with_containers.add(segment)
segments.append(segment)

View File

@ -159,7 +159,7 @@ class TestDisableSegmentFromIndexTask:
dataset: Dataset,
tenant: Tenant,
account: Account,
status: str = "completed",
status: SegmentStatus = SegmentStatus.COMPLETED,
enabled: bool = True,
) -> DocumentSegment:
"""

View File

@ -185,30 +185,31 @@ class TestDisableSegmentsFromIndexTask:
segments = []
for i in range(count):
segment = DocumentSegment()
segment.id = fake.uuid4()
segment.tenant_id = dataset.tenant_id
segment.dataset_id = dataset.id
segment.document_id = document.id
segment.position = i + 1
segment.content = f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}"
segment.answer = f"Test answer {i + 1}" if i % 2 == 0 else None
segment.word_count = fake.random_int(min=10, max=100)
segment.tokens = fake.random_int(min=5, max=50)
segment.keywords = [fake.word() for _ in range(3)]
segment.index_node_id = f"node_{segment.id}"
segment.index_node_hash = fake.sha256()
segment.hit_count = 0
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
segment.status = SegmentStatus.COMPLETED
segment.created_by = account.id
segment.updated_by = account.id
segment.indexing_at = fake.date_time_this_year()
segment.completed_at = fake.date_time_this_year()
segment.error = None
segment.stopped_at = None
id = fake.uuid4()
segment = DocumentSegment(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
position=i + 1,
content=f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}",
answer=f"Test answer {i + 1}" if i % 2 == 0 else None,
word_count=fake.random_int(min=10, max=100),
tokens=fake.random_int(min=5, max=50),
keywords=[fake.word() for _ in range(3)],
index_node_id=f"node_{id}",
index_node_hash=fake.sha256(),
hit_count=0,
enabled=True,
disabled_at=None,
disabled_by=None,
status=SegmentStatus.COMPLETED,
created_by=account.id,
updated_by=account.id,
indexing_at=fake.date_time_this_year(),
completed_at=fake.date_time_this_year(),
error=None,
stopped_at=None,
)
segments.append(segment)

View File

@ -175,7 +175,6 @@ class TestDuplicateDocumentIndexingTasks:
for document in documents:
for i in range(segments_per_doc):
segment = DocumentSegment(
id=fake.uuid4(),
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,

View File

@ -139,7 +139,6 @@ class TestEnableSegmentsToIndexTask:
for i in range(count):
text = fake.text(max_nb_chars=200)
segment = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,

View File

@ -282,7 +282,6 @@ class TestSegmentServiceQueries:
def test_get_segment_by_id_returns_only_document_segment_instances(self):
segment = DocumentSegment(
id="segment-1",
tenant_id="tenant-1",
dataset_id="dataset-1",
document_id="doc-1",
@ -292,7 +291,7 @@ class TestSegmentServiceQueries:
tokens=2,
created_by="user-1",
)
segment.id = "segment-1"
with patch("services.dataset_service.db") as mock_db:
mock_db.session.scalar.return_value = segment
result = SegmentService.get_segment_by_id("segment-1", "tenant-1")
@ -307,7 +306,6 @@ class TestSegmentServiceQueries:
def test_get_segments_by_document_and_dataset_returns_scalars_result(self):
segment = DocumentSegment(
id="segment-1",
tenant_id="tenant-1",
dataset_id="dataset-1",
document_id="doc-1",
@ -318,6 +316,7 @@ class TestSegmentServiceQueries:
created_by="user-1",
)
segment.id = "segment-1"
with patch("services.dataset_service.db") as mock_db:
mock_db.session.scalars.return_value.all.return_value = [segment]
@ -461,6 +460,7 @@ class TestSegmentServiceMutations:
vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")
result = SegmentService.multi_create_segment(segments, document, dataset)
assert result
assert len(result) == 2
assert [segment.position for segment in result] == [2, 3]