From f2d6275da405b27724f423c4266bc9bd5b3f0c5c Mon Sep 17 00:00:00 2001 From: YBoy <231405196+YB0y@users.noreply.github.com> Date: Sat, 11 Apr 2026 02:38:16 +0200 Subject: [PATCH] refactor(api): type get_prompt_template with TypedDict (#34943) --- api/core/prompt/simple_prompt_transform.py | 31 ++++++++++++---------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index c706353ffe..36fca60db3 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -2,7 +2,7 @@ import json import os from collections.abc import Mapping, Sequence from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, TypedDict, cast from graphon.file import file_manager from graphon.model_runtime.entities.message_entities import ( @@ -34,6 +34,13 @@ class ModelMode(StrEnum): prompt_file_contents: dict[str, Any] = {} +class PromptTemplateConfigDict(TypedDict): + prompt_template: PromptTemplateParser + custom_variable_keys: list[str] + special_variable_keys: list[str] + prompt_rules: dict[str, Any] + + class SimplePromptTransform(PromptTransform): """ Simple Prompt Transform for Chatbot App Basic Mode. @@ -105,18 +112,13 @@ class SimplePromptTransform(PromptTransform): with_memory_prompt=histories is not None, ) - custom_variable_keys_obj = prompt_template_config["custom_variable_keys"] - special_variable_keys_obj = prompt_template_config["special_variable_keys"] + custom_variable_keys = prompt_template_config["custom_variable_keys"] + if not isinstance(custom_variable_keys, list): + raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys)}") - # Type check for custom_variable_keys - if not isinstance(custom_variable_keys_obj, list): - raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}") - custom_variable_keys = cast(list[str], custom_variable_keys_obj) - - # Type check for special_variable_keys - if not isinstance(special_variable_keys_obj, list): - raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}") - special_variable_keys = cast(list[str], special_variable_keys_obj) + special_variable_keys = prompt_template_config["special_variable_keys"] + if not isinstance(special_variable_keys, list): + raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys)}") variables = {k: inputs[k] for k in custom_variable_keys if k in inputs} @@ -150,7 +152,7 @@ class SimplePromptTransform(PromptTransform): has_context: bool, query_in_prompt: bool, with_memory_prompt: bool = False, - ) -> dict[str, object]: + ) -> PromptTemplateConfigDict: prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) custom_variable_keys: list[str] = [] @@ -173,12 +175,13 @@ class SimplePromptTransform(PromptTransform): prompt += prompt_rules.get("query_prompt", "{{#query#}}") special_variable_keys.append("#query#") - return { + result: PromptTemplateConfigDict = { "prompt_template": PromptTemplateParser(template=prompt), "custom_variable_keys": custom_variable_keys, "special_variable_keys": special_variable_keys, "prompt_rules": prompt_rules, } + return result def _get_chat_model_prompt_messages( self,