mirror of https://github.com/langgenius/dify.git
fixed single retrieval
This commit is contained in:
parent
bab88efda9
commit
75ffdc9d3f
@@ -4,6 +4,8 @@ from typing import Optional, Union, cast
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
+from core.model_runtime.entities.model_entities import ModelPropertyKey
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -129,7 +131,8 @@ class QuestionClassifierNode(LLMNode):
         :return:
         """
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(node_data, query, memory)
+        rest_token = self._calculate_rest_token(node_data, query, model_config, context)
+        prompt_template = self._get_prompt_template(node_data, query, memory, rest_token)
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
             inputs={},
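
The call-site change above swaps the single _get_prompt_template(node_data, query, memory) call for a two-step flow: first estimate how much of the model's context window is still free, then build the prompt with chat history capped to that budget rather than a hard-coded 2000 tokens. A runnable sketch of that flow, using simplified stand-ins (calculate_rest_token, get_prompt_template, and every figure below are illustrative, not the node's real private API):

# All names and numbers below are illustrative stand-ins.

def calculate_rest_token(context_size: int, completion_reserve: int, prompt_tokens: int) -> int:
    # Clamped subtraction, mirroring what the new _calculate_rest_token helper computes.
    return max(context_size - completion_reserve - prompt_tokens, 0)

def get_prompt_template(query: str, max_token_limit: int) -> str:
    # Stand-in for _get_prompt_template: the real method returns ChatModelMessage /
    # CompletionModelPromptTemplate objects and trims chat history to max_token_limit tokens.
    return f"<classifier prompt for {query!r}; history capped at {max_token_limit} tokens>"

rest_token = calculate_rest_token(context_size=8_192, completion_reserve=512, prompt_tokens=300)
print(get_prompt_template("where is my order?", max_token_limit=rest_token))
# -> <classifier prompt for 'where is my order?'; history capped at 7380 tokens>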
@@ -144,8 +147,49 @@ class QuestionClassifierNode(LLMNode):

         return prompt_messages, stop

+    def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str,
+                              model_config: ModelConfigWithCredentialsEntity,
+                              context: Optional[str]) -> int:
+        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
+        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
+        prompt_messages = prompt_transform.get_prompt(
+            prompt_template=prompt_template,
+            inputs={},
+            query='',
+            files=[],
+            context=context,
+            memory_config=node_data.memory,
+            memory=None,
+            model_config=model_config
+        )
+        rest_tokens = 2000
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+        if model_context_tokens:
+            model_type_instance = model_config.provider_model_bundle.model_type_instance
+            model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+            curr_message_tokens = model_type_instance.get_num_tokens(
+                model_config.model,
+                model_config.credentials,
+                prompt_messages
+            )
+
+            max_tokens = 0
+            for parameter_rule in model_config.model_schema.parameter_rules:
+                if (parameter_rule.name == 'max_tokens'
+                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                    max_tokens = (model_config.parameters.get(parameter_rule.name)
+                                  or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+            rest_tokens = max(rest_tokens, 0)
+
+        return rest_tokens
+
     def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str,
-                             memory: Optional[TokenBufferMemory]) \
+                             memory: Optional[TokenBufferMemory],
+                             max_token_limit: int = 2000) \
             -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
         model_mode = ModelMode.value_of(node_data.model.mode)
         classes = node_data.classes
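
The budget computed by the new _calculate_rest_token comes from three numbers: the model's context size (the ModelPropertyKey.CONTEXT_SIZE property), the completion reserve taken from the max_tokens parameter rule, and the token count of the prompt built with an empty memory; the 2000-token value is kept as a fallback when the context size is unknown. A self-contained worked example of that arithmetic (the figures are illustrative):

def remaining_prompt_budget(model_context_tokens: int, max_tokens: int, curr_message_tokens: int) -> int:
    """Tokens left for chat history after reserving room for the completion
    (max_tokens) and the prompt assembled so far (curr_message_tokens)."""
    return max(model_context_tokens - max_tokens - curr_message_tokens, 0)

# e.g. a 16,384-token context window, 1,024 tokens reserved for the answer,
# and a classifier prompt that already uses 350 tokens:
assert remaining_prompt_budget(16_384, 1_024, 350) == 15_010

# If the prompt alone fills the window, the budget clamps to zero instead of going negative:
assert remaining_prompt_budget(4_096, 4_096, 200) == 0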
@@ -155,7 +199,7 @@ class QuestionClassifierNode(LLMNode):
         input_text = query
         memory_str = ''
         if memory:
-            memory_str = memory.get_history_prompt_text(max_token_limit=2000,
+            memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
                                                         message_limit=node_data.memory.window.size)
         prompt_messages = []
         if model_mode == ModelMode.CHAT:
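
Inside _get_prompt_template, the computed budget now replaces the fixed max_token_limit=2000 handed to memory.get_history_prompt_text, so history is trimmed to whatever actually fits the model rather than a one-size-fits-all cap. A toy illustration of budget-based truncation (fit_history and the word-count tokenizer are hypothetical helpers for this sketch, not dify APIs):

from typing import Callable

def fit_history(messages: list[str], count_tokens: Callable[[str], int], budget: int) -> list[str]:
    """Keep the most recent messages whose combined cost fits the budget,
    the kind of trimming get_history_prompt_text(max_token_limit=...) is asked to do."""
    kept: list[str] = []
    used = 0
    for message in reversed(messages):   # walk from newest to oldest
        cost = count_tokens(message)
        if used + cost > budget:
            break
        kept.append(message)
        used += cost
    return list(reversed(kept))          # restore chronological order

history = [
    "how do refunds work",
    "you can request one within 30 days",
    "what about digital goods",
]
# Toy tokenizer: one token per whitespace-separated word.
print(fit_history(history, lambda m: len(m.split()), budget=12))
# -> ['you can request one within 30 days', 'what about digital goods']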