diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 0795fdb221..6be929677e 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,7 +7,7 @@ import time import uuid from collections import Counter from collections.abc import Sequence -from typing import Any, Literal, cast +from typing import Any, Literal, TypedDict, cast import sqlalchemy as sa from graphon.file import helpers as file_helpers @@ -107,6 +107,16 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde logger = logging.getLogger(__name__) +class ProcessRulesDict(TypedDict): + mode: str + rules: dict[str, Any] + + +class AutoDisableLogsDict(TypedDict): + document_ids: list[str] + count: int + + class DatasetService: @staticmethod def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): @@ -182,7 +192,7 @@ class DatasetService: return datasets.items, datasets.total @staticmethod - def get_process_rules(dataset_id): + def get_process_rules(dataset_id) -> ProcessRulesDict: # get the latest process rule dataset_process_rule = db.session.execute( select(DatasetProcessRule) @@ -192,10 +202,10 @@ class DatasetService: ).scalar_one_or_none() if dataset_process_rule: mode = dataset_process_rule.mode - rules = dataset_process_rule.rules_dict + rules = dataset_process_rule.rules_dict or {} else: - mode = DocumentService.DEFAULT_RULES["mode"] - rules = DocumentService.DEFAULT_RULES["rules"] + mode = str(DocumentService.DEFAULT_RULES["mode"]) + rules = dict(DocumentService.DEFAULT_RULES.get("rules") or {}) return {"mode": mode, "rules": rules} @staticmethod @@ -1199,7 +1209,7 @@ class DatasetService: db.session.commit() @staticmethod - def get_dataset_auto_disable_logs(dataset_id: str): + def get_dataset_auto_disable_logs(dataset_id: str) -> AutoDisableLogsDict: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None features = FeatureService.get_features(current_user.current_tenant_id)