# dify/api/core/app/completion/app_runner.py

import logging

from core.app.app_queue_manager import AppQueueManager
from core.app.base_app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import (
    ApplicationGenerateEntity,
)
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
from models.model import App, Message

logger = logging.getLogger(__name__)


class CompletionAppRunner(AppRunner):
    """
    Completion Application Runner
    """

    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: AppQueueManager,
            message: Message) -> None:
        """
        Run application

        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param message: message
        :return:
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError("App not found")

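        # Orchestration config for this run: model, prompt template, datasets, external data variables, etc.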
        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens in the prompt messages and the tokens
        # remaining under the model's context size and max-token limits.
        # Raise an exception if the remaining tokens are not enough.
        # Includes: prompt template, inputs, query (optional), files (optional)
        # Does not include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

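        # Prompt messages are built before moderation so the direct_output fallback below
        # already has them available if the input is rejected.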
        # Organize all inputs and the template into prompt messages.
        # Includes: prompt template, inputs, query (optional), files (optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        # Moderation: apply the app's sensitive_word_avoidance rules to the inputs.
        try:
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=application_generate_entity.tenant_id,
                app_orchestration_config_entity=app_orchestration_config,
                inputs=inputs,
                query=query,
            )
        except ModerationException as e:
            # Input was rejected: emit the moderation message directly instead of invoking the model.
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream
            )
            return

        # Fill in variable inputs from external data tools, if any are configured.
        external_data_tools = app_orchestration_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query
            )

        # Get context from datasets.
        context = None
        if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids:
            hit_callback = DatasetIndexToolCallbackHandler(
                queue_manager,
                app_record.id,
                message.id,
                application_generate_entity.user_id,
                application_generate_entity.invoke_from
            )

            dataset_config = app_orchestration_config.dataset
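            # The retrieval query may come from a designated input variable instead of the user query.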
            if dataset_config and dataset_config.retrieve_config.query_variable:
                query = inputs.get(dataset_config.retrieve_config.query_variable, "")

            dataset_retrieval = DatasetRetrieval()
            context = dataset_retrieval.retrieve(
                tenant_id=app_record.tenant_id,
                model_config=app_orchestration_config.model_config,
                config=dataset_config,
                query=query,
                invoke_from=application_generate_entity.invoke_from,
                show_retrieve_source=app_orchestration_config.show_retrieve_source,
                hit_callback=hit_callback
            )

        # Reorganize all inputs and the template into prompt messages.
        # Includes: prompt template, inputs, query (optional), files (optional),
        # memory (optional), external data, dataset context (optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=context
        )

        # Check hosting moderation.
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages
        )
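        # If hosting moderation flagged the request, output has already been handled; stop here.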
        if hosting_moderation_result:
            return

        # Re-calculate max_tokens if prompt_tokens + max_tokens exceeds the model's token limit.
        self.recale_llm_max_tokens(
            model_config=app_orchestration_config.model_config,
            prompt_messages=prompt_messages
        )

        # Invoke model.
        model_instance = ModelInstance(
            provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
            model=app_orchestration_config.model_config.model
        )
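        # With stream=True this yields a generator of result chunks; otherwise a single blocking result.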
        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=app_orchestration_config.model_config.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # Handle the invoke result (streaming or blocking) via the queue manager.
        self._handle_invoke_result(
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream
        )