refactor: use EnumText for dataset and replace string literals 4 (#33606)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tmimmanuel 2026-03-18 00:18:08 +00:00 committed by GitHub
parent 0bc6c3a73e
commit 3870b2ad2d
69 changed files with 1027 additions and 849 deletions
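
The pattern repeated throughout this commit is the same: columns that were plain String are now mapped through the EnumText type imported from .types, and assignments, comparisons, and defaults use the matching StrEnum members from models/enums.py instead of raw string literals. The repository's actual EnumText implementation is not part of this diff; purely as a hypothetical sketch of how such a type decorator can behave (enum members in Python, a plain VARCHAR in the database), it might look like this:

from enum import StrEnum

from sqlalchemy import String
from sqlalchemy.types import TypeDecorator


class EnumTextSketch(TypeDecorator):
    """Illustrative only: persist a StrEnum as VARCHAR and hydrate it back as the enum."""

    impl = String
    cache_ok = True

    def __init__(self, enum_cls: type[StrEnum], length: int = 255):
        super().__init__(length)  # the underlying column stays a VARCHAR(length)
        self._enum_cls = enum_cls

    def process_bind_param(self, value, dialect):
        # Accept either an enum member or a legacy plain string; store the string value.
        return None if value is None else self._enum_cls(value).value

    def process_result_value(self, value, dialect):
        # Rows written before this refactor hold plain strings; coerce them on read.
        return None if value is None else self._enum_cls(value)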


@ -14,6 +14,7 @@ from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import DatasetMetadataType, IndexingStatus, SegmentStatus
from models.model import App, AppAnnotationSetting, MessageAnnotation
@ -242,7 +243,7 @@ def migrate_knowledge_vector_database():
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.indexing_status == IndexingStatus.COMPLETED,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
@ -254,7 +255,7 @@ def migrate_knowledge_vector_database():
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.status == SegmentStatus.COMPLETED,
DocumentSegment.enabled == True,
)
).all()
@ -430,7 +431,7 @@ def old_metadata_migration():
tenant_id=document.tenant_id,
dataset_id=document.dataset_id,
name=key,
type="string",
type=DatasetMetadataType.STRING,
created_by=document.created_by,
)
db.session.add(dataset_metadata)


@ -54,6 +54,7 @@ from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -741,13 +742,15 @@ class DatasetIndexingStatusApi(Resource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
# Create a dictionary with document attributes and additional fields


@ -42,6 +42,7 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from models.enums import IndexingStatus, SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
@ -332,13 +333,16 @@ class DatasetDocumentListApi(Resource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
document.completed_segments = completed_segments
@ -503,7 +507,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
if document.indexing_status in {"completed", "error"}:
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule
@ -573,7 +577,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
extract_settings = []
for document in documents:
if document.indexing_status in {"completed", "error"}:
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
match document.data_source_type:
@ -671,19 +675,21 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
@ -720,20 +726,20 @@ class DocumentIndexingStatusApi(DocumentResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
.count()
)
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
@ -955,7 +961,7 @@ class DocumentProcessingApi(DocumentResource):
match action:
case "pause":
if document.indexing_status != "indexing":
if document.indexing_status != IndexingStatus.INDEXING:
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
@ -964,7 +970,7 @@ class DocumentProcessingApi(DocumentResource):
db.session.commit()
case "resume":
if document.indexing_status not in {"paused", "error"}:
if document.indexing_status not in {IndexingStatus.PAUSED, IndexingStatus.ERROR}:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None
@ -1169,7 +1175,7 @@ class DocumentRetryApi(DocumentResource):
raise ArchivedDocumentImmutableError()
# 400 if document is completed
if document.indexing_status == "completed":
if document.indexing_status == IndexingStatus.COMPLETED:
raise DocumentAlreadyFinishedError()
retry_documents.append(document)
except Exception:


@ -36,6 +36,7 @@ from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
KnowledgeConfig,
@ -622,13 +623,15 @@ class DocumentIndexingStatusApi(DatasetApiResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
# Create a dictionary with document attributes and additional fields


@ -6,6 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import CollectionBindingType
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
@ -43,7 +44,7 @@ class AnnotationReplyFeature:
embedding_model_name = collection_binding_detail.model_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
)
dataset = Dataset(


@ -12,7 +12,7 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DatasetQuerySource
_logger = logging.getLogger(__name__)
@ -36,7 +36,7 @@ class DatasetIndexToolCallbackHandler:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source="app",
source=DatasetQuerySource.APP,
source_app_id=self._app_id,
created_by_role=(
CreatorUserRole.ACCOUNT


@ -40,6 +40,7 @@ from libs.datetime_utils import naive_utc_now
from models import Account
from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus
from models.model import UploadFile
from services.feature_service import FeatureService
@ -56,7 +57,7 @@ class IndexingRunner:
logger.exception("consume document failed")
document = db.session.get(DatasetDocument, document_id)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
error_message = getattr(error, "description", str(error))
document.error = str(error_message)
document.stopped_at = naive_utc_now()
@ -219,7 +220,7 @@ class IndexingRunner:
if document_segments:
for document_segment in document_segments:
# transform segment to node
if document_segment.status != "completed":
if document_segment.status != SegmentStatus.COMPLETED:
document = Document(
page_content=document_segment.content,
metadata={
@ -382,7 +383,7 @@ class IndexingRunner:
data_source_info = dataset_document.data_source_info_dict
text_docs = []
match dataset_document.data_source_type:
case "upload_file":
case DataSourceType.UPLOAD_FILE:
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
@ -395,7 +396,7 @@ class IndexingRunner:
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "notion_import":
case DataSourceType.NOTION_IMPORT:
if (
not data_source_info
or "notion_workspace_id" not in data_source_info
@ -417,7 +418,7 @@ class IndexingRunner:
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "website_crawl":
case DataSourceType.WEBSITE_CRAWL:
if (
not data_source_info
or "provider" not in data_source_info
@ -445,7 +446,7 @@ class IndexingRunner:
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="splitting",
after_indexing_status=IndexingStatus.SPLITTING,
extra_update_params={
DatasetDocument.parsing_completed_at: naive_utc_now(),
},
@ -545,7 +546,7 @@ class IndexingRunner:
Clean the document text according to the processing rules.
"""
rules: AutomaticRulesConfig | dict[str, Any]
if processing_rule.mode == "automatic":
if processing_rule.mode == ProcessRuleMode.AUTOMATIC:
rules = DatasetProcessRule.AUTOMATIC_RULES
else:
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
@ -636,7 +637,7 @@ class IndexingRunner:
# update document status to completed
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="completed",
after_indexing_status=IndexingStatus.COMPLETED,
extra_update_params={
DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: naive_utc_now(),
@ -659,10 +660,10 @@ class IndexingRunner:
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing",
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
@ -703,10 +704,10 @@ class IndexingRunner:
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing",
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
@ -725,7 +726,7 @@ class IndexingRunner:
@staticmethod
def _update_document_index_status(
document_id: str, after_indexing_status: str, extra_update_params: dict | None = None
document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
):
"""
Update the document indexing status.
@ -803,7 +804,7 @@ class IndexingRunner:
cur_time = naive_utc_now()
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="indexing",
after_indexing_status=IndexingStatus.INDEXING,
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
@ -815,7 +816,7 @@ class IndexingRunner:
self._update_segments_by_document(
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.status: SegmentStatus.INDEXING,
DocumentSegment.indexing_at: naive_utc_now(),
},
)


@ -83,7 +83,7 @@ from models.dataset import (
)
from models.dataset import Document as DatasetDocument
from models.dataset import Document as DocumentModel
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DatasetQuerySource
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService
@ -1008,7 +1008,7 @@ class DatasetRetrieval:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=json.dumps(contents),
source="app",
source=DatasetQuerySource.APP,
source_app_id=app_id,
created_by_role=CreatorUserRole(user_from),
created_by=user_id,


@ -10,6 +10,7 @@ from events.document_index_event import document_index_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Document
from models.enums import IndexingStatus
logger = logging.getLogger(__name__)
@ -35,7 +36,7 @@ def handle(sender, **kwargs):
if not document:
raise NotFound("Document not found")
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)


@ -31,7 +31,20 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
from .account import Account
from .base import Base, TypeBase
from .engine import db
from .enums import CreatorUserRole
from .enums import (
CollectionBindingType,
CreatorUserRole,
DatasetMetadataType,
DatasetQuerySource,
DatasetRuntimeMode,
DataSourceType,
DocumentCreatedFrom,
DocumentDocType,
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
SummaryStatus,
)
from .model import App, Tag, TagBinding, UploadFile
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
@ -121,7 +134,7 @@ class Dataset(Base):
server_default=sa.text("'only_me'"),
default=DatasetPermissionEnum.ONLY_ME,
)
data_source_type = mapped_column(String(255))
data_source_type = mapped_column(EnumText(DataSourceType, length=255))
indexing_technique: Mapped[str | None] = mapped_column(String(255))
index_struct = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
@ -138,7 +151,9 @@ class Dataset(Base):
summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
icon_info = mapped_column(AdjustedJSON, nullable=True)
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
runtime_mode = mapped_column(
EnumText(DatasetRuntimeMode, length=255), nullable=True, server_default=sa.text("'general'")
)
pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = mapped_column(sa.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
@ -379,7 +394,7 @@ class DatasetProcessRule(Base): # bug
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
dataset_id = mapped_column(StringUUID, nullable=False)
mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'"))
rules = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -425,12 +440,12 @@ class Document(Base):
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
data_source_type: Mapped[str] = mapped_column(String(255), nullable=False)
data_source_type: Mapped[str] = mapped_column(EnumText(DataSourceType, length=255), nullable=False)
data_source_info = mapped_column(LongText, nullable=True)
dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
batch: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
created_from: Mapped[str] = mapped_column(EnumText(DocumentCreatedFrom, length=255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_api_request_id = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -464,7 +479,9 @@ class Document(Base):
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# basic fields
indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'"))
indexing_status = mapped_column(
EnumText(IndexingStatus, length=255), nullable=False, server_default=sa.text("'waiting'")
)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
@ -475,7 +492,7 @@ class Document(Base):
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
doc_type = mapped_column(String(40), nullable=True)
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_language = mapped_column(String(255), nullable=True)
@ -784,7 +801,7 @@ class DocumentSegment(Base):
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'"))
status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"))
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
@ -1048,7 +1065,7 @@ class DatasetQuery(TypeBase):
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
source: Mapped[str] = mapped_column(String(255), nullable=False)
source: Mapped[str] = mapped_column(EnumText(DatasetQuerySource, length=255), nullable=False)
source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -1193,7 +1210,9 @@ class DatasetCollectionBinding(TypeBase):
)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
type: Mapped[str] = mapped_column(
EnumText(CollectionBindingType, length=40), server_default=sa.text("'dataset'"), nullable=False
)
collection_name: Mapped[str] = mapped_column(String(64), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
@ -1420,7 +1439,7 @@ class DatasetMetadata(TypeBase):
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[str] = mapped_column(EnumText(DatasetMetadataType, length=255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
@ -1647,7 +1666,9 @@ class DocumentSegmentSummary(Base):
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
status: Mapped[str] = mapped_column(
EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
)
error: Mapped[str] = mapped_column(LongText, nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
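
Because the decorated columns above keep their VARCHAR lengths and their quoted string server_default values ('waiting', 'general', 'automatic', 'dataset', 'generating'), the change stays at the ORM layer and, assuming EnumText wraps a plain String as in the sketch near the top of this commit, should not require a schema migration. A minimal declaration in the same shape, with DemoDocument as a purely illustrative model rather than the project's Document:

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from models.enums import IndexingStatus  # as imported in this model file
from models.types import EnumText  # assumed module path for the `.types` import above


class Base(DeclarativeBase):
    pass


class DemoDocument(Base):
    __tablename__ = "demo_documents"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Same shape as Document.indexing_status above: enum members in Python, VARCHAR(255) in SQL.
    indexing_status: Mapped[str] = mapped_column(
        EnumText(IndexingStatus, length=255), nullable=False, server_default=sa.text("'waiting'")
    )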


@ -215,6 +215,8 @@ class SegmentStatus(StrEnum):
INDEXING = "indexing"
COMPLETED = "completed"
ERROR = "error"
PAUSED = "paused"
RE_SEGMENT = "re_segment"
class DatasetRuntimeMode(StrEnum):
@ -282,6 +284,7 @@ class SummaryStatus(StrEnum):
GENERATING = "generating"
COMPLETED = "completed"
ERROR = "error"
TIMEOUT = "timeout"
class MessageChainType(StrEnum):
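
Since these enums derive from StrEnum, each member is itself a str carrying the old literal, which is what makes the literal-for-member swap in this commit behavior-preserving: comparisons against legacy strings, values already stored in the database, and the quoted server_default literals all keep matching. A quick illustration using members shown above:

from models.enums import SegmentStatus, SummaryStatus

# StrEnum members are str instances, so the old string comparisons still hold.
assert SegmentStatus.RE_SEGMENT == "re_segment"
assert SummaryStatus.COMPLETED == "completed"
assert isinstance(SegmentStatus.PAUSED, str)

# Calling the enum with a stored string returns the canonical member, which is
# how values hydrated from the database line up with the enum definitions.
assert SummaryStatus("generating") is SummaryStatus.GENERATING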


@ -51,6 +51,14 @@ from models.dataset import (
Pipeline,
SegmentAttachmentBinding,
)
from models.enums import (
DatasetRuntimeMode,
DataSourceType,
DocumentCreatedFrom,
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
from models.source import DataSourceOauthBinding
@ -319,7 +327,7 @@ class DatasetService:
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag_pipeline",
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
created_by=current_user.id,
pipeline_id=pipeline.id,
@ -614,7 +622,7 @@ class DatasetService:
"""
Update pipeline knowledge base node data.
"""
if dataset.runtime_mode != "rag_pipeline":
if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE:
return
pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
@ -1229,10 +1237,15 @@ class DocumentService:
"enabled": "available",
}
_INDEXING_STATUSES: tuple[str, ...] = ("parsing", "cleaning", "splitting", "indexing")
_INDEXING_STATUSES: tuple[IndexingStatus, ...] = (
IndexingStatus.PARSING,
IndexingStatus.CLEANING,
IndexingStatus.SPLITTING,
IndexingStatus.INDEXING,
)
DISPLAY_STATUS_FILTERS: dict[str, tuple[Any, ...]] = {
"queuing": (Document.indexing_status == "waiting",),
"queuing": (Document.indexing_status == IndexingStatus.WAITING,),
"indexing": (
Document.indexing_status.in_(_INDEXING_STATUSES),
Document.is_paused.is_not(True),
@ -1241,19 +1254,19 @@ class DocumentService:
Document.indexing_status.in_(_INDEXING_STATUSES),
Document.is_paused.is_(True),
),
"error": (Document.indexing_status == "error",),
"error": (Document.indexing_status == IndexingStatus.ERROR,),
"available": (
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived.is_(False),
Document.enabled.is_(True),
),
"disabled": (
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived.is_(False),
Document.enabled.is_(False),
),
"archived": (
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived.is_(True),
),
}
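
For reference, condition tuples shaped like DISPLAY_STATUS_FILTERS are normally unpacked straight into a where() clause; a hypothetical consumer (not code from this commit) could look like:

from sqlalchemy import select
from sqlalchemy.orm import Session

from models import Document
from services.dataset_service import DocumentService


def documents_with_display_status(session: Session, dataset_id: str, display_status: str):
    # Hypothetical helper: combine the per-status conditions with the dataset filter.
    conditions = DocumentService.DISPLAY_STATUS_FILTERS[display_status]
    stmt = select(Document).where(Document.dataset_id == dataset_id, *conditions)
    return session.scalars(stmt).all()
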
@ -1536,7 +1549,7 @@ class DocumentService:
"""
Normalize and validate `Document -> UploadFile` linkage for download flows.
"""
if document.data_source_type != "upload_file":
if document.data_source_type != DataSourceType.UPLOAD_FILE:
raise NotFound(invalid_source_message)
data_source_info: dict[str, Any] = document.data_source_info_dict or {}
@ -1617,7 +1630,7 @@ class DocumentService:
select(Document).where(
Document.id.in_(document_ids),
Document.enabled == True,
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived == False,
)
).all()
@ -1640,7 +1653,7 @@ class DocumentService:
select(Document).where(
Document.dataset_id == dataset_id,
Document.enabled == True,
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived == False,
)
).all()
@ -1650,7 +1663,10 @@ class DocumentService:
@staticmethod
def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
select(Document).where(
Document.dataset_id == dataset_id,
Document.indexing_status.in_([IndexingStatus.ERROR, IndexingStatus.PAUSED]),
)
).all()
return documents
@ -1683,7 +1699,7 @@ class DocumentService:
def delete_document(document):
# trigger document_was_deleted signal
file_id = None
if document.data_source_type == "upload_file":
if document.data_source_type == DataSourceType.UPLOAD_FILE:
if document.data_source_info:
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
@ -1704,7 +1720,7 @@ class DocumentService:
file_ids = [
document.data_source_info_dict.get("upload_file_id", "")
for document in documents
if document.data_source_type == "upload_file" and document.data_source_info_dict
if document.data_source_type == DataSourceType.UPLOAD_FILE and document.data_source_info_dict
]
# Delete documents first, then dispatch cleanup task after commit
@ -1753,7 +1769,13 @@ class DocumentService:
@staticmethod
def pause_document(document):
if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
if document.indexing_status not in {
IndexingStatus.WAITING,
IndexingStatus.PARSING,
IndexingStatus.CLEANING,
IndexingStatus.SPLITTING,
IndexingStatus.INDEXING,
}:
raise DocumentIndexingError()
# update document to be paused
assert current_user is not None
@ -1793,7 +1815,7 @@ class DocumentService:
if cache_result is not None:
raise ValueError("Document is being retried, please try again later")
# retry document indexing
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
db.session.add(document)
db.session.commit()
@ -1812,7 +1834,7 @@ class DocumentService:
if cache_result is not None:
raise ValueError("Document is being synced, please try again later")
# sync document indexing
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
data_source_info = document.data_source_info_dict
if data_source_info:
data_source_info["mode"] = "scrape"
@ -1840,7 +1862,7 @@ class DocumentService:
knowledge_config: KnowledgeConfig,
account: Account | Any,
dataset_process_rule: DatasetProcessRule | None = None,
created_from: str = "web",
created_from: str = DocumentCreatedFrom.WEB,
) -> tuple[list[Document], str]:
# check doc_form
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
@ -1932,7 +1954,7 @@ class DocumentService:
if not dataset_process_rule:
process_rule = knowledge_config.process_rule
if process_rule:
if process_rule.mode in ("custom", "hierarchical"):
if process_rule.mode in (ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL):
if process_rule.rules:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
@ -1944,7 +1966,7 @@ class DocumentService:
dataset_process_rule = dataset.latest_process_rule
if not dataset_process_rule:
raise ValueError("No process rule found.")
elif process_rule.mode == "automatic":
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
@ -1967,7 +1989,7 @@ class DocumentService:
if not dataset_process_rule:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode="automatic",
mode=ProcessRuleMode.AUTOMATIC,
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
@ -2001,7 +2023,7 @@ class DocumentService:
.where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == "upload_file",
Document.data_source_type == DataSourceType.UPLOAD_FILE,
Document.enabled == True,
Document.name.in_(file_names),
)
@ -2021,7 +2043,7 @@ class DocumentService:
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
@ -2056,7 +2078,7 @@ class DocumentService:
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
enabled=True,
)
.all()
@ -2507,7 +2529,7 @@ class DocumentService:
document_data: KnowledgeConfig,
account: Account,
dataset_process_rule: DatasetProcessRule | None = None,
created_from: str = "web",
created_from: str = DocumentCreatedFrom.WEB,
):
assert isinstance(current_user, Account)
@ -2520,14 +2542,14 @@ class DocumentService:
# save process rule
if document_data.process_rule:
process_rule = document_data.process_rule
if process_rule.mode in {"custom", "hierarchical"}:
if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
created_by=account.id,
)
elif process_rule.mode == "automatic":
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
@ -2609,7 +2631,7 @@ class DocumentService:
if document_data.name:
document.name = document_data.name
# update document to be waiting
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
document.completed_at = None
document.processing_started_at = None
document.parsing_completed_at = None
@ -2623,7 +2645,7 @@ class DocumentService:
# update document segment
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
{DocumentSegment.status: "re_segment"}
{DocumentSegment.status: SegmentStatus.RE_SEGMENT}
)
db.session.commit()
# trigger async task
@ -2754,7 +2776,7 @@ class DocumentService:
if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES:
raise ValueError("Process rule mode is invalid")
if knowledge_config.process_rule.mode == "automatic":
if knowledge_config.process_rule.mode == ProcessRuleMode.AUTOMATIC:
knowledge_config.process_rule.rules = None
else:
if not knowledge_config.process_rule.rules:
@ -2785,7 +2807,7 @@ class DocumentService:
raise ValueError("Process rule segmentation separator is invalid")
if not (
knowledge_config.process_rule.mode == "hierarchical"
knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL
and knowledge_config.process_rule.rules.parent_mode == "full-doc"
):
if not knowledge_config.process_rule.rules.segmentation.max_tokens:
@ -2814,7 +2836,7 @@ class DocumentService:
if args["process_rule"]["mode"] not in DatasetProcessRule.MODES:
raise ValueError("Process rule mode is invalid")
if args["process_rule"]["mode"] == "automatic":
if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC:
args["process_rule"]["rules"] = {}
else:
if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]:
@ -3021,7 +3043,7 @@ class DocumentService:
@staticmethod
def _prepare_disable_update(document, user, now):
"""Prepare updates for disabling a document."""
if not document.completed_at or document.indexing_status != "completed":
if not document.completed_at or document.indexing_status != IndexingStatus.COMPLETED:
raise DocumentIndexingError(f"Document: {document.name} is not completed.")
if not document.enabled:
@ -3130,7 +3152,7 @@ class SegmentService:
content=content,
word_count=len(content),
tokens=tokens,
status="completed",
status=SegmentStatus.COMPLETED,
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
@ -3167,7 +3189,7 @@ class SegmentService:
logger.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.status = SegmentStatus.ERROR
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
@ -3227,7 +3249,7 @@ class SegmentService:
word_count=len(content),
tokens=tokens,
keywords=segment_item.get("keywords", []),
status="completed",
status=SegmentStatus.COMPLETED,
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
@ -3259,7 +3281,7 @@ class SegmentService:
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.status = SegmentStatus.ERROR
segment_document.error = str(e)
db.session.commit()
return segment_data_list
@ -3405,7 +3427,7 @@ class SegmentService:
segment.index_node_hash = segment_hash
segment.word_count = len(content)
segment.tokens = tokens
segment.status = "completed"
segment.status = SegmentStatus.COMPLETED
segment.indexing_at = naive_utc_now()
segment.completed_at = naive_utc_now()
segment.updated_by = current_user.id
@ -3530,7 +3552,7 @@ class SegmentService:
logger.exception("update segment index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.status = SegmentStatus.ERROR
segment.error = str(e)
db.session.commit()
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()


@ -13,7 +13,7 @@ from dify_graph.model_runtime.entities import LLMMode
from extensions.ext_database import db
from models import Account
from models.dataset import Dataset, DatasetQuery
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DatasetQuerySource
logger = logging.getLogger(__name__)
@ -97,7 +97,7 @@ class HitTestingService:
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=json.dumps(dataset_queries),
source="hit_testing",
source=DatasetQuerySource.HIT_TESTING,
source_app_id=None,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
@ -137,7 +137,7 @@ class HitTestingService:
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=query,
source="hit_testing",
source=DatasetQuerySource.HIT_TESTING,
source_app_id=None,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,


@ -7,6 +7,7 @@ from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
from models.enums import DatasetMetadataType
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
@ -130,11 +131,11 @@ class MetadataService:
@staticmethod
def get_built_in_fields():
return [
{"name": BuiltInField.document_name, "type": "string"},
{"name": BuiltInField.uploader, "type": "string"},
{"name": BuiltInField.upload_date, "type": "time"},
{"name": BuiltInField.last_update_date, "type": "time"},
{"name": BuiltInField.source, "type": "string"},
{"name": BuiltInField.document_name, "type": DatasetMetadataType.STRING},
{"name": BuiltInField.uploader, "type": DatasetMetadataType.STRING},
{"name": BuiltInField.upload_date, "type": DatasetMetadataType.TIME},
{"name": BuiltInField.last_update_date, "type": DatasetMetadataType.TIME},
{"name": BuiltInField.source, "type": DatasetMetadataType.STRING},
]
@staticmethod


@ -6,6 +6,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.enums import IndexingStatus
from models.model import Account, App, EndUser
from models.workflow import Workflow
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -111,6 +112,6 @@ class PipelineGenerateService:
"""
document = db.session.query(Document).where(Document.id == document_id).first()
if document:
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
db.session.add(document)
db.session.commit()


@ -64,7 +64,7 @@ from models.dataset import ( # type: ignore
PipelineCustomizedTemplate,
PipelineRecommendedPlugin,
)
from models.enums import WorkflowRunTriggeredFrom
from models.enums import IndexingStatus, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
Workflow,
@ -906,7 +906,7 @@ class RagPipelineService:
if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first()
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = error
db.session.add(document)
db.session.commit()


@ -35,6 +35,7 @@ from extensions.ext_redis import redis_client
from factories import variable_factory
from models import Account
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.enums import CollectionBindingType, DatasetRuntimeMode
from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import (
IconInfo,
@ -313,7 +314,7 @@ class RagPipelineDslService:
indexing_technique=knowledge_configuration.indexing_technique,
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode="rag_pipeline",
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
if knowledge_configuration.indexing_technique == "high_quality":
@ -323,7 +324,7 @@ class RagPipelineDslService:
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
)
.order_by(DatasetCollectionBinding.created_at)
.first()
@ -334,7 +335,7 @@ class RagPipelineDslService:
provider_name=knowledge_configuration.embedding_model_provider,
model_name=knowledge_configuration.embedding_model,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
type=CollectionBindingType.DATASET,
)
self._session.add(dataset_collection_binding)
self._session.commit()
@ -445,13 +446,13 @@ class RagPipelineDslService:
indexing_technique=knowledge_configuration.indexing_technique,
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode="rag_pipeline",
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
dataset.runtime_mode = "rag_pipeline"
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
@ -460,7 +461,7 @@ class RagPipelineDslService:
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
)
.order_by(DatasetCollectionBinding.created_at)
.first()
@ -471,7 +472,7 @@ class RagPipelineDslService:
provider_name=knowledge_configuration.embedding_model_provider,
model_name=knowledge_configuration.embedding_model,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
type=CollectionBindingType.DATASET,
)
self._session.add(dataset_collection_binding)
self._session.commit()


@ -13,6 +13,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import DatasetRuntimeMode, DataSourceType
from models.model import UploadFile
from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting
@ -27,7 +28,7 @@ class RagPipelineTransformService:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline":
if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE:
return {
"pipeline_id": dataset.pipeline_id,
"dataset_id": dataset_id,
@ -85,7 +86,7 @@ class RagPipelineTransformService:
else:
raise ValueError("Unsupported doc form")
dataset.runtime_mode = "rag_pipeline"
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.pipeline_id = pipeline.id
# deal document data
@ -102,7 +103,7 @@ class RagPipelineTransformService:
pipeline_yaml = {}
if doc_form == "text_model":
match datasource_type:
case "upload_file":
case DataSourceType.UPLOAD_FILE:
if indexing_technique == "high_quality":
# get graph from transform.file-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
@ -111,7 +112,7 @@ class RagPipelineTransformService:
# get graph from transform.file-general-economy.yml
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "notion_import":
case DataSourceType.NOTION_IMPORT:
if indexing_technique == "high_quality":
# get graph from transform.notion-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
@ -120,7 +121,7 @@ class RagPipelineTransformService:
# get graph from transform.notion-general-economy.yml
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "website_crawl":
case DataSourceType.WEBSITE_CRAWL:
if indexing_technique == "high_quality":
# get graph from transform.website-crawl-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
@ -133,15 +134,15 @@ class RagPipelineTransformService:
raise ValueError("Unsupported datasource type")
elif doc_form == "hierarchical_model":
match datasource_type:
case "upload_file":
case DataSourceType.UPLOAD_FILE:
# get graph from transform.file-parentchild.yml
with open(f"{Path(__file__).parent}/transform/file-parentchild.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "notion_import":
case DataSourceType.NOTION_IMPORT:
# get graph from transform.notion-parentchild.yml
with open(f"{Path(__file__).parent}/transform/notion-parentchild.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "website_crawl":
case DataSourceType.WEBSITE_CRAWL:
# get graph from transform.website-crawl-parentchild.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-parentchild.yml") as f:
pipeline_yaml = yaml.safe_load(f)
@ -287,7 +288,7 @@ class RagPipelineTransformService:
db.session.flush()
dataset.pipeline_id = pipeline.id
dataset.runtime_mode = "rag_pipeline"
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.updated_by = current_user.id
dataset.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.add(dataset)
@ -310,8 +311,8 @@ class RagPipelineTransformService:
data_source_info_dict = document.data_source_info_dict
if not data_source_info_dict:
continue
if document.data_source_type == "upload_file":
document.data_source_type = "local_file"
if document.data_source_type == DataSourceType.UPLOAD_FILE:
document.data_source_type = DataSourceType.LOCAL_FILE
file_id = data_source_info_dict.get("upload_file_id")
if file_id:
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
@ -331,7 +332,7 @@ class RagPipelineTransformService:
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document.id,
pipeline_id=dataset.pipeline_id,
datasource_type="local_file",
datasource_type=DataSourceType.LOCAL_FILE,
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,
@ -340,8 +341,8 @@ class RagPipelineTransformService:
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document)
db.session.add(document_pipeline_execution_log)
elif document.data_source_type == "notion_import":
document.data_source_type = "online_document"
elif document.data_source_type == DataSourceType.NOTION_IMPORT:
document.data_source_type = DataSourceType.ONLINE_DOCUMENT
data_source_info = json.dumps(
{
"workspace_id": data_source_info_dict.get("notion_workspace_id"),
@ -359,7 +360,7 @@ class RagPipelineTransformService:
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document.id,
pipeline_id=dataset.pipeline_id,
datasource_type="online_document",
datasource_type=DataSourceType.ONLINE_DOCUMENT,
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,
@ -368,8 +369,7 @@ class RagPipelineTransformService:
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document)
db.session.add(document_pipeline_execution_log)
elif document.data_source_type == "website_crawl":
document.data_source_type = "website_crawl"
elif document.data_source_type == DataSourceType.WEBSITE_CRAWL:
data_source_info = json.dumps(
{
"source_url": data_source_info_dict.get("url"),
@ -388,7 +388,7 @@ class RagPipelineTransformService:
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document.id,
pipeline_id=dataset.pipeline_id,
datasource_type="website_crawl",
datasource_type=DataSourceType.WEBSITE_CRAWL,
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,


@ -18,6 +18,7 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
from libs import helper
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from models.enums import SummaryStatus
logger = logging.getLogger(__name__)
@ -73,7 +74,7 @@ class SummaryIndexService:
segment: DocumentSegment,
dataset: Dataset,
summary_content: str,
status: str = "generating",
status: SummaryStatus = SummaryStatus.GENERATING,
) -> DocumentSegmentSummary:
"""
Create or update a DocumentSegmentSummary record.
@ -83,7 +84,7 @@ class SummaryIndexService:
segment: DocumentSegment to create summary for
dataset: Dataset containing the segment
summary_content: Generated summary content
status: Summary status (default: "generating")
status: Summary status (default: SummaryStatus.GENERATING)
Returns:
Created or updated DocumentSegmentSummary instance
@ -326,7 +327,7 @@ class SummaryIndexService:
summary_index_node_id=summary_index_node_id,
summary_index_node_hash=summary_hash,
tokens=embedding_tokens,
status="completed",
status=SummaryStatus.COMPLETED,
enabled=True,
)
session.add(summary_record_in_session)
@ -362,7 +363,7 @@ class SummaryIndexService:
summary_record_in_session.summary_index_node_id = summary_index_node_id
summary_record_in_session.summary_index_node_hash = summary_hash
summary_record_in_session.tokens = embedding_tokens # Save embedding tokens
summary_record_in_session.status = "completed"
summary_record_in_session.status = SummaryStatus.COMPLETED
# Ensure summary_content is preserved (use the latest from summary_record parameter)
# This is critical: use the parameter value, not the database value
summary_record_in_session.summary_content = summary_content
@ -400,7 +401,7 @@ class SummaryIndexService:
summary_record.summary_index_node_id = summary_index_node_id
summary_record.summary_index_node_hash = summary_hash
summary_record.tokens = embedding_tokens
summary_record.status = "completed"
summary_record.status = SummaryStatus.COMPLETED
summary_record.summary_content = summary_content
if summary_record_in_session.updated_at:
summary_record.updated_at = summary_record_in_session.updated_at
@ -487,7 +488,7 @@ class SummaryIndexService:
)
if summary_record_in_session:
summary_record_in_session.status = "error"
summary_record_in_session.status = SummaryStatus.ERROR
summary_record_in_session.error = f"Vectorization failed: {str(e)}"
summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None)
error_session.add(summary_record_in_session)
@ -498,7 +499,7 @@ class SummaryIndexService:
summary_record_in_session.id,
)
# Update the original object for consistency
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = summary_record_in_session.error
summary_record.updated_at = summary_record_in_session.updated_at
else:
@ -514,7 +515,7 @@ class SummaryIndexService:
def batch_create_summary_records(
segments: list[DocumentSegment],
dataset: Dataset,
status: str = "not_started",
status: SummaryStatus = SummaryStatus.NOT_STARTED,
) -> None:
"""
Batch create summary records for segments with specified status.
@ -523,7 +524,7 @@ class SummaryIndexService:
Args:
segments: List of DocumentSegment instances
dataset: Dataset containing the segments
status: Initial status for the records (default: "not_started")
status: Initial status for the records (default: SummaryStatus.NOT_STARTED)
"""
segment_ids = [segment.id for segment in segments]
if not segment_ids:
@ -588,7 +589,7 @@ class SummaryIndexService:
)
if summary_record:
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = error
session.add(summary_record)
session.commit()
@ -631,14 +632,14 @@ class SummaryIndexService:
document_id=segment.document_id,
chunk_id=segment.id,
summary_content="",
status="generating",
status=SummaryStatus.GENERATING,
enabled=True,
)
session.add(summary_record_in_session)
session.flush()
# Update status to "generating"
summary_record_in_session.status = "generating"
summary_record_in_session.status = SummaryStatus.GENERATING
summary_record_in_session.error = None # type: ignore[assignment]
session.add(summary_record_in_session)
# Don't flush here - wait until after vectorization succeeds
@ -681,7 +682,7 @@ class SummaryIndexService:
except Exception as vectorize_error:
# If vectorization fails, update status to error in current session
logger.exception("Failed to vectorize summary for segment %s", segment.id)
summary_record_in_session.status = "error"
summary_record_in_session.status = SummaryStatus.ERROR
summary_record_in_session.error = f"Vectorization failed: {str(vectorize_error)}"
session.add(summary_record_in_session)
session.commit()
@ -694,7 +695,7 @@ class SummaryIndexService:
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
)
if summary_record_in_session:
summary_record_in_session.status = "error"
summary_record_in_session.status = SummaryStatus.ERROR
summary_record_in_session.error = str(e)
session.add(summary_record_in_session)
session.commit()
@ -770,7 +771,7 @@ class SummaryIndexService:
SummaryIndexService.batch_create_summary_records(
segments=segments,
dataset=dataset,
status="not_started",
status=SummaryStatus.NOT_STARTED,
)
summary_records = []
@ -1067,7 +1068,7 @@ class SummaryIndexService:
# Update summary content
summary_record.summary_content = summary_content
summary_record.status = "generating"
summary_record.status = SummaryStatus.GENERATING
summary_record.error = None # type: ignore[assignment] # Clear any previous errors
session.add(summary_record)
# Flush to ensure summary_content is saved before vectorize_summary queries it
@ -1102,7 +1103,7 @@ class SummaryIndexService:
# If vectorization fails, update status to error in current session
# Don't raise the exception - just log it and return the record with error status
# This allows the segment update to complete even if vectorization fails
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = f"Vectorization failed: {str(e)}"
session.commit()
logger.exception("Failed to vectorize summary for segment %s", segment.id)
@ -1112,7 +1113,7 @@ class SummaryIndexService:
else:
# Create new summary record if doesn't exist
summary_record = SummaryIndexService.create_summary_record(
segment, dataset, summary_content, status="generating"
segment, dataset, summary_content, status=SummaryStatus.GENERATING
)
# Re-vectorize summary (this will update status to "completed" and tokens in its own session)
# Note: summary_record was created in a different session,
@ -1132,7 +1133,7 @@ class SummaryIndexService:
# If vectorization fails, update status to error in current session
# Merge the record into current session first
error_record = session.merge(summary_record)
error_record.status = "error"
error_record.status = SummaryStatus.ERROR
error_record.error = f"Vectorization failed: {str(e)}"
session.commit()
logger.exception("Failed to vectorize summary for segment %s", segment.id)
@ -1146,7 +1147,7 @@ class SummaryIndexService:
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
)
if summary_record:
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = str(e)
session.add(summary_record)
session.commit()
@ -1266,7 +1267,7 @@ class SummaryIndexService:
# Check if there are any "not_started" or "generating" status summaries
has_pending_summaries = any(
summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True)
and summary_status_map[segment_id] in ("not_started", "generating")
and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING)
for segment_id in segment_ids
)
@ -1330,7 +1331,7 @@ class SummaryIndexService:
# it means the summary is disabled (enabled=False) or not created yet, ignore it
has_pending_summaries = any(
summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True)
and summary_status_map[segment_id] in ("not_started", "generating")
and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING)
for segment_id in segment_ids
)
@ -1393,17 +1394,17 @@ class SummaryIndexService:
# Count statuses
status_counts = {
"completed": 0,
"generating": 0,
"error": 0,
"not_started": 0,
SummaryStatus.COMPLETED: 0,
SummaryStatus.GENERATING: 0,
SummaryStatus.ERROR: 0,
SummaryStatus.NOT_STARTED: 0,
}
summary_list = []
for segment in segments:
summary = summary_map.get(segment.id)
if summary:
status = summary.status
status = SummaryStatus(summary.status)
status_counts[status] = status_counts.get(status, 0) + 1
summary_list.append(
{
@ -1421,12 +1422,12 @@ class SummaryIndexService:
}
)
else:
status_counts["not_started"] += 1
status_counts[SummaryStatus.NOT_STARTED] += 1
summary_list.append(
{
"segment_id": segment.id,
"segment_position": segment.position,
"status": "not_started",
"status": SummaryStatus.NOT_STARTED,
"summary_preview": None,
"error": None,
"created_at": None,

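The summary code paths above only ever read or write four status values, so the text enum behind them is presumably a str-valued enum along the lines of the sketch below (an illustrative guess based on the literals replaced in this hunk; the real definition in models/enums.py may differ or carry extra members):

    from enum import StrEnum

    class SummaryStatus(StrEnum):  # sketch only
        NOT_STARTED = "not_started"
        GENERATING = "generating"
        COMPLETED = "completed"
        ERROR = "error"

Because str-valued members hash and compare like their underlying strings, SummaryStatus(summary.status) simply normalizes whatever text comes back from the column, and keying status_counts by members still counts rows whose status is stored as a plain string.
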
View File

@ -13,6 +13,7 @@ from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DatasetAutoDisableLog, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import IndexingStatus, SegmentStatus
logger = logging.getLogger(__name__)
@ -34,7 +35,7 @@ def add_document_to_index_task(dataset_document_id: str):
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
return
if dataset_document.indexing_status != "completed":
if dataset_document.indexing_status != IndexingStatus.COMPLETED:
return
indexing_cache_key = f"document_{dataset_document.id}_indexing"
@ -48,7 +49,7 @@ def add_document_to_index_task(dataset_document_id: str):
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.status == SegmentStatus.COMPLETED,
)
.order_by(DocumentSegment.position.asc())
.all()
@ -139,7 +140,7 @@ def add_document_to_index_task(dataset_document_id: str):
logger.exception("add document to index failed")
dataset_document.enabled = False
dataset_document.disabled_at = naive_utc_now()
dataset_document.indexing_status = "error"
dataset_document.indexing_status = IndexingStatus.ERROR
dataset_document.error = str(e)
session.commit()
finally:
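
Swapping the literals for SegmentStatus.COMPLETED and IndexingStatus.ERROR in this task is behaviour-preserving as long as the members are str-valued: a member passed into a column comparison or an assignment is itself a str carrying the old literal, so SQLAlchemy binds exactly the same text as before. A minimal, self-contained illustration of that equivalence, using a stand-in enum so as not to guess at the project's real definitions (assumes Python 3.11+ StrEnum; a str/Enum mixin gives the same equality and hashing behaviour):

    from enum import StrEnum

    class Status(StrEnum):  # stand-in for the project's str-valued status enums
        COMPLETED = "completed"

    # Members behave like the strings they replace, for equality, hashing and str().
    assert Status.COMPLETED == "completed"
    assert str(Status.COMPLETED) == "completed"
    assert {"completed": 1}[Status.COMPLETED] == 1

The same property is what keeps the test assertions later in the diff (for example document.indexing_status == IndexingStatus.WAITING) passing against columns that still store plain text.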

View File

@ -11,6 +11,7 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset
from models.enums import CollectionBindingType
from models.model import App, AppAnnotationSetting, MessageAnnotation
from services.dataset_service import DatasetCollectionBindingService
@ -47,7 +48,7 @@ def enable_annotation_reply_task(
try:
documents = []
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
)
annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
@ -56,7 +57,7 @@ def enable_annotation_reply_task(
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
old_dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
annotation_setting.collection_binding_id, "annotation"
annotation_setting.collection_binding_id, CollectionBindingType.ANNOTATION
)
)
if old_dataset_collection_binding and annotations:
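
Both binding lookups above pass CollectionBindingType.ANNOTATION where the code previously passed the literal "annotation", and the service tests further down do the same with the "dataset" type. From those replaced literals the enum presumably boils down to something like this sketch (the actual definition may differ):

    from enum import StrEnum

    class CollectionBindingType(StrEnum):  # sketch only
        DATASET = "dataset"
        ANNOTATION = "annotation"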

View File

@ -10,6 +10,7 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
from models.enums import IndexingStatus, SegmentStatus
logger = logging.getLogger(__name__)
@ -31,7 +32,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
if segment.status != "waiting":
if segment.status != SegmentStatus.WAITING:
return
indexing_cache_key = f"segment_{segment.id}_indexing"
@ -40,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
# update segment status to indexing
session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "indexing",
DocumentSegment.status: SegmentStatus.INDEXING,
DocumentSegment.indexing_at: naive_utc_now(),
}
)
@ -70,7 +71,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
or dataset_document.indexing_status != IndexingStatus.COMPLETED
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
@ -82,7 +83,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
# update segment to completed
session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "completed",
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.completed_at: naive_utc_now(),
}
)
@ -94,7 +95,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
logger.exception("create segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.status = SegmentStatus.ERROR
segment.error = str(e)
session.commit()
finally:
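
This task walks a segment through most of its lifecycle, so the statuses it touches suggest SegmentStatus has at least the members sketched below (values taken from the replaced literals; the real enum may define more):

    from enum import StrEnum

    class SegmentStatus(StrEnum):  # sketch only
        WAITING = "waiting"
        INDEXING = "indexing"
        COMPLETED = "completed"
        ERROR = "error"

Note that the bulk Query.update() calls above map column attributes to members, e.g. {DocumentSegment.status: SegmentStatus.INDEXING}; since the member is a str, the UPDATE statement writes the same text the old literal did.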

View File

@ -12,6 +12,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
@ -37,7 +38,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
if document.indexing_status == "parsing":
if document.indexing_status == IndexingStatus.PARSING:
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
return
@ -88,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
return
@ -128,7 +129,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
data_source_info["last_edited_time"] = last_edited_time
document.data_source_info = json.dumps(data_source_info)
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
@ -151,6 +152,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
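
Across the indexing tasks in this diff (and the service tests below) the document lifecycle runs through waiting, parsing, splitting, indexing, completed and error, so IndexingStatus presumably looks roughly like this (sketch derived from the replaced literals; the real enum may define additional states):

    from enum import StrEnum

    class IndexingStatus(StrEnum):  # sketch only
        WAITING = "waiting"
        PARSING = "parsing"
        SPLITTING = "splitting"
        INDEXING = "indexing"
        COMPLETED = "completed"
        ERROR = "error"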

View File

@ -14,6 +14,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from models.enums import IndexingStatus
from services.feature_service import FeatureService
from tasks.generate_summary_index_task import generate_summary_index_task
@ -81,7 +82,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
@ -96,7 +97,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
for document in documents:
if document:
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
session.add(document)
# Transaction committed and closed
@ -148,7 +149,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
document.need_summary,
)
if (
document.indexing_status == "completed"
document.indexing_status == IndexingStatus.COMPLETED
and document.doc_form != "qa_model"
and document.need_summary is True
):

View File

@ -10,6 +10,7 @@ from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
logger = logging.getLogger(__name__)
@ -33,7 +34,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
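
The EnumText column type named in the commit title never appears in these hunks, so the following is purely a hypothetical sketch of how such a type is commonly built: a thin TypeDecorator over Text that accepts either a member or the raw string on write. Every name and detail here is an assumption except the SQLAlchemy API itself, and the explicit SummaryStatus(summary.status) conversion earlier in the diff hints that the real type may hand plain text back on reads rather than coercing to members.

    from enum import StrEnum

    from sqlalchemy import Text
    from sqlalchemy.types import TypeDecorator

    class EnumText(TypeDecorator):
        """Hypothetical sketch only; not the project's actual implementation."""

        impl = Text
        cache_ok = True

        def __init__(self, enum_cls: type[StrEnum], **kwargs):
            self._enum_cls = enum_cls
            super().__init__(**kwargs)

        def process_bind_param(self, value, dialect):
            # Accept members or raw strings; store the plain text either way.
            return None if value is None else str(self._enum_cls(value))

        def process_result_value(self, value, dialect):
            # Coerce reads back to members (the real type may skip this step).
            return None if value is None else self._enum_cls(value)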

View File

@ -15,6 +15,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@ -112,7 +113,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
)
for document in documents:
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
@ -146,7 +147,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()

View File

@ -12,6 +12,7 @@ from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
from models.enums import IndexingStatus, SegmentStatus
logger = logging.getLogger(__name__)
@ -33,7 +34,7 @@ def enable_segment_to_index_task(segment_id: str):
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
if segment.status != "completed":
if segment.status != SegmentStatus.COMPLETED:
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
return
@ -65,7 +66,7 @@ def enable_segment_to_index_task(segment_id: str):
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
or dataset_document.indexing_status != IndexingStatus.COMPLETED
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
@ -123,7 +124,7 @@ def enable_segment_to_index_task(segment_id: str):
logger.exception("enable segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.status = SegmentStatus.ERROR
segment.error = str(e)
session.commit()
finally:

View File

@ -12,6 +12,7 @@ from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.feature_service import FeatureService
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -63,7 +64,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
.first()
)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
@ -95,7 +96,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()
@ -108,7 +109,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(ex)
document.stopped_at = naive_utc_now()
session.add(document)

View File

@ -11,6 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@ -48,7 +49,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
@ -76,7 +77,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()
@ -85,7 +86,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
indexing_runner.run([document])
redis_client.delete(sync_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(ex)
document.stopped_at = naive_utc_now()
session.add(document)

View File

@ -7,6 +7,7 @@ from faker import Faker
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.account_service import AccountService, TenantService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -35,7 +36,7 @@ class TestGetAvailableDatasetsIntegration:
name=fake.company(),
description=fake.text(max_nb_chars=100),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
indexing_technique="high_quality",
)
@ -49,14 +50,14 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
name=f"Document {i}",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -94,7 +95,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -106,13 +107,13 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Archived Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Archived
)
@ -147,7 +148,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -159,13 +160,13 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Disabled Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Disabled
archived=False,
)
@ -200,21 +201,21 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
# Create documents with non-completed status
for i, status in enumerate(["indexing", "parsing", "splitting"]):
for i, status in enumerate([IndexingStatus.INDEXING, IndexingStatus.PARSING, IndexingStatus.SPLITTING]):
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Document {status}",
created_by=account.id,
doc_form="text_model",
@ -263,7 +264,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="external", # External provider
data_source_type="external",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -307,7 +308,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant1.id,
name="Tenant 1 Dataset",
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account1.id,
)
db_session_with_containers.add(dataset1)
@ -318,7 +319,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant2.id,
name="Tenant 2 Dataset",
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account2.id,
)
db_session_with_containers.add(dataset2)
@ -330,13 +331,13 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Document for {dataset.name}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -398,7 +399,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=f"Dataset {i}",
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -410,13 +411,13 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -456,7 +457,7 @@ class TestKnowledgeRetrievalIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
indexing_technique="high_quality",
)
@ -467,12 +468,12 @@ class TestKnowledgeRetrievalIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=fake.sentence(),
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
@ -525,7 +526,7 @@ class TestKnowledgeRetrievalIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -572,7 +573,7 @@ class TestKnowledgeRetrievalIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
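
The test factories in this file, and in the unit and service tests that follow, repeat the same two substitutions: data_source_type gains DataSourceType.UPLOAD_FILE and created_from gains DocumentCreatedFrom.WEB (with DocumentCreatedFrom.API in a later factory). From the replaced literals those enums minimally contain something like the sketch below; the real definitions presumably list further members for the other data sources:

    from enum import StrEnum

    class DataSourceType(StrEnum):  # sketch only
        UPLOAD_FILE = "upload_file"

    class DocumentCreatedFrom(StrEnum):  # sketch only
        WEB = "web"
        API = "api"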

View File

@ -12,6 +12,7 @@ import pytest
from sqlalchemy.orm import Session
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
class TestDatasetDocumentProperties:
@ -29,7 +30,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -39,10 +40,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=i + 1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name=f"doc_{i}.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -56,7 +57,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -65,12 +66,12 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="available.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -78,12 +79,12 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=2,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="pending.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
archived=False,
)
@ -91,12 +92,12 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=3,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="disabled.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False,
archived=False,
)
@ -111,7 +112,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -121,10 +122,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=i + 1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name=f"doc_{i}.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
word_count=wc,
)
@ -139,7 +140,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -148,10 +149,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="doc.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -166,7 +167,7 @@ class TestDatasetDocumentProperties:
content=f"segment {i}",
word_count=100,
tokens=50,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
created_by=created_by,
)
@ -180,7 +181,7 @@ class TestDatasetDocumentProperties:
content="waiting segment",
word_count=100,
tokens=50,
status="waiting",
status=SegmentStatus.WAITING,
enabled=True,
created_by=created_by,
)
@ -195,7 +196,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -204,10 +205,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="doc.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -235,7 +236,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -244,10 +245,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="doc.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -288,7 +289,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -298,10 +299,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)
@ -335,7 +336,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -345,10 +346,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)
@ -382,7 +383,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -392,10 +393,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)
@ -439,7 +440,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -449,10 +450,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)

View File

@ -12,6 +12,7 @@ import pytest
from sqlalchemy.orm import Session
from models.dataset import DatasetCollectionBinding
from models.enums import CollectionBindingType
from services.dataset_service import DatasetCollectionBindingService
@ -32,7 +33,7 @@ class DatasetCollectionBindingTestDataFactory:
provider_name: str = "openai",
model_name: str = "text-embedding-ada-002",
collection_name: str = "collection-abc",
collection_type: str = "dataset",
collection_type: str = CollectionBindingType.DATASET,
) -> DatasetCollectionBinding:
"""
Create a DatasetCollectionBinding with specified attributes.
@ -41,7 +42,7 @@ class DatasetCollectionBindingTestDataFactory:
provider_name: Name of the embedding model provider (e.g., "openai", "cohere")
model_name: Name of the embedding model (e.g., "text-embedding-ada-002")
collection_name: Name of the vector database collection
collection_type: Type of collection (default: "dataset")
collection_type: Type of collection (default: CollectionBindingType.DATASET)
Returns:
DatasetCollectionBinding instance
@ -76,7 +77,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
# Arrange
provider_name = "openai"
model_name = "text-embedding-ada-002"
collection_type = "dataset"
collection_type = CollectionBindingType.DATASET
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name=provider_name,
@ -104,7 +105,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
# Arrange
provider_name = f"provider-{uuid4()}"
model_name = f"model-{uuid4()}"
collection_type = "dataset"
collection_type = CollectionBindingType.DATASET
# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding(
@ -145,7 +146,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
result = DatasetCollectionBindingService.get_dataset_collection_binding(provider_name, model_name)
# Assert
assert result.type == "dataset"
assert result.type == CollectionBindingType.DATASET
assert result.provider_name == provider_name
assert result.model_name == model_name
@ -186,18 +187,20 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
collection_type="dataset",
collection_type=CollectionBindingType.DATASET,
)
# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id, "dataset")
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
binding.id, CollectionBindingType.DATASET
)
# Assert
assert result.id == binding.id
assert result.provider_name == "openai"
assert result.model_name == "text-embedding-ada-002"
assert result.collection_name == "test-collection"
assert result.type == "dataset"
assert result.type == CollectionBindingType.DATASET
def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session):
"""Test error handling when collection binding is not found by ID and type."""
@ -206,7 +209,9 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
# Act & Assert
with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset")
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
non_existent_id, CollectionBindingType.DATASET
)
def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(
self, db_session_with_containers: Session
@ -240,7 +245,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
collection_type="dataset",
collection_type=CollectionBindingType.DATASET,
)
# Act
@ -248,7 +253,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
# Assert
assert result.id == binding.id
assert result.type == "dataset"
assert result.type == CollectionBindingType.DATASET
def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session):
"""Test error when binding exists but with wrong collection type."""
@ -258,7 +263,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
collection_type="dataset",
collection_type=CollectionBindingType.DATASET,
)
# Act & Assert

View File

@ -15,6 +15,7 @@ from werkzeug.exceptions import NotFound
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum
from models.enums import DataSourceType
from models.model import App
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
@ -72,7 +73,7 @@ class DatasetUpdateDeleteTestDataFactory:
tenant_id=tenant_id,
name=name,
description="Test description",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=permission,

View File

@ -15,7 +15,7 @@ import pytest
from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
from models.model import UploadFile
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError
@ -88,7 +88,7 @@ class DocumentStatusTestDataFactory:
data_source_info=json.dumps(data_source_info or {}),
batch=f"batch-{uuid4()}",
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
doc_form="text_model",
)
@ -100,7 +100,7 @@ class DocumentStatusTestDataFactory:
document.paused_by = paused_by
document.paused_at = paused_at
document.doc_metadata = doc_metadata or {}
if indexing_status == "completed" and "completed_at" not in kwargs:
if indexing_status == IndexingStatus.COMPLETED and "completed_at" not in kwargs:
document.completed_at = FIXED_TIME
for key, value in kwargs.items():
@ -139,7 +139,7 @@ class DocumentStatusTestDataFactory:
dataset = Dataset(
tenant_id=tenant_id,
name=name,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
dataset.id = dataset_id
@ -291,7 +291,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
is_paused=False,
)
@ -326,7 +326,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
is_paused=False,
)
@ -354,7 +354,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="parsing",
indexing_status=IndexingStatus.PARSING,
is_paused=False,
)
@ -383,7 +383,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
is_paused=False,
)
@ -412,7 +412,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
is_paused=False,
)
@ -487,7 +487,7 @@ class TestDocumentServiceRecoverDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
is_paused=True,
paused_by=str(uuid4()),
paused_at=paused_time,
@ -526,7 +526,7 @@ class TestDocumentServiceRecoverDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
is_paused=False,
)
@ -609,7 +609,7 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
mock_document_service_dependencies["redis_client"].get.return_value = None
@ -619,7 +619,7 @@ class TestDocumentServiceRetryDocument:
# Assert
db_session_with_containers.refresh(document)
assert document.indexing_status == "waiting"
assert document.indexing_status == IndexingStatus.WAITING
expected_cache_key = f"document_{document.id}_is_retried"
mock_document_service_dependencies["redis_client"].setex.assert_called_once_with(expected_cache_key, 600, 1)
@ -646,14 +646,14 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
document2 = DocumentStatusTestDataFactory.create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
position=2,
)
@ -665,8 +665,8 @@ class TestDocumentServiceRetryDocument:
# Assert
db_session_with_containers.refresh(document1)
db_session_with_containers.refresh(document2)
assert document1.indexing_status == "waiting"
assert document2.indexing_status == "waiting"
assert document1.indexing_status == IndexingStatus.WAITING
assert document2.indexing_status == IndexingStatus.WAITING
mock_document_service_dependencies["retry_task"].delay.assert_called_once_with(
dataset.id, [document1.id, document2.id], mock_document_service_dependencies["user_id"]
@ -693,7 +693,7 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
mock_document_service_dependencies["redis_client"].get.return_value = "1"
@ -703,7 +703,7 @@ class TestDocumentServiceRetryDocument:
DocumentService.retry_document(dataset.id, [document])
db_session_with_containers.refresh(document)
assert document.indexing_status == "error"
assert document.indexing_status == IndexingStatus.ERROR
def test_retry_document_missing_current_user_error(
self, db_session_with_containers, mock_document_service_dependencies
@ -726,7 +726,7 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
mock_document_service_dependencies["redis_client"].get.return_value = None
@ -816,7 +816,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
enabled=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document2 = DocumentStatusTestDataFactory.create_document(
db_session_with_containers,
@ -824,7 +824,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
enabled=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
position=2,
)
document_ids = [document1.id, document2.id]
@ -866,7 +866,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
completed_at=FIXED_TIME,
)
document_ids = [document.id]
@ -909,7 +909,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
document_id=str(uuid4()),
archived=False,
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document_ids = [document.id]
@ -951,7 +951,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
document_id=str(uuid4()),
archived=True,
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document_ids = [document.id]
@ -1015,7 +1015,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document_ids = [document.id]
@ -1098,7 +1098,7 @@ class TestDocumentServiceRenameDocument:
document_id=document_id,
dataset_id=dataset.id,
tenant_id=tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act
@ -1139,7 +1139,7 @@ class TestDocumentServiceRenameDocument:
dataset_id=dataset.id,
tenant_id=tenant_id,
doc_metadata={"existing_key": "existing_value"},
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act
@ -1187,7 +1187,7 @@ class TestDocumentServiceRenameDocument:
dataset_id=dataset.id,
tenant_id=tenant_id,
data_source_info={"upload_file_id": upload_file.id},
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act
@ -1277,7 +1277,7 @@ class TestDocumentServiceRenameDocument:
document_id=document_id,
dataset_id=dataset.id,
tenant_id=str(uuid4()),
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act & Assert

View File

@ -16,6 +16,7 @@ from models.dataset import (
DatasetPermission,
DatasetPermissionEnum,
)
from models.enums import DataSourceType
from services.dataset_service import DatasetPermissionService, DatasetService
from services.errors.account import NoPermissionError
@ -67,7 +68,7 @@ class DatasetPermissionTestDataFactory:
tenant_id=tenant_id,
name=name,
description="desc",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=permission,

View File

@ -15,6 +15,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline
from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@ -74,7 +75,7 @@ class DatasetServiceIntegrationDataFactory:
tenant_id=tenant_id,
name=name,
description=description,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique=indexing_technique,
created_by=created_by,
provider=provider,
@ -98,13 +99,13 @@ class DatasetServiceIntegrationDataFactory:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info='{"upload_file_id": "upload-file-id"}',
batch=str(uuid4()),
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
)
db_session_with_containers.add(document)
@ -437,7 +438,7 @@ class TestDatasetServiceCreateRagPipelineDataset:
created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id)
assert created_dataset is not None
assert created_dataset.name == entity.name
assert created_dataset.runtime_mode == "rag_pipeline"
assert created_dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE
assert created_dataset.created_by == account.id
assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME
assert created_pipeline is not None

View File

@ -14,6 +14,7 @@ import pytest
from sqlalchemy.orm import Session
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError
@ -42,7 +43,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
dataset = Dataset(
tenant_id=tenant_id or str(uuid4()),
name=name,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by or str(uuid4()),
)
if dataset_id:
@ -72,11 +73,11 @@ class DocumentBatchUpdateIntegrationDataFactory:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=position,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info=json.dumps({"upload_file_id": str(uuid4())}),
batch=f"batch-{uuid4()}",
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by or str(uuid4()),
doc_form="text_model",
)
@ -85,7 +86,9 @@ class DocumentBatchUpdateIntegrationDataFactory:
document.archived = archived
document.indexing_status = indexing_status
document.completed_at = (
completed_at if completed_at is not None else (FIXED_TIME if indexing_status == "completed" else None)
completed_at
if completed_at is not None
else (FIXED_TIME if indexing_status == IndexingStatus.COMPLETED else None)
)
for key, value in kwargs.items():
@ -243,7 +246,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
dataset=dataset,
document_ids=document_ids,
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act
@ -277,7 +280,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
db_session_with_containers,
dataset=dataset,
enabled=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
completed_at=FIXED_TIME,
)
@ -306,7 +309,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
db_session_with_containers,
dataset=dataset,
enabled=True,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
completed_at=None,
)

View File

@ -5,6 +5,7 @@ from uuid import uuid4
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom
from services.dataset_service import DatasetService
@ -58,7 +59,7 @@ class DatasetDeleteIntegrationDataFactory:
dataset = Dataset(
tenant_id=tenant_id,
name=f"dataset-{uuid4()}",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique=indexing_technique,
index_struct=index_struct,
created_by=created_by,
@ -84,10 +85,10 @@ class DatasetDeleteIntegrationDataFactory:
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=f"batch-{uuid4()}",
name="Document",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
doc_form=doc_form,
)

View File

@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom
from services.dataset_service import SegmentService
@ -62,7 +63,7 @@ class SegmentServiceTestDataFactory:
tenant_id=tenant_id,
name=f"Test Dataset {uuid4()}",
description="Test description",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=DatasetPermissionEnum.ONLY_ME,
@ -82,10 +83,10 @@ class SegmentServiceTestDataFactory:
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=f"batch-{uuid4()}",
name=f"test-doc-{uuid4()}.txt",
created_from="api",
created_from=DocumentCreatedFrom.API,
created_by=created_by,
)
db_session_with_containers.add(document)

View File

@ -24,6 +24,7 @@ from models.dataset import (
DatasetProcessRule,
DatasetQuery,
)
from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode
from models.model import Tag, TagBinding
from services.dataset_service import DatasetService, DocumentService
@ -100,7 +101,7 @@ class DatasetRetrievalTestDataFactory:
tenant_id=tenant_id,
name=name,
description="desc",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=permission,
@ -149,7 +150,7 @@ class DatasetRetrievalTestDataFactory:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=content,
source="web",
source=DatasetQuerySource.APP,
source_app_id=None,
created_by_role="account",
created_by=created_by,
@ -601,7 +602,7 @@ class TestDatasetServiceGetProcessRules:
db_session_with_containers,
dataset_id=dataset.id,
created_by=account.id,
mode="custom",
mode=ProcessRuleMode.CUSTOM,
rules=rules_data,
)

View File

@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, ExternalKnowledgeBindings
from models.enums import DataSourceType
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
@ -64,7 +65,7 @@ class DatasetUpdateTestDataFactory:
tenant_id=tenant_id,
name=name,
description=description,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique=indexing_technique,
created_by=created_by,
provider=provider,

View File

@ -4,6 +4,7 @@ from uuid import uuid4
from sqlalchemy import select
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
@ -11,7 +12,7 @@ def _create_dataset(db_session_with_containers) -> Dataset:
dataset = Dataset(
tenant_id=str(uuid4()),
name=f"dataset-{uuid4()}",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
)
dataset.id = str(uuid4())
@ -35,11 +36,11 @@ def _create_document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=position,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info="{}",
batch=f"batch-{uuid4()}",
name=f"doc-{uuid4()}",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
doc_form="text_model",
)
@ -48,7 +49,7 @@ def _create_document(
document.enabled = enabled
document.archived = archived
document.is_paused = is_paused
if indexing_status == "completed":
if indexing_status == IndexingStatus.COMPLETED:
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db_session_with_containers.add(document)
@ -62,7 +63,7 @@ def test_build_display_status_filters_available(db_session_with_containers):
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
position=1,
@ -71,7 +72,7 @@ def test_build_display_status_filters_available(db_session_with_containers):
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False,
archived=False,
position=2,
@ -80,7 +81,7 @@ def test_build_display_status_filters_available(db_session_with_containers):
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True,
position=3,
@ -101,14 +102,14 @@ def test_apply_display_status_filter_applies_when_status_present(db_session_with
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
position=1,
)
_create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
position=2,
)
@ -125,14 +126,14 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
position=1,
)
doc2 = _create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
position=2,
)

View File

@ -9,7 +9,7 @@ import pytest
from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom
from models.model import UploadFile
from services.dataset_service import DocumentService
@ -33,7 +33,7 @@ def make_dataset(db_session_with_containers, dataset_id=None, tenant_id=None, bu
dataset = Dataset(
tenant_id=tenant_id,
name=f"dataset-{uuid4()}",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
)
dataset.id = dataset_id
@ -62,11 +62,11 @@ def make_document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info=json.dumps(data_source_info or {}),
batch=f"batch-{uuid4()}",
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
doc_form="text_model",
)

View File

@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import DataSourceType
from models.model import (
App,
AppAnnotationHitHistory,
@ -287,7 +288,7 @@ class TestMessagesCleanServiceIntegration:
dataset_name="Test dataset",
document_id=str(uuid.uuid4()),
document_name="Test document",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
segment_id=str(uuid.uuid4()),
score=0.9,
content="Test content",

View File

@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from core.rag.index_processor.constant.built_in_field import BuiltInField
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document
from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService
@ -101,7 +102,7 @@ class TestMetadataService:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
built_in_field_enabled=False,
)
@ -132,11 +133,11 @@ class TestMetadataService:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info="{}",
batch="test-batch",
name=fake.file_name(),
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text",
doc_language="en",
@ -163,7 +164,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
# Act: Execute the method under test
result = MetadataService.create_metadata(dataset.id, metadata_args)
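The MetadataArgs changes above stay interchangeable with the old literal spelling as long as the entity validates its type field by value. MetadataArgs is only imported in this diff (from services.entities.knowledge_entities.knowledge_entities), not defined; assuming it is a Pydantic model with an enum-typed field, a small sketch of why both spellings validate to the same member:

```python
# Hypothetical sketch; the real MetadataArgs may be declared differently.
from enum import Enum

from pydantic import BaseModel


class DatasetMetadataType(str, Enum):
    STRING = "string"
    NUMBER = "number"


class MetadataArgs(BaseModel):
    type: DatasetMetadataType
    name: str


# Pydantic coerces the raw literal to the enum member, so the pre-refactor and
# post-refactor test arguments produce identical validated values.
assert MetadataArgs(type="string", name="n").type is DatasetMetadataType.STRING
assert (
    MetadataArgs(type=DatasetMetadataType.STRING, name="n").type
    is DatasetMetadataType.STRING
)
```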
@ -201,7 +202,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
long_name = "a" * 256 # 256 characters, exceeding 255 limit
metadata_args = MetadataArgs(type="string", name=long_name)
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=long_name)
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
@ -226,11 +227,11 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create first metadata
first_metadata_args = MetadataArgs(type="string", name="duplicate_name")
first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="duplicate_name")
MetadataService.create_metadata(dataset.id, first_metadata_args)
# Try to create second metadata with same name
second_metadata_args = MetadataArgs(type="number", name="duplicate_name")
second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="duplicate_name")
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError, match="Metadata name already exists."):
@ -256,7 +257,7 @@ class TestMetadataService:
# Try to create metadata with built-in field name
built_in_field_name = BuiltInField.document_name
metadata_args = MetadataArgs(type="string", name=built_in_field_name)
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=built_in_field_name)
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
@ -281,7 +282,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata first
metadata_args = MetadataArgs(type="string", name="old_name")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Act: Execute the method under test
@ -318,7 +319,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata first
metadata_args = MetadataArgs(type="string", name="old_name")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Try to update with too long name
@ -347,10 +348,10 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create two metadata entries
first_metadata_args = MetadataArgs(type="string", name="first_metadata")
first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="first_metadata")
first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args)
second_metadata_args = MetadataArgs(type="number", name="second_metadata")
second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="second_metadata")
second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args)
# Try to update first metadata with second metadata's name
@ -376,7 +377,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata first
metadata_args = MetadataArgs(type="string", name="old_name")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Try to update with built-in field name
@ -432,7 +433,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata first
metadata_args = MetadataArgs(type="string", name="to_be_deleted")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="to_be_deleted")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Act: Execute the method under test
@ -496,7 +497,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Create metadata binding
@ -798,7 +799,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Mock DocumentService.get_document
@ -866,7 +867,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Mock DocumentService.get_document
@ -917,7 +918,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Create metadata operation data
@ -1038,7 +1039,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Create document and metadata binding
@ -1101,7 +1102,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Act: Execute the method under test

View File

@ -9,6 +9,7 @@ from werkzeug.exceptions import NotFound
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset
from models.enums import DataSourceType
from models.model import App, Tag, TagBinding
from services.tag_service import TagService
@ -100,7 +101,7 @@ class TestTagService:
description=fake.text(max_nb_chars=100),
provider="vendor",
permission="only_me",
data_source_type="upload",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
tenant_id=tenant_id,
created_by=mock_external_service_dependencies["current_user"].id,

View File

@ -8,6 +8,7 @@ from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from tasks.add_document_to_index_task import add_document_to_index_task
@ -79,7 +80,7 @@ class TestAddDocumentToIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
)
@ -92,12 +93,12 @@ class TestAddDocumentToIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
@ -137,7 +138,7 @@ class TestAddDocumentToIndexTask:
index_node_id=f"node_{i}",
index_node_hash=f"hash_{i}",
enabled=False,
status="completed",
status=SegmentStatus.COMPLETED,
created_by=document.created_by,
)
db_session_with_containers.add(segment)
@ -297,7 +298,7 @@ class TestAddDocumentToIndexTask:
)
# Set invalid indexing status
document.indexing_status = "processing"
document.indexing_status = IndexingStatus.INDEXING
db_session_with_containers.commit()
# Act: Execute the task
@ -339,7 +340,7 @@ class TestAddDocumentToIndexTask:
# Assert: Verify error handling
db_session_with_containers.refresh(document)
assert document.enabled is False
assert document.indexing_status == "error"
assert document.indexing_status == IndexingStatus.ERROR
assert document.error is not None
assert "doesn't exist" in document.error
assert document.disabled_at is not None
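The refreshed-document assertions above compare a value loaded from the database with an enum member. That comparison passes either because the column is still mapped as plain text (and the str-valued enums compare equal to raw strings, as in the earlier sketch) or because the model maps it with the EnumText type named in this commit's title. EnumText itself is not shown in these hunks; below is a minimal sketch of how such a SQLAlchemy type decorator could behave, offered as an assumption rather than the project's actual implementation.

```python
# Hypothetical EnumText-style column type; the repository's EnumText may differ.
import enum

from sqlalchemy import String
from sqlalchemy.types import TypeDecorator


class EnumText(TypeDecorator):
    """Store an enum's string value in a VARCHAR column and rebuild the member on load."""

    impl = String
    cache_ok = True

    def __init__(self, enum_class: type[enum.Enum], length: int = 255):
        super().__init__(length)
        self._enum_class = enum_class

    def process_bind_param(self, value, dialect):
        # Accept either the enum member or its raw string value when writing.
        if value is None:
            return None
        if isinstance(value, self._enum_class):
            return value.value
        return self._enum_class(value).value

    def process_result_value(self, value, dialect):
        # Rebuild the enum member when reading, so ORM attributes compare
        # cleanly against members like IndexingStatus.ERROR.
        return None if value is None else self._enum_class(value)
```

With a mapping along these lines, document.indexing_status == IndexingStatus.ERROR holds whether the attribute carries the member or the stored string.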
@ -434,7 +435,7 @@ class TestAddDocumentToIndexTask:
Test document indexing when segments are already enabled.
This test verifies:
- Segments with status="completed" are processed regardless of enabled status
- Segments with status=SegmentStatus.COMPLETED are processed regardless of enabled status
- Index processing occurs with all completed segments
- Auto disable log deletion still occurs
- Redis cache is cleared
@ -460,7 +461,7 @@ class TestAddDocumentToIndexTask:
index_node_id=f"node_{i}",
index_node_hash=f"hash_{i}",
enabled=True, # Already enabled
status="completed",
status=SegmentStatus.COMPLETED,
created_by=document.created_by,
)
db_session_with_containers.add(segment)
@ -482,7 +483,7 @@ class TestAddDocumentToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with all completed segments
# (implementation doesn't filter by enabled status, only by status="completed")
# (implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED)
call_args = mock_external_service_dependencies["index_processor"].load.call_args
assert call_args is not None
documents = call_args[0][1] # Second argument should be documents list
@ -594,7 +595,7 @@ class TestAddDocumentToIndexTask:
# Assert: Verify error handling
db_session_with_containers.refresh(document)
assert document.enabled is False
assert document.indexing_status == "error"
assert document.indexing_status == IndexingStatus.ERROR
assert document.error is not None
assert "Index processing failed" in document.error
assert document.disabled_at is not None
@ -614,7 +615,7 @@ class TestAddDocumentToIndexTask:
Test segment filtering with various edge cases.
This test verifies:
- Only segments with status="completed" are processed (regardless of enabled status)
- Only segments with status=SegmentStatus.COMPLETED are processed (regardless of enabled status)
- Segments with status!="completed" are NOT processed
- Segments are ordered by position correctly
- Mixed segment states are handled properly
@ -630,7 +631,7 @@ class TestAddDocumentToIndexTask:
fake = Faker()
segments = []
# Segment 1: Should be processed (enabled=False, status="completed")
# Segment 1: Should be processed (enabled=False, status=SegmentStatus.COMPLETED)
segment1 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
@ -643,14 +644,14 @@ class TestAddDocumentToIndexTask:
index_node_id="node_0",
index_node_hash="hash_0",
enabled=False,
status="completed",
status=SegmentStatus.COMPLETED,
created_by=document.created_by,
)
db_session_with_containers.add(segment1)
segments.append(segment1)
# Segment 2: Should be processed (enabled=True, status="completed")
# Note: Implementation doesn't filter by enabled status, only by status="completed"
# Segment 2: Should be processed (enabled=True, status=SegmentStatus.COMPLETED)
# Note: Implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED
segment2 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
@ -663,7 +664,7 @@ class TestAddDocumentToIndexTask:
index_node_id="node_1",
index_node_hash="hash_1",
enabled=True, # Already enabled, but will still be processed
status="completed",
status=SegmentStatus.COMPLETED,
created_by=document.created_by,
)
db_session_with_containers.add(segment2)
@ -682,13 +683,13 @@ class TestAddDocumentToIndexTask:
index_node_id="node_2",
index_node_hash="hash_2",
enabled=False,
status="processing", # Not completed
status=SegmentStatus.INDEXING, # Not completed
created_by=document.created_by,
)
db_session_with_containers.add(segment3)
segments.append(segment3)
# Segment 4: Should be processed (enabled=False, status="completed")
# Segment 4: Should be processed (enabled=False, status=SegmentStatus.COMPLETED)
segment4 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
@ -701,7 +702,7 @@ class TestAddDocumentToIndexTask:
index_node_id="node_3",
index_node_hash="hash_3",
enabled=False,
status="completed",
status=SegmentStatus.COMPLETED,
created_by=document.created_by,
)
db_session_with_containers.add(segment4)
@ -726,7 +727,7 @@ class TestAddDocumentToIndexTask:
call_args = mock_external_service_dependencies["index_processor"].load.call_args
assert call_args is not None
documents = call_args[0][1] # Second argument should be documents list
assert len(documents) == 3 # 3 segments with status="completed" should be processed
assert len(documents) == 3 # 3 segments with status=SegmentStatus.COMPLETED should be processed
# Verify correct segments were processed (by position order)
# Segments 1, 2, 4 should be processed (positions 0, 1, 3)
@ -799,7 +800,7 @@ class TestAddDocumentToIndexTask:
# Assert: Verify consistent error handling
db_session_with_containers.refresh(document)
assert document.enabled is False, f"Document should be disabled for {error_name}"
assert document.indexing_status == "error", f"Document status should be error for {error_name}"
assert document.indexing_status == IndexingStatus.ERROR, f"Document status should be error for {error_name}"
assert document.error is not None, f"Error should be recorded for {error_name}"
assert str(exception) in document.error, f"Error message should contain exception for {error_name}"
assert document.disabled_at is not None, f"Disabled timestamp should be set for {error_name}"

View File

@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from models.model import UploadFile
from tasks.batch_clean_document_task import batch_clean_document_task
@ -113,7 +114,7 @@ class TestBatchCleanDocumentTask:
tenant_id=account.current_tenant.id,
name=fake.word(),
description=fake.sentence(),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
@ -144,12 +145,12 @@ class TestBatchCleanDocumentTask:
dataset_id=dataset.id,
position=0,
name=fake.word(),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info=json.dumps({"upload_file_id": str(uuid.uuid4())}),
batch="test_batch",
created_from="test",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
)
@ -183,7 +184,7 @@ class TestBatchCleanDocumentTask:
tokens=50,
index_node_id=str(uuid.uuid4()),
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
@ -297,7 +298,7 @@ class TestBatchCleanDocumentTask:
tokens=50,
index_node_id=str(uuid.uuid4()),
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
@ -671,7 +672,7 @@ class TestBatchCleanDocumentTask:
tokens=25 + i * 5,
index_node_id=str(uuid.uuid4()),
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
segments.append(segment)

View File

@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from models.model import UploadFile
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
@ -139,7 +139,7 @@ class TestBatchCreateSegmentToIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
@ -170,12 +170,12 @@ class TestBatchCreateSegmentToIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
@ -301,7 +301,7 @@ class TestBatchCreateSegmentToIndexTask:
assert segment.dataset_id == dataset.id
assert segment.document_id == document.id
assert segment.position == i + 1
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
assert segment.answer is None # text_model doesn't have answers
@ -442,12 +442,12 @@ class TestBatchCreateSegmentToIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name="disabled_document",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Document is disabled
archived=False,
doc_form="text_model",
@ -458,12 +458,12 @@ class TestBatchCreateSegmentToIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=2,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name="archived_document",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Document is archived
doc_form="text_model",
@ -474,12 +474,12 @@ class TestBatchCreateSegmentToIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=3,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name="incomplete_document",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="indexing", # Not completed
indexing_status=IndexingStatus.INDEXING, # Not completed
enabled=True,
archived=False,
doc_form="text_model",
@ -643,7 +643,7 @@ class TestBatchCreateSegmentToIndexTask:
word_count=len(f"Existing segment {i + 1}"),
tokens=10,
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash=f"hash_{i}",
)
@ -694,7 +694,7 @@ class TestBatchCreateSegmentToIndexTask:
for i, segment in enumerate(new_segments):
expected_position = 4 + i # Should start at position 4
assert segment.position == expected_position
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None

View File

@ -29,7 +29,14 @@ from models.dataset import (
Document,
DocumentSegment,
)
from models.enums import CreatorUserRole
from models.enums import (
CreatorUserRole,
DatasetMetadataType,
DataSourceType,
DocumentCreatedFrom,
IndexingStatus,
SegmentStatus,
)
from models.model import UploadFile
from tasks.clean_dataset_task import clean_dataset_task
@ -176,12 +183,12 @@ class TestCleanDatasetTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name="test_document",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="paragraph_index",
@ -219,7 +226,7 @@ class TestCleanDatasetTask:
word_count=20,
tokens=30,
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash="test_hash",
created_at=datetime.now(),
@ -373,7 +380,7 @@ class TestCleanDatasetTask:
dataset_id=dataset.id,
tenant_id=tenant.id,
name="test_metadata",
type="string",
type=DatasetMetadataType.STRING,
created_by=account.id,
)
metadata.id = str(uuid.uuid4())
@ -587,7 +594,7 @@ class TestCleanDatasetTask:
word_count=len(segment_content),
tokens=50,
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash="test_hash",
created_at=datetime.now(),
@ -686,7 +693,7 @@ class TestCleanDatasetTask:
dataset_id=dataset.id,
tenant_id=tenant.id,
name=f"test_metadata_{i}",
type="string",
type=DatasetMetadataType.STRING,
created_by=account.id,
)
metadata.id = str(uuid.uuid4())
@ -880,11 +887,11 @@ class TestCleanDatasetTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info="{}",
batch="test_batch",
name=f"test_doc_{special_content}",
created_from="test",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
created_at=datetime.now(),
updated_at=datetime.now(),
@ -905,7 +912,7 @@ class TestCleanDatasetTask:
word_count=len(segment_content.split()),
tokens=len(segment_content) // 4, # Rough token estimation
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash="test_hash_" + "x" * 50, # Long hash within limits
created_at=datetime.now(),
@ -946,7 +953,7 @@ class TestCleanDatasetTask:
dataset_id=dataset.id,
tenant_id=tenant.id,
name=f"metadata_{special_content}",
type="string",
type=DatasetMetadataType.STRING,
created_by=account.id,
)
special_metadata.id = str(uuid.uuid4())

View File

@ -13,6 +13,7 @@ import pytest
from faker import Faker
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from services.account_service import AccountService, TenantService
from tasks.clean_notion_document_task import clean_notion_document_task
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -88,7 +89,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -105,17 +106,17 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
),
batch="test_batch",
name=f"Notion Page {i}",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model", # Set doc_form to ensure dataset.doc_form works
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -134,7 +135,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id=f"node_{i}_{j}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
segments.append(segment)
@ -220,7 +221,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -269,7 +270,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=f"{fake.company()}_{index_type}",
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -281,17 +282,17 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"}
),
batch="test_batch",
name="Test Notion Page",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form=index_type,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -308,7 +309,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id="test_node",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
db_session_with_containers.commit()
@ -357,7 +358,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -369,16 +370,16 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"}
),
batch="test_batch",
name="Test Notion Page",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -397,7 +398,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id=None, # No index node ID
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
segments.append(segment)
@ -443,7 +444,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -460,16 +461,16 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
),
batch="test_batch",
name=f"Notion Page {i}",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -488,7 +489,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id=f"node_{i}_{j}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
all_segments.append(segment)
@ -558,7 +559,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -570,22 +571,22 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"}
),
batch="test_batch",
name="Test Notion Page",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
# Create segments with different statuses
segment_statuses = ["waiting", "processing", "completed", "error"]
segment_statuses = [SegmentStatus.WAITING, SegmentStatus.INDEXING, SegmentStatus.COMPLETED, SegmentStatus.ERROR]
segments = []
index_node_ids = []
@ -654,7 +655,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -666,16 +667,16 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"}
),
batch="test_batch",
name="Test Notion Page",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -692,7 +693,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id="test_node",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
db_session_with_containers.commit()
@ -736,7 +737,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -754,16 +755,16 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
),
batch="test_batch",
name=f"Notion Page {i}",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -783,7 +784,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id=f"node_{i}_{j}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
all_segments.append(segment)
@ -848,7 +849,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=f"{fake.company()}_{i}",
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -866,16 +867,16 @@ class TestCleanNotionDocumentTask:
tenant_id=account.current_tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
),
batch="test_batch",
name=f"Notion Page {i}",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -894,7 +895,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id=f"node_{i}_{j}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
all_segments.append(segment)
@ -963,14 +964,22 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
# Create documents with different indexing statuses
document_statuses = ["waiting", "parsing", "cleaning", "splitting", "indexing", "completed", "error"]
document_statuses = [
IndexingStatus.WAITING,
IndexingStatus.PARSING,
IndexingStatus.CLEANING,
IndexingStatus.SPLITTING,
IndexingStatus.INDEXING,
IndexingStatus.COMPLETED,
IndexingStatus.ERROR,
]
documents = []
all_segments = []
all_index_node_ids = []
@ -981,13 +990,13 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
),
batch="test_batch",
name=f"Notion Page {i}",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status=status,
@ -1009,7 +1018,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id=f"node_{i}_{j}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
all_segments.append(segment)
@ -1066,7 +1075,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
built_in_field_enabled=True,
)
@ -1079,7 +1088,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{
"notion_workspace_id": "workspace_test",
@ -1091,10 +1100,10 @@ class TestCleanNotionDocumentTask:
),
batch="test_batch",
name="Test Notion Page with Metadata",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
doc_metadata={
"document_name": "Test Notion Page with Metadata",
"uploader": account.name,
@ -1122,7 +1131,7 @@ class TestCleanNotionDocumentTask:
tokens=75,
index_node_id=f"node_{i}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
keywords={"key1": ["value1", "value2"], "key2": ["value3"]},
)
db_session_with_containers.add(segment)

View File

@ -15,6 +15,7 @@ from faker import Faker
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from tasks.create_segment_to_index_task import create_segment_to_index_task
@ -118,7 +119,7 @@ class TestCreateSegmentToIndexTask:
name=fake.company(),
description=fake.text(max_nb_chars=100),
tenant_id=tenant_id,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
@ -133,13 +134,13 @@ class TestCreateSegmentToIndexTask:
dataset_id=dataset.id,
tenant_id=tenant_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account_id,
enabled=True,
archived=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
doc_form="qa_model",
)
db_session_with_containers.add(document)
@ -148,7 +149,7 @@ class TestCreateSegmentToIndexTask:
return dataset, document
def _create_test_segment(
self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status="waiting"
self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status=SegmentStatus.WAITING
):
"""
Helper method to create a test document segment for testing.
@ -200,7 +201,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task
@ -208,7 +209,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify segment status changes
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
assert segment.error is None
@ -257,7 +258,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="completed"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.COMPLETED
)
# Act: Execute the task
@ -268,7 +269,7 @@ class TestCreateSegmentToIndexTask:
# Verify segment status unchanged
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is None
# Verify no index processor calls were made
@ -293,20 +294,25 @@ class TestCreateSegmentToIndexTask:
dataset_id=invalid_dataset_id,
tenant_id=tenant.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
enabled=True,
archived=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
segment = self._create_test_segment(
db_session_with_containers, invalid_dataset_id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers,
invalid_dataset_id,
document.id,
tenant.id,
account.id,
status=SegmentStatus.WAITING,
)
# Act: Execute the task
@ -317,7 +323,7 @@ class TestCreateSegmentToIndexTask:
# Verify segment status changed to indexing (task updates status before checking document)
db_session_with_containers.refresh(segment)
assert segment.status == "indexing"
assert segment.status == SegmentStatus.INDEXING
# Verify no index processor calls were made
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
@ -337,7 +343,12 @@ class TestCreateSegmentToIndexTask:
invalid_document_id = str(uuid4())
segment = self._create_test_segment(
db_session_with_containers, dataset.id, invalid_document_id, tenant.id, account.id, status="waiting"
db_session_with_containers,
dataset.id,
invalid_document_id,
tenant.id,
account.id,
status=SegmentStatus.WAITING,
)
# Act: Execute the task
@ -348,7 +359,7 @@ class TestCreateSegmentToIndexTask:
# Verify segment status changed to indexing (task updates status before checking document)
db_session_with_containers.refresh(segment)
assert segment.status == "indexing"
assert segment.status == SegmentStatus.INDEXING
# Verify no index processor calls were made
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
@ -373,7 +384,7 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers.commit()
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task
@ -384,7 +395,7 @@ class TestCreateSegmentToIndexTask:
# Verify segment status changed to indexing (task updates status before checking document)
db_session_with_containers.refresh(segment)
assert segment.status == "indexing"
assert segment.status == SegmentStatus.INDEXING
# Verify no index processor calls were made
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
@ -409,7 +420,7 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers.commit()
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task
@ -420,7 +431,7 @@ class TestCreateSegmentToIndexTask:
# Verify segment status changed to indexing (task updates status before checking document)
db_session_with_containers.refresh(segment)
assert segment.status == "indexing"
assert segment.status == SegmentStatus.INDEXING
# Verify no index processor calls were made
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
@ -445,7 +456,7 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers.commit()
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task
@ -456,7 +467,7 @@ class TestCreateSegmentToIndexTask:
# Verify segment status changed to indexing (task updates status before checking document)
db_session_with_containers.refresh(segment)
assert segment.status == "indexing"
assert segment.status == SegmentStatus.INDEXING
# Verify no index processor calls were made
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
@ -477,7 +488,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Mock processor to raise exception
@ -488,7 +499,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify error handling
db_session_with_containers.refresh(segment)
assert segment.status == "error"
assert segment.status == SegmentStatus.ERROR
assert segment.enabled is False
assert segment.disabled_at is not None
assert segment.error == "Processor failed"
@ -512,7 +523,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
custom_keywords = ["custom", "keywords", "test"]
@ -521,7 +532,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -555,7 +566,7 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers.commit()
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task
@ -563,7 +574,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
# Verify correct doc_form was passed to factory
mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form)
@ -583,7 +594,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task and measure time
@ -597,7 +608,7 @@ class TestCreateSegmentToIndexTask:
# Verify successful completion
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
def test_create_segment_to_index_concurrent_execution(
self, db_session_with_containers, mock_external_service_dependencies
@ -617,7 +628,7 @@ class TestCreateSegmentToIndexTask:
segments = []
for i in range(3):
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
segments.append(segment)
@ -629,7 +640,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify all segments processed
for segment in segments:
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -665,7 +676,7 @@ class TestCreateSegmentToIndexTask:
keywords=["large", "content", "test"],
index_node_id=str(uuid4()),
index_node_hash=str(uuid4()),
status="waiting",
status=SegmentStatus.WAITING,
created_by=account.id,
)
db_session_with_containers.add(segment)
@ -681,7 +692,7 @@ class TestCreateSegmentToIndexTask:
assert execution_time < 10.0 # Should complete within 10 seconds
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -700,7 +711,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Set up Redis cache key to simulate indexing in progress
@ -718,7 +729,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify indexing still completed successfully despite Redis failure
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -740,7 +751,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Simulate an error during indexing to trigger rollback path
@ -752,7 +763,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify error handling and rollback
db_session_with_containers.refresh(segment)
assert segment.status == "error"
assert segment.status == SegmentStatus.ERROR
assert segment.enabled is False
assert segment.disabled_at is not None
assert segment.error is not None
@ -772,7 +783,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task
@ -780,7 +791,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
# Verify index processor was called with correct metadata
mock_processor = mock_external_service_dependencies["index_processor"]
@ -814,11 +825,11 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Verify initial state
assert segment.status == "waiting"
assert segment.status == SegmentStatus.WAITING
assert segment.indexing_at is None
assert segment.completed_at is None
@ -827,7 +838,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify final state
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -861,7 +872,7 @@ class TestCreateSegmentToIndexTask:
keywords=[],
index_node_id=str(uuid4()),
index_node_hash=str(uuid4()),
status="waiting",
status=SegmentStatus.WAITING,
created_by=account.id,
)
db_session_with_containers.add(segment)
@ -872,7 +883,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -907,7 +918,7 @@ class TestCreateSegmentToIndexTask:
keywords=["special", "unicode", "test"],
index_node_id=str(uuid4()),
index_node_hash=str(uuid4()),
status="waiting",
status=SegmentStatus.WAITING,
created_by=account.id,
)
db_session_with_containers.add(segment)
@ -918,7 +929,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -937,7 +948,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Create long keyword list
@ -948,7 +959,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -979,10 +990,10 @@ class TestCreateSegmentToIndexTask:
)
segment1 = self._create_test_segment(
db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status="waiting"
db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status=SegmentStatus.WAITING
)
segment2 = self._create_test_segment(
db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status="waiting"
db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status=SegmentStatus.WAITING
)
# Act: Execute tasks for both tenants
@ -993,8 +1004,8 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers.refresh(segment1)
db_session_with_containers.refresh(segment2)
assert segment1.status == "completed"
assert segment2.status == "completed"
assert segment1.status == SegmentStatus.COMPLETED
assert segment2.status == SegmentStatus.COMPLETED
assert segment1.tenant_id == tenant1.id
assert segment2.tenant_id == tenant2.id
assert segment1.tenant_id != segment2.tenant_id
@ -1014,7 +1025,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task with None keywords
@ -1022,7 +1033,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -1050,7 +1061,7 @@ class TestCreateSegmentToIndexTask:
segments = []
for i in range(5):
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
segments.append(segment)
@ -1067,7 +1078,7 @@ class TestCreateSegmentToIndexTask:
# Verify all segments processed successfully
for segment in segments:
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
assert segment.error is None
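
For context, the assertions above compare enum members directly against values that the database stores as plain strings. A minimal sketch of why that works, assuming the members added to models/enums.py are string-valued enums along the lines shown below (the actual definitions and the EnumText column type in the repository may differ):

from enum import StrEnum  # Python 3.11+; subclassing (str, Enum) behaves the same on older versions


class SegmentStatus(StrEnum):
    WAITING = "waiting"
    INDEXING = "indexing"
    COMPLETED = "completed"
    ERROR = "error"
    RE_SEGMENT = "re_segment"


class IndexingStatus(StrEnum):
    WAITING = "waiting"
    PARSING = "parsing"
    INDEXING = "indexing"
    PAUSED = "paused"
    COMPLETED = "completed"
    ERROR = "error"


# StrEnum members are str instances, so equality with the raw strings
# read back from string-backed columns still holds.
assert SegmentStatus.COMPLETED == "completed"
assert IndexingStatus.PARSING == "parsing"
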

View File

@ -11,6 +11,7 @@ from core.indexing_runner import DocumentIsPausedError
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from tasks.document_indexing_task import (
_document_indexing,
_document_indexing_with_tenant_queue,
@ -139,7 +140,7 @@ class TestDatasetIndexingTaskIntegration:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
)
@ -155,12 +156,12 @@ class TestDatasetIndexingTaskIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=position,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=f"doc-{position}.txt",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
)
db_session_with_containers.add(document)
@ -181,7 +182,7 @@ class TestDatasetIndexingTaskIntegration:
for document_id in document_ids:
updated = self._query_document(db_session_with_containers, document_id)
assert updated is not None
assert updated.indexing_status == "parsing"
assert updated.indexing_status == IndexingStatus.PARSING
assert updated.processing_started_at is not None
def _assert_documents_error_contains(
@ -195,7 +196,7 @@ class TestDatasetIndexingTaskIntegration:
for document_id in document_ids:
updated = self._query_document(db_session_with_containers, document_id)
assert updated is not None
assert updated.indexing_status == "error"
assert updated.indexing_status == IndexingStatus.ERROR
assert updated.error is not None
assert expected_error_substring in updated.error
assert updated.stopped_at is not None

View File

@ -13,6 +13,7 @@ import pytest
from faker import Faker
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from services.account_service import AccountService, TenantService
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -90,7 +91,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -102,13 +103,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -150,7 +151,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -162,13 +163,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -182,13 +183,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -209,7 +210,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -220,7 +221,7 @@ class TestDealDatasetVectorIndexTask:
# Verify document status was updated to indexing then completed
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor load method was called
mock_factory = mock_index_processor_factory.return_value
@ -251,7 +252,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -263,13 +264,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="parent_child_index",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -283,13 +284,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="parent_child_index",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -310,7 +311,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -321,7 +322,7 @@ class TestDealDatasetVectorIndexTask:
# Verify document status was updated to indexing then completed
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor clean and load methods were called
mock_factory = mock_index_processor_factory.return_value
@ -367,7 +368,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -399,7 +400,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -411,13 +412,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -430,7 +431,7 @@ class TestDealDatasetVectorIndexTask:
# Verify document status was updated to indexing then completed
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify that no index processor load was called since no segments exist
mock_factory = mock_index_processor_factory.return_value
@ -455,7 +456,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -488,7 +489,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -500,13 +501,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -520,13 +521,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -547,7 +548,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -563,7 +564,7 @@ class TestDealDatasetVectorIndexTask:
# Verify document status was updated to error
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "error"
assert updated_document.indexing_status == IndexingStatus.ERROR
assert "Test exception during indexing" in updated_document.error
def test_deal_dataset_vector_index_task_with_custom_index_type(
@ -584,7 +585,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -596,13 +597,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="qa_index",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -623,7 +624,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -634,7 +635,7 @@ class TestDealDatasetVectorIndexTask:
# Verify document status was updated to indexing then completed
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor was initialized with custom index type
mock_index_processor_factory.assert_called_once_with("qa_index")
@ -660,7 +661,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -672,13 +673,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -699,7 +700,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -710,7 +711,7 @@ class TestDealDatasetVectorIndexTask:
# Verify document status was updated to indexing then completed
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor was initialized with the document's index type
mock_index_processor_factory.assert_called_once_with("text_model")
@ -736,7 +737,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -748,13 +749,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -770,13 +771,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name=f"Test Document {i}",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -801,7 +802,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{i}_{j}",
index_node_hash=f"hash_{i}_{j}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -814,7 +815,7 @@ class TestDealDatasetVectorIndexTask:
# Verify all documents were processed
for document in documents:
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor load was called multiple times
mock_factory = mock_index_processor_factory.return_value
@ -839,7 +840,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -851,13 +852,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -871,13 +872,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -898,7 +899,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -916,7 +917,7 @@ class TestDealDatasetVectorIndexTask:
# Verify final document status
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
def test_deal_dataset_vector_index_task_with_disabled_documents(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
@ -936,7 +937,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -948,13 +949,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -968,13 +969,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Enabled Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -987,13 +988,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Disabled Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # This document should be skipped
archived=False,
batch="test_batch",
@ -1015,7 +1016,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -1026,13 +1027,13 @@ class TestDealDatasetVectorIndexTask:
# Verify only enabled document was processed
updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first()
assert updated_enabled_document.indexing_status == "completed"
assert updated_enabled_document.indexing_status == IndexingStatus.COMPLETED
# Verify disabled document status remains unchanged
updated_disabled_document = (
db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first()
)
assert updated_disabled_document.indexing_status == "completed" # Should not change
assert updated_disabled_document.indexing_status == IndexingStatus.COMPLETED # Should not change
# Verify index processor load was called only once (for enabled document)
mock_factory = mock_index_processor_factory.return_value
@ -1057,7 +1058,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -1069,13 +1070,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -1089,13 +1090,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Active Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -1108,13 +1109,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Archived Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # This document should be skipped
batch="test_batch",
@ -1136,7 +1137,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -1147,13 +1148,13 @@ class TestDealDatasetVectorIndexTask:
# Verify only active document was processed
updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first()
assert updated_active_document.indexing_status == "completed"
assert updated_active_document.indexing_status == IndexingStatus.COMPLETED
# Verify archived document status remains unchanged
updated_archived_document = (
db_session_with_containers.query(Document).filter_by(id=archived_document.id).first()
)
assert updated_archived_document.indexing_status == "completed" # Should not change
assert updated_archived_document.indexing_status == IndexingStatus.COMPLETED # Should not change
# Verify index processor load was called only once (for active document)
mock_factory = mock_index_processor_factory.return_value
@ -1178,7 +1179,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -1190,13 +1191,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -1210,13 +1211,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Completed Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -1229,13 +1230,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Incomplete Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="indexing", # This document should be skipped
indexing_status=IndexingStatus.INDEXING, # This document should be skipped
enabled=True,
archived=False,
batch="test_batch",
@ -1257,7 +1258,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -1270,13 +1271,13 @@ class TestDealDatasetVectorIndexTask:
updated_completed_document = (
db_session_with_containers.query(Document).filter_by(id=completed_document.id).first()
)
assert updated_completed_document.indexing_status == "completed"
assert updated_completed_document.indexing_status == IndexingStatus.COMPLETED
# Verify incomplete document status remains unchanged
updated_incomplete_document = (
db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first()
)
assert updated_incomplete_document.indexing_status == "indexing" # Should not change
assert updated_incomplete_document.indexing_status == IndexingStatus.INDEXING # Should not change
# Verify index processor load was called only once (for completed document)
mock_factory = mock_index_processor_factory.return_value

View File

@ -14,6 +14,7 @@ from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Dataset, Document, DocumentSegment, Tenant
from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
logger = logging.getLogger(__name__)
@ -106,7 +107,7 @@ class TestDeleteSegmentFromIndexTask:
dataset.description = fake.text(max_nb_chars=200)
dataset.provider = "vendor"
dataset.permission = "only_me"
dataset.data_source_type = "upload_file"
dataset.data_source_type = DataSourceType.UPLOAD_FILE
dataset.indexing_technique = "high_quality"
dataset.index_struct = '{"type": "paragraph"}'
dataset.created_by = account.id
@ -145,7 +146,7 @@ class TestDeleteSegmentFromIndexTask:
document.data_source_info = kwargs.get("data_source_info", "{}")
document.batch = kwargs.get("batch", fake.uuid4())
document.name = kwargs.get("name", f"Test Document {fake.word()}")
document.created_from = kwargs.get("created_from", "api")
document.created_from = kwargs.get("created_from", DocumentCreatedFrom.API)
document.created_by = account.id
document.created_at = fake.date_time_this_year()
document.processing_started_at = kwargs.get("processing_started_at", fake.date_time_this_year())
@ -162,7 +163,7 @@ class TestDeleteSegmentFromIndexTask:
document.enabled = kwargs.get("enabled", True)
document.archived = kwargs.get("archived", False)
document.updated_at = fake.date_time_this_year()
document.doc_type = kwargs.get("doc_type", "text")
document.doc_type = kwargs.get("doc_type", DocumentDocType.PERSONAL_DOCUMENT)
document.doc_metadata = kwargs.get("doc_metadata", {})
document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX)
document.doc_language = kwargs.get("doc_language", "en")
@ -204,7 +205,7 @@ class TestDeleteSegmentFromIndexTask:
segment.index_node_hash = fake.sha256()
segment.hit_count = 0
segment.enabled = True
segment.status = "completed"
segment.status = SegmentStatus.COMPLETED
segment.created_by = account.id
segment.created_at = fake.date_time_this_year()
segment.updated_by = account.id
@ -386,7 +387,7 @@ class TestDeleteSegmentFromIndexTask:
account = self._create_test_account(db_session_with_containers, tenant, fake)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
document = self._create_test_document(
db_session_with_containers, dataset, account, fake, indexing_status="indexing"
db_session_with_containers, dataset, account, fake, indexing_status=IndexingStatus.INDEXING
)
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)

View File

@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
logger = logging.getLogger(__name__)
@ -97,7 +98,7 @@ class TestDisableSegmentFromIndexTask:
tenant_id=tenant.id,
name=fake.sentence(nb_words=3),
description=fake.text(max_nb_chars=200),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
)
@ -132,12 +133,12 @@ class TestDisableSegmentFromIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=fake.uuid4(),
name=fake.file_name(),
created_from="api",
created_from=DocumentCreatedFrom.API,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form=doc_form,
@ -189,7 +190,7 @@ class TestDisableSegmentFromIndexTask:
status=status,
enabled=enabled,
created_by=account.id,
completed_at=datetime.now(UTC) if status == "completed" else None,
completed_at=datetime.now(UTC) if status == SegmentStatus.COMPLETED else None,
)
db_session_with_containers.add(segment)
db_session_with_containers.commit()
@ -271,7 +272,7 @@ class TestDisableSegmentFromIndexTask:
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(
db_session_with_containers, document, dataset, tenant, account, status="indexing", enabled=True
db_session_with_containers, document, dataset, tenant, account, status=SegmentStatus.INDEXING, enabled=True
)
# Act: Execute the task

View File

@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
from models import Account, Dataset, DocumentSegment
from models import Document as DatasetDocument
from models.dataset import DatasetProcessRule
from models.enums import DataSourceType, DocumentCreatedFrom, ProcessRuleMode, SegmentStatus
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
@ -100,7 +101,7 @@ class TestDisableSegmentsFromIndexTask:
description=fake.text(max_nb_chars=200),
provider="vendor",
permission="only_me",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
updated_by=account.id,
@ -134,11 +135,11 @@ class TestDisableSegmentsFromIndexTask:
document.tenant_id = dataset.tenant_id
document.dataset_id = dataset.id
document.position = 1
document.data_source_type = "upload_file"
document.data_source_type = DataSourceType.UPLOAD_FILE
document.data_source_info = '{"upload_file_id": "test_file_id"}'
document.batch = fake.uuid4()
document.name = f"Test Document {fake.word()}.txt"
document.created_from = "upload_file"
document.created_from = DocumentCreatedFrom.WEB
document.created_by = account.id
document.created_api_request_id = fake.uuid4()
document.processing_started_at = fake.date_time_this_year()
@ -197,7 +198,7 @@ class TestDisableSegmentsFromIndexTask:
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
segment.status = "completed"
segment.status = SegmentStatus.COMPLETED
segment.created_by = account.id
segment.updated_by = account.id
segment.indexing_at = fake.date_time_this_year()
@ -230,7 +231,7 @@ class TestDisableSegmentsFromIndexTask:
process_rule.id = fake.uuid4()
process_rule.tenant_id = dataset.tenant_id
process_rule.dataset_id = dataset.id
process_rule.mode = "automatic"
process_rule.mode = ProcessRuleMode.AUTOMATIC
process_rule.rules = (
"{"
'"mode": "automatic", '

View File

@ -16,6 +16,7 @@ import pytest
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from tasks.document_indexing_sync_task import document_indexing_sync_task
@ -54,7 +55,7 @@ class DocumentIndexingSyncTaskTestDataFactory:
tenant_id=tenant_id,
name=f"dataset-{uuid4()}",
description="sync test dataset",
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
indexing_technique="high_quality",
created_by=created_by,
)
@ -76,11 +77,11 @@ class DocumentIndexingSyncTaskTestDataFactory:
tenant_id=tenant_id,
dataset_id=dataset_id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(data_source_info) if data_source_info is not None else None,
batch="test-batch",
name=f"doc-{uuid4()}",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status=indexing_status,
enabled=True,
@ -113,7 +114,7 @@ class DocumentIndexingSyncTaskTestDataFactory:
word_count=10,
tokens=5,
index_node_id=f"node-{document_id}-{i}",
status="completed",
status=SegmentStatus.COMPLETED,
created_by=created_by,
)
db_session_with_containers.add(segment)
@ -181,7 +182,7 @@ class TestDocumentIndexingSyncTask:
dataset_id=dataset.id,
created_by=account.id,
data_source_info=notion_info,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
segments = DocumentIndexingSyncTaskTestDataFactory.create_segments(
@ -276,7 +277,7 @@ class TestDocumentIndexingSyncTask:
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
)
assert updated_document is not None
assert updated_document.indexing_status == "error"
assert updated_document.indexing_status == IndexingStatus.ERROR
assert "Datasource credential not found" in updated_document.error
assert updated_document.stopped_at is not None
mock_external_dependencies["indexing_runner"].run.assert_not_called()
@ -301,7 +302,7 @@ class TestDocumentIndexingSyncTask:
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
assert updated_document.processing_started_at is None
assert remaining_segments == 3
mock_external_dependencies["index_processor"].clean.assert_not_called()
@ -327,7 +328,7 @@ class TestDocumentIndexingSyncTask:
)
assert updated_document is not None
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
assert updated_document.data_source_info_dict.get("last_edited_time") == "2024-01-02T00:00:00Z"
assert remaining_segments == 0
@ -369,7 +370,7 @@ class TestDocumentIndexingSyncTask:
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
)
assert updated_document is not None
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
mock_external_dependencies["index_processor"].clean.assert_not_called()
mock_external_dependencies["indexing_runner"].run.assert_called_once()
@ -393,7 +394,7 @@ class TestDocumentIndexingSyncTask:
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert remaining_segments == 0
mock_external_dependencies["indexing_runner"].run.assert_called_once()
@ -412,7 +413,7 @@ class TestDocumentIndexingSyncTask:
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
)
assert updated_document is not None
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.error is None
def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies):
@ -430,7 +431,7 @@ class TestDocumentIndexingSyncTask:
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
)
assert updated_document is not None
assert updated_document.indexing_status == "error"
assert updated_document.indexing_status == IndexingStatus.ERROR
assert "Indexing error" in updated_document.error
assert updated_document.stopped_at is not None

View File

@ -8,6 +8,7 @@ from core.entities.document_task import DocumentTask
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from tasks.document_indexing_task import (
_document_indexing, # Core function
_document_indexing_with_tenant_queue, # Tenant queue wrapper function
@ -97,7 +98,7 @@ class TestDocumentIndexingTasks:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
)
@ -112,12 +113,12 @@ class TestDocumentIndexingTasks:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
)
db_session_with_containers.add(document)
@ -179,7 +180,7 @@ class TestDocumentIndexingTasks:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
)
@ -194,12 +195,12 @@ class TestDocumentIndexingTasks:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
)
db_session_with_containers.add(document)
@ -250,7 +251,7 @@ class TestDocumentIndexingTasks:
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
# Verify the run method was called with correct documents
@ -320,7 +321,7 @@ class TestDocumentIndexingTasks:
# Re-query documents from database since _document_indexing uses a different session
for doc_id in existing_document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
# Verify the run method was called with only existing documents
@ -367,7 +368,7 @@ class TestDocumentIndexingTasks:
# Re-query documents from database since _document_indexing closes the session
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
def test_document_indexing_task_mixed_document_states(
@ -397,12 +398,12 @@ class TestDocumentIndexingTasks:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=2,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=dataset.created_by,
indexing_status="completed", # Already completed
indexing_status=IndexingStatus.COMPLETED, # Already completed
enabled=True,
)
db_session_with_containers.add(doc1)
@ -414,12 +415,12 @@ class TestDocumentIndexingTasks:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=3,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=dataset.created_by,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=False, # Disabled
)
db_session_with_containers.add(doc2)
@ -444,7 +445,7 @@ class TestDocumentIndexingTasks:
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
# Verify the run method was called with all documents
@ -482,12 +483,12 @@ class TestDocumentIndexingTasks:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=i + 3,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=dataset.created_by,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
)
db_session_with_containers.add(document)
@ -507,7 +508,7 @@ class TestDocumentIndexingTasks:
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.indexing_status == IndexingStatus.ERROR
assert updated_document.error is not None
assert "batch upload" in updated_document.error
assert updated_document.stopped_at is not None
@ -548,7 +549,7 @@ class TestDocumentIndexingTasks:
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
def test_document_indexing_task_document_is_paused_error(
@ -591,7 +592,7 @@ class TestDocumentIndexingTasks:
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
# ==================== NEW TESTS FOR REFACTORED FUNCTIONS ====================
@ -702,7 +703,7 @@ class TestDocumentIndexingTasks:
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
# Verify the run method was called with correct documents
@ -827,7 +828,7 @@ class TestDocumentIndexingTasks:
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
# Verify waiting task was still processed despite core processing error

View File

@ -5,6 +5,7 @@ from faker import Faker
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from tasks.document_indexing_update_task import document_indexing_update_task
@ -61,7 +62,7 @@ class TestDocumentIndexingUpdateTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=64),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
)
@ -72,12 +73,12 @@ class TestDocumentIndexingUpdateTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
doc_form="text_model",
)
@ -98,7 +99,7 @@ class TestDocumentIndexingUpdateTask:
word_count=10,
tokens=5,
index_node_id=node_id,
status="completed",
status=SegmentStatus.COMPLETED,
created_by=account.id,
)
db_session_with_containers.add(seg)
@ -122,7 +123,7 @@ class TestDocumentIndexingUpdateTask:
# Assert document status updated before reindex
updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
assert updated.indexing_status == "parsing"
assert updated.indexing_status == IndexingStatus.PARSING
assert updated.processing_started_at is not None
# Segments should be deleted

View File

@ -7,6 +7,7 @@ from core.indexing_runner import DocumentIsPausedError
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from tasks.duplicate_document_indexing_task import (
_duplicate_document_indexing_task, # Core function
_duplicate_document_indexing_task_with_tenant_queue, # Tenant queue wrapper function
@ -107,7 +108,7 @@ class TestDuplicateDocumentIndexingTasks:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
)
@ -122,12 +123,12 @@ class TestDuplicateDocumentIndexingTasks:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
doc_form="text_model",
)
@ -177,7 +178,7 @@ class TestDuplicateDocumentIndexingTasks:
content=fake.text(max_nb_chars=200),
word_count=50,
tokens=100,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
indexing_at=fake.date_time_this_year(),
created_by=dataset.created_by, # Add required field
@ -242,7 +243,7 @@ class TestDuplicateDocumentIndexingTasks:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
)
@ -257,12 +258,12 @@ class TestDuplicateDocumentIndexingTasks:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
doc_form="text_model",
)
@ -316,7 +317,7 @@ class TestDuplicateDocumentIndexingTasks:
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
# Verify the run method was called with correct documents
@ -368,7 +369,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were updated to parsing status
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
# Verify indexing runner was called
@ -437,7 +438,7 @@ class TestDuplicateDocumentIndexingTasks:
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in existing_document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
# Verify the run method was called with only existing documents
@ -484,7 +485,7 @@ class TestDuplicateDocumentIndexingTasks:
# Re-query documents from database since _duplicate_document_indexing_task closes the session
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.processing_started_at is not None
def _test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
@ -516,12 +517,12 @@ class TestDuplicateDocumentIndexingTasks:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=i + 3,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=dataset.created_by,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
doc_form="text_model",
)
@ -542,7 +543,7 @@ class TestDuplicateDocumentIndexingTasks:
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.indexing_status == IndexingStatus.ERROR
assert updated_document.error is not None
assert "batch upload" in updated_document.error.lower()
assert updated_document.stopped_at is not None
@ -584,7 +585,7 @@ class TestDuplicateDocumentIndexingTasks:
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.indexing_status == IndexingStatus.ERROR
assert updated_document.error is not None
assert "limit" in updated_document.error.lower()
assert updated_document.stopped_at is not None
@ -648,7 +649,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were processed
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
def test_normal_duplicate_document_indexing_task_with_tenant_queue(
@ -691,7 +692,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were processed
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
def test_priority_duplicate_document_indexing_task_with_tenant_queue(
@ -735,7 +736,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were processed
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
def test_tenant_queue_wrapper_processes_next_tasks(
@ -851,7 +852,7 @@ class TestDuplicateDocumentIndexingTasks:
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.is_paused is True
assert updated_document.indexing_status == "parsing"
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.display_status == "paused"
assert updated_document.processing_started_at is not None
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

View File

@ -8,6 +8,7 @@ from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
@ -79,7 +80,7 @@ class TestEnableSegmentsToIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
)
@ -92,12 +93,12 @@ class TestEnableSegmentsToIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
@ -110,7 +111,13 @@ class TestEnableSegmentsToIndexTask:
return dataset, document
def _create_test_segments(
self, db_session_with_containers: Session, document, dataset, count=3, enabled=False, status="completed"
self,
db_session_with_containers: Session,
document,
dataset,
count=3,
enabled=False,
status=SegmentStatus.COMPLETED,
):
"""
Helper method to create test document segments.
@ -278,7 +285,7 @@ class TestEnableSegmentsToIndexTask:
invalid_statuses = [
("disabled", {"enabled": False}),
("archived", {"archived": True}),
("not_completed", {"indexing_status": "processing"}),
("not_completed", {"indexing_status": IndexingStatus.INDEXING}),
]
for _, status_attrs in invalid_statuses:
@ -447,7 +454,7 @@ class TestEnableSegmentsToIndexTask:
for segment in segments:
db_session_with_containers.refresh(segment)
assert segment.enabled is False
assert segment.status == "error"
assert segment.status == SegmentStatus.ERROR
assert segment.error is not None
assert "Index processing failed" in segment.error
assert segment.disabled_at is not None

View File

@ -30,6 +30,7 @@ from controllers.console.datasets.error import (
InvalidActionError,
InvalidMetadataError,
)
from models.enums import DataSourceType, IndexingStatus
def unwrap(func):
@ -62,8 +63,8 @@ def document():
return MagicMock(
id="doc-1",
tenant_id="tenant-1",
indexing_status="indexing",
data_source_type="upload_file",
indexing_status=IndexingStatus.INDEXING,
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info_dict={"upload_file_id": "file-1"},
doc_form="text",
archived=False,
@ -407,7 +408,7 @@ class TestDocumentProcessingApi:
api = DocumentProcessingApi()
method = unwrap(api.patch)
doc = MagicMock(indexing_status="error", is_paused=True)
doc = MagicMock(indexing_status=IndexingStatus.ERROR, is_paused=True)
with (
app.test_request_context("/"),
@ -425,7 +426,7 @@ class TestDocumentProcessingApi:
api = DocumentProcessingApi()
method = unwrap(api.patch)
document = MagicMock(indexing_status="paused", is_paused=True)
document = MagicMock(indexing_status=IndexingStatus.PAUSED, is_paused=True)
with (
app.test_request_context("/"),
@ -461,7 +462,7 @@ class TestDocumentProcessingApi:
api = DocumentProcessingApi()
method = unwrap(api.patch)
document = MagicMock(indexing_status="completed")
document = MagicMock(indexing_status=IndexingStatus.COMPLETED)
with app.test_request_context("/"), patch.object(api, "get_document", return_value=document):
with pytest.raises(InvalidActionError):
@ -630,7 +631,7 @@ class TestDocumentRetryApi:
payload = {"document_ids": ["doc-1"]}
document = MagicMock(indexing_status="indexing", archived=False)
document = MagicMock(indexing_status=IndexingStatus.INDEXING, archived=False)
with (
app.test_request_context("/", json=payload),
@ -659,7 +660,7 @@ class TestDocumentRetryApi:
payload = {"document_ids": ["doc-1"]}
document = MagicMock(indexing_status="completed", archived=False)
document = MagicMock(indexing_status=IndexingStatus.COMPLETED, archived=False)
with (
app.test_request_context("/", json=payload),
@ -817,8 +818,8 @@ class TestDocumentIndexingEstimateApi:
method = unwrap(api.get)
document = MagicMock(
indexing_status="indexing",
data_source_type="upload_file",
indexing_status=IndexingStatus.INDEXING,
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info_dict={"upload_file_id": "file-1"},
tenant_id="tenant-1",
doc_form="text",
@ -844,8 +845,8 @@ class TestDocumentIndexingEstimateApi:
method = unwrap(api.get)
document = MagicMock(
indexing_status="indexing",
data_source_type="upload_file",
indexing_status=IndexingStatus.INDEXING,
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info_dict={"upload_file_id": "file-1"},
tenant_id="tenant-1",
doc_form="text",
@ -882,7 +883,7 @@ class TestDocumentIndexingEstimateApi:
api = DocumentIndexingEstimateApi()
method = unwrap(api.get)
document = MagicMock(indexing_status="completed")
document = MagicMock(indexing_status=IndexingStatus.COMPLETED)
with app.test_request_context("/"), patch.object(api, "get_document", return_value=document):
with pytest.raises(DocumentAlreadyFinishedError):
@ -963,8 +964,8 @@ class TestDocumentBatchIndexingEstimateApi:
method = unwrap(api.get)
doc = MagicMock(
indexing_status="indexing",
data_source_type="website_crawl",
indexing_status=IndexingStatus.INDEXING,
data_source_type=DataSourceType.WEBSITE_CRAWL,
data_source_info_dict={
"provider": "firecrawl",
"job_id": "j1",
@ -992,8 +993,8 @@ class TestDocumentBatchIndexingEstimateApi:
method = unwrap(api.get)
doc = MagicMock(
indexing_status="indexing",
data_source_type="notion_import",
indexing_status=IndexingStatus.INDEXING,
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info_dict={
"credential_id": "c1",
"notion_workspace_id": "w1",
@ -1020,7 +1021,7 @@ class TestDocumentBatchIndexingEstimateApi:
method = unwrap(api.get)
document = MagicMock(
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
data_source_type="unknown",
data_source_info_dict={},
doc_form="text",
@ -1130,7 +1131,7 @@ class TestDocumentProcessingApiResume:
api = DocumentProcessingApi()
method = unwrap(api.patch)
document = MagicMock(indexing_status="completed", is_paused=False)
document = MagicMock(indexing_status=IndexingStatus.COMPLETED, is_paused=False)
with app.test_request_context("/"), patch.object(api, "get_document", return_value=document):
with pytest.raises(InvalidActionError):
@ -1348,8 +1349,8 @@ class TestDocumentIndexingEdgeCases:
method = unwrap(api.get)
document = MagicMock(
indexing_status="indexing",
data_source_type="upload_file",
indexing_status=IndexingStatus.INDEXING,
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info_dict={"upload_file_id": "file-1"},
tenant_id="tenant-1",
doc_form="text",

View File

@ -32,6 +32,7 @@ from controllers.service_api.dataset.segment import (
SegmentListQuery,
)
from models.dataset import ChildChunk, Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.dataset_service import DocumentService, SegmentService
@ -657,12 +658,27 @@ class TestSegmentIndexingRequirements:
dataset.indexing_technique = technique
assert dataset.indexing_technique in ["high_quality", "economy"]
@pytest.mark.parametrize("status", ["waiting", "parsing", "indexing", "completed", "error"])
@pytest.mark.parametrize(
"status",
[
IndexingStatus.WAITING,
IndexingStatus.PARSING,
IndexingStatus.INDEXING,
IndexingStatus.COMPLETED,
IndexingStatus.ERROR,
],
)
def test_valid_indexing_statuses(self, status):
"""Test valid document indexing statuses."""
document = Mock(spec=Document)
document.indexing_status = status
assert document.indexing_status in ["waiting", "parsing", "indexing", "completed", "error"]
assert document.indexing_status in {
IndexingStatus.WAITING,
IndexingStatus.PARSING,
IndexingStatus.INDEXING,
IndexingStatus.COMPLETED,
IndexingStatus.ERROR,
}
def test_completed_status_required_for_segments(self):
"""Test that completed status is required for segment operations."""

View File

@ -35,6 +35,7 @@ from controllers.service_api.dataset.document import (
InvalidMetadataError,
)
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError
from models.enums import IndexingStatus
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import ProcessRule, RetrievalModel
@ -244,23 +245,26 @@ class TestDocumentService:
class TestDocumentIndexingStatus:
"""Test document indexing status values."""
_VALID_STATUSES = {
IndexingStatus.WAITING,
IndexingStatus.PARSING,
IndexingStatus.INDEXING,
IndexingStatus.COMPLETED,
IndexingStatus.ERROR,
IndexingStatus.PAUSED,
}
def test_completed_status(self):
"""Test completed status."""
status = "completed"
valid_statuses = ["waiting", "parsing", "indexing", "completed", "error", "paused"]
assert status in valid_statuses
assert IndexingStatus.COMPLETED in self._VALID_STATUSES
def test_indexing_status(self):
"""Test indexing status."""
status = "indexing"
valid_statuses = ["waiting", "parsing", "indexing", "completed", "error", "paused"]
assert status in valid_statuses
assert IndexingStatus.INDEXING in self._VALID_STATUSES
def test_error_status(self):
"""Test error status."""
status = "error"
valid_statuses = ["waiting", "parsing", "indexing", "completed", "error", "paused"]
assert status in valid_statuses
assert IndexingStatus.ERROR in self._VALID_STATUSES
class TestDocumentDocForm:

View File

@ -25,6 +25,13 @@ from models.dataset import (
DocumentSegment,
Embedding,
)
from models.enums import (
DataSourceType,
DocumentCreatedFrom,
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
)
class TestDatasetModelValidation:
@ -40,14 +47,14 @@ class TestDatasetModelValidation:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
# Assert
assert dataset.name == "Test Dataset"
assert dataset.tenant_id == tenant_id
assert dataset.data_source_type == "upload_file"
assert dataset.data_source_type == DataSourceType.UPLOAD_FILE
assert dataset.created_by == created_by
# Note: Default values are set by database, not by model instantiation
@ -57,7 +64,7 @@ class TestDatasetModelValidation:
dataset = Dataset(
tenant_id=str(uuid4()),
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
description="Test description",
indexing_technique="high_quality",
@ -77,14 +84,14 @@ class TestDatasetModelValidation:
dataset_high_quality = Dataset(
tenant_id=str(uuid4()),
name="High Quality Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
indexing_technique="high_quality",
)
dataset_economy = Dataset(
tenant_id=str(uuid4()),
name="Economy Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
indexing_technique="economy",
)
@ -101,14 +108,14 @@ class TestDatasetModelValidation:
dataset_vendor = Dataset(
tenant_id=str(uuid4()),
name="Vendor Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
provider="vendor",
)
dataset_external = Dataset(
tenant_id=str(uuid4()),
name="External Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
provider="external",
)
@ -126,7 +133,7 @@ class TestDatasetModelValidation:
dataset = Dataset(
tenant_id=str(uuid4()),
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
index_struct=json.dumps(index_struct_data),
)
@ -145,7 +152,7 @@ class TestDatasetModelValidation:
dataset = Dataset(
tenant_id=str(uuid4()),
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
)
@ -161,7 +168,7 @@ class TestDatasetModelValidation:
dataset = Dataset(
tenant_id=str(uuid4()),
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
)
@ -178,7 +185,7 @@ class TestDatasetModelValidation:
dataset = Dataset(
tenant_id=str(uuid4()),
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
)
@ -218,10 +225,10 @@ class TestDocumentModelRelationships:
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test_document.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
@ -229,10 +236,10 @@ class TestDocumentModelRelationships:
assert document.tenant_id == tenant_id
assert document.dataset_id == dataset_id
assert document.position == 1
assert document.data_source_type == "upload_file"
assert document.data_source_type == DataSourceType.UPLOAD_FILE
assert document.batch == "batch_001"
assert document.name == "test_document.pdf"
assert document.created_from == "web"
assert document.created_from == DocumentCreatedFrom.WEB
assert document.created_by == created_by
# Note: Default values are set by database, not by model instantiation
@ -250,12 +257,12 @@ class TestDocumentModelRelationships:
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
)
# Act
@ -271,12 +278,12 @@ class TestDocumentModelRelationships:
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
indexing_status="parsing",
indexing_status=IndexingStatus.PARSING,
is_paused=True,
)
@ -289,15 +296,20 @@ class TestDocumentModelRelationships:
def test_document_display_status_indexing(self):
"""Test document display_status property for indexing state."""
# Arrange
for indexing_status in ["parsing", "cleaning", "splitting", "indexing"]:
for indexing_status in [
IndexingStatus.PARSING,
IndexingStatus.CLEANING,
IndexingStatus.SPLITTING,
IndexingStatus.INDEXING,
]:
document = Document(
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
indexing_status=indexing_status,
)
@ -315,12 +327,12 @@ class TestDocumentModelRelationships:
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
# Act
@ -336,12 +348,12 @@ class TestDocumentModelRelationships:
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -359,12 +371,12 @@ class TestDocumentModelRelationships:
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False,
archived=False,
)
@ -382,12 +394,12 @@ class TestDocumentModelRelationships:
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
archived=True,
)
@ -405,10 +417,10 @@ class TestDocumentModelRelationships:
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
data_source_info=json.dumps(data_source_info),
)
@ -428,10 +440,10 @@ class TestDocumentModelRelationships:
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
)
@ -448,10 +460,10 @@ class TestDocumentModelRelationships:
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
word_count=1000,
)
@ -471,10 +483,10 @@ class TestDocumentModelRelationships:
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
word_count=0,
)
@ -582,7 +594,7 @@ class TestDocumentSegmentIndexing:
word_count=1,
tokens=2,
created_by=str(uuid4()),
status="waiting",
status=SegmentStatus.WAITING,
)
segment_completed = DocumentSegment(
tenant_id=str(uuid4()),
@ -593,12 +605,12 @@ class TestDocumentSegmentIndexing:
word_count=1,
tokens=2,
created_by=str(uuid4()),
status="completed",
status=SegmentStatus.COMPLETED,
)
# Assert
assert segment_waiting.status == "waiting"
assert segment_completed.status == "completed"
assert segment_waiting.status == SegmentStatus.WAITING
assert segment_completed.status == SegmentStatus.COMPLETED
def test_document_segment_enabled_disabled_tracking(self):
"""Test document segment enabled/disabled state tracking."""
@ -769,13 +781,13 @@ class TestDatasetProcessRule:
# Act
process_rule = DatasetProcessRule(
dataset_id=dataset_id,
mode="automatic",
mode=ProcessRuleMode.AUTOMATIC,
created_by=created_by,
)
# Assert
assert process_rule.dataset_id == dataset_id
assert process_rule.mode == "automatic"
assert process_rule.mode == ProcessRuleMode.AUTOMATIC
assert process_rule.created_by == created_by
def test_dataset_process_rule_modes(self):
@ -797,7 +809,7 @@ class TestDatasetProcessRule:
}
process_rule = DatasetProcessRule(
dataset_id=str(uuid4()),
mode="custom",
mode=ProcessRuleMode.CUSTOM,
created_by=str(uuid4()),
rules=json.dumps(rules_data),
)
@ -817,7 +829,7 @@ class TestDatasetProcessRule:
rules_data = {"test": "data"}
process_rule = DatasetProcessRule(
dataset_id=dataset_id,
mode="automatic",
mode=ProcessRuleMode.AUTOMATIC,
created_by=str(uuid4()),
rules=json.dumps(rules_data),
)
@ -827,7 +839,7 @@ class TestDatasetProcessRule:
# Assert
assert result["dataset_id"] == dataset_id
assert result["mode"] == "automatic"
assert result["mode"] == ProcessRuleMode.AUTOMATIC
assert result["rules"] == rules_data
def test_dataset_process_rule_automatic_rules(self):
@ -969,7 +981,7 @@ class TestModelIntegration:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
indexing_technique="high_quality",
)
@ -980,10 +992,10 @@ class TestModelIntegration:
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
word_count=100,
)
@ -999,7 +1011,7 @@ class TestModelIntegration:
word_count=3,
tokens=5,
created_by=created_by,
status="completed",
status=SegmentStatus.COMPLETED,
)
# Assert
@ -1009,7 +1021,7 @@ class TestModelIntegration:
assert segment.document_id == document_id
assert dataset.indexing_technique == "high_quality"
assert document.word_count == 100
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
def test_document_to_dict_serialization(self):
"""Test document to_dict method for serialization."""
@ -1022,13 +1034,13 @@ class TestModelIntegration:
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
word_count=100,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Mock segment_count and hit_count
@ -1044,6 +1056,6 @@ class TestModelIntegration:
assert result["dataset_id"] == dataset_id
assert result["name"] == "test.pdf"
assert result["word_count"] == 100
assert result["indexing_status"] == "completed"
assert result["indexing_status"] == IndexingStatus.COMPLETED
assert result["segment_count"] == 5
assert result["hit_count"] == 10
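The model-construction tests above pass enum members directly into columns declared as strings (data_source_type=DataSourceType.UPLOAD_FILE, status=SegmentStatus.COMPLETED, and so on). A minimal sketch of why that round-trips cleanly, assuming string-valued enums and an ordinary String column; the table, column, and engine below are illustrative stand-ins, not the project's real schema:

from enum import StrEnum

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class DataSourceType(StrEnum):
    # Value mirrors the literal used throughout this diff.
    UPLOAD_FILE = "upload_file"

class Base(DeclarativeBase):
    pass

class DemoDocument(Base):
    # Hypothetical table for illustration only; not the project's Document model.
    __tablename__ = "demo_documents"
    id: Mapped[int] = mapped_column(primary_key=True)
    data_source_type: Mapped[str] = mapped_column(String(255))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(DemoDocument(id=1, data_source_type=DataSourceType.UPLOAD_FILE))
    session.commit()
    stored = session.scalar(select(DemoDocument.data_source_type))
    # The column stores the plain string, which still compares equal to the enum member.
    assert stored == "upload_file" == DataSourceType.UPLOAD_FILE

The same equality is what lets the to_dict assertion result["indexing_status"] == IndexingStatus.COMPLETED succeed whether the serializer returns the member or the stored string.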

View File

@ -11,6 +11,7 @@ from unittest.mock import MagicMock
import pytest
import services.summary_index_service as summary_module
from models.enums import SegmentStatus, SummaryStatus
from services.summary_index_service import SummaryIndexService
@ -42,7 +43,7 @@ def _segment(*, has_document: bool = True) -> MagicMock:
segment.dataset_id = "dataset-1"
segment.content = "hello world"
segment.enabled = True
segment.status = "completed"
segment.status = SegmentStatus.COMPLETED
segment.position = 1
if has_document:
doc = MagicMock(name="document")
@ -64,7 +65,7 @@ def _summary_record(*, summary_content: str = "summary", node_id: str | None = N
record.summary_index_node_id = node_id
record.summary_index_node_hash = None
record.tokens = None
record.status = "generating"
record.status = SummaryStatus.GENERATING
record.error = None
record.enabled = True
record.created_at = datetime(2024, 1, 1, tzinfo=UTC)
@ -133,10 +134,10 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes
segment = _segment()
dataset = _dataset()
result = SummaryIndexService.create_summary_record(segment, dataset, "new", status="generating")
result = SummaryIndexService.create_summary_record(segment, dataset, "new", status=SummaryStatus.GENERATING)
assert result is existing
assert existing.summary_content == "new"
assert existing.status == "generating"
assert existing.status == SummaryStatus.GENERATING
assert existing.enabled is True
assert existing.disabled_at is None
assert existing.disabled_by is None
@ -155,7 +156,7 @@ def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> N
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
record = SummaryIndexService.create_summary_record(_segment(), _dataset(), "new", status="generating")
record = SummaryIndexService.create_summary_record(_segment(), _dataset(), "new", status=SummaryStatus.GENERATING)
assert record.dataset_id == "dataset-1"
assert record.chunk_id == "seg-1"
assert record.summary_content == "new"
@ -204,7 +205,7 @@ def test_vectorize_summary_retries_connection_errors_then_succeeds(monkeypatch:
assert vector_instance.add_texts.call_count == 2
summary_module.time.sleep.assert_called_once() # type: ignore[attr-defined]
session.flush.assert_called_once()
assert summary.status == "completed"
assert summary.status == SummaryStatus.COMPLETED
assert summary.summary_index_node_id == "uuid-1"
assert summary.summary_index_node_hash == "hash-1"
assert summary.tokens == 5
@ -245,7 +246,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat
create_session_mock.assert_called()
session.add.assert_called()
session.commit.assert_called_once()
assert summary.status == "completed"
assert summary.status == SummaryStatus.COMPLETED
assert summary.summary_index_node_id == "old-node" # reused
@ -275,7 +276,7 @@ def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytes
with pytest.raises(RuntimeError, match="boom"):
SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None)
assert summary.status == "error"
assert summary.status == SummaryStatus.ERROR
assert "Vectorization failed" in (summary.error or "")
error_session.commit.assert_called_once()
@ -310,7 +311,7 @@ def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.Mo
SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))),
)
SummaryIndexService.batch_create_summary_records([s1, s2], dataset, status="not_started")
SummaryIndexService.batch_create_summary_records([s1, s2], dataset, status=SummaryStatus.NOT_STARTED)
session.commit.assert_called_once()
assert existing.enabled is True
@ -332,7 +333,7 @@ def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.Mon
)
SummaryIndexService.update_summary_record_error(segment, dataset, "err")
assert record.status == "error"
assert record.status == SummaryStatus.ERROR
assert record.error == "err"
session.commit.assert_called_once()
@ -387,7 +388,7 @@ def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch
with pytest.raises(RuntimeError, match="boom"):
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True})
assert record.status == "error"
assert record.status == SummaryStatus.ERROR
# Outer exception handler overwrites the error with the raw exception message.
assert record.error == "boom"
@ -614,7 +615,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo
monkeypatch.setattr(summary_module, "logger", logger_mock)
result = SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True})
assert result.status in {"generating", "completed"}
assert result.status in {SummaryStatus.GENERATING, SummaryStatus.COMPLETED}
logger_mock.info.assert_called()
@ -787,7 +788,7 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt
segment = _segment()
segment.id = summary.chunk_id
segment.enabled = True
segment.status = "completed"
segment.status = SegmentStatus.COMPLETED
session = MagicMock()
summary_query = MagicMock()
@ -850,11 +851,11 @@ def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vect
bad_segment = _segment()
bad_segment.enabled = False
bad_segment.status = "completed"
bad_segment.status = SegmentStatus.COMPLETED
good_segment = _segment()
good_segment.enabled = True
good_segment.status = "completed"
good_segment.status = SegmentStatus.COMPLETED
session = MagicMock()
summary_query = MagicMock()
@ -1084,7 +1085,7 @@ def test_update_summary_for_segment_existing_vectorize_failure_returns_error_rec
out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new")
assert out is record
assert out.status == "error"
assert out.status == SummaryStatus.ERROR
assert "Vectorization failed" in (out.error or "")
@ -1133,7 +1134,7 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk
with pytest.raises(RuntimeError, match="flush boom"):
SummaryIndexService.update_summary_for_segment(segment, dataset, "new")
assert record.status == "error"
assert record.status == SummaryStatus.ERROR
assert record.error == "flush boom"
session.commit.assert_called()
@ -1222,7 +1223,7 @@ def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: py
monkeypatch.setattr(
SummaryIndexService,
"get_segments_summaries",
MagicMock(return_value={"seg-1": SimpleNamespace(status="completed")}),
MagicMock(return_value={"seg-1": SimpleNamespace(status=SummaryStatus.COMPLETED)}),
)
result = SummaryIndexService.get_documents_summary_index_status(["doc-1"], "dataset-1", "tenant-1")
assert result["doc-1"] is None
@ -1254,7 +1255,7 @@ def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_erro
monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vectorize_mock)
out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new")
assert out.status == "error"
assert out.status == SummaryStatus.ERROR
assert "Vectorization failed" in (out.error or "")
@ -1276,7 +1277,7 @@ def test_get_document_summary_index_status_and_documents_status(monkeypatch: pyt
monkeypatch.setattr(
SummaryIndexService,
"get_segments_summaries",
MagicMock(return_value={"seg-1": SimpleNamespace(status="generating")}),
MagicMock(return_value={"seg-1": SimpleNamespace(status=SummaryStatus.GENERATING)}),
)
assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING"
@ -1294,7 +1295,7 @@ def test_get_document_summary_index_status_and_documents_status(monkeypatch: pyt
monkeypatch.setattr(
SummaryIndexService,
"get_segments_summaries",
MagicMock(return_value={"seg-1": SimpleNamespace(status="not_started")}),
MagicMock(return_value={"seg-1": SimpleNamespace(status=SummaryStatus.NOT_STARTED)}),
)
result = SummaryIndexService.get_documents_summary_index_status(["doc-1", "doc-2"], "dataset-1", "tenant-1")
assert result["doc-1"] == "SUMMARIZING"
@ -1311,7 +1312,7 @@ def test_get_document_summary_status_detail_counts_and_previews(monkeypatch: pyt
summary1 = _summary_record(summary_content="x" * 150, node_id="n1")
summary1.chunk_id = "seg-1"
summary1.status = "completed"
summary1.status = SummaryStatus.COMPLETED
summary1.error = None
summary1.created_at = datetime(2024, 1, 1, tzinfo=UTC)
summary1.updated_at = datetime(2024, 1, 2, tzinfo=UTC)

View File

@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch
import pytest
from models.enums import DataSourceType
from tasks.clean_dataset_task import clean_dataset_task
# ============================================================================
@ -116,7 +117,7 @@ def mock_document():
doc.id = str(uuid.uuid4())
doc.tenant_id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.data_source_type = "upload_file"
doc.data_source_type = DataSourceType.UPLOAD_FILE
doc.data_source_info = '{"upload_file_id": "test-file-id"}'
doc.data_source_info_dict = {"upload_file_id": "test-file-id"}
return doc

View File

@ -19,6 +19,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document
from models.enums import IndexingStatus
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
from tasks.document_indexing_task import (
_document_indexing,
@ -424,7 +425,7 @@ class TestBatchProcessing:
# Assert - All documents should be set to 'parsing' status
for doc in mock_documents:
assert doc.indexing_status == "parsing"
assert doc.indexing_status == IndexingStatus.PARSING
assert doc.processing_started_at is not None
# IndexingRunner should be called with all documents
@ -573,7 +574,7 @@ class TestProgressTracking:
# Assert - Status should be 'parsing'
for doc in mock_documents:
assert doc.indexing_status == "parsing"
assert doc.indexing_status == IndexingStatus.PARSING
assert doc.processing_started_at is not None
# Verify commit was called to persist status
@ -1158,7 +1159,7 @@ class TestAdvancedScenarios:
# Assert
# All documents should be set to parsing (no limit errors)
for doc in mock_documents:
assert doc.indexing_status == "parsing"
assert doc.indexing_status == IndexingStatus.PARSING
# IndexingRunner should be called with all documents
mock_indexing_runner.run.assert_called_once()
@ -1377,7 +1378,7 @@ class TestPerformanceScenarios:
# Assert
for doc in mock_documents:
assert doc.indexing_status == "parsing"
assert doc.indexing_status == IndexingStatus.PARSING
mock_indexing_runner.run.assert_called_once()
call_args = mock_indexing_runner.run.call_args[0][0]