From 554f06009265f12c6b47a48f912cf601a66ce70b Mon Sep 17 00:00:00 2001 From: wdeveloper16 Date: Mon, 13 Apr 2026 15:05:23 +0200 Subject: [PATCH] refactor: replace bare dict with AdvancedPromptTemplateArgs TypedDict (#35056) --- .../console/app/advanced_prompt_template.py | 11 +++++++--- .../advanced_prompt_template_service.py | 20 ++++++++++++++----- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index 3bd61feb44..ed66da1be5 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required -from services.advanced_prompt_template_service import AdvancedPromptTemplateService +from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService class AdvancedPromptTemplateQuery(BaseModel): @@ -35,5 +35,10 @@ class AdvancedPromptTemplateList(Resource): @account_initialization_required def get(self): args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - - return AdvancedPromptTemplateService.get_prompt(args.model_dump()) + prompt_args: AdvancedPromptTemplateArgs = { + "app_mode": args.app_mode, + "model_mode": args.model_mode, + "model_name": args.model_name, + "has_context": args.has_context, + } + return AdvancedPromptTemplateService.get_prompt(prompt_args) diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index a6e6b1bae7..5d136e7393 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,4 +1,5 @@ import copy +from typing import Any, TypedDict from core.prompt.prompt_templates.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, @@ -15,9 +16,18 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( from models.model import AppMode +class AdvancedPromptTemplateArgs(TypedDict): + """Expected shape of the args dict passed to AdvancedPromptTemplateService.get_prompt.""" + + app_mode: str + model_mode: str + model_name: str + has_context: str + + class AdvancedPromptTemplateService: @classmethod - def get_prompt(cls, args: dict): + def get_prompt(cls, args: AdvancedPromptTemplateArgs) -> dict[str, Any]: app_mode = args["app_mode"] model_mode = args["model_mode"] model_name = args["model_name"] @@ -29,7 +39,7 @@ class AdvancedPromptTemplateService: return cls.get_common_prompt(app_mode, model_mode, has_context) @classmethod - def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str): + def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict[str, Any]: context_prompt = copy.deepcopy(CONTEXT) match app_mode: @@ -63,7 +73,7 @@ class AdvancedPromptTemplateService: return {} @classmethod - def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str): + def get_completion_prompt(cls, prompt_template: dict[str, Any], has_context: str, context: str) -> dict[str, Any]: if has_context == "true": prompt_template["completion_prompt_config"]["prompt"]["text"] = ( context + prompt_template["completion_prompt_config"]["prompt"]["text"] @@ -72,7 +82,7 @@ class AdvancedPromptTemplateService: return prompt_template @classmethod - def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str): + def get_chat_prompt(cls, prompt_template: dict[str, Any], has_context: str, context: str) -> dict[str, Any]: if has_context == "true": prompt_template["chat_prompt_config"]["prompt"][0]["text"] = ( context + prompt_template["chat_prompt_config"]["prompt"][0]["text"] @@ -81,7 +91,7 @@ class AdvancedPromptTemplateService: return prompt_template @classmethod - def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str): + def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict[str, Any]: baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) match app_mode: