diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 5d312149f7..daf9641121 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -21,7 +21,7 @@ from fields.conversation_fields import (
 )
 from libs.helper import datetime_string
 from libs.login import login_required
-from models.model import Conversation, Message, MessageAnnotation, AppMode
+from models.model import AppMode, Conversation, Message, MessageAnnotation
 
 
 class CompletionConversationApi(Resource):
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index 9a177116ea..c384e878aa 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -26,7 +26,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields
 from libs.helper import uuid_value
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from libs.login import login_required
-from models.model import Conversation, Message, MessageAnnotation, MessageFeedback, AppMode
+from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
 from services.annotation_service import AppAnnotationService
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 2794735bbb..1bb0ea34c1 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse, marshal_with
+from flask_restful import Resource, marshal_with, reqparse
 
 from controllers.console import api
 from controllers.console.app.error import DraftWorkflowNotExist
@@ -6,8 +6,8 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from fields.workflow_fields import workflow_fields
-from libs.login import login_required, current_user
-from models.model import App, ChatbotAppEngine, AppMode
+from libs.login import current_user, login_required
+from models.model import App, AppMode, ChatbotAppEngine
 from services.workflow_service import WorkflowService
 
 
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py
index fe35e72304..1c2c4cf5c7 100644
--- a/api/controllers/console/app/wraps.py
+++ b/api/controllers/console/app/wraps.py
@@ -5,7 +5,7 @@ from typing import Optional, Union
 from controllers.console.app.error import AppNotFoundError
 from extensions.ext_database import db
 from libs.login import current_user
-from models.model import App, ChatbotAppEngine, AppMode
+from models.model import App, AppMode, ChatbotAppEngine
 
 
 def get_app_model(view: Optional[Callable] = None, *,
diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py
index f9678b372f..c6f6268a7a 100644
--- a/api/core/app_runner/app_runner.py
+++ b/api/core/app_runner/app_runner.py
@@ -22,7 +22,7 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.errors.invoke import InvokeBadRequestError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.prompt.prompt_transform import PromptTransform
+from core.prompt.simple_prompt_transform import SimplePromptTransform
 from models.model import App, Message, MessageAnnotation
 
 
@@ -140,12 +140,11 @@ class AppRunner:
         :param memory: memory
         :return:
         """
-        prompt_transform = PromptTransform()
+        prompt_transform = SimplePromptTransform()
 
         # get prompt without memory and context
         if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
             prompt_messages, stop = prompt_transform.get_prompt(
-                app_mode=app_record.mode,
                 prompt_template_entity=prompt_template_entity,
                 inputs=inputs,
                 query=query if query else '',
@@ -155,17 +154,7 @@ class AppRunner:
                 model_config=model_config
             )
         else:
-            prompt_messages = prompt_transform.get_advanced_prompt(
-                app_mode=app_record.mode,
-                prompt_template_entity=prompt_template_entity,
-                inputs=inputs,
-                query=query,
-                files=files,
-                context=context,
-                memory=memory,
-                model_config=model_config
-            )
-            stop = model_config.stop
+            raise NotImplementedError("Advanced prompt is not supported yet.")
 
         return prompt_messages, stop
 
diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py
index 26e9cc84aa..0e0fe6e3bf 100644
--- a/api/core/app_runner/basic_app_runner.py
+++ b/api/core/app_runner/basic_app_runner.py
@@ -15,7 +15,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationException
 from extensions.ext_database import db
-from models.model import App, Conversation, Message, AppMode
+from models.model import App, AppMode, Conversation, Message
 
 logger = logging.getLogger(__name__)
 
diff --git a/api/core/application_manager.py b/api/core/application_manager.py
index 2fde422d47..cf463be1df 100644
--- a/api/core/application_manager.py
+++ b/api/core/application_manager.py
@@ -28,7 +28,8 @@ from core.entities.application_entities import (
     ModelConfigEntity,
     PromptTemplateEntity,
     SensitiveWordAvoidanceEntity,
-    TextToSpeechEntity, VariableEntity,
+    TextToSpeechEntity,
+    VariableEntity,
 )
 from core.entities.model_entities import ModelStatus
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
@@ -541,8 +542,7 @@ class ApplicationManager:
                         query_variable=query_variable,
                         retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                             dataset_configs['retrieval_model']
-                        ),
-                        single_strategy=datasets.get('strategy', 'router')
+                        )
                     )
                 )
             else:
diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py
index 092591a73f..f8f293d96a 100644
--- a/api/core/entities/application_entities.py
+++ b/api/core/entities/application_entities.py
@@ -156,7 +156,6 @@ class DatasetRetrieveConfigEntity(BaseModel):
 
     query_variable: Optional[str] = None  # Only when app mode is completion
     retrieve_strategy: RetrieveStrategy
-    single_strategy: Optional[str] = None  # for temp
     top_k: Optional[int] = None
     score_threshold: Optional[float] = None
     reranking_model: Optional[dict] = None
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
new file mode 100644
index 0000000000..9ca3ef0375
--- /dev/null
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -0,0 +1,198 @@
+from typing import Optional
+
+from core.entities.application_entities import PromptTemplateEntity, ModelConfigEntity, \
+    AdvancedCompletionPromptTemplateEntity
+from core.file.file_obj import FileObj
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, UserPromptMessage, \
+    SystemPromptMessage, AssistantPromptMessage, TextPromptMessageContent
+from core.prompt.prompt_template import PromptTemplateParser
+from core.prompt.prompt_transform import PromptTransform
+from core.prompt.simple_prompt_transform import ModelMode
+
+
+class AdvancePromptTransform(PromptTransform):
+    """
+    Advanced Prompt Transform for Workflow LLM Node.
+    """
+
+    def get_prompt(self, prompt_template_entity: PromptTemplateEntity,
+                   inputs: dict,
+                   query: str,
+                   files: list[FileObj],
+                   context: Optional[str],
+                   memory: Optional[TokenBufferMemory],
+                   model_config: ModelConfigEntity) -> list[PromptMessage]:
+        prompt_messages = []
+
+        model_mode = ModelMode.value_of(model_config.mode)
+        if model_mode == ModelMode.COMPLETION:
+            prompt_messages = self._get_completion_model_prompt_messages(
+                prompt_template_entity=prompt_template_entity,
+                inputs=inputs,
+                files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config
+            )
+        elif model_mode == ModelMode.CHAT:
+            prompt_messages = self._get_chat_model_prompt_messages(
+                prompt_template_entity=prompt_template_entity,
+                inputs=inputs,
+                query=query,
+                files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config
+            )
+
+        return prompt_messages
+
+    def _get_completion_model_prompt_messages(self,
+                                              prompt_template_entity: PromptTemplateEntity,
+                                              inputs: dict,
+                                              files: list[FileObj],
+                                              context: Optional[str],
+                                              memory: Optional[TokenBufferMemory],
+                                              model_config: ModelConfigEntity) -> list[PromptMessage]:
+        """
+        Get completion model prompt messages.
+        """
+        raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt
+
+        prompt_messages = []
+
+        prompt_template = PromptTemplateParser(template=raw_prompt)
+        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+        self._set_context_variable(context, prompt_template, prompt_inputs)
+
+        role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix
+        self._set_histories_variable(
+            memory=memory,
+            raw_prompt=raw_prompt,
+            role_prefix=role_prefix,
+            prompt_template=prompt_template,
+            prompt_inputs=prompt_inputs,
+            model_config=model_config
+        )
+
+        prompt = prompt_template.format(
+            prompt_inputs
+        )
+
+        if files:
+            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+            for file in files:
+                prompt_message_contents.append(file.prompt_message_content)
+
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+        else:
+            prompt_messages.append(UserPromptMessage(content=prompt))
+
+        return prompt_messages
+
+    def _get_chat_model_prompt_messages(self,
+                                        prompt_template_entity: PromptTemplateEntity,
+                                        inputs: dict,
+                                        query: str,
+                                        files: list[FileObj],
+                                        context: Optional[str],
+                                        memory: Optional[TokenBufferMemory],
+                                        model_config: ModelConfigEntity) -> list[PromptMessage]:
+        """
+        Get chat model prompt messages.
+        """
+        raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages
+
+        prompt_messages = []
+
+        for prompt_item in raw_prompt_list:
+            raw_prompt = prompt_item.text
+
+            prompt_template = PromptTemplateParser(template=raw_prompt)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+            self._set_context_variable(context, prompt_template, prompt_inputs)
+
+            prompt = prompt_template.format(
+                prompt_inputs
+            )
+
+            if prompt_item.role == PromptMessageRole.USER:
+                prompt_messages.append(UserPromptMessage(content=prompt))
+            elif prompt_item.role == PromptMessageRole.SYSTEM and prompt:
+                prompt_messages.append(SystemPromptMessage(content=prompt))
+            elif prompt_item.role == PromptMessageRole.ASSISTANT:
+                prompt_messages.append(AssistantPromptMessage(content=prompt))
+
+        if memory:
+            self._append_chat_histories(memory, prompt_messages, model_config)
+
+            if files:
+                prompt_message_contents = [TextPromptMessageContent(data=query)]
+                for file in files:
+                    prompt_message_contents.append(file.prompt_message_content)
+
+                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+            else:
+                prompt_messages.append(UserPromptMessage(content=query))
+        elif files:
+            # get last message
+            last_message = prompt_messages[-1] if prompt_messages else None
+            if last_message and last_message.role == PromptMessageRole.USER:
+                # get last user message content and add files
+                prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
+                for file in files:
+                    prompt_message_contents.append(file.prompt_message_content)
+
+                last_message.content = prompt_message_contents
+            else:
+                prompt_message_contents = [TextPromptMessageContent(data=query)]
+                for file in files:
+                    prompt_message_contents.append(file.prompt_message_content)
+
+                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+
+        return prompt_messages
+
+    def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
+        if '#context#' in prompt_template.variable_keys:
+            if context:
+                prompt_inputs['#context#'] = context
+            else:
+                prompt_inputs['#context#'] = ''
+
+    def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
+        if '#query#' in prompt_template.variable_keys:
+            if query:
+                prompt_inputs['#query#'] = query
+            else:
+                prompt_inputs['#query#'] = ''
+
+    def _set_histories_variable(self, memory: TokenBufferMemory,
+                                raw_prompt: str,
+                                role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
+                                prompt_template: PromptTemplateParser,
+                                prompt_inputs: dict,
+                                model_config: ModelConfigEntity) -> None:
+        if '#histories#' in prompt_template.variable_keys:
+            if memory:
+                inputs = {'#histories#': '', **prompt_inputs}
+                prompt_template = PromptTemplateParser(raw_prompt)
+                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+                tmp_human_message = UserPromptMessage(
+                    content=prompt_template.format(prompt_inputs)
+                )
+
+                rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
+
+                histories = self._get_history_messages_from_memory(
+                    memory=memory,
+                    max_token_limit=rest_tokens,
+                    human_prefix=role_prefix.user,
+                    ai_prefix=role_prefix.assistant
+                )
+                prompt_inputs['#histories#'] = histories
+            else:
+                prompt_inputs['#histories#'] = ''
diff --git a/api/core/prompt/generate_prompts/baichuan_chat.json b/api/core/prompt/generate_prompts/baichuan_chat.json
index 5bf83cd9c7..03b6a53cff 100644
--- a/api/core/prompt/generate_prompts/baichuan_chat.json
+++ b/api/core/prompt/generate_prompts/baichuan_chat.json
@@ -1,13 +1,13 @@
 {
     "human_prefix": "用户",
     "assistant_prefix": "助手",
-    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n\n",
-    "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{histories}}\n```\n\n",
+    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n",
+    "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n",
     "system_prompt_orders": [
         "context_prompt",
         "pre_prompt",
         "histories_prompt"
     ],
-    "query_prompt": "\n\n用户:{{query}}",
+    "query_prompt": "\n\n用户:{{#query#}}",
     "stops": ["用户:"]
 }
\ No newline at end of file
diff --git a/api/core/prompt/generate_prompts/baichuan_completion.json b/api/core/prompt/generate_prompts/baichuan_completion.json
index a3a2054e83..ae8c0dac53 100644
--- a/api/core/prompt/generate_prompts/baichuan_completion.json
+++ b/api/core/prompt/generate_prompts/baichuan_completion.json
@@ -1,9 +1,9 @@
 {
-    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n",
+    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n",
     "system_prompt_orders": [
         "context_prompt",
         "pre_prompt"
     ],
-    "query_prompt": "{{query}}",
+    "query_prompt": "{{#query#}}",
     "stops": null
 }
\ No newline at end of file
diff --git a/api/core/prompt/generate_prompts/common_chat.json b/api/core/prompt/generate_prompts/common_chat.json
index 709a8d8866..d398a512e6 100644
--- a/api/core/prompt/generate_prompts/common_chat.json
+++ b/api/core/prompt/generate_prompts/common_chat.json
@@ -1,13 +1,13 @@
 {
     "human_prefix": "Human",
     "assistant_prefix": "Assistant",
-    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
-    "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{histories}}\n</histories>\n\n",
+    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+    "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n",
     "system_prompt_orders": [
         "context_prompt",
         "pre_prompt",
         "histories_prompt"
     ],
-    "query_prompt": "\n\nHuman: {{query}}\n\nAssistant: ",
+    "query_prompt": "\n\nHuman: {{#query#}}\n\nAssistant: ",
     "stops": ["\nHuman:", "</histories>"]
 }
diff --git a/api/core/prompt/generate_prompts/common_completion.json b/api/core/prompt/generate_prompts/common_completion.json
index 9e7e8d68ef..c148772010 100644
--- a/api/core/prompt/generate_prompts/common_completion.json
+++ b/api/core/prompt/generate_prompts/common_completion.json
@@ -1,9 +1,9 @@
 {
-    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
     "system_prompt_orders": [
         "context_prompt",
         "pre_prompt"
     ],
-    "query_prompt": "{{query}}",
+    "query_prompt": "{{#query#}}",
    "stops": null
 }
\ No newline at end of file
diff --git a/api/core/prompt/prompt_builder.py b/api/core/prompt/prompt_builder.py
deleted file mode 100644
index 7727b0f92e..0000000000
--- a/api/core/prompt/prompt_builder.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from core.prompt.prompt_template import PromptTemplateParser
-
-
-class PromptBuilder:
-    @classmethod
-    def parse_prompt(cls, prompt: str, inputs: dict) -> str:
-        prompt_template = PromptTemplateParser(prompt)
-        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-        prompt = prompt_template.format(prompt_inputs)
-        return prompt
diff --git a/api/core/prompt/prompt_template.py b/api/core/prompt/prompt_template.py
index 32c5a791de..454f92e3b7 100644
--- a/api/core/prompt/prompt_template.py
+++ b/api/core/prompt/prompt_template.py
@@ -32,7 +32,8 @@ class PromptTemplateParser:
                 return PromptTemplateParser.remove_template_variables(value)
             return value
 
-        return re.sub(REGEX, replacer, self.template)
+        prompt = re.sub(REGEX, replacer, self.template)
+        return re.sub(r'<\|.*?\|>', '', prompt)
 
     @classmethod
     def remove_template_variables(cls, text: str):
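Two things change in the template layer here: the rule files now use reserved `{{#context#}}`-style names that cannot collide with user-defined variables, and `PromptTemplateParser.format` now strips model special tokens of the form `<|...|>` from the rendered prompt. The following is a minimal sketch of both behaviors; the variable regex below is a simplified stand-in for the parser's real pattern, while the `<\|.*?\|>` substitution is taken verbatim from the diff above.

```python
import re

# Simplified stand-in for PromptTemplateParser's variable pattern (assumption);
# it accepts both plain {{name}} and reserved {{#name#}} placeholders.
REGEX = re.compile(r'\{\{(#?[a-zA-Z_][a-zA-Z0-9_]*#?)\}\}')


class MiniTemplateParser:
    def __init__(self, template: str):
        self.template = template
        self.variable_keys = REGEX.findall(template)

    def format(self, inputs: dict) -> str:
        def replacer(match):
            key = match.group(1)
            # leave unknown placeholders untouched
            return str(inputs.get(key, match.group(0)))

        prompt = re.sub(REGEX, replacer, self.template)
        # strip special tokens such as <|im_start|>, as format() now does
        return re.sub(r'<\|.*?\|>', '', prompt)


parser = MiniTemplateParser('Context:\n{{#context#}}\n\nQuestion: {{#query#}}')
print(parser.variable_keys)   # ['#context#', '#query#']
print(parser.format({'#context#': 'Dify docs', '#query#': 'What is a workflow?'}))
```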
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index abbfa96249..c0f70ae0bb 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -1,393 +1,13 @@
-import enum
-import json
-import os
-import re
 from typing import Optional, cast
 
-from core.entities.application_entities import (
-    AdvancedCompletionPromptTemplateEntity,
-    ModelConfigEntity,
-    PromptTemplateEntity,
-)
-from core.file.file_obj import FileObj
+from core.entities.application_entities import ModelConfigEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    PromptMessage,
-    PromptMessageRole,
-    SystemPromptMessage,
-    TextPromptMessageContent,
-    UserPromptMessage,
-)
+from core.model_runtime.entities.message_entities import PromptMessage
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import PromptTemplateParser
-from models.model import AppMode
-
-
-class ModelMode(enum.Enum):
-    COMPLETION = 'completion'
-    CHAT = 'chat'
-
-    @classmethod
-    def value_of(cls, value: str) -> 'ModelMode':
-        """
-        Get value of given mode.
-
-        :param value: mode value
-        :return: mode
-        """
-        for mode in cls:
-            if mode.value == value:
-                return mode
-        raise ValueError(f'invalid mode value {value}')
 
 
 class PromptTransform:
-    def get_prompt(self,
-                   app_mode: str,
-                   prompt_template_entity: PromptTemplateEntity,
-                   inputs: dict,
-                   query: str,
-                   files: list[FileObj],
-                   context: Optional[str],
-                   memory: Optional[TokenBufferMemory],
-                   model_config: ModelConfigEntity) -> \
-            tuple[list[PromptMessage], Optional[list[str]]]:
-        app_mode = AppMode.value_of(app_mode)
-        model_mode = ModelMode.value_of(model_config.mode)
-
-        prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(
-            app_mode=app_mode,
-            provider=model_config.provider,
-            model=model_config.model
-        ))
-
-        if app_mode == AppMode.CHAT and model_mode == ModelMode.CHAT:
-            stops = None
-
-            prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages(
-                prompt_rules=prompt_rules,
-                pre_prompt=prompt_template_entity.simple_prompt_template,
-                inputs=inputs,
-                query=query,
-                files=files,
-                context=context,
-                memory=memory,
-                model_config=model_config
-            )
-        else:
-            stops = prompt_rules.get('stops')
-            if stops is not None and len(stops) == 0:
-                stops = None
-
-            prompt_messages = self._get_simple_others_prompt_messages(
-                prompt_rules=prompt_rules,
-                pre_prompt=prompt_template_entity.simple_prompt_template,
-                inputs=inputs,
-                query=query,
-                files=files,
-                context=context,
-                memory=memory,
-                model_config=model_config
-            )
-        return prompt_messages, stops
-
-    def get_advanced_prompt(self, app_mode: str,
-                            prompt_template_entity: PromptTemplateEntity,
-                            inputs: dict,
-                            query: str,
-                            files: list[FileObj],
-                            context: Optional[str],
-                            memory: Optional[TokenBufferMemory],
-                            model_config: ModelConfigEntity) -> list[PromptMessage]:
-        app_mode = AppMode.value_of(app_mode)
-        model_mode = ModelMode.value_of(model_config.mode)
-
-        prompt_messages = []
-
-        if app_mode == AppMode.CHAT:
-            if model_mode == ModelMode.COMPLETION:
-                prompt_messages = self._get_chat_app_completion_model_prompt_messages(
-                    prompt_template_entity=prompt_template_entity,
-                    inputs=inputs,
-                    query=query,
-                    files=files,
-                    context=context,
-                    memory=memory,
-                    model_config=model_config
-                )
-            elif model_mode == ModelMode.CHAT:
-                prompt_messages = self._get_chat_app_chat_model_prompt_messages(
-                    prompt_template_entity=prompt_template_entity,
-                    inputs=inputs,
-                    query=query,
-                    files=files,
-                    context=context,
-                    memory=memory,
-                    model_config=model_config
-                )
-        elif app_mode == AppMode.COMPLETION:
-            if model_mode == ModelMode.CHAT:
-                prompt_messages = self._get_completion_app_chat_model_prompt_messages(
-                    prompt_template_entity=prompt_template_entity,
-                    inputs=inputs,
-                    files=files,
-                    context=context,
-                )
-            elif model_mode == ModelMode.COMPLETION:
-                prompt_messages = self._get_completion_app_completion_model_prompt_messages(
-                    prompt_template_entity=prompt_template_entity,
-                    inputs=inputs,
-                    context=context,
-                )
-
-        return prompt_messages
-
-    def _get_history_messages_from_memory(self, memory: TokenBufferMemory,
-                                          max_token_limit: int,
-                                          human_prefix: Optional[str] = None,
-                                          ai_prefix: Optional[str] = None) -> str:
-        """Get memory messages."""
-        kwargs = {
-            "max_token_limit": max_token_limit
-        }
-
-        if human_prefix:
-            kwargs['human_prefix'] = human_prefix
-
-        if ai_prefix:
-            kwargs['ai_prefix'] = ai_prefix
-
-        return memory.get_history_prompt_text(
-            **kwargs
-        )
-
-    def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory,
-                                               max_token_limit: int) -> list[PromptMessage]:
-        """Get memory messages."""
-        return memory.get_history_prompt_messages(
-            max_token_limit=max_token_limit
-        )
-
-    def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
-        # baichuan
-        if provider == 'baichuan':
-            return self._prompt_file_name_for_baichuan(app_mode)
-
-        baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
-        if provider in baichuan_supported_providers and 'baichuan' in model.lower():
-            return self._prompt_file_name_for_baichuan(app_mode)
-
-        # common
-        if app_mode == AppMode.COMPLETION:
-            return 'common_completion'
-        else:
-            return 'common_chat'
-
-    def _prompt_file_name_for_baichuan(self, app_mode: AppMode) -> str:
-        if app_mode == AppMode.COMPLETION:
-            return 'baichuan_completion'
-        else:
-            return 'baichuan_chat'
-
-    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
-        # Get the absolute path of the subdirectory
-        prompt_path = os.path.join(
-            os.path.dirname(os.path.realpath(__file__)),
-            'generate_prompts')
-
-        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
-        # Open the JSON file and read its content
-        with open(json_file_path, encoding='utf-8') as json_file:
-            return json.load(json_file)
-
-    def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict,
-                                                        pre_prompt: str,
-                                                        inputs: dict,
-                                                        query: str,
-                                                        context: Optional[str],
-                                                        files: list[FileObj],
-                                                        memory: Optional[TokenBufferMemory],
-                                                        model_config: ModelConfigEntity) -> list[PromptMessage]:
-        prompt_messages = []
-
-        context_prompt_content = ''
-        if context and 'context_prompt' in prompt_rules:
-            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
-            context_prompt_content = prompt_template.format(
-                {'context': context}
-            )
-
-        pre_prompt_content = ''
-        if pre_prompt:
-            prompt_template = PromptTemplateParser(template=pre_prompt)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-            pre_prompt_content = prompt_template.format(
-                prompt_inputs
-            )
-
-        prompt = ''
-        for order in prompt_rules['system_prompt_orders']:
-            if order == 'context_prompt':
-                prompt += context_prompt_content
-            elif order == 'pre_prompt':
-                prompt += pre_prompt_content
-
-        prompt = re.sub(r'<\|.*?\|>', '', prompt)
-
-        if prompt:
-            prompt_messages.append(SystemPromptMessage(content=prompt))
-
-        self._append_chat_histories(
-            memory=memory,
-            prompt_messages=prompt_messages,
-            model_config=model_config
-        )
-
-        if files:
-            prompt_message_contents = [TextPromptMessageContent(data=query)]
-            for file in files:
-                prompt_message_contents.append(file.prompt_message_content)
-
-            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
-        else:
-            prompt_messages.append(UserPromptMessage(content=query))
-
-        return prompt_messages
-
-    def _get_simple_others_prompt_messages(self, prompt_rules: dict,
-                                           pre_prompt: str,
-                                           inputs: dict,
-                                           query: str,
-                                           context: Optional[str],
-                                           memory: Optional[TokenBufferMemory],
-                                           files: list[FileObj],
-                                           model_config: ModelConfigEntity) -> list[PromptMessage]:
-        context_prompt_content = ''
-        if context and 'context_prompt' in prompt_rules:
-            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
-            context_prompt_content = prompt_template.format(
-                {'context': context}
-            )
-
-        pre_prompt_content = ''
-        if pre_prompt:
-            prompt_template = PromptTemplateParser(template=pre_prompt)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-            pre_prompt_content = prompt_template.format(
-                prompt_inputs
-            )
-
-        prompt = ''
-        for order in prompt_rules['system_prompt_orders']:
-            if order == 'context_prompt':
-                prompt += context_prompt_content
-            elif order == 'pre_prompt':
-                prompt += pre_prompt_content
-
-        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
-
-        if memory and 'histories_prompt' in prompt_rules:
-            # append chat histories
-            tmp_human_message = UserPromptMessage(
-                content=PromptBuilder.parse_prompt(
-                    prompt=prompt + query_prompt,
-                    inputs={
-                        'query': query
-                    }
-                )
-            )
-
-            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
-
-            histories = self._get_history_messages_from_memory(
-                memory=memory,
-                max_token_limit=rest_tokens,
-                ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-                human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
-            )
-            prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
-            histories_prompt_content = prompt_template.format({'histories': histories})
-
-            prompt = ''
-            for order in prompt_rules['system_prompt_orders']:
-                if order == 'context_prompt':
-                    prompt += context_prompt_content
-                elif order == 'pre_prompt':
-                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
-                elif order == 'histories_prompt':
-                    prompt += histories_prompt_content
-
-        prompt_template = PromptTemplateParser(template=query_prompt)
-        query_prompt_content = prompt_template.format({'query': query})
-
-        prompt += query_prompt_content
-
-        prompt = re.sub(r'<\|.*?\|>', '', prompt)
-
-        model_mode = ModelMode.value_of(model_config.mode)
-
-        if model_mode == ModelMode.CHAT and files:
-            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
-            for file in files:
-                prompt_message_contents.append(file.prompt_message_content)
-
-            prompt_message = UserPromptMessage(content=prompt_message_contents)
-        else:
-            if files:
-                prompt_message_contents = [TextPromptMessageContent(data=prompt)]
-                for file in files:
-                    prompt_message_contents.append(file.prompt_message_content)
-
-                prompt_message = UserPromptMessage(content=prompt_message_contents)
-            else:
-                prompt_message = UserPromptMessage(content=prompt)
-
-        return [prompt_message]
-
-    def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
-        if '#context#' in prompt_template.variable_keys:
-            if context:
-                prompt_inputs['#context#'] = context
-            else:
-                prompt_inputs['#context#'] = ''
-
-    def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
-        if '#query#' in prompt_template.variable_keys:
-            if query:
-                prompt_inputs['#query#'] = query
-            else:
-                prompt_inputs['#query#'] = ''
-
-    def _set_histories_variable(self, memory: TokenBufferMemory,
-                                raw_prompt: str,
-                                role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
-                                prompt_template: PromptTemplateParser,
-                                prompt_inputs: dict,
-                                model_config: ModelConfigEntity) -> None:
-        if '#histories#' in prompt_template.variable_keys:
-            if memory:
-                tmp_human_message = UserPromptMessage(
-                    content=PromptBuilder.parse_prompt(
-                        prompt=raw_prompt,
-                        inputs={'#histories#': '', **prompt_inputs}
-                    )
-                )
-
-                rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
-
-                histories = self._get_history_messages_from_memory(
-                    memory=memory,
-                    max_token_limit=rest_tokens,
-                    human_prefix=role_prefix.user,
-                    ai_prefix=role_prefix.assistant
-                )
-                prompt_inputs['#histories#'] = histories
-            else:
-                prompt_inputs['#histories#'] = ''
-
     def _append_chat_histories(self, memory: TokenBufferMemory,
                                prompt_messages: list[PromptMessage],
                                model_config: ModelConfigEntity) -> None:
@@ -422,152 +42,28 @@ class PromptTransform:
 
         return rest_tokens
 
-    def _format_prompt(self, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> str:
-        prompt = prompt_template.format(
-            prompt_inputs
+    def _get_history_messages_from_memory(self, memory: TokenBufferMemory,
+                                          max_token_limit: int,
+                                          human_prefix: Optional[str] = None,
+                                          ai_prefix: Optional[str] = None) -> str:
+        """Get memory messages."""
+        kwargs = {
+            "max_token_limit": max_token_limit
+        }
+
+        if human_prefix:
+            kwargs['human_prefix'] = human_prefix
+
+        if ai_prefix:
+            kwargs['ai_prefix'] = ai_prefix
+
+        return memory.get_history_prompt_text(
+            **kwargs
         )
-        prompt = re.sub(r'<\|.*?\|>', '', prompt)
-        return prompt
-
-    def _get_chat_app_completion_model_prompt_messages(self,
-                                                       prompt_template_entity: PromptTemplateEntity,
-                                                       inputs: dict,
-                                                       query: str,
-                                                       files: list[FileObj],
-                                                       context: Optional[str],
-                                                       memory: Optional[TokenBufferMemory],
-                                                       model_config: ModelConfigEntity) -> list[PromptMessage]:
-
-        raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt
-        role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix
-
-        prompt_messages = []
-
-        prompt_template = PromptTemplateParser(template=raw_prompt)
-        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-
-        self._set_context_variable(context, prompt_template, prompt_inputs)
-
-        self._set_query_variable(query, prompt_template, prompt_inputs)
-
-        self._set_histories_variable(
-            memory=memory,
-            raw_prompt=raw_prompt,
-            role_prefix=role_prefix,
-            prompt_template=prompt_template,
-            prompt_inputs=prompt_inputs,
-            model_config=model_config
+    def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory,
+                                               max_token_limit: int) -> list[PromptMessage]:
+        """Get memory messages."""
+        return memory.get_history_prompt_messages(
+            max_token_limit=max_token_limit
         )
-
-        prompt = self._format_prompt(prompt_template, prompt_inputs)
-
-        if files:
-            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
-            for file in files:
-                prompt_message_contents.append(file.prompt_message_content)
-
-            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
-        else:
-            prompt_messages.append(UserPromptMessage(content=prompt))
-
-        return prompt_messages
-
-    def _get_chat_app_chat_model_prompt_messages(self,
-                                                 prompt_template_entity: PromptTemplateEntity,
-                                                 inputs: dict,
-                                                 query: str,
-                                                 files: list[FileObj],
-                                                 context: Optional[str],
-                                                 memory: Optional[TokenBufferMemory],
-                                                 model_config: ModelConfigEntity) -> list[PromptMessage]:
-        raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages
-
-        prompt_messages = []
-
-        for prompt_item in raw_prompt_list:
-            raw_prompt = prompt_item.text
-
-            prompt_template = PromptTemplateParser(template=raw_prompt)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-
-            self._set_context_variable(context, prompt_template, prompt_inputs)
-
-            prompt = self._format_prompt(prompt_template, prompt_inputs)
-
-            if prompt_item.role == PromptMessageRole.USER:
-                prompt_messages.append(UserPromptMessage(content=prompt))
-            elif prompt_item.role == PromptMessageRole.SYSTEM and prompt:
-                prompt_messages.append(SystemPromptMessage(content=prompt))
-            elif prompt_item.role == PromptMessageRole.ASSISTANT:
-                prompt_messages.append(AssistantPromptMessage(content=prompt))
-
-        self._append_chat_histories(memory, prompt_messages, model_config)
-
-        if files:
-            prompt_message_contents = [TextPromptMessageContent(data=query)]
-            for file in files:
-                prompt_message_contents.append(file.prompt_message_content)
-
-            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
-        else:
-            prompt_messages.append(UserPromptMessage(content=query))
-
-        return prompt_messages
-
-    def _get_completion_app_completion_model_prompt_messages(self,
-                                                             prompt_template_entity: PromptTemplateEntity,
-                                                             inputs: dict,
-                                                             context: Optional[str]) -> list[PromptMessage]:
-        raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt
-
-        prompt_messages = []
-
-        prompt_template = PromptTemplateParser(template=raw_prompt)
-        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-
-        self._set_context_variable(context, prompt_template, prompt_inputs)
-
-        prompt = self._format_prompt(prompt_template, prompt_inputs)
-
-        prompt_messages.append(UserPromptMessage(content=prompt))
-
-        return prompt_messages
-
-    def _get_completion_app_chat_model_prompt_messages(self,
-                                                       prompt_template_entity: PromptTemplateEntity,
-                                                       inputs: dict,
-                                                       files: list[FileObj],
-                                                       context: Optional[str]) -> list[PromptMessage]:
-        raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages
-
-        prompt_messages = []
-
-        for prompt_item in raw_prompt_list:
-            raw_prompt = prompt_item.text
-
-            prompt_template = PromptTemplateParser(template=raw_prompt)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-
-            self._set_context_variable(context, prompt_template, prompt_inputs)
-
-            prompt = self._format_prompt(prompt_template, prompt_inputs)
-
-            if prompt_item.role == PromptMessageRole.USER:
-                prompt_messages.append(UserPromptMessage(content=prompt))
-            elif prompt_item.role == PromptMessageRole.SYSTEM and prompt:
-                prompt_messages.append(SystemPromptMessage(content=prompt))
-            elif prompt_item.role == PromptMessageRole.ASSISTANT:
-                prompt_messages.append(AssistantPromptMessage(content=prompt))
-
-        for prompt_message in prompt_messages[::-1]:
-            if prompt_message.role == PromptMessageRole.USER:
-                if files:
-                    prompt_message_contents = [TextPromptMessageContent(data=prompt_message.content)]
-                    for file in files:
-                        prompt_message_contents.append(file.prompt_message_content)
-
-                    prompt_message.content = prompt_message_contents
-                break
-
-        return prompt_messages
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
new file mode 100644
index 0000000000..a898c37c4a
--- /dev/null
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -0,0 +1,300 @@
+import enum
+import json
+import os
+from typing import Optional, Tuple
+
+from core.entities.application_entities import (
+    ModelConfigEntity,
+    PromptTemplateEntity,
+)
+from core.file.file_obj import FileObj
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.prompt.prompt_template import PromptTemplateParser
+from core.prompt.prompt_transform import PromptTransform
+from models.model import AppMode
+
+
+class ModelMode(enum.Enum):
+    COMPLETION = 'completion'
+    CHAT = 'chat'
+
+    @classmethod
+    def value_of(cls, value: str) -> 'ModelMode':
+        """
+        Get value of given mode.
+
+        :param value: mode value
+        :return: mode
+        """
+        for mode in cls:
+            if mode.value == value:
+                return mode
+        raise ValueError(f'invalid mode value {value}')
+
+
+prompt_file_contents = {}
+
+
+class SimplePromptTransform(PromptTransform):
+    """
+    Simple Prompt Transform for Chatbot App Basic Mode.
+    """
+    def get_prompt(self,
+                   prompt_template_entity: PromptTemplateEntity,
+                   inputs: dict,
+                   query: str,
+                   files: list[FileObj],
+                   context: Optional[str],
+                   memory: Optional[TokenBufferMemory],
+                   model_config: ModelConfigEntity) -> \
+            tuple[list[PromptMessage], Optional[list[str]]]:
+        model_mode = ModelMode.value_of(model_config.mode)
+        if model_mode == ModelMode.CHAT:
+            prompt_messages, stops = self._get_chat_model_prompt_messages(
+                pre_prompt=prompt_template_entity.simple_prompt_template,
+                inputs=inputs,
+                query=query,
+                files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config
+            )
+        else:
+            prompt_messages, stops = self._get_completion_model_prompt_messages(
+                pre_prompt=prompt_template_entity.simple_prompt_template,
+                inputs=inputs,
+                query=query,
+                files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config
+            )
+
+        return prompt_messages, stops
+
+    def get_prompt_str_and_rules(self, app_mode: AppMode,
+                                 model_config: ModelConfigEntity,
+                                 pre_prompt: str,
+                                 inputs: dict,
+                                 query: Optional[str] = None,
+                                 context: Optional[str] = None,
+                                 histories: Optional[str] = None,
+                                 ) -> Tuple[str, dict]:
+        # get prompt template
+        prompt_template_config = self.get_prompt_template(
+            app_mode=app_mode,
+            provider=model_config.provider,
+            model=model_config.model,
+            pre_prompt=pre_prompt,
+            has_context=context is not None,
+            query_in_prompt=query is not None,
+            with_memory_prompt=histories is not None
+        )
+
+        variables = {k: inputs[k] for k in prompt_template_config['custom_variable_keys'] if k in inputs}
+
+        for v in prompt_template_config['special_variable_keys']:
+            # support #context#, #query# and #histories#
+            if v == '#context#':
+                variables['#context#'] = context if context else ''
+            elif v == '#query#':
+                variables['#query#'] = query if query else ''
+            elif v == '#histories#':
+                variables['#histories#'] = histories if histories else ''
+
+        prompt_template = prompt_template_config['prompt_template']
+        prompt = prompt_template.format(variables)
+
+        return prompt, prompt_template_config['prompt_rules']
+
+    def get_prompt_template(self, app_mode: AppMode,
+                            provider: str,
+                            model: str,
+                            pre_prompt: str,
+                            has_context: bool,
+                            query_in_prompt: bool,
+                            with_memory_prompt: bool = False) -> dict:
+        prompt_rules = self._get_prompt_rule(
+            app_mode=app_mode,
+            provider=provider,
+            model=model
+        )
+
+        custom_variable_keys = []
+        special_variable_keys = []
+
+        prompt = ''
+        for order in prompt_rules['system_prompt_orders']:
+            if order == 'context_prompt' and has_context:
+                prompt += prompt_rules['context_prompt']
+                special_variable_keys.append('#context#')
+            elif order == 'pre_prompt' and pre_prompt:
+                prompt += pre_prompt + '\n'
+                pre_prompt_template = PromptTemplateParser(template=pre_prompt)
+                custom_variable_keys = pre_prompt_template.variable_keys
+            elif order == 'histories_prompt' and with_memory_prompt:
+                prompt += prompt_rules['histories_prompt']
+                special_variable_keys.append('#histories#')
+
+        if query_in_prompt:
+            prompt += prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{#query#}}'
+            special_variable_keys.append('#query#')
+
+        return {
+            "prompt_template": PromptTemplateParser(template=prompt),
+            "custom_variable_keys": custom_variable_keys,
+            "special_variable_keys": special_variable_keys,
+            "prompt_rules": prompt_rules
+        }
+
+    def _get_chat_model_prompt_messages(self, pre_prompt: str,
+                                        inputs: dict,
+                                        query: str,
+                                        context: Optional[str],
+                                        files: list[FileObj],
+                                        memory: Optional[TokenBufferMemory],
+                                        model_config: ModelConfigEntity) \
+            -> Tuple[list[PromptMessage], Optional[list[str]]]:
+        prompt_messages = []
+
+        # get prompt
+        prompt, _ = self.get_prompt_str_and_rules(
+            app_mode=AppMode.CHAT,
+            model_config=model_config,
+            pre_prompt=pre_prompt,
+            inputs=inputs,
+            context=context
+        )
+
+        if prompt:
+            prompt_messages.append(SystemPromptMessage(content=prompt))
+
+        self._append_chat_histories(
+            memory=memory,
+            prompt_messages=prompt_messages,
+            model_config=model_config
+        )
+
+        prompt_messages.append(self.get_last_user_message(query, files))
+
+        return prompt_messages, None
+
+    def _get_completion_model_prompt_messages(self, pre_prompt: str,
+                                              inputs: dict,
+                                              query: str,
+                                              context: Optional[str],
+                                              files: list[FileObj],
+                                              memory: Optional[TokenBufferMemory],
+                                              model_config: ModelConfigEntity) \
+            -> Tuple[list[PromptMessage], Optional[list[str]]]:
+        # get prompt
+        prompt, prompt_rules = self.get_prompt_str_and_rules(
+            app_mode=AppMode.CHAT,
+            model_config=model_config,
+            pre_prompt=pre_prompt,
+            inputs=inputs,
+            query=query,
+            context=context
+        )
+
+        if memory:
+            tmp_human_message = UserPromptMessage(
+                content=prompt
+            )
+
+            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
+            histories = self._get_history_messages_from_memory(
+                memory=memory,
+                max_token_limit=rest_tokens,
+                human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+                ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+            )
+
+            # get prompt
+            prompt, prompt_rules = self.get_prompt_str_and_rules(
+                app_mode=AppMode.CHAT,
+                model_config=model_config,
+                pre_prompt=pre_prompt,
+                inputs=inputs,
+                query=query,
+                context=context,
+                histories=histories
+            )
+
+        stops = prompt_rules.get('stops')
+        if stops is not None and len(stops) == 0:
+            stops = None
+
+        return [self.get_last_user_message(prompt, files)], stops
+
+    def get_last_user_message(self, prompt: str, files: list[FileObj]) -> UserPromptMessage:
+        if files:
+            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+            for file in files:
+                prompt_message_contents.append(file.prompt_message_content)
+
+            prompt_message = UserPromptMessage(content=prompt_message_contents)
+        else:
+            prompt_message = UserPromptMessage(content=prompt)
+
+        return prompt_message
+
+    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict:
+        """
+        Get simple prompt rule.
+        :param app_mode: app mode
+        :param provider: model provider
+        :param model: model name
+        :return:
+        """
+        prompt_file_name = self._prompt_file_name(
+            app_mode=app_mode,
+            provider=provider,
+            model=model
+        )
+
+        # Check if the prompt file is already loaded
+        if prompt_file_name in prompt_file_contents:
+            return prompt_file_contents[prompt_file_name]
+
+        # Get the absolute path of the subdirectory
+        prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts')
+        json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')
+
+        # Open the JSON file and read its content
+        with open(json_file_path, encoding='utf-8') as json_file:
+            content = json.load(json_file)
+
+        # Store the content of the prompt file
+        prompt_file_contents[prompt_file_name] = content
+
+        return content
+
+    def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
+        # baichuan
+        is_baichuan = False
+        if provider == 'baichuan':
+            is_baichuan = True
+        else:
+            baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
+            if provider in baichuan_supported_providers and 'baichuan' in model.lower():
+                is_baichuan = True
+
+        if is_baichuan:
+            if app_mode == AppMode.WORKFLOW:
+                return 'baichuan_completion'
+            else:
+                return 'baichuan_chat'
+
+        # common
+        if app_mode == AppMode.WORKFLOW:
+            return 'common_completion'
+        else:
+            return 'common_chat'
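SimplePromptTransform assembles the final prompt from a per-model rule file: sections are concatenated in the order `system_prompt_orders` dictates, and each included section contributes a reserved placeholder that is filled in later. A standalone sketch of that assembly, with an inlined rules dict standing in for the JSON files under generate_prompts/ (the helper name is illustrative):

```python
rules = {
    'context_prompt': 'Use this context:\n{{#context#}}\n',
    'histories_prompt': 'Prior turns:\n{{#histories#}}\n',
    'query_prompt': '\nHuman: {{#query#}}\nAssistant: ',
    'system_prompt_orders': ['context_prompt', 'pre_prompt', 'histories_prompt'],
}


def build_template(pre_prompt: str, has_context: bool, with_memory: bool) -> str:
    prompt = ''
    for order in rules['system_prompt_orders']:
        if order == 'context_prompt' and has_context:
            prompt += rules['context_prompt']
        elif order == 'pre_prompt' and pre_prompt:
            prompt += pre_prompt + '\n'
        elif order == 'histories_prompt' and with_memory:
            prompt += rules['histories_prompt']
    # the query placeholder is appended last, as get_prompt_template does
    return prompt + rules['query_prompt']


print(build_template('You are a helpful bot.', has_context=True, with_memory=True))
```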
diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py
index d9cd6c03bb..c778084475 100644
--- a/api/fields/annotation_fields.py
+++ b/api/fields/annotation_fields.py
@@ -2,7 +2,6 @@ from flask_restful import fields
 
 from libs.helper import TimestampField
 
-
 annotation_fields = {
     "id": fields.String,
     "question": fields.String,
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index 9dc92ea43b..decdc0567f 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -5,7 +5,6 @@ from flask_restful import fields
 from fields.member_fields import simple_account_fields
 from libs.helper import TimestampField
 
-
 workflow_fields = {
     'id': fields.String,
     'graph': fields.Raw(attribute=lambda x: json.loads(x.graph) if hasattr(x, 'graph') else None),
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index c2fad83aaf..7d18f4f675 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -2,9 +2,17 @@ import json
 from typing import Optional
 
 from core.application_manager import ApplicationManager
-from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, FileUploadEntity, \
-    ExternalDataVariableEntity, DatasetEntity, VariableEntity
+from core.entities.application_entities import (
+    DatasetEntity,
+    ExternalDataVariableEntity,
+    FileUploadEntity,
+    ModelConfigEntity,
+    PromptTemplateEntity,
+    VariableEntity, DatasetRetrieveConfigEntity,
+)
+from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.utils import helper
+from core.prompt.simple_prompt_transform import SimplePromptTransform
 from core.workflow.entities.NodeEntities import NodeType
 from core.workflow.nodes.end.entities import EndNodeOutputType
 from extensions.ext_database import db
@@ -32,6 +40,9 @@ class WorkflowConverter:
         :param account: Account instance
         :return: workflow instance
         """
+        # get new app mode
+        new_app_mode = self._get_new_app_mode(app_model)
+
         # get original app config
         app_model_config = app_model.app_model_config
 
@@ -75,14 +86,17 @@ class WorkflowConverter:
         # convert to knowledge retrieval node
         if app_model_config.dataset:
             knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node(
-                dataset=app_model_config.dataset,
-                show_retrieve_source=app_model_config.show_retrieve_source
+                new_app_mode=new_app_mode,
+                dataset_config=app_model_config.dataset
             )
 
-            graph = self._append_node(graph, knowledge_retrieval_node)
+            if knowledge_retrieval_node:
+                graph = self._append_node(graph, knowledge_retrieval_node)
 
         # convert to llm node
         llm_node = self._convert_to_llm_node(
+            new_app_mode=new_app_mode,
+            graph=graph,
             model_config=app_model_config.model_config,
             prompt_template=app_model_config.prompt_template,
             file_upload=app_model_config.file_upload
@@ -95,14 +109,11 @@ class WorkflowConverter:
 
         graph = self._append_node(graph, end_node)
 
-        # get new app mode
-        app_mode = self._get_new_app_mode(app_model)
-
         # create workflow record
         workflow = Workflow(
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
-            type=WorkflowType.from_app_mode(app_mode).value,
+            type=WorkflowType.from_app_mode(new_app_mode).value,
             version='draft',
             graph=json.dumps(graph),
             created_by=account.id
@@ -124,7 +135,7 @@ class WorkflowConverter:
         new_app_model_config.completion_prompt_config = ''
         new_app_model_config.dataset_configs = ''
         new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \
-            if app_mode == AppMode.CHAT else ChatbotAppEngine.NORMAL.value
+            if new_app_mode == AppMode.CHAT else ChatbotAppEngine.NORMAL.value
         new_app_model_config.workflow_id = workflow.id
 
         db.session.add(new_app_model_config)
@@ -157,18 +168,22 @@ class WorkflowConverter:
         # TODO: implement
         pass
 
-    def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset: DatasetEntity) -> dict:
+    def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset_config: DatasetEntity) \
+            -> Optional[dict]:
         """
         Convert datasets to Knowledge Retrieval Node
         :param new_app_mode: new app mode
-        :param dataset: dataset
+        :param dataset_config: dataset
         :return:
         """
-        # TODO: implement
+        retrieve_config = dataset_config.retrieve_config
         if new_app_mode == AppMode.CHAT:
             query_variable_selector = ["start", "sys.query"]
+        elif retrieve_config.query_variable:
+            # fetch query variable
+            query_variable_selector = ["start", retrieve_config.query_variable]
         else:
-            pass
+            return None
 
         return {
             "id": "knowledge-retrieval",
@@ -176,20 +191,139 @@ class WorkflowConverter:
             "data": {
                 "title": "KNOWLEDGE RETRIEVAL",
                 "type": NodeType.KNOWLEDGE_RETRIEVAL.value,
+                "query_variable_selector": query_variable_selector,
+                "dataset_ids": dataset_config.dataset_ids,
+                "retrieval_mode": retrieve_config.retrieve_strategy.value,
+                "multiple_retrieval_config": {
+                    "top_k": retrieve_config.top_k,
+                    "score_threshold": retrieve_config.score_threshold,
+                    "reranking_model": retrieve_config.reranking_model
+                }
+                if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE
+                else None,
             }
         }
 
-    def _convert_to_llm_node(self, model_config: ModelConfigEntity,
+    def _convert_to_llm_node(self, new_app_mode: AppMode,
+                             graph: dict,
+                             model_config: ModelConfigEntity,
                              prompt_template: PromptTemplateEntity,
                              file_upload: Optional[FileUploadEntity] = None) -> dict:
         """
         Convert to LLM Node
+        :param new_app_mode: new app mode
+        :param graph: graph
         :param model_config: model config
         :param prompt_template: prompt template
         :param file_upload: file upload config (optional)
         """
-        # TODO: implement
-        pass
+        # fetch start and knowledge retrieval node
+        start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes']))
+        knowledge_retrieval_node = next(filter(
+            lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value,
+            graph['nodes']
+        ), None)
+
+        role_prefix = None
+
+        # Chat Model
+        if model_config.mode == LLMMode.CHAT.value:
+            if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+                # get prompt template
+                prompt_transform = SimplePromptTransform()
+                prompt_template_config = prompt_transform.get_prompt_template(
+                    app_mode=AppMode.WORKFLOW,
+                    provider=model_config.provider,
+                    model=model_config.model,
+                    pre_prompt=prompt_template.simple_prompt_template,
+                    has_context=knowledge_retrieval_node is not None,
+                    query_in_prompt=False
+                )
+                prompts = [
+                    {
+                        "role": 'user',
+                        "text": prompt_template_config['prompt_template'].template
+                    }
+                ]
+            else:
+                advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template
+                prompts = [helper.dump_model(m) for m in advanced_chat_prompt_template.messages] \
+                    if advanced_chat_prompt_template else []
+        # Completion Model
+        else:
+            if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+                # get prompt template
+                prompt_transform = SimplePromptTransform()
+                prompt_template_config = prompt_transform.get_prompt_template(
+                    app_mode=AppMode.WORKFLOW,
+                    provider=model_config.provider,
+                    model=model_config.model,
+                    pre_prompt=prompt_template.simple_prompt_template,
+                    has_context=knowledge_retrieval_node is not None,
+                    query_in_prompt=False
+                )
+                prompts = {
+                    "text": prompt_template_config['prompt_template'].template
+                }
+
+                prompt_rules = prompt_template_config['prompt_rules']
+                role_prefix = {
+                    "user": prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+                    "assistant": prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+                }
+            else:
+                advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template
+                prompts = {
+                    "text": advanced_completion_prompt_template.prompt,
+                } if advanced_completion_prompt_template else {"text": ""}
+
+                if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix:
+                    role_prefix = {
+                        "user": advanced_completion_prompt_template.role_prefix.user,
+                        "assistant": advanced_completion_prompt_template.role_prefix.assistant
+                    }
+
+        memory = None
+        if new_app_mode == AppMode.CHAT:
+            memory = {
+                "role_prefix": role_prefix,
+                "window": {
+                    "enabled": False
+                }
+            }
+
+        return {
+            "id": "llm",
+            "position": None,
+            "data": {
+                "title": "LLM",
+                "type": NodeType.LLM.value,
+                "model": {
+                    "provider": model_config.provider,
+                    "name": model_config.model,
+                    "mode": model_config.mode,
+                    "completion_params": {**model_config.parameters, "stop": model_config.stop}
+                },
+                "variables": [{
+                    "variable": v['variable'],
+                    "value_selector": ["start", v['variable']]
+                } for v in start_node['data']['variables']],
+                "prompts": prompts,
+                "memory": memory,
+                "context": {
+                    "enabled": knowledge_retrieval_node is not None,
+                    "variable_selector": ["knowledge-retrieval", "result"]
+                    if knowledge_retrieval_node is not None else None
+                },
+                "vision": {
+                    "enabled": file_upload is not None,
+                    "variable_selector": ["start", "sys.files"] if file_upload is not None else None,
+                    "configs": {
+                        "detail": file_upload.image_config['detail']
+                    } if file_upload is not None else None
+                }
+            }
+        }
 
     def _convert_to_end_node(self, app_model: App) -> dict:
         """