refactor: replace bare dict with dict[str, Any] in moderation module (#35076)

This commit is contained in:
wdeveloper16 2026-04-13 19:09:06 +02:00 committed by GitHub
parent d8fbc00cb9
commit 7056d2ae99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 25 additions and 18 deletions

View File

@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from sqlalchemy import select
@ -10,7 +12,7 @@ from models.api_based_extension import APIBasedExtension
class ModerationInputParams(BaseModel):
app_id: str = ""
inputs: dict = Field(default_factory=dict)
inputs: dict[str, Any] = Field(default_factory=dict)
query: str = ""
@ -23,7 +25,7 @@ class ApiModeration(Moderation):
name: str = "api"
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -41,7 +43,7 @@ class ApiModeration(Moderation):
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -73,7 +75,7 @@ class ApiModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]):
if self.config is None:
raise ValueError("The config is not set.")
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
@ -15,7 +16,7 @@ class ModerationInputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
inputs: dict = Field(default_factory=dict)
inputs: dict[str, Any] = Field(default_factory=dict)
query: str = ""
@ -33,13 +34,13 @@ class Moderation(Extensible, ABC):
module: ExtensionModule = ExtensionModule.MODERATION
def __init__(self, app_id: str, tenant_id: str, config: dict | None = None):
def __init__(self, app_id: str, tenant_id: str, config: dict[str, Any] | None = None):
super().__init__(tenant_id, config)
self.app_id = app_id
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None:
"""
Validate the incoming form config data.
@ -50,7 +51,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError
@abstractmethod
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review
@ -75,7 +76,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError
@classmethod
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool):
def _validate_inputs_and_outputs_config(cls, config: dict[str, Any], is_preset_response_required: bool):
# inputs_config
inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict):

View File

@ -1,3 +1,5 @@
from typing import Any
from core.extension.extensible import ExtensionModule
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
from extensions.ext_code_based_extension import code_based_extension
@ -6,12 +8,12 @@ from extensions.ext_code_based_extension import code_based_extension
class ModerationFactory:
__extension_instance: Moderation
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict):
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict[str, Any]):
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
self.__extension_instance = extension_class(app_id, tenant_id, config)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict):
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -24,7 +26,7 @@ class ModerationFactory:
# FIXME: mypy error, try to fix it instead of using type: ignore
extension_class.validate_config(tenant_id, config) # type: ignore
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review

View File

@ -8,7 +8,7 @@ class KeywordsModeration(Moderation):
name: str = "keywords"
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -28,7 +28,7 @@ class KeywordsModeration(Moderation):
if len(keywords_row_len) > 100:
raise ValueError("the number of rows for the keywords must be less than 100")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -66,7 +66,7 @@ class KeywordsModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
def _is_violated(self, inputs: dict[str, Any], keywords_list: list[str]) -> bool:
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:

View File

@ -1,3 +1,5 @@
from typing import Any
from graphon.model_runtime.entities.model_entities import ModelType
from core.model_manager import ModelManager
@ -8,7 +10,7 @@ class OpenAIModeration(Moderation):
name: str = "openai_moderation"
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -18,7 +20,7 @@ class OpenAIModeration(Moderation):
"""
cls._validate_inputs_and_outputs_config(config, True)
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -49,7 +51,7 @@ class OpenAIModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _is_violated(self, inputs: dict):
def _is_violated(self, inputs: dict[str, Any]):
text = "\n".join(str(inputs.values()))
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
model_instance = model_manager.get_model_instance(