mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 09:57:03 +08:00
refactor: replace bare dict with dict[str, Any] in moderation module (#35076)
This commit is contained in:
parent
d8fbc00cb9
commit
7056d2ae99
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -10,7 +12,7 @@ from models.api_based_extension import APIBasedExtension
|
||||
|
||||
class ModerationInputParams(BaseModel):
|
||||
app_id: str = ""
|
||||
inputs: dict = Field(default_factory=dict)
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
query: str = ""
|
||||
|
||||
|
||||
@ -23,7 +25,7 @@ class ApiModeration(Moderation):
|
||||
name: str = "api"
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -41,7 +43,7 @@ class ApiModeration(Moderation):
|
||||
if not extension:
|
||||
raise ValueError("API-based Extension not found. Please check it again.")
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
@ -73,7 +75,7 @@ class ApiModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
|
||||
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]):
|
||||
if self.config is None:
|
||||
raise ValueError("The config is not set.")
|
||||
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -15,7 +16,7 @@ class ModerationInputsResult(BaseModel):
|
||||
flagged: bool = False
|
||||
action: ModerationAction
|
||||
preset_response: str = ""
|
||||
inputs: dict = Field(default_factory=dict)
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
query: str = ""
|
||||
|
||||
|
||||
@ -33,13 +34,13 @@ class Moderation(Extensible, ABC):
|
||||
|
||||
module: ExtensionModule = ExtensionModule.MODERATION
|
||||
|
||||
def __init__(self, app_id: str, tenant_id: str, config: dict | None = None):
|
||||
def __init__(self, app_id: str, tenant_id: str, config: dict[str, Any] | None = None):
|
||||
super().__init__(tenant_id, config)
|
||||
self.app_id = app_id
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict) -> None:
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -50,7 +51,7 @@ class Moderation(Extensible, ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
"""
|
||||
Moderation for inputs.
|
||||
After the user inputs, this method will be called to perform sensitive content review
|
||||
@ -75,7 +76,7 @@ class Moderation(Extensible, ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool):
|
||||
def _validate_inputs_and_outputs_config(cls, config: dict[str, Any], is_preset_response_required: bool):
|
||||
# inputs_config
|
||||
inputs_config = config.get("inputs_config")
|
||||
if not isinstance(inputs_config, dict):
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from core.extension.extensible import ExtensionModule
|
||||
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
|
||||
from extensions.ext_code_based_extension import code_based_extension
|
||||
@ -6,12 +8,12 @@ from extensions.ext_code_based_extension import code_based_extension
|
||||
class ModerationFactory:
|
||||
__extension_instance: Moderation
|
||||
|
||||
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict):
|
||||
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict[str, Any]):
|
||||
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
|
||||
self.__extension_instance = extension_class(app_id, tenant_id, config)
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, name: str, tenant_id: str, config: dict):
|
||||
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -24,7 +26,7 @@ class ModerationFactory:
|
||||
# FIXME: mypy error, try to fix it instead of using type: ignore
|
||||
extension_class.validate_config(tenant_id, config) # type: ignore
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
"""
|
||||
Moderation for inputs.
|
||||
After the user inputs, this method will be called to perform sensitive content review
|
||||
|
||||
@ -8,7 +8,7 @@ class KeywordsModeration(Moderation):
|
||||
name: str = "keywords"
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -28,7 +28,7 @@ class KeywordsModeration(Moderation):
|
||||
if len(keywords_row_len) > 100:
|
||||
raise ValueError("the number of rows for the keywords must be less than 100")
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
@ -66,7 +66,7 @@ class KeywordsModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
|
||||
def _is_violated(self, inputs: dict[str, Any], keywords_list: list[str]) -> bool:
|
||||
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
|
||||
|
||||
def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
@ -8,7 +10,7 @@ class OpenAIModeration(Moderation):
|
||||
name: str = "openai_moderation"
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -18,7 +20,7 @@ class OpenAIModeration(Moderation):
|
||||
"""
|
||||
cls._validate_inputs_and_outputs_config(config, True)
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
@ -49,7 +51,7 @@ class OpenAIModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def _is_violated(self, inputs: dict):
|
||||
def _is_violated(self, inputs: dict[str, Any]):
|
||||
text = "\n".join(str(inputs.values()))
|
||||
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user