diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index c1fe769997..977ee1192c 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -148,6 +148,11 @@ class _EstimateRules(BaseModel):
         return list(seen.values())
 
 
+class _EstimateHierarchicalRules(_EstimateRules):
+    parent_mode: Literal["full-doc", "paragraph"] | None = None
+    subchunk_segmentation: _EstimateSegmentation | None = None
+
+
 class _SummaryIndexSettingDisabled(BaseModel):
     enable: Literal[False] = False
 
@@ -203,7 +208,7 @@ class _HierarchicalProcessRule(BaseModel):
     model_config = ConfigDict(extra="allow")
 
     mode: Literal[ProcessRuleMode.HIERARCHICAL]
-    rules: _EstimateRules
+    rules: _EstimateHierarchicalRules
     summary_index_setting: _SummaryIndexSetting | None = None
 
     @field_validator("summary_index_setting", mode="before")
@@ -2971,6 +2976,10 @@ class DocumentService:
         process_rule_dict = validated.process_rule.model_dump(exclude_none=True)
         if validated.process_rule.mode == ProcessRuleMode.AUTOMATIC:
             process_rule_dict["rules"] = {}
+        elif validated.process_rule.mode == ProcessRuleMode.HIERARCHICAL:
+            rules = process_rule_dict.get("rules")
+            if isinstance(rules, dict) and not rules.get("parent_mode"):
+                rules["parent_mode"] = "paragraph"
         args["process_rule"] = process_rule_dict
 
     @staticmethod
diff --git a/api/tests/unit_tests/services/test_dataset_service_document.py b/api/tests/unit_tests/services/test_dataset_service_document.py
index a78bc7f9d6..9a8243936b 100644
--- a/api/tests/unit_tests/services/test_dataset_service_document.py
+++ b/api/tests/unit_tests/services/test_dataset_service_document.py
@@ -1344,6 +1344,27 @@ class TestDocumentServiceEstimateValidation:
 
         assert args["process_rule"]["rules"]["pre_processing_rules"] == [{"id": "remove_stopwords", "enabled": False}]
 
+    def test_estimate_args_validate_custom_mode_drops_hierarchical_fields(self):
+        args = {
+            "info_list": {"data_source_type": "upload_file"},
+            "process_rule": {
+                "mode": "custom",
+                "rules": {
+                    "pre_processing_rules": [{"id": "remove_stopwords", "enabled": True}],
+                    "segmentation": {"separator": "\n", "max_tokens": 128},
+                    "parent_mode": "full-doc",
+                    "subchunk_segmentation": {"separator": "###", "max_tokens": 64},
+                },
+            },
+        }
+
+        DocumentService.estimate_args_validate(args)
+
+        assert args["process_rule"]["rules"] == {
+            "pre_processing_rules": [{"id": "remove_stopwords", "enabled": True}],
+            "segmentation": {"separator": "\n", "max_tokens": 128},
+        }
+
     def test_estimate_args_validate_requires_summary_index_provider_name(self):
         args = {
             "info_list": {"data_source_type": "upload_file"},
@@ -1360,6 +1381,43 @@ class TestDocumentServiceEstimateValidation:
         with pytest.raises(ValueError, match="Field required"):
             DocumentService.estimate_args_validate(args)
 
+    def test_estimate_args_validate_preserves_hierarchical_fields(self):
+        args = {
+            "info_list": {"data_source_type": "upload_file"},
+            "process_rule": {
+                "mode": "hierarchical",
+                "rules": {
+                    "pre_processing_rules": [{"id": "remove_stopwords", "enabled": True}],
+                    "segmentation": {"separator": "\n", "max_tokens": 512},
+                    "parent_mode": "full-doc",
+                    "subchunk_segmentation": {"separator": "###", "max_tokens": 128},
+                },
+            },
+        }
+
+        DocumentService.estimate_args_validate(args)
+
+        assert args["process_rule"]["rules"]["parent_mode"] == "full-doc"
+        assert args["process_rule"]["rules"]["subchunk_segmentation"] == {"separator": "###", "max_tokens": 128}
+
+    def test_estimate_args_validate_hierarchical_defaults_parent_mode_to_paragraph(self):
+        args = {
+            "info_list": {"data_source_type": "upload_file"},
+            "process_rule": {
+                "mode": "hierarchical",
+                "rules": {
+                    "pre_processing_rules": [{"id": "remove_stopwords", "enabled": True}],
+                    "segmentation": {"separator": "\n", "max_tokens": 512},
+                    "subchunk_segmentation": {"separator": "###", "max_tokens": 128},
+                },
+            },
+        }
+
+        DocumentService.estimate_args_validate(args)
+
+        assert args["process_rule"]["rules"]["parent_mode"] == "paragraph"
+        assert args["process_rule"]["rules"]["subchunk_segmentation"] == {"separator": "###", "max_tokens": 128}
+
 
 class TestDocumentServiceSaveDocumentAdditionalBranches:
     """Additional unit tests for dataset bootstrap and process-rule branches."""