mirror of https://github.com/langgenius/dify.git
fixed single retrieval
This commit is contained in:
parent bab88efda9
commit 75ffdc9d3f
@@ -4,6 +4,8 @@ from typing import Optional, Union, cast
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
+from core.model_runtime.entities.model_entities import ModelPropertyKey
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -129,7 +131,8 @@ class QuestionClassifierNode(LLMNode):
         :return:
         """
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(node_data, query, memory)
+        rest_token = self._calculate_rest_token(node_data, query, model_config, context)
+        prompt_template = self._get_prompt_template(node_data, query, memory, rest_token)
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
             inputs={},
@@ -144,8 +147,49 @@ class QuestionClassifierNode(LLMNode):
 
         return prompt_messages, stop
 
+    def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str,
+                              model_config: ModelConfigWithCredentialsEntity,
+                              context: Optional[str]) -> int:
+        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
+        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
+        prompt_messages = prompt_transform.get_prompt(
+            prompt_template=prompt_template,
+            inputs={},
+            query='',
+            files=[],
+            context=context,
+            memory_config=node_data.memory,
+            memory=None,
+            model_config=model_config
+        )
+        rest_tokens = 2000
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+        if model_context_tokens:
+            model_type_instance = model_config.provider_model_bundle.model_type_instance
+            model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+            curr_message_tokens = model_type_instance.get_num_tokens(
+                model_config.model,
+                model_config.credentials,
+                prompt_messages
+            )
+
+            max_tokens = 0
+            for parameter_rule in model_config.model_schema.parameter_rules:
+                if (parameter_rule.name == 'max_tokens'
+                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                    max_tokens = (model_config.parameters.get(parameter_rule.name)
+                                  or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+            rest_tokens = max(rest_tokens, 0)
+
+        return rest_tokens
+
     def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str,
-                             memory: Optional[TokenBufferMemory]) \
+                             memory: Optional[TokenBufferMemory],
+                             max_token_limit: int = 2000) \
             -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
         model_mode = ModelMode.value_of(node_data.model.mode)
         classes = node_data.classes
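
The core of the change is the new _calculate_rest_token helper: it renders a draft of the classifier prompt, counts its tokens, and subtracts both that count and the configured max_tokens from the model's context window, clamping at zero and falling back to 2000 when the model schema exposes no CONTEXT_SIZE. Below is a minimal standalone sketch of that budget arithmetic, with the three inputs passed as plain integers rather than derived from dify's model config objects.

from typing import Optional

def calculate_rest_tokens(model_context_tokens: Optional[int],
                          max_tokens: int,
                          curr_message_tokens: int,
                          fallback: int = 2000) -> int:
    """Remaining prompt budget: context window minus the tokens reserved for
    the completion and the tokens already used by the draft prompt, never
    negative. Simplified illustration, not the dify method itself."""
    if not model_context_tokens:
        # Mirrors the diff's default budget of 2000 when the model schema
        # exposes no context size.
        return fallback
    return max(model_context_tokens - max_tokens - curr_message_tokens, 0)

# Example: a 16384-token context model reserving 512 completion tokens for a
# classifier prompt that already occupies 1200 tokens leaves 14672 tokens.
assert calculate_rest_tokens(16384, 512, 1200) == 14672
assert calculate_rest_tokens(None, 512, 1200) == 2000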
@@ -155,7 +199,7 @@ class QuestionClassifierNode(LLMNode):
             input_text = query
             memory_str = ''
             if memory:
-                memory_str = memory.get_history_prompt_text(max_token_limit=2000,
+                memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
                                                             message_limit=node_data.memory.window.size)
             prompt_messages = []
             if model_mode == ModelMode.CHAT:
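
The computed budget is then threaded through _get_prompt_template via the new max_token_limit parameter (default 2000) and replaces the hard-coded max_token_limit=2000 in memory.get_history_prompt_text, so conversation history only consumes whatever room is actually left in the context window. The sketch below is a toy illustration of that kind of token-limited history truncation; it is not dify's TokenBufferMemory implementation, and the whitespace-based token estimate is only a placeholder for the model's real tokenizer.

def truncate_history(turns: list[str], max_token_limit: int) -> str:
    """Keep the most recent turns whose combined (estimated) token count
    still fits within max_token_limit."""
    kept: list[str] = []
    used = 0
    for turn in reversed(turns):       # walk from the newest turn backwards
        cost = len(turn.split())       # placeholder token estimate
        if used + cost > max_token_limit:
            break
        kept.append(turn)
        used += cost
    return "\n".join(reversed(kept))   # restore chronological order

history = ["Human: hi", "Assistant: hello", "Human: classify this ticket"]
# With a budget of 6 estimated tokens, only the two most recent turns fit.
print(truncate_history(history, max_token_limit=6))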