mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
refactor(api): type get_prompt_template with TypedDict (#34943)
This commit is contained in:
parent
992ac38d0d
commit
f2d6275da4
@ -2,7 +2,7 @@ import json
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
|
||||
from graphon.file import file_manager
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
@ -34,6 +34,13 @@ class ModelMode(StrEnum):
|
||||
prompt_file_contents: dict[str, Any] = {}
|
||||
|
||||
|
||||
class PromptTemplateConfigDict(TypedDict):
|
||||
prompt_template: PromptTemplateParser
|
||||
custom_variable_keys: list[str]
|
||||
special_variable_keys: list[str]
|
||||
prompt_rules: dict[str, Any]
|
||||
|
||||
|
||||
class SimplePromptTransform(PromptTransform):
|
||||
"""
|
||||
Simple Prompt Transform for Chatbot App Basic Mode.
|
||||
@ -105,18 +112,13 @@ class SimplePromptTransform(PromptTransform):
|
||||
with_memory_prompt=histories is not None,
|
||||
)
|
||||
|
||||
custom_variable_keys_obj = prompt_template_config["custom_variable_keys"]
|
||||
special_variable_keys_obj = prompt_template_config["special_variable_keys"]
|
||||
custom_variable_keys = prompt_template_config["custom_variable_keys"]
|
||||
if not isinstance(custom_variable_keys, list):
|
||||
raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys)}")
|
||||
|
||||
# Type check for custom_variable_keys
|
||||
if not isinstance(custom_variable_keys_obj, list):
|
||||
raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}")
|
||||
custom_variable_keys = cast(list[str], custom_variable_keys_obj)
|
||||
|
||||
# Type check for special_variable_keys
|
||||
if not isinstance(special_variable_keys_obj, list):
|
||||
raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}")
|
||||
special_variable_keys = cast(list[str], special_variable_keys_obj)
|
||||
special_variable_keys = prompt_template_config["special_variable_keys"]
|
||||
if not isinstance(special_variable_keys, list):
|
||||
raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys)}")
|
||||
|
||||
variables = {k: inputs[k] for k in custom_variable_keys if k in inputs}
|
||||
|
||||
@ -150,7 +152,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
has_context: bool,
|
||||
query_in_prompt: bool,
|
||||
with_memory_prompt: bool = False,
|
||||
) -> dict[str, object]:
|
||||
) -> PromptTemplateConfigDict:
|
||||
prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)
|
||||
|
||||
custom_variable_keys: list[str] = []
|
||||
@ -173,12 +175,13 @@ class SimplePromptTransform(PromptTransform):
|
||||
prompt += prompt_rules.get("query_prompt", "{{#query#}}")
|
||||
special_variable_keys.append("#query#")
|
||||
|
||||
return {
|
||||
result: PromptTemplateConfigDict = {
|
||||
"prompt_template": PromptTemplateParser(template=prompt),
|
||||
"custom_variable_keys": custom_variable_keys,
|
||||
"special_variable_keys": special_variable_keys,
|
||||
"prompt_rules": prompt_rules,
|
||||
}
|
||||
return result
|
||||
|
||||
def _get_chat_model_prompt_messages(
|
||||
self,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user