From d71eae8f93d2586f1f335643b9e77f3a2065ea99 Mon Sep 17 00:00:00 2001
From: takatost
Date: Thu, 21 Mar 2024 15:02:55 +0800
Subject: [PATCH] fix qc

---
 api/core/workflow/nodes/llm/llm_node.py      |  36 +--
 .../nodes/question_classifier/entities.py    |  16 +-
 .../question_classifier_node.py              | 206 ++----------------
 3 files changed, 35 insertions(+), 223 deletions(-)

diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index 27a8302537..cbb6d954b9 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -15,12 +15,13 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.llm.entities import LLMNodeData
+from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
 from extensions.ext_database import db
 from models.model import Conversation
 from models.provider import Provider, ProviderType
@@ -64,10 +65,10 @@ class LLMNode(BaseNode):
         node_inputs['#context#'] = context
 
         # fetch model config
-        model_instance, model_config = self._fetch_model_config(node_data)
+        model_instance, model_config = self._fetch_model_config(node_data.model)
 
         # fetch memory
-        memory = self._fetch_memory(node_data, variable_pool, model_instance)
+        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
 
         # fetch prompt messages
         prompt_messages, stop = self._fetch_prompt_messages(
@@ -89,7 +90,7 @@ class LLMNode(BaseNode):
 
         # handle invoke result
         result_text, usage = self._invoke_llm(
-            node_data=node_data,
+            node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop
@@ -119,13 +120,13 @@ class LLMNode(BaseNode):
             }
         )
 
-    def _invoke_llm(self, node_data: LLMNodeData,
+    def _invoke_llm(self, node_data_model: ModelConfig,
                     model_instance: ModelInstance,
                     prompt_messages: list[PromptMessage],
                     stop: list[str]) -> tuple[str, LLMUsage]:
         """
         Invoke large language model
-        :param node_data: node data
+        :param node_data_model: node data model
         :param model_instance: model instance
         :param prompt_messages: prompt messages
         :param stop: stop
@@ -135,7 +136,7 @@ class LLMNode(BaseNode):
 
         invoke_result = model_instance.invoke_llm(
             prompt_messages=prompt_messages,
-            model_parameters=node_data.model.completion_params,
+            model_parameters=node_data_model.completion_params,
             stop=stop,
             stream=True,
             user=self.user_id,
@@ -286,14 +287,14 @@ class LLMNode(BaseNode):
 
         return None
 
-    def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
         """
         Fetch model config
-        :param node_data: node data
+        :param node_data_model: node data model
         :return:
         """
-        model_name = node_data.model.name
-        provider_name = node_data.model.provider
+        model_name = node_data_model.name
+        provider_name = node_data_model.provider
 
         model_manager = ModelManager()
         model_instance = model_manager.get_model_instance(
@@ -326,14 +327,14 @@ class LLMNode(BaseNode):
             raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
 
         # model config
-        completion_params = node_data.model.completion_params
+        completion_params = node_data_model.completion_params
         stop = []
         if 'stop' in completion_params:
             stop = completion_params['stop']
             del completion_params['stop']
 
         # get model mode
-        model_mode = node_data.model.mode
+        model_mode = node_data_model.mode
         if not model_mode:
             raise ValueError("LLM mode is required.")
 
@@ -356,26 +357,25 @@ class LLMNode(BaseNode):
             stop=stop,
         )
 
-    def _fetch_memory(self, node_data: LLMNodeData,
+    def _fetch_memory(self, node_data_memory: Optional[MemoryConfig],
                       variable_pool: VariablePool,
                       model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
         """
         Fetch memory
-        :param node_data: node data
+        :param node_data_memory: node data memory
         :param variable_pool: variable pool
         :return:
         """
-        if not node_data.memory:
+        if not node_data_memory:
             return None
 
         # get conversation id
-        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION])
+        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
         if conversation_id is None:
             return None
 
         # get conversation
         conversation = db.session.query(Conversation).filter(
-            Conversation.tenant_id == self.tenant_id,
             Conversation.app_id == self.app_id,
             Conversation.id == conversation_id
         ).first()
diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py
index f9a72f562b..9e660a88dd 100644
--- a/api/core/workflow/nodes/question_classifier/entities.py
+++ b/api/core/workflow/nodes/question_classifier/entities.py
@@ -2,6 +2,7 @@ from typing import Any, Optional
 
 from pydantic import BaseModel
 
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 
 
@@ -23,21 +24,6 @@ class ClassConfig(BaseModel):
     name: str
 
 
-class WindowConfig(BaseModel):
-    """
-    Window Config.
-    """
-    enabled: bool
-    size: int
-
-
-class MemoryConfig(BaseModel):
-    """
-    Memory Config.
-    """
-    window: WindowConfig
-
-
 class QuestionClassifierNodeData(BaseNodeData):
     """
     Knowledge retrieval Node Data.
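Taken together, the llm_node.py changes above narrow the shared helpers to sub-configs (ModelConfig, Optional[MemoryConfig]) so subclasses can reuse them without owning a full LLMNodeData. A minimal sketch of the intended call pattern, not part of the patch: the subclass name is hypothetical, the classes and helper signatures are the ones visible in the hunks above, and prompt assembly plus result handling are elided. This is exactly the pattern the question_classifier_node.py diff below switches to.

    from typing import cast

    from core.workflow.entities.node_entities import NodeRunResult
    from core.workflow.entities.variable_pool import VariablePool
    from core.workflow.nodes.llm.entities import LLMNodeData
    from core.workflow.nodes.llm.llm_node import LLMNode


    class MyLLMBackedNode(LLMNode):  # hypothetical subclass, for illustration only
        _node_data_cls = LLMNodeData

        def _run(self, variable_pool: VariablePool) -> NodeRunResult:
            node_data = cast(LLMNodeData, self.node_data)

            # the helper now takes the model sub-config, not the whole node data
            model_instance, model_config = self._fetch_model_config(node_data.model)

            # node_data.memory is an Optional[MemoryConfig]; the helper returns
            # None when memory is unconfigured or no conversation id is in the pool
            memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)

            # prompt assembly elided; _invoke_llm likewise takes the sub-config
            result_text, usage = self._invoke_llm(
                node_data_model=node_data.model,
                model_instance=model_instance,
                prompt_messages=[],  # placeholder; build real messages here
                stop=[]
            )
            # wrap result_text / usage into a NodeRunResult, as llm_node.py does
            ...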
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index a4696845ea..ceebfe2e25 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -1,25 +1,17 @@
 import json
-from collections.abc import Generator
 from typing import Optional, Union, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.model_entities import ModelStatus
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.llm.llm_node import LLMNode
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
 from core.workflow.nodes.question_classifier.template_prompts import (
@@ -31,28 +23,28 @@ from core.workflow.nodes.question_classifier.template_prompts import (
     QUESTION_CLASSIFIER_USER_PROMPT_2,
     QUESTION_CLASSIFIER_USER_PROMPT_3,
 )
-from extensions.ext_database import db
-from models.model import Conversation
 from models.workflow import WorkflowNodeExecutionStatus
 
 
-class QuestionClassifierNode(BaseNode):
+class QuestionClassifierNode(LLMNode):
     _node_data_cls = QuestionClassifierNodeData
     _node_type = NodeType.QUESTION_CLASSIFIER
 
     def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
+        node_data = cast(QuestionClassifierNodeData, node_data)
+
         # extract variables
         query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
         variables = {
             'query': query
         }
         # fetch model config
-        model_instance, model_config = self._fetch_model_config(node_data)
+        model_instance, model_config = self._fetch_model_config(node_data.model)
         # fetch memory
-        memory = self._fetch_memory(node_data, variable_pool, model_instance)
+        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
         # fetch prompt messages
-        prompt_messages, stop = self._fetch_prompt_messages(
+        prompt_messages, stop = self._fetch_prompt(
             node_data=node_data,
             context='',
             query=query,
@@ -62,7 +54,7 @@ class QuestionClassifierNode(BaseNode):
 
         # handle invoke result
         result_text, usage = self._invoke_llm(
-            node_data=node_data,
+            node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop
@@ -117,126 +109,20 @@ class QuestionClassifierNode(BaseNode):
         return {
             "type": "question-classifier",
             "config": {
-                "instructions": ""  # TODO
+                "instructions": ""
             }
         }
 
-    def _fetch_model_config(self, node_data: QuestionClassifierNodeData) \
-            -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
-        """
-        Fetch model config
-        :param node_data: node data
-        :return:
-        """
-        model_name = node_data.model.name
-        provider_name = node_data.model.provider
-
-        model_manager = ModelManager()
-        model_instance = model_manager.get_model_instance(
-            tenant_id=self.tenant_id,
-            model_type=ModelType.LLM,
-            provider=provider_name,
-            model=model_name
-        )
-
-        provider_model_bundle = model_instance.provider_model_bundle
-        model_type_instance = model_instance.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        model_credentials = model_instance.credentials
-
-        # check model
-        provider_model = provider_model_bundle.configuration.get_provider_model(
-            model=model_name,
-            model_type=ModelType.LLM
-        )
-
-        if provider_model is None:
-            raise ValueError(f"Model {model_name} not exist.")
-
-        if provider_model.status == ModelStatus.NO_CONFIGURE:
-            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-        elif provider_model.status == ModelStatus.NO_PERMISSION:
-            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
-
-        # model config
-        completion_params = node_data.model.completion_params
-        stop = []
-        if 'stop' in completion_params:
-            stop = completion_params['stop']
-            del completion_params['stop']
-
-        # get model mode
-        model_mode = node_data.model.mode
-        if not model_mode:
-            raise ValueError("LLM mode is required.")
-
-        model_schema = model_type_instance.get_model_schema(
-            model_name,
-            model_credentials
-        )
-
-        if not model_schema:
-            raise ValueError(f"Model {model_name} not exist.")
-
-        return model_instance, ModelConfigWithCredentialsEntity(
-            provider=provider_name,
-            model=model_name,
-            model_schema=model_schema,
-            mode=model_mode,
-            provider_model_bundle=provider_model_bundle,
-            credentials=model_credentials,
-            parameters=completion_params,
-            stop=stop,
-        )
-
-    def _fetch_memory(self, node_data: QuestionClassifierNodeData,
-                      variable_pool: VariablePool,
-                      model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
-        """
-        Fetch memory
-        :param node_data: node data
-        :param variable_pool: variable pool
-        :return:
-        """
-        if not node_data.memory:
-            return None
-
-        # get conversation id
-        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION])
-        if conversation_id is None:
-            return None
-
-        # get conversation
-        conversation = db.session.query(Conversation).filter(
-            Conversation.tenant_id == self.tenant_id,
-            Conversation.app_id == self.app_id,
-            Conversation.id == conversation_id
-        ).first()
-
-        if not conversation:
-            return None
-
-        memory = TokenBufferMemory(
-            conversation=conversation,
-            model_instance=model_instance
-        )
-
-        return memory
-
-    def _fetch_prompt_messages(self, node_data: QuestionClassifierNodeData,
-                               query: str,
-                               context: Optional[str],
-                               memory: Optional[TokenBufferMemory],
-                               model_config: ModelConfigWithCredentialsEntity) \
+    def _fetch_prompt(self, node_data: QuestionClassifierNodeData,
+                      query: str,
+                      context: Optional[str],
+                      memory: Optional[TokenBufferMemory],
+                      model_config: ModelConfigWithCredentialsEntity) \
             -> tuple[list[PromptMessage], Optional[list[str]]]:
         """
-        Fetch prompt messages
+        Fetch prompt
         :param node_data: node data
-        :param inputs: inputs
-        :param files: files
+        :param query: query
         :param context: context
         :param memory: memory
         :param model_config: model config
@@ -310,63 +196,3 @@ class QuestionClassifierNode(BaseNode):
             return prompt_messages
         else:
             raise ValueError(f"Model mode {model_mode} not support.")
-
-    def _invoke_llm(self, node_data: QuestionClassifierNodeData,
-                    model_instance: ModelInstance,
-                    prompt_messages: list[PromptMessage],
-                    stop: list[str]) -> tuple[str, LLMUsage]:
-        """
-        Invoke large language model
-        :param node_data: node data
-        :param model_instance: model instance
-        :param prompt_messages: prompt messages
-        :param stop: stop
-        :return:
-        """
-        invoke_result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            model_parameters=node_data.model.completion_params,
-            stop=stop,
-            stream=True,
-            user=self.user_id,
-        )
-
-        # handle invoke result
-        text, usage = self._handle_invoke_result(
-            invoke_result=invoke_result
-        )
-
-        # deduct quota
-        LLMNode.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
-
-        return text, usage
-
-    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
-        """
-        Handle invoke result
-        :param invoke_result: invoke result
-        :return:
-        """
-        model = None
-        prompt_messages = []
-        full_text = ''
-        usage = None
-        for result in invoke_result:
-            text = result.delta.message.content
-            full_text += text
-
-            self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
-
-            if not model:
-                model = result.model
-
-            if not prompt_messages:
-                prompt_messages = result.prompt_messages
-
-            if not usage and result.delta.usage:
-                usage = result.delta.usage
-
-        if not usage:
-            usage = LLMUsage.empty_usage()
-
-        return full_text, usage
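One behavioral fix in _fetch_memory is easy to miss: the conversation id lookup now passes SystemVariable.CONVERSATION.value instead of the enum member itself. A self-contained sketch of why that matters, assuming the pool keys system variables by their plain string values and that SystemVariable is a plain Enum whose CONVERSATION member carries the string 'conversation' (stand-ins for illustration, not the real VariablePool internals):

    from enum import Enum


    class SystemVariable(Enum):  # assumed stand-in for the real enum
        CONVERSATION = 'conversation'


    # assumed stand-in for the pool's internal mapping, keyed by plain strings
    pool = {('sys', 'conversation'): 'conv-123'}

    # a plain Enum member is not equal to its value, so this lookup misses
    print(pool.get(('sys', SystemVariable.CONVERSATION)))        # prints: None
    # passing .value matches the string key, so the conversation id is found
    print(pool.get(('sys', SystemVariable.CONVERSATION.value)))  # prints: conv-123

Under those assumptions, looking up with the bare enum member would always return None, so _fetch_memory would silently skip attaching chat history.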