mirror of https://github.com/langgenius/dify.git
add llm node
This commit is contained in:
parent
e8751bebfa
commit
d88ac6c238
@@ -23,7 +23,8 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError
|
|||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
from models.model import App, AppMode, Message, MessageAnnotation
|
||||
|
||||
|
||||
|
|
@ -155,13 +156,39 @@ class AppRunner:
|
|||
model_config=model_config
|
||||
)
|
||||
else:
|
||||
memory_config = MemoryConfig(
|
||||
window=MemoryConfig.WindowConfig(
|
||||
enabled=False
|
||||
)
|
||||
)
|
||||
|
||||
model_mode = ModelMode.value_of(model_config.mode)
|
||||
if model_mode == ModelMode.COMPLETION:
|
||||
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
|
||||
prompt_template = CompletionModelPromptTemplate(
|
||||
text=advanced_completion_prompt_template.prompt
|
||||
)
|
||||
|
||||
memory_config.role_prefix = MemoryConfig.RolePrefix(
|
||||
user=advanced_completion_prompt_template.role_prefix.user,
|
||||
assistant=advanced_completion_prompt_template.role_prefix.assistant
|
||||
)
|
||||
else:
|
||||
prompt_template = []
|
||||
for message in prompt_template_entity.advanced_chat_prompt_template.messages:
|
||||
prompt_template.append(ChatModelMessage(
|
||||
text=message.text,
|
||||
role=message.role
|
||||
))
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
prompt_template=prompt_template,
|
||||
inputs=inputs,
|
||||
query=query if query else '',
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
)
|
||||
|
|
|
|||
|
|
@ -30,17 +30,12 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
|||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from events.message_event import message_was_created
|
||||
|
|
@ -438,7 +433,10 @@ class EasyUIBasedGenerateTaskPipeline:
|
|||
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
|
||||
|
||||
self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
|
||||
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
self._model_config.mode,
|
||||
self._task_state.llm_result.prompt_messages
|
||||
)
|
||||
self._message.message_tokens = usage.prompt_tokens
|
||||
self._message.message_unit_price = usage.prompt_unit_price
|
||||
self._message.message_price_unit = usage.prompt_price_unit
|
||||
|
|
@ -582,77 +580,6 @@ class EasyUIBasedGenerateTaskPipeline:
|
|||
"""
|
||||
return "data: " + json.dumps(response) + "\n\n"
|
||||
|
||||
def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
|
||||
"""
|
||||
Prompt messages to prompt for saving.
|
||||
:param prompt_messages: prompt messages
|
||||
:return:
|
||||
"""
|
||||
prompts = []
|
||||
if self._model_config.mode == ModelMode.CHAT.value:
|
||||
for prompt_message in prompt_messages:
|
||||
if prompt_message.role == PromptMessageRole.USER:
|
||||
role = 'user'
|
||||
elif prompt_message.role == PromptMessageRole.ASSISTANT:
|
||||
role = 'assistant'
|
||||
elif prompt_message.role == PromptMessageRole.SYSTEM:
|
||||
role = 'system'
|
||||
else:
|
||||
continue
|
||||
|
||||
text = ''
|
||||
files = []
|
||||
if isinstance(prompt_message.content, list):
|
||||
for content in prompt_message.content:
|
||||
if content.type == PromptMessageContentType.TEXT:
|
||||
content = cast(TextPromptMessageContent, content)
|
||||
text += content.data
|
||||
else:
|
||||
content = cast(ImagePromptMessageContent, content)
|
||||
files.append({
|
||||
"type": 'image',
|
||||
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
|
||||
"detail": content.detail.value
|
||||
})
|
||||
else:
|
||||
text = prompt_message.content
|
||||
|
||||
prompts.append({
|
||||
"role": role,
|
||||
"text": text,
|
||||
"files": files
|
||||
})
|
||||
else:
|
||||
prompt_message = prompt_messages[0]
|
||||
text = ''
|
||||
files = []
|
||||
if isinstance(prompt_message.content, list):
|
||||
for content in prompt_message.content:
|
||||
if content.type == PromptMessageContentType.TEXT:
|
||||
content = cast(TextPromptMessageContent, content)
|
||||
text += content.data
|
||||
else:
|
||||
content = cast(ImagePromptMessageContent, content)
|
||||
files.append({
|
||||
"type": 'image',
|
||||
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
|
||||
"detail": content.detail.value
|
||||
})
|
||||
else:
|
||||
text = prompt_message.content
|
||||
|
||||
params = {
|
||||
"role": 'user',
|
||||
"text": text,
|
||||
}
|
||||
|
||||
if files:
|
||||
params['files'] = files
|
||||
|
||||
prompts.append(params)
|
||||
|
||||
return prompts
|
||||
|
||||
def _init_output_moderation(self) -> Optional[OutputModeration]:
|
||||
"""
|
||||
Init output moderation.
|
||||
|
|
|
|||
|
|
@ -24,11 +24,11 @@ class ModelInstance:
|
|||
"""
|
||||
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
|
||||
self._provider_model_bundle = provider_model_bundle
|
||||
self.provider_model_bundle = provider_model_bundle
|
||||
self.model = model
|
||||
self.provider = provider_model_bundle.configuration.provider.provider
|
||||
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
self.model_type_instance = self._provider_model_bundle.model_type_instance
|
||||
self.model_type_instance = self.provider_model_bundle.model_type_instance
|
||||
|
||||
def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file.file_obj import FileObj
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
|
@ -12,6 +11,7 @@ from core.model_runtime.entities.message_entities import (
|
|||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
|
|
@ -22,11 +22,12 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
Advanced Prompt Transform for Workflow LLM Node.
|
||||
"""
|
||||
|
||||
def get_prompt(self, prompt_template_entity: PromptTemplateEntity,
|
||||
def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
|
||||
inputs: dict,
|
||||
query: str,
|
||||
files: list[FileObj],
|
||||
context: Optional[str],
|
||||
memory_config: Optional[MemoryConfig],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
|
||||
prompt_messages = []
|
||||
|
|
@ -34,21 +35,23 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
model_mode = ModelMode.value_of(model_config.mode)
|
||||
if model_mode == ModelMode.COMPLETION:
|
||||
prompt_messages = self._get_completion_model_prompt_messages(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
prompt_template=prompt_template,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
)
|
||||
elif model_mode == ModelMode.CHAT:
|
||||
prompt_messages = self._get_chat_model_prompt_messages(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
prompt_template=prompt_template,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
)
|
||||
|
|
@ -56,17 +59,18 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
return prompt_messages
|
||||
|
||||
def _get_completion_model_prompt_messages(self,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
prompt_template: CompletionModelPromptTemplate,
|
||||
inputs: dict,
|
||||
query: Optional[str],
|
||||
files: list[FileObj],
|
||||
context: Optional[str],
|
||||
memory_config: Optional[MemoryConfig],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
|
||||
"""
|
||||
Get completion model prompt messages.
|
||||
"""
|
||||
raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt
|
||||
raw_prompt = prompt_template.text
|
||||
|
||||
prompt_messages = []
|
||||
|
||||
|
|
@ -75,15 +79,17 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
|
||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
||||
|
||||
role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix
|
||||
prompt_inputs = self._set_histories_variable(
|
||||
memory=memory,
|
||||
raw_prompt=raw_prompt,
|
||||
role_prefix=role_prefix,
|
||||
prompt_template=prompt_template,
|
||||
prompt_inputs=prompt_inputs,
|
||||
model_config=model_config
|
||||
)
|
||||
if memory and memory_config:
|
||||
role_prefix = memory_config.role_prefix
|
||||
prompt_inputs = self._set_histories_variable(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
raw_prompt=raw_prompt,
|
||||
role_prefix=role_prefix,
|
||||
prompt_template=prompt_template,
|
||||
prompt_inputs=prompt_inputs,
|
||||
model_config=model_config
|
||||
)
|
||||
|
||||
if query:
|
||||
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
|
||||
|
|
@ -104,17 +110,18 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
return prompt_messages
|
||||
|
||||
def _get_chat_model_prompt_messages(self,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
prompt_template: list[ChatModelMessage],
|
||||
inputs: dict,
|
||||
query: Optional[str],
|
||||
files: list[FileObj],
|
||||
context: Optional[str],
|
||||
memory_config: Optional[MemoryConfig],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
|
||||
"""
|
||||
Get chat model prompt messages.
|
||||
"""
|
||||
raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages
|
||||
raw_prompt_list = prompt_template
|
||||
|
||||
prompt_messages = []
|
||||
|
||||
|
|
@ -137,8 +144,8 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
elif prompt_item.role == PromptMessageRole.ASSISTANT:
|
||||
prompt_messages.append(AssistantPromptMessage(content=prompt))
|
||||
|
||||
if memory:
|
||||
prompt_messages = self._append_chat_histories(memory, prompt_messages, model_config)
|
||||
if memory and memory_config:
|
||||
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
|
||||
|
||||
if files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||
|
|
@ -195,8 +202,9 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
return prompt_inputs
|
||||
|
||||
def _set_histories_variable(self, memory: TokenBufferMemory,
|
||||
memory_config: MemoryConfig,
|
||||
raw_prompt: str,
|
||||
role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
|
||||
role_prefix: MemoryConfig.RolePrefix,
|
||||
prompt_template: PromptTemplateParser,
|
||||
prompt_inputs: dict,
|
||||
model_config: ModelConfigWithCredentialsEntity) -> dict:
|
||||
|
|
@ -213,6 +221,7 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
|
||||
histories = self._get_history_messages_from_memory(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
max_token_limit=rest_tokens,
|
||||
human_prefix=role_prefix.user,
|
||||
ai_prefix=role_prefix.assistant
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,42 @@
from typing import Optional

from pydantic import BaseModel

from core.model_runtime.entities.message_entities import PromptMessageRole


class ChatModelMessage(BaseModel):
    """
    Chat Message.
    """
    text: str
    role: PromptMessageRole


class CompletionModelPromptTemplate(BaseModel):
    """
    Completion Model Prompt Template.
    """
    text: str


class MemoryConfig(BaseModel):
    """
    Memory Config.
    """
    class RolePrefix(BaseModel):
        """
        Role Prefix.
        """
        user: str
        assistant: str

    class WindowConfig(BaseModel):
        """
        Window Config.
        """
        enabled: bool
        size: Optional[int] = None

    role_prefix: Optional[RolePrefix] = None
    window: WindowConfig
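For orientation, a minimal usage sketch of the entities above (editor's illustration, not part of the commit; values are invented): chat models take an ordered list of ChatModelMessage, while completion models take a single CompletionModelPromptTemplate and carry the role prefixes on MemoryConfig.

# Illustrative sketch only — mirrors how AppRunner and the unit tests below build these objects.
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.prompt.entities.advanced_prompt_entities import (
    ChatModelMessage,
    CompletionModelPromptTemplate,
    MemoryConfig,
)

# Chat mode: an ordered list of role-tagged template messages
chat_template = [
    ChatModelMessage(text="You are a helpful assistant named {{name}}.", role=PromptMessageRole.SYSTEM),
    ChatModelMessage(text="Hi.", role=PromptMessageRole.USER),
    ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT),
]

# Completion mode: one text template; role prefixes travel with the memory config
completion_template = CompletionModelPromptTemplate(
    text="Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
)
memory_config = MemoryConfig(
    role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
    window=MemoryConfig.WindowConfig(enabled=False),
)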
@ -5,19 +5,22 @@ from core.memory.token_buffer_memory import TokenBufferMemory
|
|||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
|
||||
|
||||
class PromptTransform:
|
||||
def _append_chat_histories(self, memory: TokenBufferMemory,
|
||||
memory_config: MemoryConfig,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
|
||||
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
|
||||
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
|
||||
histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
|
||||
prompt_messages.extend(histories)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity) -> int:
|
||||
def _calculate_rest_token(self, prompt_messages: list[PromptMessage],
|
||||
model_config: ModelConfigWithCredentialsEntity) -> int:
|
||||
rest_tokens = 2000
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
|
|
@ -44,6 +47,7 @@ class PromptTransform:
|
|||
return rest_tokens
|
||||
|
||||
def _get_history_messages_from_memory(self, memory: TokenBufferMemory,
|
||||
memory_config: MemoryConfig,
|
||||
max_token_limit: int,
|
||||
human_prefix: Optional[str] = None,
|
||||
ai_prefix: Optional[str] = None) -> str:
|
||||
|
|
@ -58,13 +62,22 @@ class PromptTransform:
|
|||
if ai_prefix:
|
||||
kwargs['ai_prefix'] = ai_prefix
|
||||
|
||||
if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
|
||||
kwargs['message_limit'] = memory_config.window.size
|
||||
|
||||
return memory.get_history_prompt_text(
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory,
|
||||
memory_config: MemoryConfig,
|
||||
max_token_limit: int) -> list[PromptMessage]:
|
||||
"""Get memory messages."""
|
||||
return memory.get_history_prompt_messages(
|
||||
max_token_limit=max_token_limit
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=memory_config.window.size
|
||||
if (memory_config.window.enabled
|
||||
and memory_config.window.size is not None
|
||||
and memory_config.window.size > 0)
|
||||
else 10
|
||||
)
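In other words (editor's sketch, not in the diff): a window that is enabled with a positive size caps how many history messages are pulled; otherwise the list variant above falls back to 10 and the text variant sets no message limit.

# Illustrative only
from core.prompt.entities.advanced_prompt_entities import MemoryConfig

window = MemoryConfig.WindowConfig(enabled=True, size=6)
use_window = window.enabled and window.size is not None and window.size > 0
message_limit = window.size if use_window else 10  # 10 is the fallback used above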
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from core.model_runtime.entities.message_entities import (
|
|||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from models.model import AppMode
|
||||
|
|
@ -182,6 +183,11 @@ class SimplePromptTransform(PromptTransform):
|
|||
if memory:
|
||||
prompt_messages = self._append_chat_histories(
|
||||
memory=memory,
|
||||
memory_config=MemoryConfig(
|
||||
window=MemoryConfig.WindowConfig(
|
||||
enabled=False,
|
||||
)
|
||||
),
|
||||
prompt_messages=prompt_messages,
|
||||
model_config=model_config
|
||||
)
|
||||
|
|
@ -220,6 +226,11 @@ class SimplePromptTransform(PromptTransform):
|
|||
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
|
||||
histories = self._get_history_messages_from_memory(
|
||||
memory=memory,
|
||||
memory_config=MemoryConfig(
|
||||
window=MemoryConfig.WindowConfig(
|
||||
enabled=False,
|
||||
)
|
||||
),
|
||||
max_token_limit=rest_tokens,
|
||||
ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
|
||||
human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,85 @@
|
|||
from typing import cast
|
||||
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
|
||||
|
||||
class PromptMessageUtil:
|
||||
@staticmethod
|
||||
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
|
||||
"""
|
||||
Prompt messages to prompt for saving.
|
||||
:param model_mode: model mode
|
||||
:param prompt_messages: prompt messages
|
||||
:return:
|
||||
"""
|
||||
prompts = []
|
||||
if model_mode == ModelMode.CHAT.value:
|
||||
for prompt_message in prompt_messages:
|
||||
if prompt_message.role == PromptMessageRole.USER:
|
||||
role = 'user'
|
||||
elif prompt_message.role == PromptMessageRole.ASSISTANT:
|
||||
role = 'assistant'
|
||||
elif prompt_message.role == PromptMessageRole.SYSTEM:
|
||||
role = 'system'
|
||||
else:
|
||||
continue
|
||||
|
||||
text = ''
|
||||
files = []
|
||||
if isinstance(prompt_message.content, list):
|
||||
for content in prompt_message.content:
|
||||
if content.type == PromptMessageContentType.TEXT:
|
||||
content = cast(TextPromptMessageContent, content)
|
||||
text += content.data
|
||||
else:
|
||||
content = cast(ImagePromptMessageContent, content)
|
||||
files.append({
|
||||
"type": 'image',
|
||||
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
|
||||
"detail": content.detail.value
|
||||
})
|
||||
else:
|
||||
text = prompt_message.content
|
||||
|
||||
prompts.append({
|
||||
"role": role,
|
||||
"text": text,
|
||||
"files": files
|
||||
})
|
||||
else:
|
||||
prompt_message = prompt_messages[0]
|
||||
text = ''
|
||||
files = []
|
||||
if isinstance(prompt_message.content, list):
|
||||
for content in prompt_message.content:
|
||||
if content.type == PromptMessageContentType.TEXT:
|
||||
content = cast(TextPromptMessageContent, content)
|
||||
text += content.data
|
||||
else:
|
||||
content = cast(ImagePromptMessageContent, content)
|
||||
files.append({
|
||||
"type": 'image',
|
||||
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
|
||||
"detail": content.detail.value
|
||||
})
|
||||
else:
|
||||
text = prompt_message.content
|
||||
|
||||
params = {
|
||||
"role": 'user',
|
||||
"text": text,
|
||||
}
|
||||
|
||||
if files:
|
||||
params['files'] = files
|
||||
|
||||
prompts.append(params)
|
||||
|
||||
return prompts
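A usage sketch of the new helper (editor's illustration, not part of the commit), assuming UserPromptMessage and AssistantPromptMessage default to their respective roles:

# Illustrative only
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil

saved = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
    model_mode=ModelMode.CHAT.value,
    prompt_messages=[
        UserPromptMessage(content="Hi."),
        AssistantPromptMessage(content="Hello!"),
    ],
)
# saved == [
#     {"role": "user", "text": "Hi.", "files": []},
#     {"role": "assistant", "text": "Hello!", "files": []},
# ]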
|
||||
|
|
@ -12,7 +12,7 @@ class NodeType(Enum):
|
|||
"""
|
||||
START = 'start'
|
||||
END = 'end'
|
||||
DIRECT_ANSWER = 'direct-answer'
|
||||
ANSWER = 'answer'
|
||||
LLM = 'llm'
|
||||
KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval'
|
||||
IF_ELSE = 'if-else'
|
||||
|
|
|
|||
|
|
@ -5,14 +5,14 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
|||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import ValueType, VariablePool
|
||||
from core.workflow.nodes.answer.entities import AnswerNodeData
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class DirectAnswerNode(BaseNode):
|
||||
_node_data_cls = DirectAnswerNodeData
|
||||
node_type = NodeType.DIRECT_ANSWER
|
||||
class AnswerNode(BaseNode):
|
||||
_node_data_cls = AnswerNodeData
|
||||
node_type = NodeType.ANSWER
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
"""
|
||||
|
|
@ -2,9 +2,9 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
|
|||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
|
||||
class DirectAnswerNodeData(BaseNodeData):
|
||||
class AnswerNodeData(BaseNodeData):
|
||||
"""
|
||||
DirectAnswer Node Data.
|
||||
Answer Node Data.
|
||||
"""
|
||||
variables: list[VariableSelector] = []
|
||||
answer: str
|
||||
|
|
@@ -1,8 +1,51 @@
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel

from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector


class ModelConfig(BaseModel):
    """
    Model Config.
    """
    provider: str
    name: str
    mode: str
    completion_params: dict[str, Any] = {}


class ContextConfig(BaseModel):
    """
    Context Config.
    """
    enabled: bool
    variable_selector: Optional[list[str]] = None


class VisionConfig(BaseModel):
    """
    Vision Config.
    """
    class Configs(BaseModel):
        """
        Configs.
        """
        detail: Literal['low', 'high']

    enabled: bool
    configs: Optional[Configs] = None


class LLMNodeData(BaseNodeData):
    """
    LLM Node Data.
    """
    pass
    model: ModelConfig
    variables: list[VariableSelector] = []
    prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
    memory: Optional[MemoryConfig] = None
    context: ContextConfig
    vision: VisionConfig
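To make the shape concrete, a hypothetical node payload this entity would accept (editor's sketch; values are invented, `title` is assumed to be the only required BaseNodeData field, and the role strings assume PromptMessageRole is a string enum):

# Hypothetical example payload, not taken from the commit
from core.workflow.nodes.llm.entities import LLMNodeData

node_config = {
    "title": "LLM",
    "model": {
        "provider": "openai",
        "name": "gpt-4",
        "mode": "chat",
        "completion_params": {"temperature": 0.7, "stop": ["Human:"]},
    },
    "variables": [
        {"variable": "name", "value_selector": ["start", "name"]},
    ],
    "prompt_template": [
        {"text": "You are a helpful assistant named {{name}}.", "role": "system"},
        {"text": "{{#context#}}", "role": "user"},
    ],
    "memory": {"window": {"enabled": False}},
    "context": {"enabled": True, "variable_selector": ["knowledge_retrieval", "result"]},
    "vision": {"enabled": False},
}

node_data = LLMNodeData(**node_config)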
|
|||
|
|
@@ -1,10 +1,27 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.file.file_obj import FileObj
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class LLMNode(BaseNode):
|
||||
|
|
@@ -20,7 +37,341 @@ class LLMNode(BaseNode):
|
|||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
|
||||
pass
|
||||
node_inputs = None
|
||||
process_data = None
|
||||
|
||||
try:
|
||||
# fetch variables and fetch values from variable pool
|
||||
inputs = self._fetch_inputs(node_data, variable_pool)
|
||||
|
||||
node_inputs = {
|
||||
**inputs
|
||||
}
|
||||
|
||||
# fetch files
|
||||
files: list[FileObj] = self._fetch_files(node_data, variable_pool)
|
||||
|
||||
if files:
|
||||
node_inputs['#files#'] = [{
|
||||
'type': file.type.value,
|
||||
'transfer_method': file.transfer_method.value,
|
||||
'url': file.url,
|
||||
'upload_file_id': file.upload_file_id,
|
||||
} for file in files]
|
||||
|
||||
# fetch context value
|
||||
context = self._fetch_context(node_data, variable_pool)
|
||||
|
||||
if context:
|
||||
node_inputs['#context#'] = context
|
||||
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data)
|
||||
|
||||
# fetch memory
|
||||
memory = self._fetch_memory(node_data, variable_pool, model_instance)
|
||||
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
)
|
||||
|
||||
process_data = {
|
||||
'model_mode': model_config.mode,
|
||||
'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
}
|
||||
|
||||
# handle invoke result
|
||||
result_text, usage = self._invoke_llm(
|
||||
node_data=node_data,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
inputs=node_inputs,
|
||||
process_data=process_data
|
||||
)
|
||||
|
||||
outputs = {
|
||||
'text': result_text,
|
||||
'usage': jsonable_encoder(usage)
|
||||
}
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency
|
||||
}
|
||||
)
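On success the node therefore exposes the generated text plus the JSON-encoded usage; an invented illustration of that shape (editor's note, usage keys limited to those referenced elsewhere in this diff):

# Illustrative only
# outputs == {
#     "text": "Hello! How can I help you today?",
#     "usage": {"prompt_tokens": 12, "total_tokens": 19, "total_price": "0.00057", "currency": "USD", ...},
# }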
|
||||
|
||||
def _invoke_llm(self, node_data: LLMNodeData,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: list[str]) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Invoke large language model
|
||||
:param node_data: node data
|
||||
:param model_instance: model instance
|
||||
:param prompt_messages: prompt messages
|
||||
:param stop: stop
|
||||
:return:
|
||||
"""
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=node_data.model.completion_params,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
return self._handle_invoke_result(
|
||||
invoke_result=invoke_result
|
||||
)
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
:return:
|
||||
"""
|
||||
model = None
|
||||
prompt_messages = []
|
||||
full_text = ''
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
full_text += text
|
||||
|
||||
self.publish_text_chunk(text=text)
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = result.prompt_messages
|
||||
|
||||
if not usage and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
return full_text, usage
|
||||
|
||||
def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
|
||||
"""
|
||||
Fetch inputs
|
||||
:param node_data: node data
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
inputs = {}
|
||||
for variable_selector in node_data.variables:
|
||||
variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
|
||||
if variable_value is None:
|
||||
raise ValueError(f'Variable {variable_selector.value_selector} not found')
|
||||
|
||||
inputs[variable_selector.variable] = variable_value
|
||||
|
||||
return inputs
|
||||
|
||||
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileObj]:
|
||||
"""
|
||||
Fetch files
|
||||
:param node_data: node data
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
if not node_data.vision.enabled:
|
||||
return []
|
||||
|
||||
files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
|
||||
if not files:
|
||||
return []
|
||||
|
||||
return files
|
||||
|
||||
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
|
||||
"""
|
||||
Fetch context
|
||||
:param node_data: node data
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
if not node_data.context.enabled:
|
||||
return None
|
||||
|
||||
context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
|
||||
if context_value:
|
||||
if isinstance(context_value, str):
|
||||
return context_value
|
||||
elif isinstance(context_value, list):
|
||||
context_str = ''
|
||||
for item in context_value:
|
||||
if 'content' not in item:
|
||||
raise ValueError(f'Invalid context structure: {item}')
|
||||
|
||||
context_str += item['content'] + '\n'
|
||||
|
||||
return context_str.strip()
|
||||
|
||||
return None
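For clarity (editor's note): the context value may be either a plain string, returned unchanged, or a list of retrieval records each carrying a 'content' key, which are joined line by line. An invented illustration:

# Illustrative values only
plain_context = "I am superman."
retrieval_context = [
    {"content": "Dify is an LLMOps platform."},
    {"content": "Workflow apps are built from nodes."},
]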
|
||||
|
||||
def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
model_name = node_data.model.name
|
||||
provider_name = node_data.model.provider
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=provider_name,
|
||||
model=model_name
|
||||
)
|
||||
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_credentials = model_instance.credentials
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_name,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||
|
||||
# model config
|
||||
completion_params = node_data.model.completion_params
|
||||
stop = []
|
||||
if 'stop' in completion_params:
|
||||
stop = completion_params['stop']
|
||||
del completion_params['stop']
|
||||
|
||||
# get model mode
|
||||
model_mode = node_data.model.mode
|
||||
if not model_mode:
|
||||
raise ValueError("LLM mode is required.")
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model_name,
|
||||
model_credentials
|
||||
)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
return model_instance, ModelConfigWithCredentialsEntity(
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
model_schema=model_schema,
|
||||
mode=model_mode,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=model_credentials,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _fetch_memory(self, node_data: LLMNodeData,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
|
||||
"""
|
||||
Fetch memory
|
||||
:param node_data: node data
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
if not node_data.memory:
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION])
|
||||
if conversation_id is None:
|
||||
return None
|
||||
|
||||
# get conversation
|
||||
conversation = db.session.query(Conversation).filter(
|
||||
Conversation.tenant_id == self.tenant_id,
|
||||
Conversation.app_id == self.app_id,
|
||||
Conversation.id == conversation_id
|
||||
).first()
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
|
||||
return memory
|
||||
|
||||
def _fetch_prompt_messages(self, node_data: LLMNodeData,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileObj],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
"""
|
||||
Fetch prompt messages
|
||||
:param node_data: node data
|
||||
:param inputs: inputs
|
||||
:param files: files
|
||||
:param context: context
|
||||
:param memory: memory
|
||||
:param model_config: model config
|
||||
:return:
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=node_data.prompt_template,
|
||||
inputs=inputs,
|
||||
query='',
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
)
|
||||
stop = model_config.stop
|
||||
|
||||
return prompt_messages, stop
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
|
|
@ -29,9 +380,20 @@ class LLMNode(BaseNode):
|
|||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
# TODO extract variable selector to variable mapping for single step debugging
|
||||
return {}
|
||||
node_data = node_data
|
||||
node_data = cast(cls._node_data_cls, node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in node_data.variables:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
if node_data.context.enabled:
|
||||
variable_mapping['#context#'] = node_data.context.variable_selector
|
||||
|
||||
if node_data.vision.enabled:
|
||||
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
|
||||
|
||||
return variable_mapping
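For example (editor's sketch with invented selectors): a node with one variable `name` taken from the start node, context enabled on a knowledge-retrieval output, and vision enabled would produce:

# {
#     "name": ["start", "name"],
#     "#context#": ["knowledge_retrieval", "result"],
#     "#files#": ["sys", "files"],  # assuming SystemVariable.FILES.value == "files"
# }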
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu
|
|||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base_node import BaseNode, UserFrom
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
|
|
@ -24,13 +24,12 @@ from extensions.ext_database import db
|
|||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
|
||||
node_classes = {
|
||||
NodeType.START: StartNode,
|
||||
NodeType.END: EndNode,
|
||||
NodeType.DIRECT_ANSWER: DirectAnswerNode,
|
||||
NodeType.ANSWER: AnswerNode,
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
|
|
@ -156,7 +155,7 @@ class WorkflowEngineManager:
|
|||
callbacks=callbacks
|
||||
)
|
||||
|
||||
if next_node.node_type == NodeType.END:
|
||||
if next_node.node_type in [NodeType.END, NodeType.ANSWER]:
|
||||
break
|
||||
|
||||
predecessor_node = next_node
|
||||
|
|
@ -402,10 +401,16 @@ class WorkflowEngineManager:
|
|||
# add to workflow_nodes_and_results
|
||||
workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result)
|
||||
|
||||
# run node, result must have inputs, process_data, outputs, execution_metadata
|
||||
node_run_result = node.run(
|
||||
variable_pool=workflow_run_state.variable_pool
|
||||
)
|
||||
try:
|
||||
# run node, result must have inputs, process_data, outputs, execution_metadata
|
||||
node_run_result = node.run(
|
||||
variable_pool=workflow_run_state.variable_pool
|
||||
)
|
||||
except Exception as e:
|
||||
node_run_result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
# node run failed
|
||||
|
|
@ -420,9 +425,6 @@ class WorkflowEngineManager:
|
|||
|
||||
raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}")
|
||||
|
||||
# set end node output if in chat
|
||||
self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result)
|
||||
|
||||
workflow_nodes_and_result.result = node_run_result
|
||||
|
||||
# node run success
|
||||
|
|
@ -453,29 +455,6 @@ class WorkflowEngineManager:
|
|||
|
||||
db.session.close()
|
||||
|
||||
def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState,
|
||||
node: BaseNode,
|
||||
node_run_result: NodeRunResult) -> None:
|
||||
"""
|
||||
Set end node output if in chat
|
||||
:param workflow_run_state: workflow run state
|
||||
:param node: current node
|
||||
:param node_run_result: node run result
|
||||
:return:
|
||||
"""
|
||||
if workflow_run_state.workflow_type == WorkflowType.CHAT and node.node_type == NodeType.END:
|
||||
workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2]
|
||||
if workflow_nodes_and_result_before_end:
|
||||
if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM:
|
||||
if not node_run_result.outputs:
|
||||
node_run_result.outputs = {}
|
||||
|
||||
node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('text')
|
||||
elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER:
|
||||
if not node_run_result.outputs:
|
||||
node_run_result.outputs = {}
|
||||
|
||||
node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('answer')
|
||||
|
||||
def _append_variables_recursively(self, variable_pool: VariablePool,
|
||||
node_id: str,
|
||||
|
|
|
|||
|
|
@ -2,12 +2,12 @@ from unittest.mock import MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \
|
||||
ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity, FileUploadEntity
|
||||
from core.app.app_config.entities import ModelConfigEntity, FileUploadEntity
|
||||
from core.file.file_obj import FileObj, FileType, FileTransferMethod
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig, ChatModelMessage
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from models.model import Conversation
|
||||
|
||||
|
|
@ -18,16 +18,20 @@ def test__get_completion_model_prompt_messages():
|
|||
model_config_mock.model = 'gpt-3.5-turbo-instruct'
|
||||
|
||||
prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
|
||||
prompt=prompt_template,
|
||||
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(
|
||||
user="Human",
|
||||
assistant="Assistant"
|
||||
)
|
||||
prompt_template_config = CompletionModelPromptTemplate(
|
||||
text=prompt_template
|
||||
)
|
||||
|
||||
memory_config = MemoryConfig(
|
||||
role_prefix=MemoryConfig.RolePrefix(
|
||||
user="Human",
|
||||
assistant="Assistant"
|
||||
),
|
||||
window=MemoryConfig.WindowConfig(
|
||||
enabled=False
|
||||
)
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"name": "John"
|
||||
}
|
||||
|
|
@ -48,11 +52,12 @@ def test__get_completion_model_prompt_messages():
|
|||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
prompt_messages = prompt_transform._get_completion_model_prompt_messages(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
prompt_template=prompt_template_config,
|
||||
inputs=inputs,
|
||||
query=None,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config_mock
|
||||
)
|
||||
|
|
@ -67,7 +72,7 @@ def test__get_completion_model_prompt_messages():
|
|||
|
||||
|
||||
def test__get_chat_model_prompt_messages(get_chat_model_args):
|
||||
model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
|
||||
model_config_mock, memory_config, messages, inputs, context = get_chat_model_args
|
||||
|
||||
files = []
|
||||
query = "Hi2."
|
||||
|
|
@ -86,11 +91,12 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
|
|||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
prompt_template=messages,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config_mock
|
||||
)
|
||||
|
|
@ -98,24 +104,25 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
|
|||
assert len(prompt_messages) == 6
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
assert prompt_messages[0].content == PromptTemplateParser(
|
||||
template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
|
||||
template=messages[0].text
|
||||
).format({**inputs, "#context#": context})
|
||||
assert prompt_messages[5].content == query
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
|
||||
model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
|
||||
model_config_mock, _, messages, inputs, context = get_chat_model_args
|
||||
|
||||
files = []
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
prompt_template=messages,
|
||||
inputs=inputs,
|
||||
query=None,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock
|
||||
)
|
||||
|
|
@ -123,12 +130,12 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
|
|||
assert len(prompt_messages) == 3
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
assert prompt_messages[0].content == PromptTemplateParser(
|
||||
template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
|
||||
template=messages[0].text
|
||||
).format({**inputs, "#context#": context})
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
|
||||
model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
|
||||
model_config_mock, _, messages, inputs, context = get_chat_model_args
|
||||
|
||||
files = [
|
||||
FileObj(
|
||||
|
|
@ -148,11 +155,12 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
|||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
prompt_template=messages,
|
||||
inputs=inputs,
|
||||
query=None,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock
|
||||
)
|
||||
|
|
@ -160,7 +168,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
|||
assert len(prompt_messages) == 4
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
assert prompt_messages[0].content == PromptTemplateParser(
|
||||
template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
|
||||
template=messages[0].text
|
||||
).format({**inputs, "#context#": context})
|
||||
assert isinstance(prompt_messages[3].content, list)
|
||||
assert len(prompt_messages[3].content) == 2
|
||||
|
|
@ -173,22 +181,31 @@ def get_chat_model_args():
|
|||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.model = 'gpt-4'
|
||||
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
|
||||
messages=[
|
||||
AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
|
||||
role=PromptMessageRole.SYSTEM),
|
||||
AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
|
||||
AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
|
||||
]
|
||||
memory_config = MemoryConfig(
|
||||
window=MemoryConfig.WindowConfig(
|
||||
enabled=False
|
||||
)
|
||||
)
|
||||
|
||||
prompt_messages = [
|
||||
ChatModelMessage(
|
||||
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
|
||||
role=PromptMessageRole.SYSTEM
|
||||
),
|
||||
ChatModelMessage(
|
||||
text="Hi.",
|
||||
role=PromptMessageRole.USER
|
||||
),
|
||||
ChatModelMessage(
|
||||
text="Hello!",
|
||||
role=PromptMessageRole.ASSISTANT
|
||||
)
|
||||
]
|
||||
|
||||
inputs = {
|
||||
"name": "John"
|
||||
}
|
||||
|
||||
context = "I am superman."
|
||||
|
||||
return model_config_mock, prompt_template_entity, inputs, context
|