mirror of https://github.com/langgenius/dify.git

fix prompt transform bugs

parent 3b234febf5
commit 6aecf42b6e
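Note: the changes below all follow one pattern. Helper methods that fill prompt variables or append chat histories used to mutate their arguments in place and return None; the commit changes them to return the updated dict or message list and updates every call site to reassign the result. A minimal, self-contained sketch of that pattern (simplified names, not the actual Dify classes):

def set_context_variable(context: str, variable_keys: list, prompt_inputs: dict) -> dict:
    # fill the reserved '#context#' slot only when the template declares it
    if '#context#' in variable_keys:
        prompt_inputs['#context#'] = context if context else ''
    return prompt_inputs  # before the fix, helpers like this ended with an implicit `return None`


prompt_inputs = {'name': 'John'}
# call sites now reassign instead of relying on in-place mutation:
prompt_inputs = set_context_variable('I am superman.', ['#context#', 'name'], prompt_inputs)
assert prompt_inputs['#context#'] == 'I am superman.'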
core/prompt/advanced_prompt_transform.py:

@@ -20,7 +20,7 @@ from core.prompt.prompt_transform import PromptTransform
 from core.prompt.simple_prompt_transform import ModelMode


-class AdvancePromptTransform(PromptTransform):
+class AdvancedPromptTransform(PromptTransform):
     """
     Advanced Prompt Transform for Workflow LLM Node.
     """
@@ -74,10 +74,10 @@ class AdvancePromptTransform(PromptTransform):
         prompt_template = PromptTemplateParser(template=raw_prompt)
         prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

-        self._set_context_variable(context, prompt_template, prompt_inputs)
+        prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

         role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix
-        self._set_histories_variable(
+        prompt_inputs = self._set_histories_variable(
             memory=memory,
             raw_prompt=raw_prompt,
             role_prefix=role_prefix,
@@ -104,7 +104,7 @@ class AdvancePromptTransform(PromptTransform):
     def _get_chat_model_prompt_messages(self,
                                         prompt_template_entity: PromptTemplateEntity,
                                         inputs: dict,
-                                        query: str,
+                                        query: Optional[str],
                                         files: list[FileObj],
                                         context: Optional[str],
                                         memory: Optional[TokenBufferMemory],
@@ -122,7 +122,7 @@ class AdvancePromptTransform(PromptTransform):
             prompt_template = PromptTemplateParser(template=raw_prompt)
             prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

-            self._set_context_variable(context, prompt_template, prompt_inputs)
+            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

             prompt = prompt_template.format(
                 prompt_inputs
@@ -136,7 +136,7 @@ class AdvancePromptTransform(PromptTransform):
                 prompt_messages.append(AssistantPromptMessage(content=prompt))

         if memory:
-            self._append_chat_histories(memory, prompt_messages, model_config)
+            prompt_messages = self._append_chat_histories(memory, prompt_messages, model_config)

         if files:
             prompt_message_contents = [TextPromptMessageContent(data=query)]
@@ -157,7 +157,7 @@ class AdvancePromptTransform(PromptTransform):

                 last_message.content = prompt_message_contents
             else:
-                prompt_message_contents = [TextPromptMessageContent(data=query)]
+                prompt_message_contents = [TextPromptMessageContent(data='')]  # not for query
                 for file in files:
                     prompt_message_contents.append(file.prompt_message_content)

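Note: the hunk above empties the text part of a file-only message instead of echoing query, which can now be None on this path. A rough, self-contained illustration of that branch, using stand-in content classes rather than the real TextPromptMessageContent and file objects:

from dataclasses import dataclass
from typing import Optional


@dataclass
class TextContent:
    data: str


@dataclass
class ImageContent:
    data: str  # e.g. a remote URL


def build_file_message_contents(file_urls: list, query: Optional[str], is_query_message: bool) -> list:
    # only the message that carries the user's query repeats the query text;
    # a file-only message gets an empty text part
    text = TextContent(data=query) if (is_query_message and query) else TextContent(data='')
    return [text] + [ImageContent(data=url) for url in file_urls]


contents = build_file_message_contents(['https://example.com/image1.jpg'], query=None, is_query_message=False)
assert contents[0].data == ''
assert contents[1].data == 'https://example.com/image1.jpg'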
@@ -165,26 +165,30 @@ class AdvancePromptTransform(PromptTransform):

         return prompt_messages

-    def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
+    def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
         if '#context#' in prompt_template.variable_keys:
             if context:
                 prompt_inputs['#context#'] = context
             else:
                 prompt_inputs['#context#'] = ''

-    def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
+        return prompt_inputs
+
+    def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
         if '#query#' in prompt_template.variable_keys:
             if query:
                 prompt_inputs['#query#'] = query
             else:
                 prompt_inputs['#query#'] = ''

+        return prompt_inputs
+
     def _set_histories_variable(self, memory: TokenBufferMemory,
                                 raw_prompt: str,
                                 role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
                                 prompt_template: PromptTemplateParser,
                                 prompt_inputs: dict,
-                                model_config: ModelConfigEntity) -> None:
+                                model_config: ModelConfigEntity) -> dict:
         if '#histories#' in prompt_template.variable_keys:
             if memory:
                 inputs = {'#histories#': '', **prompt_inputs}
@@ -205,3 +209,5 @@ class AdvancePromptTransform(PromptTransform):
                 prompt_inputs['#histories#'] = histories
             else:
                 prompt_inputs['#histories#'] = ''
+
+        return prompt_inputs
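Note: taken together, the _set_context_variable, _set_query_variable and _set_histories_variable helpers just populate the reserved keys '#context#', '#query#' and '#histories#' in prompt_inputs, which the template parser then substitutes into the {{...}} placeholders. A rough stand-in for PromptTemplateParser (the real parser lives in core.prompt.prompt_template; this is only an illustration of the end result):

import re


def render(template: str, inputs: dict) -> str:
    # naive {{variable}} substitution; missing keys become empty strings
    return re.sub(r'\{\{([#\w]+)\}\}', lambda m: str(inputs.get(m.group(1), '')), template)


template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
prompt_inputs = {'name': 'John'}
prompt_inputs = {**prompt_inputs,
                 '#context#': 'I am superman.',
                 '#histories#': 'Human: Hi\nAssistant: Hello'}

print(render(template, prompt_inputs))
# Context:
# I am superman.
#
# Histories:
# Human: Hi
# Assistant: Hello
#
# you are John.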
core/prompt/prompt_transform.py:

@@ -10,12 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 class PromptTransform:
     def _append_chat_histories(self, memory: TokenBufferMemory,
                                prompt_messages: list[PromptMessage],
-                               model_config: ModelConfigEntity) -> None:
+                               model_config: ModelConfigEntity) -> list[PromptMessage]:
         if memory:
             rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
             histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
             prompt_messages.extend(histories)

+        return prompt_messages
+
     def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> int:
         rest_tokens = 2000

core/prompt/simple_prompt_transform.py:

@@ -177,7 +177,7 @@ class SimplePromptTransform(PromptTransform):
         if prompt:
             prompt_messages.append(SystemPromptMessage(content=prompt))

-        self._append_chat_histories(
+        prompt_messages = self._append_chat_histories(
             memory=memory,
             prompt_messages=prompt_messages,
             model_config=model_config
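Note: with the reassignment in place, a chat prompt built by SimplePromptTransform with memory attached comes out as the system prompt, then the history messages, then the current query; the updated SimplePromptTransform test further below asserts exactly this ordering (len(prompt_messages) == 4). A rough illustration of the expected shape, with placeholder contents:

history = [('user', 'Hi'), ('assistant', 'Hello')]
query = 'Hi2.'  # placeholder query, for illustration only

expected = [('system', '<rendered pre_prompt>')] + history + [('user', query)]
assert [role for role, _ in expected] == ['system', 'user', 'assistant', 'user']
assert len(expected) == 4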
New unit tests for AdvancedPromptTransform (new file):

@@ -0,0 +1,193 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from core.entities.application_entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \
+    ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity
+from core.file.file_obj import FileObj, FileType, FileTransferMethod
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
+from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.prompt_template import PromptTemplateParser
+from models.model import Conversation
+
+
+def test__get_completion_model_prompt_messages():
+    model_config_mock = MagicMock(spec=ModelConfigEntity)
+    model_config_mock.provider = 'openai'
+    model_config_mock.model = 'gpt-3.5-turbo-instruct'
+
+    prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
+    prompt_template_entity = PromptTemplateEntity(
+        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
+        advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
+            prompt=prompt_template,
+            role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(
+                user="Human",
+                assistant="Assistant"
+            )
+        )
+    )
+    inputs = {
+        "name": "John"
+    }
+    files = []
+    context = "I am superman."
+
+    memory = TokenBufferMemory(
+        conversation=Conversation(),
+        model_instance=model_config_mock
+    )
+
+    history_prompt_messages = [
+        UserPromptMessage(content="Hi"),
+        AssistantPromptMessage(content="Hello")
+    ]
+    memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
+
+    prompt_transform = AdvancedPromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
+    prompt_messages = prompt_transform._get_completion_model_prompt_messages(
+        prompt_template_entity=prompt_template_entity,
+        inputs=inputs,
+        files=files,
+        context=context,
+        memory=memory,
+        model_config=model_config_mock
+    )
+
+    assert len(prompt_messages) == 1
+    assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({
+        "#context#": context,
+        "#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: "
+                                  f"{prompt.content}" for prompt in history_prompt_messages]),
+        **inputs,
+    })
+
+
+def test__get_chat_model_prompt_messages(get_chat_model_args):
+    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+
+    files = []
+    query = "Hi2."
+
+    memory = TokenBufferMemory(
+        conversation=Conversation(),
+        model_instance=model_config_mock
+    )
+
+    history_prompt_messages = [
+        UserPromptMessage(content="Hi1."),
+        AssistantPromptMessage(content="Hello1!")
+    ]
+    memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
+
+    prompt_transform = AdvancedPromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
+    prompt_messages = prompt_transform._get_chat_model_prompt_messages(
+        prompt_template_entity=prompt_template_entity,
+        inputs=inputs,
+        query=query,
+        files=files,
+        context=context,
+        memory=memory,
+        model_config=model_config_mock
+    )
+
+    assert len(prompt_messages) == 6
+    assert prompt_messages[0].role == PromptMessageRole.SYSTEM
+    assert prompt_messages[0].content == PromptTemplateParser(
+        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+    ).format({**inputs, "#context#": context})
+    assert prompt_messages[5].content == query
+
+
+def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
+    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+
+    files = []
+
+    prompt_transform = AdvancedPromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
+    prompt_messages = prompt_transform._get_chat_model_prompt_messages(
+        prompt_template_entity=prompt_template_entity,
+        inputs=inputs,
+        query=None,
+        files=files,
+        context=context,
+        memory=None,
+        model_config=model_config_mock
+    )
+
+    assert len(prompt_messages) == 3
+    assert prompt_messages[0].role == PromptMessageRole.SYSTEM
+    assert prompt_messages[0].content == PromptTemplateParser(
+        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+    ).format({**inputs, "#context#": context})
+
+
+def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
+    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+
+    files = [
+        FileObj(
+            id="file1",
+            tenant_id="tenant1",
+            type=FileType.IMAGE,
+            transfer_method=FileTransferMethod.REMOTE_URL,
+            url="https://example.com/image1.jpg",
+            file_config={
+                "image": {
+                    "detail": "high",
+                }
+            }
+        )
+    ]
+
+    prompt_transform = AdvancedPromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
+    prompt_messages = prompt_transform._get_chat_model_prompt_messages(
+        prompt_template_entity=prompt_template_entity,
+        inputs=inputs,
+        query=None,
+        files=files,
+        context=context,
+        memory=None,
+        model_config=model_config_mock
+    )
+
+    assert len(prompt_messages) == 4
+    assert prompt_messages[0].role == PromptMessageRole.SYSTEM
+    assert prompt_messages[0].content == PromptTemplateParser(
+        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+    ).format({**inputs, "#context#": context})
+    assert isinstance(prompt_messages[3].content, list)
+    assert len(prompt_messages[3].content) == 2
+    assert prompt_messages[3].content[1].data == files[0].url
+
+
+@pytest.fixture
+def get_chat_model_args():
+    model_config_mock = MagicMock(spec=ModelConfigEntity)
+    model_config_mock.provider = 'openai'
+    model_config_mock.model = 'gpt-4'
+
+    prompt_template_entity = PromptTemplateEntity(
+        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
+        advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
+            messages=[
+                AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
+                                          role=PromptMessageRole.SYSTEM),
+                AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
+                AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
+            ]
+        )
+    )
+
+    inputs = {
+        "name": "John"
+    }
+
+    context = "I am superman."
+
+    return model_config_mock, prompt_template_entity, inputs, context
Updated unit tests for SimplePromptTransform:

@@ -1,8 +1,10 @@
 from unittest.mock import MagicMock

 from core.entities.application_entities import ModelConfigEntity
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
 from core.prompt.simple_prompt_transform import SimplePromptTransform
-from models.model import AppMode
+from models.model import AppMode, Conversation


 def test_get_common_chat_app_prompt_template_with_pcqm():
@@ -141,7 +143,16 @@ def test__get_chat_model_prompt_messages():
     model_config_mock.provider = 'openai'
     model_config_mock.model = 'gpt-4'

+    memory_mock = MagicMock(spec=TokenBufferMemory)
+    history_prompt_messages = [
+        UserPromptMessage(content="Hi"),
+        AssistantPromptMessage(content="Hello")
+    ]
+    memory_mock.get_history_prompt_messages.return_value = history_prompt_messages
+
     prompt_transform = SimplePromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
+
     pre_prompt = "You are a helpful assistant {{name}}."
     inputs = {
         "name": "John"
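Note on the mocking style used here (standard unittest.mock behaviour, not Dify-specific): MagicMock(spec=TokenBufferMemory) restricts the mock to attributes that actually exist on the class, so a mistyped method name fails loudly instead of silently returning another mock. A small standalone example:

from unittest.mock import MagicMock


class Memory:
    def get_history_prompt_messages(self):
        raise NotImplementedError


memory_mock = MagicMock(spec=Memory)
memory_mock.get_history_prompt_messages.return_value = ['Hi', 'Hello']
assert memory_mock.get_history_prompt_messages() == ['Hi', 'Hello']

try:
    memory_mock.get_history_promt_messages  # typo: not on the spec'd class
except AttributeError:
    pass  # spec'd mocks reject unknown attributes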
@@ -154,7 +165,7 @@ def test__get_chat_model_prompt_messages():
         query=query,
         files=[],
         context=context,
-        memory=None,
+        memory=memory_mock,
         model_config=model_config_mock
     )

@@ -171,9 +182,11 @@ def test__get_chat_model_prompt_messages():
     full_inputs = {**inputs, '#context#': context}
     real_system_prompt = prompt_template['prompt_template'].format(full_inputs)

-    assert len(prompt_messages) == 2
+    assert len(prompt_messages) == 4
     assert prompt_messages[0].content == real_system_prompt
-    assert prompt_messages[1].content == query
+    assert prompt_messages[1].content == history_prompt_messages[0].content
+    assert prompt_messages[2].content == history_prompt_messages[1].content
+    assert prompt_messages[3].content == query


 def test__get_completion_model_prompt_messages():
@@ -181,7 +194,19 @@ def test__get_completion_model_prompt_messages():
     model_config_mock.provider = 'openai'
     model_config_mock.model = 'gpt-3.5-turbo-instruct'

+    memory = TokenBufferMemory(
+        conversation=Conversation(),
+        model_instance=model_config_mock
+    )
+
+    history_prompt_messages = [
+        UserPromptMessage(content="Hi"),
+        AssistantPromptMessage(content="Hello")
+    ]
+    memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
+
     prompt_transform = SimplePromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     pre_prompt = "You are a helpful assistant {{name}}."
     inputs = {
         "name": "John"
@@ -194,7 +219,7 @@ def test__get_completion_model_prompt_messages():
         query=query,
         files=[],
         context=context,
-        memory=None,
+        memory=memory,
         model_config=model_config_mock
     )

@@ -205,12 +230,17 @@ def test__get_completion_model_prompt_messages():
         pre_prompt=pre_prompt,
         has_context=True,
         query_in_prompt=True,
-        with_memory_prompt=False,
+        with_memory_prompt=True,
     )

-    full_inputs = {**inputs, '#context#': context, '#query#': query}
+    prompt_rules = prompt_template['prompt_rules']
+    full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
+        max_token_limit=2000,
+        ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+        human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+    )}
     real_prompt = prompt_template['prompt_template'].format(full_inputs)

     assert len(prompt_messages) == 1
-    assert stops == prompt_template['prompt_rules'].get('stops')
+    assert stops == prompt_rules.get('stops')
     assert prompt_messages[0].content == real_prompt
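Note: for reference, the '#histories#' text the completion-style prompt is built around is a role-prefixed transcript. With the mocked history above it should read as below; this mirrors the join used in the AdvancedPromptTransform test earlier and is only an illustration, not the memory implementation itself:

history = [('user', 'Hi'), ('assistant', 'Hello')]
histories_text = "\n".join(
    f"{'Human' if role == 'user' else 'Assistant'}: {content}"
    for role, content in history
)
assert histories_text == "Human: Hi\nAssistant: Hello"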