refactor(api): type get_prompt_template with TypedDict (#34943)

This commit is contained in:
YBoy 2026-04-11 02:38:16 +02:00 committed by GitHub
parent 992ac38d0d
commit f2d6275da4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,7 @@ import json
import os
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypedDict, cast
from graphon.file import file_manager
from graphon.model_runtime.entities.message_entities import (
@ -34,6 +34,13 @@ class ModelMode(StrEnum):
prompt_file_contents: dict[str, Any] = {}
class PromptTemplateConfigDict(TypedDict):
prompt_template: PromptTemplateParser
custom_variable_keys: list[str]
special_variable_keys: list[str]
prompt_rules: dict[str, Any]
class SimplePromptTransform(PromptTransform):
"""
Simple Prompt Transform for Chatbot App Basic Mode.
@ -105,18 +112,13 @@ class SimplePromptTransform(PromptTransform):
with_memory_prompt=histories is not None,
)
custom_variable_keys_obj = prompt_template_config["custom_variable_keys"]
special_variable_keys_obj = prompt_template_config["special_variable_keys"]
custom_variable_keys = prompt_template_config["custom_variable_keys"]
if not isinstance(custom_variable_keys, list):
raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys)}")
# Type check for custom_variable_keys
if not isinstance(custom_variable_keys_obj, list):
raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}")
custom_variable_keys = cast(list[str], custom_variable_keys_obj)
# Type check for special_variable_keys
if not isinstance(special_variable_keys_obj, list):
raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}")
special_variable_keys = cast(list[str], special_variable_keys_obj)
special_variable_keys = prompt_template_config["special_variable_keys"]
if not isinstance(special_variable_keys, list):
raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys)}")
variables = {k: inputs[k] for k in custom_variable_keys if k in inputs}
@ -150,7 +152,7 @@ class SimplePromptTransform(PromptTransform):
has_context: bool,
query_in_prompt: bool,
with_memory_prompt: bool = False,
) -> dict[str, object]:
) -> PromptTemplateConfigDict:
prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)
custom_variable_keys: list[str] = []
@ -173,12 +175,13 @@ class SimplePromptTransform(PromptTransform):
prompt += prompt_rules.get("query_prompt", "{{#query#}}")
special_variable_keys.append("#query#")
return {
result: PromptTemplateConfigDict = {
"prompt_template": PromptTemplateParser(template=prompt),
"custom_variable_keys": custom_variable_keys,
"special_variable_keys": special_variable_keys,
"prompt_rules": prompt_rules,
}
return result
def _get_chat_model_prompt_messages(
self,