mirror of https://github.com/langgenius/dify.git
fix qc

This commit is contained in:
parent a05fcedd61
commit d71eae8f93
@@ -15,12 +15,13 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.llm.entities import LLMNodeData
+from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
@@ -64,10 +65,10 @@ class LLMNode(BaseNode):
        node_inputs['#context#'] = context

        # fetch model config
-        model_instance, model_config = self._fetch_model_config(node_data)
+        model_instance, model_config = self._fetch_model_config(node_data.model)

        # fetch memory
-        memory = self._fetch_memory(node_data, variable_pool, model_instance)
+        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)

        # fetch prompt messages
        prompt_messages, stop = self._fetch_prompt_messages(
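These two call-site changes are the core of the refactor: the helpers no longer receive the whole LLMNodeData, only the model and memory sub-objects they actually use, which is what lets QuestionClassifierNode reuse them later in this diff. A minimal sketch of the pattern, using simplified stand-in classes rather than the real dify entities (field lists here are assumptions, trimmed to what the example needs):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelConfig:
    provider: str
    name: str
    mode: str
    completion_params: dict

@dataclass
class MemoryConfig:
    window_size: int

@dataclass
class LLMNodeData:
    model: ModelConfig
    memory: Optional[MemoryConfig]

def fetch_model_config(model: ModelConfig) -> tuple[str, str]:
    # Depends only on the model slice, so any node data that carries a
    # ModelConfig can reuse it -- not just LLMNodeData.
    return model.provider, model.name

def fetch_memory(memory: Optional[MemoryConfig]) -> Optional[int]:
    # Same idea: take the memory slice, not the whole node data class.
    if not memory:
        return None
    return memory.window_size

node_data = LLMNodeData(
    model=ModelConfig(provider="openai", name="gpt-4", mode="chat", completion_params={}),
    memory=None,
)
print(fetch_model_config(node_data.model))  # ('openai', 'gpt-4')
print(fetch_memory(node_data.memory))       # None

The narrower parameter types also make the helpers easier to test in isolation, since callers no longer have to build a full node-data object.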
@@ -89,7 +90,7 @@ class LLMNode(BaseNode):

        # handle invoke result
        result_text, usage = self._invoke_llm(
-            node_data=node_data,
+            node_data_model=node_data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            stop=stop
@@ -119,13 +120,13 @@ class LLMNode(BaseNode):
            }
        )

-    def _invoke_llm(self, node_data: LLMNodeData,
+    def _invoke_llm(self, node_data_model: ModelConfig,
                    model_instance: ModelInstance,
                    prompt_messages: list[PromptMessage],
                    stop: list[str]) -> tuple[str, LLMUsage]:
        """
        Invoke large language model
-        :param node_data: node data
+        :param node_data_model: node data model
        :param model_instance: model instance
        :param prompt_messages: prompt messages
        :param stop: stop
@@ -135,7 +136,7 @@ class LLMNode(BaseNode):

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
-            model_parameters=node_data.model.completion_params,
+            model_parameters=node_data_model.completion_params,
            stop=stop,
            stream=True,
            user=self.user_id,
@@ -286,14 +287,14 @@ class LLMNode(BaseNode):

        return None

-    def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config
-        :param node_data: node data
+        :param node_data_model: node data model
        :return:
        """
-        model_name = node_data.model.name
-        provider_name = node_data.model.provider
+        model_name = node_data_model.name
+        provider_name = node_data_model.provider

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
@@ -326,14 +327,14 @@ class LLMNode(BaseNode):
            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

        # model config
-        completion_params = node_data.model.completion_params
+        completion_params = node_data_model.completion_params
        stop = []
        if 'stop' in completion_params:
            stop = completion_params['stop']
            del completion_params['stop']

        # get model mode
-        model_mode = node_data.model.mode
+        model_mode = node_data_model.mode
        if not model_mode:
            raise ValueError("LLM mode is required.")

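The context lines above also show the stop-sequence handling both nodes rely on: 'stop' is pulled out of completion_params before the remaining parameters are handed to the model. An equivalent sketch of that split (a hypothetical helper, not dify code) using dict.pop on a copy so the caller's config dict is left untouched:

def split_stop_sequences(completion_params: dict) -> tuple[dict, list[str]]:
    # Copy first so a shared node-config dict is not mutated in place,
    # then pop 'stop' out because it is passed to the model separately.
    params = dict(completion_params)
    stop = params.pop('stop', [])
    return params, stop

params, stop = split_stop_sequences({'temperature': 0.7, 'stop': ['\nHuman:']})
print(params)  # {'temperature': 0.7}
print(stop)    # ['\nHuman:']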
@@ -356,26 +357,25 @@ class LLMNode(BaseNode):
            stop=stop,
        )

-    def _fetch_memory(self, node_data: LLMNodeData,
+    def _fetch_memory(self, node_data_memory: Optional[MemoryConfig],
                      variable_pool: VariablePool,
                      model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
        """
        Fetch memory
-        :param node_data: node data
+        :param node_data_memory: node data memory
        :param variable_pool: variable pool
        :return:
        """
-        if not node_data.memory:
+        if not node_data_memory:
            return None

        # get conversation id
-        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION])
+        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
        if conversation_id is None:
            return None

        # get conversation
        conversation = db.session.query(Conversation).filter(
            Conversation.tenant_id == self.tenant_id,
            Conversation.app_id == self.app_id,
            Conversation.id == conversation_id
        ).first()
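The added .value matters because SystemVariable is an enum: if the variable pool keys its entries by plain strings (which the ['sys', 'conversation'] style selector suggests), looking them up with the enum member instead of its string value silently misses. A small self-contained illustration of the failure mode, with a hypothetical dict standing in for the pool:

from enum import Enum

class SystemVariable(Enum):
    CONVERSATION = 'conversation'

# Hypothetical pool keyed by plain strings, as a ['sys', 'conversation'] selector implies.
pool = {('sys', 'conversation'): 'conv-123'}

print(pool.get(('sys', SystemVariable.CONVERSATION)))        # None -- the enum member != its value
print(pool.get(('sys', SystemVariable.CONVERSATION.value)))  # 'conv-123'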
@@ -2,6 +2,7 @@ from typing import Any, Optional

from pydantic import BaseModel

+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData

@@ -23,21 +24,6 @@ class ClassConfig(BaseModel):
    name: str


-class WindowConfig(BaseModel):
-    """
-    Window Config.
-    """
-    enabled: bool
-    size: int
-
-
-class MemoryConfig(BaseModel):
-    """
-    Memory Config.
-    """
-    window: WindowConfig
-
-
class QuestionClassifierNodeData(BaseNodeData):
    """
    Knowledge retrieval Node Data.
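With the node-local WindowConfig and MemoryConfig deleted, both the LLM node and the question classifier import a single definition from core.prompt.entities.advanced_prompt_entities. Roughly what such a shared pydantic model looks like; this is a sketch of the pattern mirroring the deleted fields, not a copy of the real module:

from typing import Optional
from pydantic import BaseModel

class WindowConfig(BaseModel):
    """Rolling-window settings for conversation memory (sketch)."""
    enabled: bool = False
    size: Optional[int] = None

class MemoryConfig(BaseModel):
    """Memory settings shared by any node that needs conversation history (sketch)."""
    window: WindowConfig

cfg = MemoryConfig(window=WindowConfig(enabled=True, size=10))
print(cfg.window.size)  # 10

Keeping one definition in a shared module avoids the two node packages drifting apart when memory options change.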
@@ -1,25 +1,17 @@
import json
from collections.abc import Generator
from typing import Optional, Union, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.question_classifier.template_prompts import (
@@ -31,28 +23,28 @@ from core.workflow.nodes.question_classifier.template_prompts import (
    QUESTION_CLASSIFIER_USER_PROMPT_2,
    QUESTION_CLASSIFIER_USER_PROMPT_3,
)
from extensions.ext_database import db
from models.model import Conversation
from models.workflow import WorkflowNodeExecutionStatus


-class QuestionClassifierNode(BaseNode):
+class QuestionClassifierNode(LLMNode):
    _node_data_cls = QuestionClassifierNodeData
    _node_type = NodeType.QUESTION_CLASSIFIER

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
-        node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
+        node_data = cast(QuestionClassifierNodeData, node_data)

        # extract variables
        query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
        variables = {
            'query': query
        }
        # fetch model config
-        model_instance, model_config = self._fetch_model_config(node_data)
+        model_instance, model_config = self._fetch_model_config(node_data.model)
        # fetch memory
-        memory = self._fetch_memory(node_data, variable_pool, model_instance)
+        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
        # fetch prompt messages
-        prompt_messages, stop = self._fetch_prompt_messages(
+        prompt_messages, stop = self._fetch_prompt(
            node_data=node_data,
            context='',
            query=query,
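Re-parenting QuestionClassifierNode from BaseNode to LLMNode is what enables the large deletions later in this diff: model-config fetching, memory fetching and LLM invocation now come from the parent instead of being duplicated. The shape of that refactor, sketched with toy classes rather than the real node API:

class LLMNodeSketch:
    """Toy parent: owns the model/invoke plumbing once."""

    def _fetch_model_config(self, model_cfg: dict) -> dict:
        return {'provider': model_cfg['provider'], 'model': model_cfg['name']}

    def _invoke_llm(self, model_cfg: dict, prompt: str) -> str:
        return f"[{model_cfg['model']}] answer to: {prompt}"

class QuestionClassifierNodeSketch(LLMNodeSketch):
    """Toy child: adds only its own prompt building; everything else is inherited."""

    def run(self, node_data: dict, query: str) -> str:
        model_cfg = self._fetch_model_config(node_data['model'])
        prompt = f"Classify the question: {query}"
        return self._invoke_llm(model_cfg, prompt)

node = QuestionClassifierNodeSketch()
print(node.run({'model': {'provider': 'openai', 'name': 'gpt-4'}}, 'How do I reset my password?'))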
@@ -62,7 +54,7 @@ class QuestionClassifierNode(BaseNode):

        # handle invoke result
        result_text, usage = self._invoke_llm(
-            node_data=node_data,
+            node_data_model=node_data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            stop=stop
@@ -117,126 +109,20 @@ class QuestionClassifierNode(BaseNode):
        return {
            "type": "question-classifier",
            "config": {
-                "instructions": ""  # TODO
+                "instructions": ""
            }
        }

-    def _fetch_model_config(self, node_data: QuestionClassifierNodeData) \
-            -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
-        """
-        Fetch model config
-        :param node_data: node data
-        :return:
-        """
-        model_name = node_data.model.name
-        provider_name = node_data.model.provider
-
-        model_manager = ModelManager()
-        model_instance = model_manager.get_model_instance(
-            tenant_id=self.tenant_id,
-            model_type=ModelType.LLM,
-            provider=provider_name,
-            model=model_name
-        )
-
-        provider_model_bundle = model_instance.provider_model_bundle
-        model_type_instance = model_instance.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        model_credentials = model_instance.credentials
-
-        # check model
-        provider_model = provider_model_bundle.configuration.get_provider_model(
-            model=model_name,
-            model_type=ModelType.LLM
-        )
-
-        if provider_model is None:
-            raise ValueError(f"Model {model_name} not exist.")
-
-        if provider_model.status == ModelStatus.NO_CONFIGURE:
-            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-        elif provider_model.status == ModelStatus.NO_PERMISSION:
-            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
-
-        # model config
-        completion_params = node_data.model.completion_params
-        stop = []
-        if 'stop' in completion_params:
-            stop = completion_params['stop']
-            del completion_params['stop']
-
-        # get model mode
-        model_mode = node_data.model.mode
-        if not model_mode:
-            raise ValueError("LLM mode is required.")
-
-        model_schema = model_type_instance.get_model_schema(
-            model_name,
-            model_credentials
-        )
-
-        if not model_schema:
-            raise ValueError(f"Model {model_name} not exist.")
-
-        return model_instance, ModelConfigWithCredentialsEntity(
-            provider=provider_name,
-            model=model_name,
-            model_schema=model_schema,
-            mode=model_mode,
-            provider_model_bundle=provider_model_bundle,
-            credentials=model_credentials,
-            parameters=completion_params,
-            stop=stop,
-        )
-
-    def _fetch_memory(self, node_data: QuestionClassifierNodeData,
-                      variable_pool: VariablePool,
-                      model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
-        """
-        Fetch memory
-        :param node_data: node data
-        :param variable_pool: variable pool
-        :return:
-        """
-        if not node_data.memory:
-            return None
-
-        # get conversation id
-        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION])
-        if conversation_id is None:
-            return None
-
-        # get conversation
-        conversation = db.session.query(Conversation).filter(
-            Conversation.tenant_id == self.tenant_id,
-            Conversation.app_id == self.app_id,
-            Conversation.id == conversation_id
-        ).first()
-
-        if not conversation:
-            return None
-
-        memory = TokenBufferMemory(
-            conversation=conversation,
-            model_instance=model_instance
-        )
-
-        return memory
-
-    def _fetch_prompt_messages(self, node_data: QuestionClassifierNodeData,
-                               query: str,
-                               context: Optional[str],
-                               memory: Optional[TokenBufferMemory],
-                               model_config: ModelConfigWithCredentialsEntity) \
+    def _fetch_prompt(self, node_data: QuestionClassifierNodeData,
+                      query: str,
+                      context: Optional[str],
+                      memory: Optional[TokenBufferMemory],
+                      model_config: ModelConfigWithCredentialsEntity) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
-        Fetch prompt messages
+        Fetch prompt
        :param node_data: node data
-        :param inputs: inputs
-        :param files: files
        :param query: inputs
        :param context: context
        :param memory: memory
        :param model_config: model config
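The rename from _fetch_prompt_messages to _fetch_prompt is presumably to avoid overriding the parent's _fetch_prompt_messages, which LLMNode._run calls with its own argument list; keeping the old name with this incompatible signature would break the inherited method. A compact illustration of why distinct names are the safer choice here, using toy classes rather than the dify API:

class Parent:
    def run(self) -> str:
        # The parent calls its own helper with keyword arguments it defines.
        return self.build_prompt(query='hi', context='ctx')

    def build_prompt(self, query: str, context: str) -> str:
        return f'{context}: {query}'

class BadChild(Parent):
    # Overriding with an incompatible signature breaks the inherited run().
    def build_prompt(self, query: str) -> str:  # type: ignore[override]
        return f'classify: {query}'

class GoodChild(Parent):
    # A differently named helper leaves the parent's run() intact.
    def build_classifier_prompt(self, query: str) -> str:
        return f'classify: {query}'

print(GoodChild().run())                          # 'ctx: hi'
print(GoodChild().build_classifier_prompt('hi'))  # 'classify: hi'
try:
    BadChild().run()
except TypeError as e:
    print('override broke run():', e)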
@@ -310,63 +196,3 @@ class QuestionClassifierNode(BaseNode):
            return prompt_messages
        else:
            raise ValueError(f"Model mode {model_mode} not support.")
-
-    def _invoke_llm(self, node_data: QuestionClassifierNodeData,
-                    model_instance: ModelInstance,
-                    prompt_messages: list[PromptMessage],
-                    stop: list[str]) -> tuple[str, LLMUsage]:
-        """
-        Invoke large language model
-        :param node_data: node data
-        :param model_instance: model instance
-        :param prompt_messages: prompt messages
-        :param stop: stop
-        :return:
-        """
-        invoke_result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            model_parameters=node_data.model.completion_params,
-            stop=stop,
-            stream=True,
-            user=self.user_id,
-        )
-
-        # handle invoke result
-        text, usage = self._handle_invoke_result(
-            invoke_result=invoke_result
-        )
-
-        # deduct quota
-        LLMNode.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
-
-        return text, usage
-
-    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
-        """
-        Handle invoke result
-        :param invoke_result: invoke result
-        :return:
-        """
-        model = None
-        prompt_messages = []
-        full_text = ''
-        usage = None
-        for result in invoke_result:
-            text = result.delta.message.content
-            full_text += text
-
-            self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
-
-            if not model:
-                model = result.model
-
-            if not prompt_messages:
-                prompt_messages = result.prompt_messages
-
-            if not usage and result.delta.usage:
-                usage = result.delta.usage
-
-        if not usage:
-            usage = LLMUsage.empty_usage()
-
-        return full_text, usage
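The deleted _handle_invoke_result duplicated streaming-accumulation logic that LLMNode already carries: iterate the result generator, concatenate text deltas, remember the first model name and prompt messages seen, and keep the usage report once it appears. A standalone sketch of that accumulation pattern over a generator of simple chunk objects (stand-in types, not the dify entities):

from dataclasses import dataclass
from typing import Iterator, Optional

@dataclass
class Chunk:
    text: str
    usage: Optional[dict] = None  # typically only the final chunk carries usage

def accumulate(chunks: Iterator[Chunk]) -> tuple[str, dict]:
    full_text = ''
    usage: Optional[dict] = None
    for chunk in chunks:
        full_text += chunk.text      # append each streamed text delta
        if chunk.usage and not usage:
            usage = chunk.usage      # keep the first usage report seen
    return full_text, usage or {'total_tokens': 0}

stream = iter([Chunk('Hel'), Chunk('lo'), Chunk('!', usage={'total_tokens': 5})])
print(accumulate(stream))  # ('Hello!', {'total_tokens': 5})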