fixed single retrieval

jyong 2024-03-29 19:44:26 +08:00
parent bab88efda9
commit 75ffdc9d3f
1 changed file with 47 additions and 3 deletions


@@ -4,6 +4,8 @@ from typing import Optional, Union, cast
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
+from core.model_runtime.entities.model_entities import ModelPropertyKey
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -129,7 +131,8 @@ class QuestionClassifierNode(LLMNode):
         :return:
         """
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(node_data, query, memory)
+        rest_token = self._calculate_rest_token(node_data, query, model_config, context)
+        prompt_template = self._get_prompt_template(node_data, query, memory, rest_token)
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
             inputs={},
@@ -144,8 +147,49 @@ class QuestionClassifierNode(LLMNode):

         return prompt_messages, stop

+    def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str,
+                              model_config: ModelConfigWithCredentialsEntity,
+                              context: Optional[str]) -> int:
+        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
+        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
+        prompt_messages = prompt_transform.get_prompt(
+            prompt_template=prompt_template,
+            inputs={},
+            query='',
+            files=[],
+            context=context,
+            memory_config=node_data.memory,
+            memory=None,
+            model_config=model_config
+        )
+        rest_tokens = 2000
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+        if model_context_tokens:
+            model_type_instance = model_config.provider_model_bundle.model_type_instance
+            model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+            curr_message_tokens = model_type_instance.get_num_tokens(
+                model_config.model,
+                model_config.credentials,
+                prompt_messages
+            )
+
+            max_tokens = 0
+            for parameter_rule in model_config.model_schema.parameter_rules:
+                if (parameter_rule.name == 'max_tokens'
+                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                    max_tokens = (model_config.parameters.get(parameter_rule.name)
+                                  or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+            rest_tokens = max(rest_tokens, 0)
+
+        return rest_tokens
+
     def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str,
-                             memory: Optional[TokenBufferMemory]) \
+                             memory: Optional[TokenBufferMemory],
+                             max_token_limit: int = 2000) \
             -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
         model_mode = ModelMode.value_of(node_data.model.mode)
         classes = node_data.classes
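
The core of the new _calculate_rest_token above is a simple budget: the conversation history gets whatever is left of the model's context window after reserving room for the completion (max_tokens) and the tokens the classifier prompt itself consumes, with 2000 as the fallback when the model does not report a context size. A minimal sketch of that arithmetic; calculate_rest_tokens is a hypothetical standalone helper with illustrative values, not part of Dify's API:

from typing import Optional

def calculate_rest_tokens(model_context_tokens: Optional[int],
                          max_tokens: int,
                          curr_message_tokens: int,
                          fallback: int = 2000) -> int:
    """Tokens left for history after reserving the completion and the prompt."""
    if not model_context_tokens:
        # Context size unknown for this model: keep the fixed 2000-token default.
        return fallback
    # Budget = context window - reserved completion - tokens the prompt already uses.
    return max(model_context_tokens - max_tokens - curr_message_tokens, 0)

# 16k context, 512 tokens reserved for the answer, 1200 already used by the prompt:
assert calculate_rest_tokens(16384, 512, 1200) == 14672
assert calculate_rest_tokens(None, 512, 1200) == 2000   # context size unknown

Note that the first pass calls _get_prompt_template(node_data, query, None, 2000) with no memory, so the token count measures only the fixed parts of the prompt before the budget is spent on history.
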
@@ -155,7 +199,7 @@ class QuestionClassifierNode(LLMNode):
         input_text = query
         memory_str = ''
         if memory:
-            memory_str = memory.get_history_prompt_text(max_token_limit=2000,
+            memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
                                                         message_limit=node_data.memory.window.size)
         prompt_messages = []
         if model_mode == ModelMode.CHAT:
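
Downstream, _get_prompt_template now threads the computed budget into memory.get_history_prompt_text(max_token_limit=...) instead of the previous hard-coded 2000, so long conversations are trimmed to whatever the model can actually accept. A rough sketch of what a budget-capped history render does, assuming a naive whitespace tokenizer and oldest-first trimming; history_prompt_text below is a hypothetical stand-in for TokenBufferMemory, not its real implementation:

from typing import Callable, List, Tuple

def history_prompt_text(messages: List[Tuple[str, str]],
                        count_tokens: Callable[[str], int],
                        max_token_limit: int) -> str:
    """Render (role, text) pairs, dropping the oldest turns until the budget fits."""
    def render(msgs: List[Tuple[str, str]]) -> str:
        return '\n'.join(f'{role}: {text}' for role, text in msgs)

    kept = list(messages)
    while kept and count_tokens(render(kept)) > max_token_limit:
        kept.pop(0)  # drop the oldest turn first
    return render(kept)

# Example with a naive whitespace "tokenizer": only the most recent turns survive.
msgs = [('Human', 'hello there'), ('Assistant', 'hi'), ('Human', 'classify this')]
print(history_prompt_text(msgs, lambda s: len(s.split()), max_token_limit=5))

The default of 2000 on max_token_limit keeps every other caller of _get_prompt_template behaving exactly as before this change; only the question-classifier path passes the computed budget.
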