takatost 2024-03-21 15:02:55 +08:00
parent a05fcedd61
commit d71eae8f93
3 changed files with 35 additions and 223 deletions
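In short: the LLMNode helpers now accept the narrowed sub-configs (ModelConfig, MemoryConfig) instead of the whole LLMNodeData, which lets QuestionClassifierNode drop its duplicated copies and reuse them by subclassing LLMNode. A minimal sketch of the resulting call pattern, with bodies and unrelated parameters elided; only the class, method, and attribute names shown in the diff are real, the stubs are illustrative:

from typing import Optional

class ModelConfig:          # per the diff: lives in core.workflow.nodes.llm.entities
    provider: str
    name: str
    mode: str
    completion_params: dict

class MemoryConfig:         # per the diff: lives in core.prompt.entities.advanced_prompt_entities
    ...

class LLMNode:
    # Helpers take only the config slice they need, so subclasses can call them directly.
    def _fetch_model_config(self, node_data_model: ModelConfig):
        ...

    def _fetch_memory(self, node_data_memory: Optional[MemoryConfig], variable_pool, model_instance):
        ...

    def _invoke_llm(self, node_data_model: ModelConfig, model_instance, prompt_messages, stop):
        ...

class QuestionClassifierNode(LLMNode):
    def _run(self, variable_pool):
        node_data = self.node_data
        # Reuse the base-class helpers instead of re-implementing them here.
        model_instance, model_config = self._fetch_model_config(node_data.model)
        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
        ...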

View File

@@ -15,12 +15,13 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.llm.entities import LLMNodeData
+from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
@@ -64,10 +65,10 @@ class LLMNode(BaseNode):
node_inputs['#context#'] = context
# fetch model config
-model_instance, model_config = self._fetch_model_config(node_data)
+model_instance, model_config = self._fetch_model_config(node_data.model)
# fetch memory
-memory = self._fetch_memory(node_data, variable_pool, model_instance)
+memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt_messages(
@@ -89,7 +90,7 @@ class LLMNode(BaseNode):
# handle invoke result
result_text, usage = self._invoke_llm(
-node_data=node_data,
+node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop
@@ -119,13 +120,13 @@ }
}
)
-def _invoke_llm(self, node_data: LLMNodeData,
+def _invoke_llm(self, node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: list[str]) -> tuple[str, LLMUsage]:
"""
Invoke large language model
-:param node_data: node data
+:param node_data_model: node data model
:param model_instance: model instance
:param prompt_messages: prompt messages
:param stop: stop
@@ -135,7 +136,7 @@ class LLMNode(BaseNode):
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
-model_parameters=node_data.model.completion_params,
+model_parameters=node_data_model.completion_params,
stop=stop,
stream=True,
user=self.user_id,
@@ -286,14 +287,14 @@ class LLMNode(BaseNode):
return None
-def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
-:param node_data: node data
+:param node_data_model: node data model
:return:
"""
-model_name = node_data.model.name
-provider_name = node_data.model.provider
+model_name = node_data_model.name
+provider_name = node_data_model.provider
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
@@ -326,14 +327,14 @@ class LLMNode(BaseNode):
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
-completion_params = node_data.model.completion_params
+completion_params = node_data_model.completion_params
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
# get model mode
-model_mode = node_data.model.mode
+model_mode = node_data_model.mode
if not model_mode:
raise ValueError("LLM mode is required.")
@@ -356,26 +357,25 @@ class LLMNode(BaseNode):
stop=stop,
)
-def _fetch_memory(self, node_data: LLMNodeData,
+def _fetch_memory(self, node_data_memory: Optional[MemoryConfig],
variable_pool: VariablePool,
model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
"""
Fetch memory
-:param node_data: node data
+:param node_data_memory: node data memory
:param variable_pool: variable pool
:return:
"""
-if not node_data.memory:
+if not node_data_memory:
return None
# get conversation id
-conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION])
+conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
if conversation_id is None:
return None
# get conversation
conversation = db.session.query(Conversation).filter(
Conversation.tenant_id == self.tenant_id,
Conversation.app_id == self.app_id,
Conversation.id == conversation_id
).first()

View File

@@ -2,6 +2,7 @@ from typing import Any, Optional
from pydantic import BaseModel
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
@@ -23,21 +24,6 @@ class ClassConfig(BaseModel):
name: str
-class WindowConfig(BaseModel):
-"""
-Window Config.
-"""
-enabled: bool
-size: int
-class MemoryConfig(BaseModel):
-"""
-Memory Config.
-"""
-window: WindowConfig
class QuestionClassifierNodeData(BaseNodeData):
"""
Knowledge retrieval Node Data.

View File

@@ -1,25 +1,17 @@
import json
from collections.abc import Generator
from typing import Optional, Union, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.question_classifier.template_prompts import (
@@ -31,28 +23,28 @@ from core.workflow.nodes.question_classifier.template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_2,
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
-from extensions.ext_database import db
-from models.model import Conversation
from models.workflow import WorkflowNodeExecutionStatus
-class QuestionClassifierNode(BaseNode):
+class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData
_node_type = NodeType.QUESTION_CLASSIFIER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
+node_data = cast(QuestionClassifierNodeData, node_data)
# extract variables
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
variables = {
'query': query
}
# fetch model config
-model_instance, model_config = self._fetch_model_config(node_data)
+model_instance, model_config = self._fetch_model_config(node_data.model)
# fetch memory
-memory = self._fetch_memory(node_data, variable_pool, model_instance)
+memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
# fetch prompt messages
-prompt_messages, stop = self._fetch_prompt_messages(
+prompt_messages, stop = self._fetch_prompt(
node_data=node_data,
context='',
query=query,
@@ -62,7 +54,7 @@ class QuestionClassifierNode(BaseNode):
# handle invoke result
result_text, usage = self._invoke_llm(
-node_data=node_data,
+node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop
@@ -117,126 +109,20 @@ class QuestionClassifierNode(BaseNode):
return {
"type": "question-classifier",
"config": {
"instructions": "" # TODO
"instructions": ""
}
}
-def _fetch_model_config(self, node_data: QuestionClassifierNodeData) \
--> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
-"""
-Fetch model config
-:param node_data: node data
-:return:
-"""
-model_name = node_data.model.name
-provider_name = node_data.model.provider
-model_manager = ModelManager()
-model_instance = model_manager.get_model_instance(
-tenant_id=self.tenant_id,
-model_type=ModelType.LLM,
-provider=provider_name,
-model=model_name
-)
-provider_model_bundle = model_instance.provider_model_bundle
-model_type_instance = model_instance.model_type_instance
-model_type_instance = cast(LargeLanguageModel, model_type_instance)
-model_credentials = model_instance.credentials
-# check model
-provider_model = provider_model_bundle.configuration.get_provider_model(
-model=model_name,
-model_type=ModelType.LLM
-)
-if provider_model is None:
-raise ValueError(f"Model {model_name} not exist.")
-if provider_model.status == ModelStatus.NO_CONFIGURE:
-raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-elif provider_model.status == ModelStatus.NO_PERMISSION:
-raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
-# model config
-completion_params = node_data.model.completion_params
-stop = []
-if 'stop' in completion_params:
-stop = completion_params['stop']
-del completion_params['stop']
-# get model mode
-model_mode = node_data.model.mode
-if not model_mode:
-raise ValueError("LLM mode is required.")
-model_schema = model_type_instance.get_model_schema(
-model_name,
-model_credentials
-)
-if not model_schema:
-raise ValueError(f"Model {model_name} not exist.")
-return model_instance, ModelConfigWithCredentialsEntity(
-provider=provider_name,
-model=model_name,
-model_schema=model_schema,
-mode=model_mode,
-provider_model_bundle=provider_model_bundle,
-credentials=model_credentials,
-parameters=completion_params,
-stop=stop,
-)
-def _fetch_memory(self, node_data: QuestionClassifierNodeData,
-variable_pool: VariablePool,
-model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
-"""
-Fetch memory
-:param node_data: node data
-:param variable_pool: variable pool
-:return:
-"""
-if not node_data.memory:
-return None
-# get conversation id
-conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION])
-if conversation_id is None:
-return None
-# get conversation
-conversation = db.session.query(Conversation).filter(
-Conversation.tenant_id == self.tenant_id,
-Conversation.app_id == self.app_id,
-Conversation.id == conversation_id
-).first()
-if not conversation:
-return None
-memory = TokenBufferMemory(
-conversation=conversation,
-model_instance=model_instance
-)
-return memory
-def _fetch_prompt_messages(self, node_data: QuestionClassifierNodeData,
-query: str,
-context: Optional[str],
-memory: Optional[TokenBufferMemory],
-model_config: ModelConfigWithCredentialsEntity) \
+def _fetch_prompt(self, node_data: QuestionClassifierNodeData,
+query: str,
+context: Optional[str],
+memory: Optional[TokenBufferMemory],
+model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
"""
-Fetch prompt messages
+Fetch prompt
:param node_data: node data
-:param inputs: inputs
-:param files: files
:param query: inputs
:param context: context
:param memory: memory
:param model_config: model config
@@ -310,63 +196,3 @@ class QuestionClassifierNode(BaseNode):
return prompt_messages
else:
raise ValueError(f"Model mode {model_mode} not support.")
-def _invoke_llm(self, node_data: QuestionClassifierNodeData,
-model_instance: ModelInstance,
-prompt_messages: list[PromptMessage],
-stop: list[str]) -> tuple[str, LLMUsage]:
-"""
-Invoke large language model
-:param node_data: node data
-:param model_instance: model instance
-:param prompt_messages: prompt messages
-:param stop: stop
-:return:
-"""
-invoke_result = model_instance.invoke_llm(
-prompt_messages=prompt_messages,
-model_parameters=node_data.model.completion_params,
-stop=stop,
-stream=True,
-user=self.user_id,
-)
-# handle invoke result
-text, usage = self._handle_invoke_result(
-invoke_result=invoke_result
-)
-# deduct quota
-LLMNode.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
-return text, usage
-def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
-"""
-Handle invoke result
-:param invoke_result: invoke result
-:return:
-"""
-model = None
-prompt_messages = []
-full_text = ''
-usage = None
-for result in invoke_result:
-text = result.delta.message.content
-full_text += text
-self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
-if not model:
-model = result.model
-if not prompt_messages:
-prompt_messages = result.prompt_messages
-if not usage and result.delta.usage:
-usage = result.delta.usage
-if not usage:
-usage = LLMUsage.empty_usage()
-return full_text, usage