diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 383474f4f6..4f5a95dcde 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,9 +7,10 @@ import time import uuid from collections import Counter from collections.abc import Sequence -from typing import Any, Literal, TypedDict, cast +from typing import Annotated, Any, Literal, TypedDict, cast import sqlalchemy as sa +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator from redis.exceptions import LockNotOwnedError from sqlalchemy import delete, exists, func, select, update from sqlalchemy.orm import Session, sessionmaker @@ -117,6 +118,86 @@ class AutoDisableLogsDict(TypedDict): count: int +class _EstimatePreProcessingRule(BaseModel): + id: str = Field(min_length=1) + enabled: bool + + @field_validator("id") + @classmethod + def _validate_id(cls, v: str) -> str: + if v not in DatasetProcessRule.PRE_PROCESSING_RULES: + raise ValueError("Process rule pre_processing_rules id is invalid") + return v + + +class _EstimateSegmentation(BaseModel): + separator: str = Field(min_length=1) + max_tokens: int = Field(gt=0) + + +class _EstimateRules(BaseModel): + pre_processing_rules: list[_EstimatePreProcessingRule] + segmentation: _EstimateSegmentation + + @field_validator("pre_processing_rules") + @classmethod + def _deduplicate(cls, v: list[_EstimatePreProcessingRule]) -> list[_EstimatePreProcessingRule]: + seen: dict[str, _EstimatePreProcessingRule] = {} + for rule in v: + seen[rule.id] = rule + return list(seen.values()) + + +class _SummaryIndexSettingDisabled(BaseModel): + enable: Literal[False] = False + + +class _SummaryIndexSettingEnabled(BaseModel): + enable: Literal[True] + model_name: str = Field(min_length=1) + model_provider_name: str = Field(min_length=1) + + +_SummaryIndexSetting = Annotated[ + _SummaryIndexSettingDisabled | _SummaryIndexSettingEnabled, + Field(discriminator="enable"), +] + + +class _AutomaticProcessRule(BaseModel): + model_config = ConfigDict(extra="allow") + + mode: Literal[ProcessRuleMode.AUTOMATIC] + summary_index_setting: _SummaryIndexSetting | None = None + + +class _CustomProcessRule(BaseModel): + model_config = ConfigDict(extra="allow") + + mode: Literal[ProcessRuleMode.CUSTOM] + rules: _EstimateRules + summary_index_setting: _SummaryIndexSetting | None = None + + +class _HierarchicalProcessRule(BaseModel): + model_config = ConfigDict(extra="allow") + + mode: Literal[ProcessRuleMode.HIERARCHICAL] + rules: _EstimateRules + summary_index_setting: _SummaryIndexSetting | None = None + + +_EstimateProcessRule = Annotated[ + _AutomaticProcessRule | _CustomProcessRule | _HierarchicalProcessRule, + Field(discriminator="mode"), +] + + +class _EstimateArgs(BaseModel): + info_list: dict[str, Any] + process_rule: _EstimateProcessRule + + class DatasetService: @staticmethod def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): @@ -2851,94 +2932,16 @@ class DocumentService: @classmethod def estimate_args_validate(cls, args: dict[str, Any]): - if "info_list" not in args or not args["info_list"]: - raise ValueError("Data source info is required") - - if not isinstance(args["info_list"], dict): - raise ValueError("Data info is invalid") - - if "process_rule" not in args or not args["process_rule"]: - raise ValueError("Process rule is required") - - if not isinstance(args["process_rule"], dict): - raise ValueError("Process rule is invalid") - - if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: - raise ValueError("Process rule mode is required") - - if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: - raise ValueError("Process rule mode is invalid") - - if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC: - args["process_rule"]["rules"] = {} - else: - if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: - raise ValueError("Process rule rules is required") - - if not isinstance(args["process_rule"]["rules"], dict): - raise ValueError("Process rule rules is invalid") - - if ( - "pre_processing_rules" not in args["process_rule"]["rules"] - or args["process_rule"]["rules"]["pre_processing_rules"] is None - ): - raise ValueError("Process rule pre_processing_rules is required") - - if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): - raise ValueError("Process rule pre_processing_rules is invalid") - - unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: - if "id" not in pre_processing_rule or not pre_processing_rule["id"]: - raise ValueError("Process rule pre_processing_rules id is required") - - if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: - raise ValueError("Process rule pre_processing_rules id is invalid") - - if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: - raise ValueError("Process rule pre_processing_rules enabled is required") - - if not isinstance(pre_processing_rule["enabled"], bool): - raise ValueError("Process rule pre_processing_rules enabled is invalid") - - unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - - args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - - if ( - "segmentation" not in args["process_rule"]["rules"] - or args["process_rule"]["rules"]["segmentation"] is None - ): - raise ValueError("Process rule segmentation is required") - - if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): - raise ValueError("Process rule segmentation is invalid") - - if ( - "separator" not in args["process_rule"]["rules"]["segmentation"] - or not args["process_rule"]["rules"]["segmentation"]["separator"] - ): - raise ValueError("Process rule segmentation separator is required") - - if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): - raise ValueError("Process rule segmentation separator is invalid") - - if ( - "max_tokens" not in args["process_rule"]["rules"]["segmentation"] - or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] - ): - raise ValueError("Process rule segmentation max_tokens is required") - - if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): - raise ValueError("Process rule segmentation max_tokens is invalid") - - # valid summary index setting - summary_index_setting = args["process_rule"].get("summary_index_setting") - if summary_index_setting and summary_index_setting.get("enable"): - if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]: - raise ValueError("Summary index model name is required") - if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]: - raise ValueError("Summary index model provider name is required") + try: + validated = _EstimateArgs.model_validate(args) + except ValidationError as e: + first = e.errors()[0] + original = first.get("ctx", {}).get("error") + raise ValueError(str(original) if isinstance(original, ValueError) else first["msg"]) from e + process_rule_dict = validated.process_rule.model_dump(exclude_none=True) + if validated.process_rule.mode == ProcessRuleMode.AUTOMATIC: + process_rule_dict["rules"] = {} + args["process_rule"] = process_rule_dict @staticmethod def batch_update_document_status( diff --git a/api/tests/unit_tests/services/test_dataset_service_document.py b/api/tests/unit_tests/services/test_dataset_service_document.py index 1633194aa8..a78bc7f9d6 100644 --- a/api/tests/unit_tests/services/test_dataset_service_document.py +++ b/api/tests/unit_tests/services/test_dataset_service_document.py @@ -1297,7 +1297,7 @@ class TestDocumentServiceEstimateValidation: """Unit tests for estimate_args_validate branches.""" def test_estimate_args_validate_rejects_missing_info_list(self): - with pytest.raises(ValueError, match="Data source info is required"): + with pytest.raises(ValueError, match="Field required"): DocumentService.estimate_args_validate({}) def test_estimate_args_validate_sets_empty_rules_for_automatic_mode(self): @@ -1357,7 +1357,7 @@ class TestDocumentServiceEstimateValidation: }, } - with pytest.raises(ValueError, match="Summary index model provider name is required"): + with pytest.raises(ValueError, match="Field required"): DocumentService.estimate_args_validate(args)