diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 905ee1f80d..c0f66a9b69 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -45,7 +45,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ - variables: list[VariableSelector] + query_variable_selector: list[str] dataset_ids: list[str] retrieval_mode: Literal['single', 'multiple'] multiple_retrieval_config: MultipleRetrievalConfig diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index a501113dc3..b9756b4b63 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -42,16 +42,14 @@ class KnowledgeRetrievalNode(BaseNode): node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) # extract variables + query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) variables = { - variable_selector.variable: variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector) - for variable_selector in node_data.variables + 'query': query } - # retrieve knowledge try: outputs = self._fetch_dataset_retriever( - node_data=node_data, variables=variables + node_data=node_data, query=query ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -68,12 +66,12 @@ class KnowledgeRetrievalNode(BaseNode): error=str(e) ) - def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[ + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[ dict[str, Any]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset 
:param node_data: node data - :param variables: variables + :param query: query """ tools = [] available_datasets = [] @@ -97,9 +95,9 @@ class KnowledgeRetrievalNode(BaseNode): available_datasets.append(dataset) all_documents = [] if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: - all_documents = self._single_retrieve(available_datasets, node_data, variables) + all_documents = self._single_retrieve(available_datasets, node_data, query) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: - all_documents = self._multiple_retrieve(available_datasets, node_data, variables) + all_documents = self._multiple_retrieve(available_datasets, node_data, query) document_score_list = {} for item in all_documents: @@ -169,7 +167,7 @@ class KnowledgeRetrievalNode(BaseNode): variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables } - def _single_retrieve(self, available_datasets, node_data, variables): + def _single_retrieve(self, available_datasets, node_data, query): tools = [] for dataset in available_datasets: description = dataset.description @@ -191,7 +189,7 @@ class KnowledgeRetrievalNode(BaseNode): model_instance, model_config = self._fetch_model_config(node_data) prompt_messages = [ SystemPromptMessage(content='You are a helpful AI assistant.'), - UserPromptMessage(content=variables['#query#']) + UserPromptMessage(content=query) ] result = model_instance.invoke_llm( prompt_messages=prompt_messages, @@ -227,7 +225,7 @@ class KnowledgeRetrievalNode(BaseNode): score_threshold = retrieval_model_config.get("score_threshold") results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, - query=variables['#query#'], + query=query, top_k=top_k, score_threshold=score_threshold, reranking_model=reranking_model) return results @@ -303,7 +301,7 @@ class KnowledgeRetrievalNode(BaseNode): stop=stop, ) - def _multiple_retrieve(self, 
available_datasets, node_data, variables): +    def _multiple_retrieve(self, available_datasets, node_data, query): threads = [] all_documents = [] dataset_ids = [dataset.id for dataset in available_datasets] @@ -311,7 +309,7 @@ class KnowledgeRetrievalNode(BaseNode): retrieval_thread = threading.Thread(target=self._retriever, kwargs={ 'flask_app': current_app._get_current_object(), 'dataset_id': dataset.id, - 'query': variables['#query#'], + 'query': query, 'top_k': node_data.multiple_retrieval_config.top_k, 'all_documents': all_documents, }) @@ -329,7 +327,7 @@ class KnowledgeRetrievalNode(BaseNode): ) rerank_runner = RerankRunner(rerank_model_instance) - all_documents = rerank_runner.run(variables['#query#'], all_documents, + all_documents = rerank_runner.run(query, all_documents, node_data.multiple_retrieval_config.score_threshold, node_data.multiple_retrieval_config.top_k) diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py new file mode 100644 index 0000000000..a407ea01c9 --- /dev/null +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -0,0 +1,52 @@ +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class ModelConfig(BaseModel): + """ + Model Config. + """ + provider: str + name: str + mode: str + completion_params: dict[str, Any] = {} + + +class ClassConfig(BaseModel): + """ + Class Config. + """ + id: str + name: str + + +class WindowConfig(BaseModel): + """ + Window Config. + """ + enabled: bool + size: int + + +class MemoryConfig(BaseModel): + """ + Memory Config. + """ + window: WindowConfig + + +class QuestionClassifierNodeData(BaseNodeData): + """ + Question Classifier Node Data. 
+ """ + query_variable_selector: list[str] + title: str + description: str + model: ModelConfig + classes: list[ClassConfig] + instruction: str + memory: MemoryConfig diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index f676b6372a..fdeb40c53d 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,9 +1,108 @@ -from typing import Optional - +import json +from typing import Optional, cast, Union +from collections.abc import Generator +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelStatus +from core.entities.provider_entities import QuotaUnit +from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from 
core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData +from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_SYSTEM_PROMPT, \ + QUESTION_CLASSIFIER_USER_PROMPT_1, QUESTION_CLASSIFIER_USER_PROMPT_2, QUESTION_CLASSIFIER_USER_PROMPT_3, \ + QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2, \ + QUESTION_CLASSIFIER_COMPLETION_PROMPT +from extensions.ext_database import db +from models.model import Conversation +from models.provider import ProviderType, Provider +from core.model_runtime.utils.encoders import jsonable_encoder +from models.workflow import WorkflowNodeExecutionStatus class QuestionClassifierNode(BaseNode): + _node_data_cls = QuestionClassifierNodeData + _node_type = NodeType.QUESTION_CLASSIFIER + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data) + # extract variables + query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) + variables = { + 'query': query + } + # fetch model config + model_instance, model_config = self._fetch_model_config(node_data) + # fetch memory + memory = self._fetch_memory(node_data, variable_pool, model_instance) + # fetch prompt messages + prompt_messages, stop = self._fetch_prompt_messages( + node_data=node_data, + context='', + query=query, + memory=memory, + model_config=model_config + ) + + # handle invoke result + result_text, usage = self._invoke_llm( + node_data=node_data, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop + ) + try: + result_text_json = json.loads(result_text) + categories = result_text_json.get('categories', []) + process_data = { + 'model_mode': model_config.mode, + 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, + prompt_messages=prompt_messages + ), + 'usage': jsonable_encoder(usage), + 'topics': 
categories[0] if categories else '' + } + outputs = { + 'class_name': categories[0] if categories else '' + } + classes = node_data.classes + classes_map = {class_.name: class_.id for class_ in classes} + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + process_data=process_data, + outputs=outputs, + edge_source_handle=classes_map.get(categories[0], None) + ) + + except ValueError as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e) + ) + + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + node_data = cast(cls._node_data_cls, node_data) + return { + 'query': node_data.query_variable_selector + } + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ @@ -17,3 +116,300 @@ class QuestionClassifierNode(BaseNode): "instructions": "" # TODO } } + + def _fetch_model_config(self, node_data: QuestionClassifierNodeData) \ + -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + """ + Fetch model config + :param node_data: node data + :return: + """ + model_name = node_data.model.name + provider_name = node_data.model.provider + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=provider_name, + model=model_name + ) + + provider_model_bundle = model_instance.provider_model_bundle + model_type_instance = model_instance.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + model_credentials = model_instance.credentials + + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_name, + model_type=ModelType.LLM + ) + + if provider_model is None: + raise ValueError(f"Model {model_name} not exist.") + + if 
provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + + # model config + completion_params = node_data.model.completion_params + stop = [] + if 'stop' in completion_params: + stop = completion_params['stop'] + del completion_params['stop'] + + # get model mode + model_mode = node_data.model.mode + if not model_mode: + raise ValueError("LLM mode is required.") + + model_schema = model_type_instance.get_model_schema( + model_name, + model_credentials + ) + + if not model_schema: + raise ValueError(f"Model {model_name} not exist.") + + return model_instance, ModelConfigWithCredentialsEntity( + provider=provider_name, + model=model_name, + model_schema=model_schema, + mode=model_mode, + provider_model_bundle=provider_model_bundle, + credentials=model_credentials, + parameters=completion_params, + stop=stop, + ) + + def _fetch_memory(self, node_data: QuestionClassifierNodeData, + variable_pool: VariablePool, + model_instance: ModelInstance) -> Optional[TokenBufferMemory]: + """ + Fetch memory + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + if not node_data.memory: + return None + + # get conversation id + conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION]) + if conversation_id is None: + return None + + # get conversation + conversation = db.session.query(Conversation).filter( + Conversation.tenant_id == self.tenant_id, + Conversation.app_id == self.app_id, + Conversation.id == conversation_id + ).first() + + if not conversation: + return None + + memory = TokenBufferMemory( + conversation=conversation, + 
model_instance=model_instance + ) + + return memory + + def _fetch_prompt_messages(self, node_data: QuestionClassifierNodeData, + query: str, + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity) \ + -> tuple[list[PromptMessage], Optional[list[str]]]: + """ + Fetch prompt messages + :param node_data: node data + :param inputs: inputs + :param files: files + :param context: context + :param memory: memory + :param model_config: model config + :return: + """ + prompt_transform = AdvancedPromptTransform() + prompt_template = self._get_prompt_template(node_data, query) + prompt_messages = prompt_transform.get_prompt( + prompt_template=prompt_template, + inputs={}, + query='', + files=[], + context=context, + memory_config=node_data.memory, + memory=memory, + model_config=model_config + ) + stop = model_config.stop + + return prompt_messages, stop + + def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str) \ + -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: + model_mode = ModelMode.value_of(node_data.model.mode) + classes = node_data.classes + class_names = [class_.name for class_ in classes] + class_names_str = ','.join(class_names) + instruction = node_data.instruction if node_data.instruction else '' + input_text = query + + prompt_messages = [] + if model_mode == ModelMode.CHAT: + system_prompt_messages = ChatModelMessage( + role=PromptMessageRole.SYSTEM, + text=QUESTION_CLASSIFIER_SYSTEM_PROMPT + ) + prompt_messages.append(system_prompt_messages) + user_prompt_message_1 = ChatModelMessage( + role=PromptMessageRole.USER, + text=QUESTION_CLASSIFIER_USER_PROMPT_1 + ) + prompt_messages.append(user_prompt_message_1) + assistant_prompt_message_1 = ChatModelMessage( + role=PromptMessageRole.ASSISTANT, + text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 + ) + prompt_messages.append(assistant_prompt_message_1) + user_prompt_message_2 = ChatModelMessage( + 
role=PromptMessageRole.USER, + text=QUESTION_CLASSIFIER_USER_PROMPT_2 + ) + prompt_messages.append(user_prompt_message_2) + assistant_prompt_message_2 = ChatModelMessage( + role=PromptMessageRole.ASSISTANT, + text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 + ) + prompt_messages.append(assistant_prompt_message_2) + user_prompt_message_3 = ChatModelMessage( + role=PromptMessageRole.USER, + text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text, categories=class_names_str, + classification_instructions=instruction) + ) + prompt_messages.append(user_prompt_message_3) + return prompt_messages + elif model_mode == ModelMode.COMPLETION: + prompt_messages.append(CompletionModelPromptTemplate( + text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(input_text=input_text, categories=class_names_str, + classification_instructions=instruction) + )) + + return prompt_messages + else: + raise ValueError(f"Model mode {model_mode} not support.") + + def _invoke_llm(self, node_data: QuestionClassifierNodeData, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: list[str]) -> tuple[str, LLMUsage]: + """ + Invoke large language model + :param node_data: node data + :param model_instance: model instance + :param prompt_messages: prompt messages + :param stop: stop + :return: + """ + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=node_data.model.completion_params, + stop=stop, + stream=True, + user=self.user_id, + ) + + # handle invoke result + text, usage = self._handle_invoke_result( + invoke_result=invoke_result + ) + + # deduct quota + self._deduct_llm_quota(model_instance=model_instance, usage=usage) + + return text, usage + + def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: + """ + Handle invoke result + :param invoke_result: invoke result + :return: + """ + model = None + prompt_messages = [] + full_text = '' + usage = None + for result in invoke_result: + text = 
result.delta.message.content + full_text += text + + self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text']) + + if not model: + model = result.model + + if not prompt_messages: + prompt_messages = result.prompt_messages + + if not usage and result.delta.usage: + usage = result.delta.usage + + if not usage: + usage = LLMUsage.empty_usage() + + return full_text, usage + + def _deduct_llm_quota(self, model_instance: ModelInstance, usage: LLMUsage) -> None: + """ + Deduct LLM quota + :param model_instance: model instance + :param usage: usage + :return: + """ + provider_model_bundle = model_instance.provider_model_bundle + provider_configuration = provider_model_bundle.configuration + + if provider_configuration.using_provider_type != ProviderType.SYSTEM: + return + + system_configuration = provider_configuration.system_configuration + + quota_unit = None + for quota_configuration in system_configuration.quota_configurations: + if quota_configuration.quota_type == system_configuration.current_quota_type: + quota_unit = quota_configuration.quota_unit + + if quota_configuration.quota_limit == -1: + return + + break + + used_quota = None + if quota_unit: + if quota_unit == QuotaUnit.TOKENS: + used_quota = usage.total_tokens + elif quota_unit == QuotaUnit.CREDITS: + used_quota = 1 + + if 'gpt-4' in model_instance.model: + used_quota = 20 + else: + used_quota = 1 + + if used_quota is not None: + db.session.query(Provider).filter( + Provider.tenant_id == self.tenant_id, + Provider.provider_name == model_instance.provider, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_limit > Provider.quota_used + ).update({'quota_used': Provider.quota_used + used_quota}) + db.session.commit() diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py new file mode 100644 index 
0000000000..871fc8d3e9 --- /dev/null +++ b/api/core/workflow/nodes/question_classifier/template_prompts.py @@ -0,0 +1,62 @@ + + +QUESTION_CLASSIFIER_SYSTEM_PROMPT = ( + '### Job Description\n' + 'You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.\n' + '### Task\n' + 'Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output.Additionally, you need to extract the key words from the text that are related to the classification.\n' + '### Format\n' + 'The input text is in the variable text_field.Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination.Classification instructions may be included to improve the classification accuracy.\n' + '### Constraint\n' + 'DO NOT include anything other than the JSON array in your response.' +) + +QUESTION_CLASSIFIER_USER_PROMPT_1 = ( + '{ "input_text": ["I recently had a great experience with your company. 
The service was prompt and the staff was very friendly."],' + '"categories": ["Customer Service, Satisfaction, Sales, Product"],' + '"classification_instructions": ["classify the text based on the feedback provided by customer"]}```JSON' +) + +QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = ( + '{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],' + '"categories": ["Customer Service"]}```' +) + +QUESTION_CLASSIFIER_USER_PROMPT_2 = ( + '{"input_text": ["bad service, slow to bring the food"],' + '"categories": ["Food Quality, Experience, Price" ], ' + '"classification_instructions": []}```JSON' +) + +QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = ( + '{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],' + '"categories": ["Experience"]}```' +) + +QUESTION_CLASSIFIER_USER_PROMPT_3 = ( + '{{"input_text": ["{input_text}"],' + '"categories": ["{categories}" ], ' + '"classification_instructions": ["{classification_instructions}"]}}```JSON' +) + +QUESTION_CLASSIFIER_COMPLETION_PROMPT = """ +### Job Description +You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. +### Task +Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification. +### Format +The input text is in the variable text_field. Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination. Classification instructions may be included to improve the classification accuracy. +### Constraint +DO NOT include anything other than the JSON array in your response. +### Example +Input: +{{"input_text": ["I recently had a great experience with your company. 
The service was prompt and the staff was very friendly."],"categories": ["Customer Service, Satisfaction, Sales, Product"], "classification_instructions": ["classify the text based on the feedback provided by customer"]}} +{{"input_text": ["bad service, slow to bring the food"],"categories": ["Food Quality, Experience, Price" ], "classification_instructions": []}} +Output: +{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"categories": ["Customer Service"]}} +{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"categories": ["Experience"]}} +### Memory +Here is the chat histories between human and assistant, inside XML tags. +### User Input +{{"input_text" : ["{input_text}"], "class" : ["{categories}"],"classification_instruction" : ["{classification_instructions}"]}} +""" \ No newline at end of file