Mirror of https://github.com/langgenius/dify.git (synced 2026-05-13 00:33:37 +08:00)
refactor: rewrite estimate_args_validate using Pydantic v2 models (#36036)
Signed-off-by: Deepam Goyal <deepam02goyal@gmail.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
parent cd90d7ffc1
commit 1a93af5cd0
@@ -7,9 +7,10 @@ import time
 import uuid
 from collections import Counter
 from collections.abc import Sequence
-from typing import Any, Literal, TypedDict, cast
+from typing import Annotated, Any, Literal, TypedDict, cast

 import sqlalchemy as sa
+from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
 from redis.exceptions import LockNotOwnedError
 from sqlalchemy import delete, exists, func, select, update
 from sqlalchemy.orm import Session, sessionmaker
@@ -117,6 +118,86 @@ class AutoDisableLogsDict(TypedDict):
     count: int


+class _EstimatePreProcessingRule(BaseModel):
+    id: str = Field(min_length=1)
+    enabled: bool
+
+    @field_validator("id")
+    @classmethod
+    def _validate_id(cls, v: str) -> str:
+        if v not in DatasetProcessRule.PRE_PROCESSING_RULES:
+            raise ValueError("Process rule pre_processing_rules id is invalid")
+        return v
+
+
+class _EstimateSegmentation(BaseModel):
+    separator: str = Field(min_length=1)
+    max_tokens: int = Field(gt=0)
+
+
+class _EstimateRules(BaseModel):
+    pre_processing_rules: list[_EstimatePreProcessingRule]
+    segmentation: _EstimateSegmentation
+
+    @field_validator("pre_processing_rules")
+    @classmethod
+    def _deduplicate(cls, v: list[_EstimatePreProcessingRule]) -> list[_EstimatePreProcessingRule]:
+        seen: dict[str, _EstimatePreProcessingRule] = {}
+        for rule in v:
+            seen[rule.id] = rule
+        return list(seen.values())
+
+
+class _SummaryIndexSettingDisabled(BaseModel):
+    enable: Literal[False] = False
+
+
+class _SummaryIndexSettingEnabled(BaseModel):
+    enable: Literal[True]
+    model_name: str = Field(min_length=1)
+    model_provider_name: str = Field(min_length=1)
+
+
+_SummaryIndexSetting = Annotated[
+    _SummaryIndexSettingDisabled | _SummaryIndexSettingEnabled,
+    Field(discriminator="enable"),
+]
+
+
+class _AutomaticProcessRule(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
+    mode: Literal[ProcessRuleMode.AUTOMATIC]
+    summary_index_setting: _SummaryIndexSetting | None = None
+
+
+class _CustomProcessRule(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
+    mode: Literal[ProcessRuleMode.CUSTOM]
+    rules: _EstimateRules
+    summary_index_setting: _SummaryIndexSetting | None = None
+
+
+class _HierarchicalProcessRule(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
+    mode: Literal[ProcessRuleMode.HIERARCHICAL]
+    rules: _EstimateRules
+    summary_index_setting: _SummaryIndexSetting | None = None
+
+
+_EstimateProcessRule = Annotated[
+    _AutomaticProcessRule | _CustomProcessRule | _HierarchicalProcessRule,
+    Field(discriminator="mode"),
+]
+
+
+class _EstimateArgs(BaseModel):
+    info_list: dict[str, Any]
+    process_rule: _EstimateProcessRule
+
+
 class DatasetService:
     @staticmethod
     def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
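For readers unfamiliar with the pattern, the models above combine two Pydantic v2 features: tagged unions selected via Field(discriminator=...) and field validators that normalize data during parsing. The following standalone sketch uses simplified stand-ins, not the Dify models themselves; the rule ids and modes are purely illustrative.

# Standalone sketch only: simplified stand-ins for the discriminated-union and
# deduplication patterns used by the models in this commit.
from typing import Annotated, Literal

from pydantic import BaseModel, Field, field_validator


class AutomaticRule(BaseModel):
    mode: Literal["automatic"]


class CustomRule(BaseModel):
    mode: Literal["custom"]
    pre_processing_rules: list[dict]

    @field_validator("pre_processing_rules")
    @classmethod
    def _dedupe(cls, v: list[dict]) -> list[dict]:
        # Later entries win, mirroring the dict-overwrite trick in _EstimateRules.
        seen: dict[str, dict] = {}
        for rule in v:
            seen[rule["id"]] = rule
        return list(seen.values())


# The discriminator tells Pydantic to pick the union member whose "mode" literal matches.
Rule = Annotated[AutomaticRule | CustomRule, Field(discriminator="mode")]


class Payload(BaseModel):
    process_rule: Rule


payload = Payload.model_validate(
    {
        "process_rule": {
            "mode": "custom",
            "pre_processing_rules": [
                {"id": "remove_extra_spaces", "enabled": True},
                {"id": "remove_extra_spaces", "enabled": False},
            ],
        }
    }
)
assert isinstance(payload.process_rule, CustomRule)
assert len(payload.process_rule.pre_processing_rules) == 1  # duplicates collapsed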
@@ -2851,94 +2932,16 @@ class DocumentService:

     @classmethod
     def estimate_args_validate(cls, args: dict[str, Any]):
-        if "info_list" not in args or not args["info_list"]:
-            raise ValueError("Data source info is required")
-
-        if not isinstance(args["info_list"], dict):
-            raise ValueError("Data info is invalid")
-
-        if "process_rule" not in args or not args["process_rule"]:
-            raise ValueError("Process rule is required")
-
-        if not isinstance(args["process_rule"], dict):
-            raise ValueError("Process rule is invalid")
-
-        if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]:
-            raise ValueError("Process rule mode is required")
-
-        if args["process_rule"]["mode"] not in DatasetProcessRule.MODES:
-            raise ValueError("Process rule mode is invalid")
-
-        if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC:
-            args["process_rule"]["rules"] = {}
-        else:
-            if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]:
-                raise ValueError("Process rule rules is required")
-
-            if not isinstance(args["process_rule"]["rules"], dict):
-                raise ValueError("Process rule rules is invalid")
-
-            if (
-                "pre_processing_rules" not in args["process_rule"]["rules"]
-                or args["process_rule"]["rules"]["pre_processing_rules"] is None
-            ):
-                raise ValueError("Process rule pre_processing_rules is required")
-
-            if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list):
-                raise ValueError("Process rule pre_processing_rules is invalid")
-
-            unique_pre_processing_rule_dicts = {}
-            for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]:
-                if "id" not in pre_processing_rule or not pre_processing_rule["id"]:
-                    raise ValueError("Process rule pre_processing_rules id is required")
-
-                if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES:
-                    raise ValueError("Process rule pre_processing_rules id is invalid")
-
-                if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None:
-                    raise ValueError("Process rule pre_processing_rules enabled is required")
-
-                if not isinstance(pre_processing_rule["enabled"], bool):
-                    raise ValueError("Process rule pre_processing_rules enabled is invalid")
-
-                unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule
-
-            args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values())
-
-            if (
-                "segmentation" not in args["process_rule"]["rules"]
-                or args["process_rule"]["rules"]["segmentation"] is None
-            ):
-                raise ValueError("Process rule segmentation is required")
-
-            if not isinstance(args["process_rule"]["rules"]["segmentation"], dict):
-                raise ValueError("Process rule segmentation is invalid")
-
-            if (
-                "separator" not in args["process_rule"]["rules"]["segmentation"]
-                or not args["process_rule"]["rules"]["segmentation"]["separator"]
-            ):
-                raise ValueError("Process rule segmentation separator is required")
-
-            if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str):
-                raise ValueError("Process rule segmentation separator is invalid")
-
-            if (
-                "max_tokens" not in args["process_rule"]["rules"]["segmentation"]
-                or not args["process_rule"]["rules"]["segmentation"]["max_tokens"]
-            ):
-                raise ValueError("Process rule segmentation max_tokens is required")
-
-            if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
-                raise ValueError("Process rule segmentation max_tokens is invalid")
-
-        # valid summary index setting
-        summary_index_setting = args["process_rule"].get("summary_index_setting")
-        if summary_index_setting and summary_index_setting.get("enable"):
-            if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]:
-                raise ValueError("Summary index model name is required")
-            if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]:
-                raise ValueError("Summary index model provider name is required")
+        try:
+            validated = _EstimateArgs.model_validate(args)
+        except ValidationError as e:
+            first = e.errors()[0]
+            original = first.get("ctx", {}).get("error")
+            raise ValueError(str(original) if isinstance(original, ValueError) else first["msg"]) from e
+        process_rule_dict = validated.process_rule.model_dump(exclude_none=True)
+        if validated.process_rule.mode == ProcessRuleMode.AUTOMATIC:
+            process_rule_dict["rules"] = {}
+        args["process_rule"] = process_rule_dict

     @staticmethod
     def batch_update_document_status(
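The new except branch leans on a Pydantic v2 behavior: a ValueError raised inside a field_validator is preserved in the error's ctx, while a structurally missing field carries only the generic "Field required" msg and no ctx. A minimal sketch of that behavior follows (plain Pydantic, no Dify imports; the model and helper names are illustrative, not from the codebase).

# Sketch of how the new except branch surfaces errors (illustrative names).
from pydantic import BaseModel, ValidationError, field_validator


class Segmentation(BaseModel):
    separator: str

    @field_validator("separator")
    @classmethod
    def _check(cls, v: str) -> str:
        if not v:
            raise ValueError("Process rule segmentation separator is required")
        return v


def first_error_as_value_error(exc: ValidationError) -> ValueError:
    first = exc.errors()[0]
    # For validator-raised ValueErrors, ctx carries the original exception;
    # for structural errors (e.g. missing fields) only the generic msg exists.
    original = first.get("ctx", {}).get("error")
    return ValueError(str(original) if isinstance(original, ValueError) else first["msg"])


try:
    Segmentation.model_validate({"separator": ""})
except ValidationError as e:
    print(first_error_as_value_error(e))  # Process rule segmentation separator is required

try:
    Segmentation.model_validate({})
except ValidationError as e:
    print(first_error_as_value_error(e))  # Field required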
@@ -1297,7 +1297,7 @@ class TestDocumentServiceEstimateValidation:
     """Unit tests for estimate_args_validate branches."""

     def test_estimate_args_validate_rejects_missing_info_list(self):
-        with pytest.raises(ValueError, match="Data source info is required"):
+        with pytest.raises(ValueError, match="Field required"):
             DocumentService.estimate_args_validate({})

     def test_estimate_args_validate_sets_empty_rules_for_automatic_mode(self):
@@ -1357,7 +1357,7 @@ class TestDocumentServiceEstimateValidation:
             },
         }

-        with pytest.raises(ValueError, match="Summary index model provider name is required"):
+        with pytest.raises(ValueError, match="Field required"):
             DocumentService.estimate_args_validate(args)

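Both test updates reflect the same consequence of the rewrite: a missing key is now reported by Pydantic itself before any hand-written message can fire, so the surfaced ValueError carries the generic "Field required" text. A quick hypothetical check, not part of this commit:

# Hypothetical check: a missing required field yields Pydantic's "Field required" message.
from pydantic import BaseModel, ValidationError


class Args(BaseModel):
    info_list: dict


try:
    Args.model_validate({})
except ValidationError as e:
    assert e.errors()[0]["msg"] == "Field required"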