refactor: replace bare dict with AdvancedPromptTemplateArgs TypedDict (#35056)

This commit is contained in:
wdeveloper16 2026-04-13 15:05:23 +02:00 committed by GitHub
parent e243e8d8a3
commit 554f060092
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 8 deletions

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService
class AdvancedPromptTemplateQuery(BaseModel):
@@ -35,5 +35,10 @@ class AdvancedPromptTemplateList(Resource):
@account_initialization_required
def get(self):
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
prompt_args: AdvancedPromptTemplateArgs = {
"app_mode": args.app_mode,
"model_mode": args.model_mode,
"model_name": args.model_name,
"has_context": args.has_context,
}
return AdvancedPromptTemplateService.get_prompt(prompt_args)

View File

@@ -1,4 +1,5 @@
import copy
from typing import Any, TypedDict
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
@@ -15,9 +16,18 @@ from core.prompt.prompt_templates.advanced_prompt_templates import (
from models.model import AppMode
class AdvancedPromptTemplateArgs(TypedDict):
    """Expected shape of the args dict passed to AdvancedPromptTemplateService.get_prompt."""

    # Application mode; presumably one of the models.model.AppMode values — verify against get_prompt's match.
    app_mode: str
    # Model interaction mode (e.g. completion vs chat) — assumed from the get_*_prompt helpers; confirm.
    model_mode: str
    # Concrete model identifier; appears to select model-specific (e.g. baichuan) prompt configs — confirm.
    model_name: str
    # String flag, not a bool: the service compares it against the literal "true".
    has_context: str
class AdvancedPromptTemplateService:
@classmethod
def get_prompt(cls, args: dict):
def get_prompt(cls, args: AdvancedPromptTemplateArgs) -> dict[str, Any]:
app_mode = args["app_mode"]
model_mode = args["model_mode"]
model_name = args["model_name"]
@@ -29,7 +39,7 @@ class AdvancedPromptTemplateService:
return cls.get_common_prompt(app_mode, model_mode, has_context)
@classmethod
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict[str, Any]:
context_prompt = copy.deepcopy(CONTEXT)
match app_mode:
@@ -63,7 +73,7 @@ class AdvancedPromptTemplateService:
return {}
@classmethod
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str):
def get_completion_prompt(cls, prompt_template: dict[str, Any], has_context: str, context: str) -> dict[str, Any]:
if has_context == "true":
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
@@ -72,7 +82,7 @@ class AdvancedPromptTemplateService:
return prompt_template
@classmethod
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str):
def get_chat_prompt(cls, prompt_template: dict[str, Any], has_context: str, context: str) -> dict[str, Any]:
if has_context == "true":
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
@@ -81,7 +91,7 @@ class AdvancedPromptTemplateService:
return prompt_template
@classmethod
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict[str, Any]:
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
match app_mode: