refactor: rewrite estimate_args_validate using Pydantic v2 models (#36036)

Signed-off-by: Deepam Goyal <deepam02goyal@gmail.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Deepam Goyal 2026-05-12 11:04:45 +05:30 committed by GitHub
parent cd90d7ffc1
commit 1a93af5cd0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 94 additions and 91 deletions

View File

@@ -7,9 +7,10 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, TypedDict, cast
from typing import Annotated, Any, Literal, TypedDict, cast
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from redis.exceptions import LockNotOwnedError
from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session, sessionmaker
@@ -117,6 +118,86 @@ class AutoDisableLogsDict(TypedDict):
count: int
class _EstimatePreProcessingRule(BaseModel):
    """One pre-processing rule entry: a registered rule id plus an on/off flag."""

    id: str = Field(min_length=1)
    enabled: bool

    @field_validator("id")
    @classmethod
    def _validate_id(cls, value: str) -> str:
        # Only ids registered on DatasetProcessRule.PRE_PROCESSING_RULES pass.
        if value in DatasetProcessRule.PRE_PROCESSING_RULES:
            return value
        raise ValueError("Process rule pre_processing_rules id is invalid")
class _EstimateSegmentation(BaseModel):
    """Segmentation settings: a non-empty separator and a positive token cap."""

    # separator must be a non-empty string
    separator: str = Field(min_length=1)
    # max_tokens must be strictly positive
    max_tokens: int = Field(gt=0)
class _EstimateRules(BaseModel):
    """Rule payload required by the non-automatic process-rule modes."""

    pre_processing_rules: list[_EstimatePreProcessingRule]
    segmentation: _EstimateSegmentation

    @field_validator("pre_processing_rules")
    @classmethod
    def _deduplicate(cls, rules: list[_EstimatePreProcessingRule]) -> list[_EstimatePreProcessingRule]:
        # Deduplicate by id: the last occurrence wins, while each id keeps
        # the position of its first appearance (dict preserves insertion
        # order on key overwrite) — same result as the original loop.
        return list({rule.id: rule for rule in rules}.values())
class _SummaryIndexSettingDisabled(BaseModel):
    """Discriminated-union arm for summary indexing turned off."""

    enable: Literal[False] = False
class _SummaryIndexSettingEnabled(BaseModel):
    """Discriminated-union arm for summary indexing turned on.

    When enabled, both the model name and its provider name are required
    to be non-empty.
    """

    enable: Literal[True]
    model_name: str = Field(min_length=1)
    model_provider_name: str = Field(min_length=1)
# Union discriminated on `enable`: False selects the Disabled arm,
# True selects the Enabled arm (which then requires model fields).
_SummaryIndexSetting = Annotated[
    _SummaryIndexSettingDisabled | _SummaryIndexSettingEnabled,
    Field(discriminator="enable"),
]
class _AutomaticProcessRule(BaseModel):
    """Process rule in AUTOMATIC mode; no explicit `rules` payload required."""

    # extra="allow": unknown keys in the payload are kept, not rejected.
    model_config = ConfigDict(extra="allow")

    mode: Literal[ProcessRuleMode.AUTOMATIC]
    summary_index_setting: _SummaryIndexSetting | None = None
class _CustomProcessRule(BaseModel):
    """Process rule in CUSTOM mode; a full `rules` payload is mandatory."""

    # extra="allow": unknown keys in the payload are kept, not rejected.
    model_config = ConfigDict(extra="allow")

    mode: Literal[ProcessRuleMode.CUSTOM]
    rules: _EstimateRules
    summary_index_setting: _SummaryIndexSetting | None = None
class _HierarchicalProcessRule(BaseModel):
    """Process rule in HIERARCHICAL mode; a full `rules` payload is mandatory."""

    # extra="allow": unknown keys in the payload are kept, not rejected.
    model_config = ConfigDict(extra="allow")

    mode: Literal[ProcessRuleMode.HIERARCHICAL]
    rules: _EstimateRules
    summary_index_setting: _SummaryIndexSetting | None = None
# Union discriminated on `mode`: Pydantic picks the per-mode model whose
# Literal `mode` value matches the incoming payload.
_EstimateProcessRule = Annotated[
    _AutomaticProcessRule | _CustomProcessRule | _HierarchicalProcessRule,
    Field(discriminator="mode"),
]
class _EstimateArgs(BaseModel):
    """Top-level payload validated by DocumentService.estimate_args_validate."""

    # data-source info; only constrained to be a dict here
    info_list: dict[str, Any]
    process_rule: _EstimateProcessRule
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
@@ -2851,94 +2932,16 @@ class DocumentService:
@classmethod
def estimate_args_validate(cls, args: dict[str, Any]):
    """Validate the estimate request `args` and normalize `args["process_rule"]` in place.

    NOTE(review): this span is a rendered diff, so it contains TWO
    implementations byte-for-byte: the hand-rolled checks below were
    REMOVED by this commit, and the Pydantic-based block at the end was
    ADDED to replace them.
    """
    # ---- REMOVED: pre-Pydantic hand-rolled validation -------------------
    if "info_list" not in args or not args["info_list"]:
        raise ValueError("Data source info is required")
    if not isinstance(args["info_list"], dict):
        raise ValueError("Data info is invalid")
    if "process_rule" not in args or not args["process_rule"]:
        raise ValueError("Process rule is required")
    if not isinstance(args["process_rule"], dict):
        raise ValueError("Process rule is invalid")
    if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]:
        raise ValueError("Process rule mode is required")
    if args["process_rule"]["mode"] not in DatasetProcessRule.MODES:
        raise ValueError("Process rule mode is invalid")
    if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC:
        # Automatic mode ignores any caller-supplied rules.
        args["process_rule"]["rules"] = {}
    else:
        # Non-automatic modes require a fully-specified rules dict.
        if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]:
            raise ValueError("Process rule rules is required")
        if not isinstance(args["process_rule"]["rules"], dict):
            raise ValueError("Process rule rules is invalid")
        if (
            "pre_processing_rules" not in args["process_rule"]["rules"]
            or args["process_rule"]["rules"]["pre_processing_rules"] is None
        ):
            raise ValueError("Process rule pre_processing_rules is required")
        if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list):
            raise ValueError("Process rule pre_processing_rules is invalid")
        # Deduplicate by rule id; last occurrence wins.
        unique_pre_processing_rule_dicts = {}
        for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]:
            if "id" not in pre_processing_rule or not pre_processing_rule["id"]:
                raise ValueError("Process rule pre_processing_rules id is required")
            if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES:
                raise ValueError("Process rule pre_processing_rules id is invalid")
            if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None:
                raise ValueError("Process rule pre_processing_rules enabled is required")
            if not isinstance(pre_processing_rule["enabled"], bool):
                raise ValueError("Process rule pre_processing_rules enabled is invalid")
            unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule
        args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values())
        if (
            "segmentation" not in args["process_rule"]["rules"]
            or args["process_rule"]["rules"]["segmentation"] is None
        ):
            raise ValueError("Process rule segmentation is required")
        if not isinstance(args["process_rule"]["rules"]["segmentation"], dict):
            raise ValueError("Process rule segmentation is invalid")
        if (
            "separator" not in args["process_rule"]["rules"]["segmentation"]
            or not args["process_rule"]["rules"]["segmentation"]["separator"]
        ):
            raise ValueError("Process rule segmentation separator is required")
        if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str):
            raise ValueError("Process rule segmentation separator is invalid")
        if (
            "max_tokens" not in args["process_rule"]["rules"]["segmentation"]
            or not args["process_rule"]["rules"]["segmentation"]["max_tokens"]
        ):
            raise ValueError("Process rule segmentation max_tokens is required")
        if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
            raise ValueError("Process rule segmentation max_tokens is invalid")
    # valid summary index setting
    summary_index_setting = args["process_rule"].get("summary_index_setting")
    if summary_index_setting and summary_index_setting.get("enable"):
        if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]:
            raise ValueError("Summary index model name is required")
        if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]:
            raise ValueError("Summary index model provider name is required")
    # ---- ADDED: Pydantic-based validation -------------------------------
    try:
        validated = _EstimateArgs.model_validate(args)
    except ValidationError as e:
        # Re-raise as ValueError for callers; prefer the wrapped original
        # ValueError's message (from custom validators) when present.
        first = e.errors()[0]
        original = first.get("ctx", {}).get("error")
        raise ValueError(str(original) if isinstance(original, ValueError) else first["msg"]) from e
    # Normalize: dump the validated process rule back into `args`,
    # forcing an empty rules dict for automatic mode.
    process_rule_dict = validated.process_rule.model_dump(exclude_none=True)
    if validated.process_rule.mode == ProcessRuleMode.AUTOMATIC:
        process_rule_dict["rules"] = {}
    args["process_rule"] = process_rule_dict
@staticmethod
def batch_update_document_status(

View File

@@ -1297,7 +1297,7 @@ class TestDocumentServiceEstimateValidation:
"""Unit tests for estimate_args_validate branches."""
def test_estimate_args_validate_rejects_missing_info_list(self):
    # NOTE(review): rendered diff — the first `with` carries the REMOVED
    # expectation ("Data source info is required"); the second is its
    # ADDED replacement matching Pydantic's generic "Field required".
    with pytest.raises(ValueError, match="Data source info is required"):
        with pytest.raises(ValueError, match="Field required"):
            DocumentService.estimate_args_validate({})
def test_estimate_args_validate_sets_empty_rules_for_automatic_mode(self):
@@ -1357,7 +1357,7 @@ class TestDocumentServiceEstimateValidation:
},
}
with pytest.raises(ValueError, match="Summary index model provider name is required"):
with pytest.raises(ValueError, match="Field required"):
DocumentService.estimate_args_validate(args)