mirror of https://github.com/langgenius/dify.git
llm and answer node support inner variable template
This commit is contained in:
parent 8a2d04b305
commit 971436d935
@@ -21,6 +21,8 @@ class AdvancedPromptTransform(PromptTransform):
     """
     Advanced Prompt Transform for Workflow LLM Node.
     """
+    def __init__(self, with_variable_tmpl: bool = False) -> None:
+        self.with_variable_tmpl = with_variable_tmpl
 
     def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
                    inputs: dict,
@@ -74,7 +76,7 @@ class AdvancedPromptTransform(PromptTransform):
 
         prompt_messages = []
 
-        prompt_template = PromptTemplateParser(template=raw_prompt)
+        prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
         prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
 
         prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
@@ -128,7 +130,7 @@ class AdvancedPromptTransform(PromptTransform):
         for prompt_item in raw_prompt_list:
             raw_prompt = prompt_item.text
 
-            prompt_template = PromptTemplateParser(template=raw_prompt)
+            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
             prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
 
             prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
@@ -211,7 +213,7 @@ class AdvancedPromptTransform(PromptTransform):
         if '#histories#' in prompt_template.variable_keys:
             if memory:
                 inputs = {'#histories#': '', **prompt_inputs}
-                prompt_template = PromptTemplateParser(raw_prompt)
+                prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
                 prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
                 tmp_human_message = UserPromptMessage(
                     content=prompt_template.format(prompt_inputs)
@@ -1,6 +1,9 @@
 import re
 
 REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}")
+WITH_VARIABLE_TMPL_REGEX = re.compile(
+    r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#[a-zA-Z0-9_]{1,50}\.[a-zA-Z0-9_\.]{1,100}#|#histories#|#query#|#context#)\}\}"
+)
 
 
 class PromptTemplateParser:
@@ -15,13 +18,15 @@ class PromptTemplateParser:
     `{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
     """
 
-    def __init__(self, template: str):
+    def __init__(self, template: str, with_variable_tmpl: bool = False):
         self.template = template
+        self.with_variable_tmpl = with_variable_tmpl
+        self.regex = WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX
         self.variable_keys = self.extract()
 
     def extract(self) -> list:
         # Regular expression to match the template rules
-        return re.findall(REGEX, self.template)
+        return re.findall(self.regex, self.template)
 
     def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
         def replacer(match):
@@ -29,12 +34,12 @@ class PromptTemplateParser:
             value = inputs.get(key, match.group(0))  # return original matched string if key not found
 
             if remove_template_variables:
-                return PromptTemplateParser.remove_template_variables(value)
+                return PromptTemplateParser.remove_template_variables(value, self.with_variable_tmpl)
             return value
 
-        prompt = re.sub(REGEX, replacer, self.template)
+        prompt = re.sub(self.regex, replacer, self.template)
         return re.sub(r'<\|.*?\|>', '', prompt)
 
     @classmethod
-    def remove_template_variables(cls, text: str):
-        return re.sub(REGEX, r'{\1}', text)
+    def remove_template_variables(cls, text: str, with_variable_tmpl: bool = False):
+        return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r'{\1}', text)
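A minimal standalone sketch of what the widened pattern accepts — the two regexes are copied from the hunk above, and the template string is an invented example:

```python
# Sketch: old vs. new pattern. Only the widened regex recognizes
# node-scoped keys such as {{#start.weather#}}.
import re

REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}")
WITH_VARIABLE_TMPL_REGEX = re.compile(
    r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#[a-zA-Z0-9_]{1,50}\.[a-zA-Z0-9_\.]{1,100}#|#histories#|#query#|#context#)\}\}"
)

template = "Today's weather is {{#start.weather#}}, query: {{query}}"
print(REGEX.findall(template))                     # ['query']
print(WITH_VARIABLE_TMPL_REGEX.findall(template))  # ['#start.weather#', 'query']
```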
@@ -13,6 +13,7 @@ from core.workflow.nodes.answer.entities import (
     VarGenerateRouteChunk,
 )
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from models.workflow import WorkflowNodeExecutionStatus
 
 
@@ -66,32 +67,8 @@ class AnswerNode(BaseNode):
                 part = cast(TextGenerateRouteChunk, part)
                 answer += part.text
 
-        # re-fetch variable values
-        variable_values = {}
-        for variable_selector in node_data.variables:
-            value = variable_pool.get_variable_value(
-                variable_selector=variable_selector.value_selector
-            )
-
-            if isinstance(value, str | int | float):
-                value = str(value)
-            elif isinstance(value, FileVar):
-                value = value.to_dict()
-            elif isinstance(value, list):
-                new_value = []
-                for item in value:
-                    if isinstance(item, FileVar):
-                        new_value.append(item.to_dict())
-                    else:
-                        new_value.append(item)
-
-                value = new_value
-
-            variable_values[variable_selector.variable] = value
-
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
-            inputs=variable_values,
             outputs={
                 "answer": answer
             }
@@ -116,15 +93,18 @@ class AnswerNode(BaseNode):
         :param node_data: node data object
         :return:
         """
+        variable_template_parser = VariableTemplateParser(template=node_data.answer)
+        variable_selectors = variable_template_parser.extract_variable_selectors()
+
         value_selector_mapping = {
             variable_selector.variable: variable_selector.value_selector
-            for variable_selector in node_data.variables
+            for variable_selector in variable_selectors
         }
 
         variable_keys = list(value_selector_mapping.keys())
 
         # format answer template
-        template_parser = PromptTemplateParser(node_data.answer)
+        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
         template_variable_keys = template_parser.variable_keys
 
         # Take the intersection of variable_keys and template_variable_keys
@@ -164,8 +144,11 @@ class AnswerNode(BaseNode):
         """
         node_data = cast(cls._node_data_cls, node_data)
 
+        variable_template_parser = VariableTemplateParser(template=node_data.answer)
+        variable_selectors = variable_template_parser.extract_variable_selectors()
+
         variable_mapping = {}
-        for variable_selector in node_data.variables:
+        for variable_selector in variable_selectors:
            variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
         return variable_mapping
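With this change an answer node's mapping is recovered from the template itself rather than from a declared `variables` list. A hedged sketch of the derivation (assumes the dify package is importable; the answer template is invented):

```python
# Hedged sketch of the new mapping derivation; the answer template is
# invented for illustration.
from core.workflow.utils.variable_template_parser import VariableTemplateParser

answer = "Today's weather is {{#start.weather#}}\n{{#llm.text#}}"
selectors = VariableTemplateParser(template=answer).extract_variable_selectors()

variable_mapping = {s.variable: s.value_selector for s in selectors}
print(variable_mapping)
# {'#start.weather#': ['start', 'weather'], '#llm.text#': ['llm', 'text']} (order may vary)
```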
@@ -2,14 +2,12 @@
 from pydantic import BaseModel
 
 from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.variable_entities import VariableSelector
 
 
 class AnswerNodeData(BaseNodeData):
     """
     Answer Node Data.
     """
-    variables: list[VariableSelector] = []
     answer: str
 
 
@@ -4,7 +4,6 @@ from pydantic import BaseModel
 
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.variable_entities import VariableSelector
 
 
 class ModelConfig(BaseModel):
@@ -44,7 +43,6 @@ class LLMNodeData(BaseNodeData):
     LLM Node Data.
     """
     model: ModelConfig
-    variables: list[VariableSelector] = []
     prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
@@ -15,13 +15,14 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
 from models.provider import Provider, ProviderType
@@ -48,9 +49,7 @@ class LLMNode(BaseNode):
         # fetch variables and fetch values from variable pool
         inputs = self._fetch_inputs(node_data, variable_pool)
 
-        node_inputs = {
-            **inputs
-        }
+        node_inputs = {}
 
         # fetch files
         files: list[FileVar] = self._fetch_files(node_data, variable_pool)
@@ -192,10 +191,21 @@ class LLMNode(BaseNode):
         :return:
         """
         inputs = {}
-        for variable_selector in node_data.variables:
+        prompt_template = node_data.prompt_template
+
+        variable_selectors = []
+        if isinstance(prompt_template, list):
+            for prompt in prompt_template:
+                variable_template_parser = VariableTemplateParser(template=prompt.text)
+                variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+        elif isinstance(prompt_template, CompletionModelPromptTemplate):
+            variable_template_parser = VariableTemplateParser(template=prompt_template.text)
+            variable_selectors = variable_template_parser.extract_variable_selectors()
+
+        for variable_selector in variable_selectors:
             variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
             if variable_value is None:
-                raise ValueError(f'Variable {variable_selector.value_selector} not found')
+                raise ValueError(f'Variable {variable_selector.variable} not found')
 
             inputs[variable_selector.variable] = variable_value
 
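`_fetch_inputs` now derives its selectors from the prompt text itself. A simplified, runnable sketch of the flow, with a plain dict standing in for the real `VariablePool` (whose API differs):

```python
# Simplified sketch of the new input-fetching flow. The prompt text is
# invented, and a plain dict stands in for the real VariablePool.
from core.workflow.utils.variable_template_parser import VariableTemplateParser

prompt_text = "Today's weather is {{#start.weather#}}."
variable_pool = {('start', 'weather'): 'sunny'}  # stand-in, not the real class

inputs = {}
for selector in VariableTemplateParser(template=prompt_text).extract_variable_selectors():
    value = variable_pool.get(tuple(selector.value_selector))
    if value is None:
        raise ValueError(f'Variable {selector.variable} not found')
    inputs[selector.variable] = value

print(inputs)  # {'#start.weather#': 'sunny'}
```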
@@ -411,7 +421,7 @@ class LLMNode(BaseNode):
         :param model_config: model config
         :return:
         """
-        prompt_transform = AdvancedPromptTransform()
+        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=node_data.prompt_template,
             inputs=inputs,
@@ -486,9 +496,6 @@ class LLMNode(BaseNode):
         node_data = cast(cls._node_data_cls, node_data)
 
         variable_mapping = {}
-        for variable_selector in node_data.variables:
-            variable_mapping[variable_selector.variable] = variable_selector.value_selector
-
         if node_data.context.enabled:
             variable_mapping['#context#'] = node_data.context.variable_selector
 
@@ -128,7 +128,7 @@ class QuestionClassifierNode(LLMNode):
         :param model_config: model config
         :return:
         """
-        prompt_transform = AdvancedPromptTransform()
+        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
         prompt_template = self._get_prompt_template(node_data, query)
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
@@ -0,0 +1,58 @@
+import re
+
+from core.workflow.entities.variable_entities import VariableSelector
+
+REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
+
+
+class VariableTemplateParser:
+    """
+    Rules:
+
+    1. Template variables must be enclosed in `{{}}`.
+    2. The template variable Key can only be: #node_id.var1.var2#.
+    3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
+    """
+
+    def __init__(self, template: str):
+        self.template = template
+        self.variable_keys = self.extract()
+
+    def extract(self) -> list:
+        # Regular expression to match the template rules
+        matches = re.findall(REGEX, self.template)
+
+        first_group_matches = [match[0] for match in matches]
+
+        return list(set(first_group_matches))
+
+    def extract_variable_selectors(self) -> list[VariableSelector]:
+        variable_selectors = []
+        for variable_key in self.variable_keys:
+            remove_hash = variable_key.replace('#', '')
+            split_result = remove_hash.split('.')
+            if len(split_result) < 2:
+                continue
+
+            variable_selectors.append(VariableSelector(
+                variable=variable_key,
+                value_selector=split_result
+            ))
+
+        return variable_selectors
+
+    def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
+        def replacer(match):
+            key = match.group(1)
+            value = inputs.get(key, match.group(0))  # return original matched string if key not found
+
+            if remove_template_variables:
+                return VariableTemplateParser.remove_template_variables(value)
+            return value
+
+        prompt = re.sub(REGEX, replacer, self.template)
+        return re.sub(r'<\|.*?\|>', '', prompt)
+
+    @classmethod
+    def remove_template_variables(cls, text: str):
+        return re.sub(REGEX, r'{\1}', text)
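A hedged usage example for the new parser (assumes the dify package is importable; the template is invented, and the expected output is inferred from `extract()` and the regex above):

```python
# Hedged usage example for VariableTemplateParser; template invented,
# expected output inferred from the regex and extract() above.
from core.workflow.utils.variable_template_parser import VariableTemplateParser

parser = VariableTemplateParser(template="{{#start.weather#}} {{#llm.text#}} {{plain}}")
print(parser.variable_keys)
# ['#start.weather#', '#llm.text#'] in some order; '{{plain}}' has no
# '#node_id.var#' form and is ignored by this parser.

for selector in parser.extract_variable_selectors():
    print(selector.variable, selector.value_selector)
# '#start.weather#' ['start', 'weather']
# '#llm.text#' ['llm', 'text']
```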
@@ -291,7 +291,7 @@ class WorkflowConverter:
         if app_model.mode == AppMode.CHAT.value:
             http_request_variables.append({
                 "variable": "_query",
-                "value_selector": ["start", "sys.query"]
+                "value_selector": ["sys", "query"]
             })
 
         request_body = {
@@ -375,7 +375,7 @@ class WorkflowConverter:
         """
         retrieve_config = dataset_config.retrieve_config
         if new_app_mode == AppMode.ADVANCED_CHAT:
-            query_variable_selector = ["start", "sys.query"]
+            query_variable_selector = ["sys", "query"]
         elif retrieve_config.query_variable:
             # fetch query variable
             query_variable_selector = ["start", retrieve_config.query_variable]
@@ -449,19 +449,31 @@ class WorkflowConverter:
                     has_context=knowledge_retrieval_node is not None,
                     query_in_prompt=False
                 )
+
+                template = prompt_template_config['prompt_template'].template
+                for v in start_node['data']['variables']:
+                    template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}')
+
                 prompts = [
                     {
                         "role": 'user',
-                        "text": prompt_template_config['prompt_template'].template
+                        "text": template
                     }
                 ]
             else:
                 advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template
-                prompts = [{
-                    "role": m.role.value,
-                    "text": m.text
-                } for m in advanced_chat_prompt_template.messages] \
-                    if advanced_chat_prompt_template else []
+
+                prompts = []
+                if advanced_chat_prompt_template:
+                    for m in advanced_chat_prompt_template.messages:
+                        text = m.text
+                        for v in start_node['data']['variables']:
+                            text = text.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}')
+
+                        prompts.append({
+                            "role": m.role.value,
+                            "text": text
+                        })
         # Completion Model
         else:
             if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
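The rewrite loop the converter now applies is simple string substitution and can be shown in isolation; a runnable sketch with invented start-node variables:

```python
# Standalone sketch of the rewrite the converter now applies; the
# start-node variables and the template are invented.
start_variables = [{'variable': 'weather'}, {'variable': 'name'}]
template = "Hello {{name}}, the weather is {{weather}}."

for v in start_variables:
    template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}')

print(template)  # Hello {{#start.name#}}, the weather is {{#start.weather#}}.
```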
@@ -475,8 +487,13 @@ class WorkflowConverter:
                     has_context=knowledge_retrieval_node is not None,
                     query_in_prompt=False
                 )
+
+                template = prompt_template_config['prompt_template'].template
+                for v in start_node['data']['variables']:
+                    template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}')
+
                 prompts = {
-                    "text": prompt_template_config['prompt_template'].template
+                    "text": template
                 }
 
                 prompt_rules = prompt_template_config['prompt_rules']
@@ -486,9 +503,16 @@ class WorkflowConverter:
                 }
             else:
                 advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template
+                if advanced_completion_prompt_template:
+                    text = advanced_completion_prompt_template.prompt
+                    for v in start_node['data']['variables']:
+                        text = text.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}')
+                else:
+                    text = ""
+
                 prompts = {
-                    "text": advanced_completion_prompt_template.prompt,
-                } if advanced_completion_prompt_template else {"text": ""}
+                    "text": text,
+                }
 
                 if advanced_completion_prompt_template.role_prefix:
                     role_prefix = {
@@ -519,10 +543,6 @@ class WorkflowConverter:
                     "mode": model_config.mode,
                     "completion_params": completion_params
                 },
-                "variables": [{
-                    "variable": v['variable'],
-                    "value_selector": ["start", v['variable']]
-                } for v in start_node['data']['variables']],
                 "prompt_template": prompts,
                 "memory": memory,
                 "context": {
@@ -532,7 +552,7 @@ class WorkflowConverter:
             },
             "vision": {
                 "enabled": file_upload is not None,
-                "variable_selector": ["start", "sys.files"] if file_upload is not None else None,
+                "variable_selector": ["sys", "files"] if file_upload is not None else None,
                 "configs": {
                     "detail": file_upload.image_config['detail']
                 } if file_upload is not None else None
@@ -571,11 +591,7 @@ class WorkflowConverter:
             "data": {
                 "title": "ANSWER",
                 "type": NodeType.ANSWER.value,
-                "variables": [{
-                    "variable": "text",
-                    "value_selector": ["llm", "text"]
-                }],
-                "answer": "{{text}}"
+                "answer": "{{#llm.text#}}"
             }
         }
 
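The converted ANSWER node therefore carries no `variables` list; the node-scoped key inside the answer string is the whole contract. Roughly (field values taken from the hunk above; the plain string stands in for `NodeType.ANSWER.value`):

```python
# The generated answer node data after this commit, per the hunk above.
answer_node_data = {
    "title": "ANSWER",
    "type": "answer",  # stand-in for NodeType.ANSWER.value
    "answer": "{{#llm.text#}}",  # replaces the old variables list plus "{{text}}"
}
```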
@@ -40,32 +40,17 @@ def test_execute_llm(setup_openai_mock):
                 'mode': 'chat',
                 'completion_params': {}
             },
-            'variables': [
-                {
-                    'variable': 'weather',
-                    'value_selector': ['abc', 'output'],
-                },
-                {
-                    'variable': 'query',
-                    'value_selector': ['sys', 'query']
-                }
-            ],
             'prompt_template': [
                 {
                     'role': 'system',
-                    'text': 'you are a helpful assistant.\ntoday\'s weather is {{weather}}.'
+                    'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}.'
                 },
                 {
                     'role': 'user',
-                    'text': '{{query}}'
+                    'text': '{{#sys.query#}}'
                 }
             ],
-            'memory': {
-                'window': {
-                    'enabled': True,
-                    'size': 2
-                }
-            },
+            'memory': None,
             'context': {
                 'enabled': False
             },
@@ -4,7 +4,6 @@ from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.answer.answer_node import AnswerNode
 from core.workflow.nodes.base_node import UserFrom
-from core.workflow.nodes.if_else.if_else_node import IfElseNode
 from extensions.ext_database import db
 from models.workflow import WorkflowNodeExecutionStatus
 
@@ -21,17 +20,7 @@ def test_execute_answer():
         'data': {
             'title': '123',
             'type': 'answer',
-            'variables': [
-                {
-                    'value_selector': ['llm', 'text'],
-                    'variable': 'text'
-                },
-                {
-                    'value_selector': ['start', 'weather'],
-                    'variable': 'weather'
-                },
-            ],
-            'answer': 'Today\'s weather is {{weather}}\n{{text}}\n{{img}}\nFin.'
+            'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.'
         }
     }
 )
@@ -19,7 +19,7 @@ from services.workflow.workflow_converter import WorkflowConverter
 def default_variables():
     return [
         VariableEntity(
-            variable="text-input",
+            variable="text_input",
             label="text-input",
             type=VariableEntity.Type.TEXT_INPUT
         ),
@@ -43,7 +43,7 @@ def test__convert_to_start_node(default_variables):
     # assert
     assert isinstance(result["data"]["variables"][0]["type"], str)
     assert result["data"]["variables"][0]["type"] == "text-input"
-    assert result["data"]["variables"][0]["variable"] == "text-input"
+    assert result["data"]["variables"][0]["variable"] == "text_input"
     assert result["data"]["variables"][1]["variable"] == "paragraph"
     assert result["data"]["variables"][2]["variable"] == "select"
 
@@ -191,7 +191,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
 
 
 def test__convert_to_knowledge_retrieval_node_for_chatbot():
-    new_app_mode = AppMode.CHAT
+    new_app_mode = AppMode.ADVANCED_CHAT
 
     dataset_config = DatasetEntity(
         dataset_ids=["dataset_id_1", "dataset_id_2"],
@@ -221,7 +221,7 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot():
     )
 
     assert node["data"]["type"] == "knowledge-retrieval"
-    assert node["data"]["query_variable_selector"] == ["start", "sys.query"]
+    assert node["data"]["query_variable_selector"] == ["sys", "query"]
     assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
     assert (node["data"]["retrieval_mode"]
             == dataset_config.retrieve_config.retrieve_strategy.value)
@@ -276,7 +276,7 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app():
 
 
 def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
-    new_app_mode = AppMode.CHAT
+    new_app_mode = AppMode.ADVANCED_CHAT
     model = "gpt-4"
     model_mode = LLMMode.CHAT
 
@@ -298,7 +298,7 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
 
     prompt_template = PromptTemplateEntity(
         prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
-        simple_prompt_template="You are a helpful assistant {{text-input}}, {{paragraph}}, {{select}}."
+        simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}."
     )
 
     llm_node = workflow_converter._convert_to_llm_node(
@@ -311,16 +311,15 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
     assert llm_node["data"]["type"] == "llm"
     assert llm_node["data"]["model"]['name'] == model
     assert llm_node["data"]['model']["mode"] == model_mode.value
-    assert llm_node["data"]["variables"] == [{
-        "variable": v.variable,
-        "value_selector": ["start", v.variable]
-    } for v in default_variables]
-    assert llm_node["data"]["prompts"][0]['text'] == prompt_template.simple_prompt_template + '\n'
+    template = prompt_template.simple_prompt_template
+    for v in default_variables:
+        template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}')
+    assert llm_node["data"]["prompt_template"][0]['text'] == template + '\n'
     assert llm_node["data"]['context']['enabled'] is False
 
 
 def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables):
-    new_app_mode = AppMode.CHAT
+    new_app_mode = AppMode.ADVANCED_CHAT
     model = "gpt-3.5-turbo-instruct"
     model_mode = LLMMode.COMPLETION
 
@@ -342,7 +341,7 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables):
 
     prompt_template = PromptTemplateEntity(
         prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
-        simple_prompt_template="You are a helpful assistant {{text-input}}, {{paragraph}}, {{select}}."
+        simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}."
     )
 
     llm_node = workflow_converter._convert_to_llm_node(
@@ -355,16 +354,15 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables):
     assert llm_node["data"]["type"] == "llm"
     assert llm_node["data"]["model"]['name'] == model
     assert llm_node["data"]['model']["mode"] == model_mode.value
-    assert llm_node["data"]["variables"] == [{
-        "variable": v.variable,
-        "value_selector": ["start", v.variable]
-    } for v in default_variables]
-    assert llm_node["data"]["prompts"]['text'] == prompt_template.simple_prompt_template + '\n'
+    template = prompt_template.simple_prompt_template
+    for v in default_variables:
+        template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}')
+    assert llm_node["data"]["prompt_template"]['text'] == template + '\n'
     assert llm_node["data"]['context']['enabled'] is False
 
 
 def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables):
-    new_app_mode = AppMode.CHAT
+    new_app_mode = AppMode.ADVANCED_CHAT
     model = "gpt-4"
     model_mode = LLMMode.CHAT
 
@@ -404,17 +402,16 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables):
     assert llm_node["data"]["type"] == "llm"
     assert llm_node["data"]["model"]['name'] == model
     assert llm_node["data"]['model']["mode"] == model_mode.value
-    assert llm_node["data"]["variables"] == [{
-        "variable": v.variable,
-        "value_selector": ["start", v.variable]
-    } for v in default_variables]
-    assert isinstance(llm_node["data"]["prompts"], list)
-    assert len(llm_node["data"]["prompts"]) == len(prompt_template.advanced_chat_prompt_template.messages)
-    assert llm_node["data"]["prompts"][0]['text'] == prompt_template.advanced_chat_prompt_template.messages[0].text
+    assert isinstance(llm_node["data"]["prompt_template"], list)
+    assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages)
+    template = prompt_template.advanced_chat_prompt_template.messages[0].text
+    for v in default_variables:
+        template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}')
+    assert llm_node["data"]["prompt_template"][0]['text'] == template
 
 
 def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables):
-    new_app_mode = AppMode.CHAT
+    new_app_mode = AppMode.ADVANCED_CHAT
     model = "gpt-3.5-turbo-instruct"
     model_mode = LLMMode.COMPLETION
 
@@ -456,9 +453,8 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables):
     assert llm_node["data"]["type"] == "llm"
     assert llm_node["data"]["model"]['name'] == model
     assert llm_node["data"]['model']["mode"] == model_mode.value
-    assert llm_node["data"]["variables"] == [{
-        "variable": v.variable,
-        "value_selector": ["start", v.variable]
-    } for v in default_variables]
-    assert isinstance(llm_node["data"]["prompts"], dict)
-    assert llm_node["data"]["prompts"]['text'] == prompt_template.advanced_completion_prompt_template.prompt
+    assert isinstance(llm_node["data"]["prompt_template"], dict)
+    template = prompt_template.advanced_completion_prompt_template.prompt
+    for v in default_variables:
+        template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}')
+    assert llm_node["data"]["prompt_template"]['text'] == template