diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index e50ce8ab06..674ba29b6e 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -21,6 +21,8 @@ class AdvancedPromptTransform(PromptTransform): """ Advanced Prompt Transform for Workflow LLM Node. """ + def __init__(self, with_variable_tmpl: bool = False) -> None: + self.with_variable_tmpl = with_variable_tmpl def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], inputs: dict, @@ -74,7 +76,7 @@ class AdvancedPromptTransform(PromptTransform): prompt_messages = [] - prompt_template = PromptTemplateParser(template=raw_prompt) + prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) @@ -128,7 +130,7 @@ class AdvancedPromptTransform(PromptTransform): for prompt_item in raw_prompt_list: raw_prompt = prompt_item.text - prompt_template = PromptTemplateParser(template=raw_prompt) + prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) @@ -211,7 +213,7 @@ class AdvancedPromptTransform(PromptTransform): if '#histories#' in prompt_template.variable_keys: if memory: inputs = {'#histories#': '', **prompt_inputs} - prompt_template = PromptTemplateParser(raw_prompt) + prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} tmp_human_message = UserPromptMessage( content=prompt_template.format(prompt_inputs) diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py index 454f92e3b7..3e68492df2 100644 --- a/api/core/prompt/utils/prompt_template_parser.py +++ b/api/core/prompt/utils/prompt_template_parser.py @@ -1,6 +1,9 @@ import re REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}") +WITH_VARIABLE_TMPL_REGEX = re.compile( + r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#[a-zA-Z0-9_]{1,50}\.[a-zA-Z0-9_\.]{1,100}#|#histories#|#query#|#context#)\}\}" +) class PromptTemplateParser: @@ -15,13 +18,15 @@ class PromptTemplateParser: `{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed. """ - def __init__(self, template: str): + def __init__(self, template: str, with_variable_tmpl: bool = False): self.template = template + self.with_variable_tmpl = with_variable_tmpl + self.regex = WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX self.variable_keys = self.extract() def extract(self) -> list: # Regular expression to match the template rules - return re.findall(REGEX, self.template) + return re.findall(self.regex, self.template) def format(self, inputs: dict, remove_template_variables: bool = True) -> str: def replacer(match): @@ -29,12 +34,12 @@ class PromptTemplateParser: value = inputs.get(key, match.group(0)) # return original matched string if key not found if remove_template_variables: - return PromptTemplateParser.remove_template_variables(value) + return PromptTemplateParser.remove_template_variables(value, self.with_variable_tmpl) return value - prompt = re.sub(REGEX, replacer, self.template) + prompt = re.sub(self.regex, replacer, self.template) return re.sub(r'<\|.*?\|>', '', prompt) @classmethod - def remove_template_variables(cls, text: str): - return re.sub(REGEX, r'{\1}', text) + def remove_template_variables(cls, text: str, with_variable_tmpl: bool = False): + return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r'{\1}', text) diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 7a98150aab..9194d3fef7 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -13,6 +13,7 @@ from core.workflow.nodes.answer.entities import ( VarGenerateRouteChunk, ) from core.workflow.nodes.base_node import BaseNode +from core.workflow.utils.variable_template_parser import VariableTemplateParser from models.workflow import WorkflowNodeExecutionStatus @@ -66,32 +67,8 @@ class AnswerNode(BaseNode): part = cast(TextGenerateRouteChunk, part) answer += part.text - # re-fetch variable values - variable_values = {} - for variable_selector in node_data.variables: - value = variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector - ) - - if isinstance(value, str | int | float): - value = str(value) - elif isinstance(value, FileVar): - value = value.to_dict() - elif isinstance(value, list): - new_value = [] - for item in value: - if isinstance(item, FileVar): - new_value.append(item.to_dict()) - else: - new_value.append(item) - - value = new_value - - variable_values[variable_selector.variable] = value - return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variable_values, outputs={ "answer": answer } @@ -116,15 +93,18 @@ class AnswerNode(BaseNode): :param node_data: node data object :return: """ + variable_template_parser = VariableTemplateParser(template=node_data.answer) + variable_selectors = variable_template_parser.extract_variable_selectors() + value_selector_mapping = { variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables + for variable_selector in variable_selectors } variable_keys = list(value_selector_mapping.keys()) # format answer template - template_parser = PromptTemplateParser(node_data.answer) + template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) template_variable_keys = template_parser.variable_keys # Take the intersection of variable_keys and template_variable_keys @@ -164,8 +144,11 @@ class AnswerNode(BaseNode): """ node_data = cast(cls._node_data_cls, node_data) + variable_template_parser = VariableTemplateParser(template=node_data.answer) + variable_selectors = variable_template_parser.extract_variable_selectors() + variable_mapping = {} - for variable_selector in node_data.variables: + for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector return variable_mapping diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index 8aed752ccb..9effbbbe67 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -2,14 +2,12 @@ from pydantic import BaseModel from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.variable_entities import VariableSelector class AnswerNodeData(BaseNodeData): """ Answer Node Data. """ - variables: list[VariableSelector] = [] answer: str diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 67163c93cd..c390aaf8c9 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -4,7 +4,6 @@ from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.variable_entities import VariableSelector class ModelConfig(BaseModel): @@ -44,7 +43,6 @@ class LLMNodeData(BaseNodeData): LLM Node Data. """ model: ModelConfig - variables: list[VariableSelector] = [] prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate] memory: Optional[MemoryConfig] = None context: ContextConfig diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index cc49a22020..c0049c5bb3 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -15,13 +15,14 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig +from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from models.model import Conversation from models.provider import Provider, ProviderType @@ -48,9 +49,7 @@ class LLMNode(BaseNode): # fetch variables and fetch values from variable pool inputs = self._fetch_inputs(node_data, variable_pool) - node_inputs = { - **inputs - } + node_inputs = {} # fetch files files: list[FileVar] = self._fetch_files(node_data, variable_pool) @@ -192,10 +191,21 @@ class LLMNode(BaseNode): :return: """ inputs = {} - for variable_selector in node_data.variables: + prompt_template = node_data.prompt_template + + variable_selectors = [] + if isinstance(prompt_template, list): + for prompt in prompt_template: + variable_template_parser = VariableTemplateParser(template=prompt.text) + variable_selectors.extend(variable_template_parser.extract_variable_selectors()) + elif isinstance(prompt_template, CompletionModelPromptTemplate): + variable_template_parser = VariableTemplateParser(template=prompt_template.text) + variable_selectors = variable_template_parser.extract_variable_selectors() + + for variable_selector in variable_selectors: variable_value = variable_pool.get_variable_value(variable_selector.value_selector) if variable_value is None: - raise ValueError(f'Variable {variable_selector.value_selector} not found') + raise ValueError(f'Variable {variable_selector.variable} not found') inputs[variable_selector.variable] = variable_value @@ -411,7 +421,7 @@ class LLMNode(BaseNode): :param model_config: model config :return: """ - prompt_transform = AdvancedPromptTransform() + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_messages = prompt_transform.get_prompt( prompt_template=node_data.prompt_template, inputs=inputs, @@ -486,9 +496,6 @@ class LLMNode(BaseNode): node_data = cast(cls._node_data_cls, node_data) variable_mapping = {} - for variable_selector in node_data.variables: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - if node_data.context.enabled: variable_mapping['#context#'] = node_data.context.variable_selector diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 01b2908f85..4cc698a840 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -128,7 +128,7 @@ class QuestionClassifierNode(LLMNode): :param model_config: model config :return: """ - prompt_transform = AdvancedPromptTransform() + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_template = self._get_prompt_template(node_data, query) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, diff --git a/api/core/workflow/utils/__init__.py b/api/core/workflow/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/utils/variable_template_parser.py b/api/core/workflow/utils/variable_template_parser.py new file mode 100644 index 0000000000..23b8ce2974 --- /dev/null +++ b/api/core/workflow/utils/variable_template_parser.py @@ -0,0 +1,58 @@ +import re + +from core.workflow.entities.variable_entities import VariableSelector + +REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") + + +class VariableTemplateParser: + """ + Rules: + + 1. Template variables must be enclosed in `{{}}`. + 2. The template variable Key can only be: #node_id.var1.var2#. + 3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2. + """ + + def __init__(self, template: str): + self.template = template + self.variable_keys = self.extract() + + def extract(self) -> list: + # Regular expression to match the template rules + matches = re.findall(REGEX, self.template) + + first_group_matches = [match[0] for match in matches] + + return list(set(first_group_matches)) + + def extract_variable_selectors(self) -> list[VariableSelector]: + variable_selectors = [] + for variable_key in self.variable_keys: + remove_hash = variable_key.replace('#', '') + split_result = remove_hash.split('.') + if len(split_result) < 2: + continue + + variable_selectors.append(VariableSelector( + variable=variable_key, + value_selector=split_result + )) + + return variable_selectors + + def format(self, inputs: dict, remove_template_variables: bool = True) -> str: + def replacer(match): + key = match.group(1) + value = inputs.get(key, match.group(0)) # return original matched string if key not found + + if remove_template_variables: + return VariableTemplateParser.remove_template_variables(value) + return value + + prompt = re.sub(REGEX, replacer, self.template) + return re.sub(r'<\|.*?\|>', '', prompt) + + @classmethod + def remove_template_variables(cls, text: str): + return re.sub(REGEX, r'{\1}', text) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index c7424f3f95..d597941ef6 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -291,7 +291,7 @@ class WorkflowConverter: if app_model.mode == AppMode.CHAT.value: http_request_variables.append({ "variable": "_query", - "value_selector": ["start", "sys.query"] + "value_selector": ["sys", ".query"] }) request_body = { @@ -375,7 +375,7 @@ class WorkflowConverter: """ retrieve_config = dataset_config.retrieve_config if new_app_mode == AppMode.ADVANCED_CHAT: - query_variable_selector = ["start", "sys.query"] + query_variable_selector = ["sys", "query"] elif retrieve_config.query_variable: # fetch query variable query_variable_selector = ["start", retrieve_config.query_variable] @@ -449,19 +449,31 @@ class WorkflowConverter: has_context=knowledge_retrieval_node is not None, query_in_prompt=False ) + + template = prompt_template_config['prompt_template'].template + for v in start_node['data']['variables']: + template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}') + prompts = [ { "role": 'user', - "text": prompt_template_config['prompt_template'].template + "text": template } ] else: advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template - prompts = [{ - "role": m.role.value, - "text": m.text - } for m in advanced_chat_prompt_template.messages] \ - if advanced_chat_prompt_template else [] + + prompts = [] + for m in advanced_chat_prompt_template.messages: + if advanced_chat_prompt_template: + text = m.text + for v in start_node['data']['variables']: + text = text.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}') + + prompts.append({ + "role": m.role.value, + "text": text + }) # Completion Model else: if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: @@ -475,8 +487,13 @@ class WorkflowConverter: has_context=knowledge_retrieval_node is not None, query_in_prompt=False ) + + template = prompt_template_config['prompt_template'].template + for v in start_node['data']['variables']: + template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}') + prompts = { - "text": prompt_template_config['prompt_template'].template + "text": template } prompt_rules = prompt_template_config['prompt_rules'] @@ -486,9 +503,16 @@ class WorkflowConverter: } else: advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template + if advanced_completion_prompt_template: + text = advanced_completion_prompt_template.prompt + for v in start_node['data']['variables']: + text = text.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}') + else: + text = "" + prompts = { - "text": advanced_completion_prompt_template.prompt, - } if advanced_completion_prompt_template else {"text": ""} + "text": text, + } if advanced_completion_prompt_template.role_prefix: role_prefix = { @@ -519,10 +543,6 @@ class WorkflowConverter: "mode": model_config.mode, "completion_params": completion_params }, - "variables": [{ - "variable": v['variable'], - "value_selector": ["start", v['variable']] - } for v in start_node['data']['variables']], "prompt_template": prompts, "memory": memory, "context": { @@ -532,7 +552,7 @@ class WorkflowConverter: }, "vision": { "enabled": file_upload is not None, - "variable_selector": ["start", "sys.files"] if file_upload is not None else None, + "variable_selector": ["sys", "files"] if file_upload is not None else None, "configs": { "detail": file_upload.image_config['detail'] } if file_upload is not None else None @@ -571,11 +591,7 @@ class WorkflowConverter: "data": { "title": "ANSWER", "type": NodeType.ANSWER.value, - "variables": [{ - "variable": "text", - "value_selector": ["llm", "text"] - }], - "answer": "{{text}}" + "answer": "{{#llm.text#}}" } } diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 999ebf7734..73794336c2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -40,32 +40,17 @@ def test_execute_llm(setup_openai_mock): 'mode': 'chat', 'completion_params': {} }, - 'variables': [ - { - 'variable': 'weather', - 'value_selector': ['abc', 'output'], - }, - { - 'variable': 'query', - 'value_selector': ['sys', 'query'] - } - ], 'prompt_template': [ { 'role': 'system', - 'text': 'you are a helpful assistant.\ntoday\'s weather is {{weather}}.' + 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}.' }, { 'role': 'user', - 'text': '{{query}}' + 'text': '{{#sys.query#}}' } ], - 'memory': { - 'window': { - 'enabled': True, - 'size': 2 - } - }, + 'memory': None, 'context': { 'enabled': False }, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 038fda9dac..e2d5be769c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -4,7 +4,6 @@ from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import UserFrom -from core.workflow.nodes.if_else.if_else_node import IfElseNode from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionStatus @@ -21,17 +20,7 @@ def test_execute_answer(): 'data': { 'title': '123', 'type': 'answer', - 'variables': [ - { - 'value_selector': ['llm', 'text'], - 'variable': 'text' - }, - { - 'value_selector': ['start', 'weather'], - 'variable': 'weather' - }, - ], - 'answer': 'Today\'s weather is {{weather}}\n{{text}}\n{{img}}\nFin.' + 'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.' } } ) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index b4a4d6707a..6c1402a518 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -19,7 +19,7 @@ from services.workflow.workflow_converter import WorkflowConverter def default_variables(): return [ VariableEntity( - variable="text-input", + variable="text_input", label="text-input", type=VariableEntity.Type.TEXT_INPUT ), @@ -43,7 +43,7 @@ def test__convert_to_start_node(default_variables): # assert assert isinstance(result["data"]["variables"][0]["type"], str) assert result["data"]["variables"][0]["type"] == "text-input" - assert result["data"]["variables"][0]["variable"] == "text-input" + assert result["data"]["variables"][0]["variable"] == "text_input" assert result["data"]["variables"][1]["variable"] == "paragraph" assert result["data"]["variables"][2]["variable"] == "select" @@ -191,7 +191,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): def test__convert_to_knowledge_retrieval_node_for_chatbot(): - new_app_mode = AppMode.CHAT + new_app_mode = AppMode.ADVANCED_CHAT dataset_config = DatasetEntity( dataset_ids=["dataset_id_1", "dataset_id_2"], @@ -221,7 +221,7 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot(): ) assert node["data"]["type"] == "knowledge-retrieval" - assert node["data"]["query_variable_selector"] == ["start", "sys.query"] + assert node["data"]["query_variable_selector"] == ["sys", "query"] assert node["data"]["dataset_ids"] == dataset_config.dataset_ids assert (node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value) @@ -276,7 +276,7 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app(): def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): - new_app_mode = AppMode.CHAT + new_app_mode = AppMode.ADVANCED_CHAT model = "gpt-4" model_mode = LLMMode.CHAT @@ -298,7 +298,7 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text-input}}, {{paragraph}}, {{select}}." + simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}." ) llm_node = workflow_converter._convert_to_llm_node( @@ -311,16 +311,15 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): assert llm_node["data"]["type"] == "llm" assert llm_node["data"]["model"]['name'] == model assert llm_node["data"]['model']["mode"] == model_mode.value - assert llm_node["data"]["variables"] == [{ - "variable": v.variable, - "value_selector": ["start", v.variable] - } for v in default_variables] - assert llm_node["data"]["prompts"][0]['text'] == prompt_template.simple_prompt_template + '\n' + template = prompt_template.simple_prompt_template + for v in default_variables: + template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') + assert llm_node["data"]["prompt_template"][0]['text'] == template + '\n' assert llm_node["data"]['context']['enabled'] is False def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): - new_app_mode = AppMode.CHAT + new_app_mode = AppMode.ADVANCED_CHAT model = "gpt-3.5-turbo-instruct" model_mode = LLMMode.COMPLETION @@ -342,7 +341,7 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text-input}}, {{paragraph}}, {{select}}." + simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}." ) llm_node = workflow_converter._convert_to_llm_node( @@ -355,16 +354,15 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab assert llm_node["data"]["type"] == "llm" assert llm_node["data"]["model"]['name'] == model assert llm_node["data"]['model']["mode"] == model_mode.value - assert llm_node["data"]["variables"] == [{ - "variable": v.variable, - "value_selector": ["start", v.variable] - } for v in default_variables] - assert llm_node["data"]["prompts"]['text'] == prompt_template.simple_prompt_template + '\n' + template = prompt_template.simple_prompt_template + for v in default_variables: + template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') + assert llm_node["data"]["prompt_template"]['text'] == template + '\n' assert llm_node["data"]['context']['enabled'] is False def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): - new_app_mode = AppMode.CHAT + new_app_mode = AppMode.ADVANCED_CHAT model = "gpt-4" model_mode = LLMMode.CHAT @@ -404,17 +402,16 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) assert llm_node["data"]["type"] == "llm" assert llm_node["data"]["model"]['name'] == model assert llm_node["data"]['model']["mode"] == model_mode.value - assert llm_node["data"]["variables"] == [{ - "variable": v.variable, - "value_selector": ["start", v.variable] - } for v in default_variables] - assert isinstance(llm_node["data"]["prompts"], list) - assert len(llm_node["data"]["prompts"]) == len(prompt_template.advanced_chat_prompt_template.messages) - assert llm_node["data"]["prompts"][0]['text'] == prompt_template.advanced_chat_prompt_template.messages[0].text + assert isinstance(llm_node["data"]["prompt_template"], list) + assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages) + template = prompt_template.advanced_chat_prompt_template.messages[0].text + for v in default_variables: + template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') + assert llm_node["data"]["prompt_template"][0]['text'] == template def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): - new_app_mode = AppMode.CHAT + new_app_mode = AppMode.ADVANCED_CHAT model = "gpt-3.5-turbo-instruct" model_mode = LLMMode.COMPLETION @@ -456,9 +453,8 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var assert llm_node["data"]["type"] == "llm" assert llm_node["data"]["model"]['name'] == model assert llm_node["data"]['model']["mode"] == model_mode.value - assert llm_node["data"]["variables"] == [{ - "variable": v.variable, - "value_selector": ["start", v.variable] - } for v in default_variables] - assert isinstance(llm_node["data"]["prompts"], dict) - assert llm_node["data"]["prompts"]['text'] == prompt_template.advanced_completion_prompt_template.prompt + assert isinstance(llm_node["data"]["prompt_template"], dict) + template = prompt_template.advanced_completion_prompt_template.prompt + for v in default_variables: + template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') + assert llm_node["data"]["prompt_template"]['text'] == template