mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 04:26:30 +08:00
lint fix
This commit is contained in:
parent
a4d6954d4f
commit
451ea5308f
563
api/core/app/apps/advanced_chat/generate_task_pipeline.py
Normal file
563
api/core/app/apps/advanced_chat/generate_task_pipeline.py
Normal file
@ -0,0 +1,563 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||||
|
from core.app.entities.app_invoke_entities import (
|
||||||
|
AdvancedChatAppGenerateEntity,
|
||||||
|
InvokeFrom,
|
||||||
|
)
|
||||||
|
from core.app.entities.queue_entities import (
|
||||||
|
QueueAnnotationReplyEvent,
|
||||||
|
QueueErrorEvent,
|
||||||
|
QueueMessageFileEvent,
|
||||||
|
QueueMessageReplaceEvent,
|
||||||
|
QueueNodeFinishedEvent,
|
||||||
|
QueueNodeStartedEvent,
|
||||||
|
QueuePingEvent,
|
||||||
|
QueueRetrieverResourcesEvent,
|
||||||
|
QueueStopEvent,
|
||||||
|
QueueTextChunkEvent,
|
||||||
|
QueueWorkflowFinishedEvent,
|
||||||
|
QueueWorkflowStartedEvent,
|
||||||
|
)
|
||||||
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
|
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||||
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
from events.message_event import message_was_created
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import Conversation, Message, MessageFile
|
||||||
|
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus
|
||||||
|
from services.annotation_service import AppAnnotationService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskState(BaseModel):
    """
    Mutable state accumulated by a generate task pipeline while it consumes queue events.
    """
    # Answer text produced so far.
    answer: str = ''
    # Response metadata collected from queue events
    # (this file stores 'usage', 'retriever_resources' and 'annotation_reply' under it).
    metadata: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
class AdvancedChatAppGenerateTaskPipeline:
|
||||||
|
"""
|
||||||
|
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
             queue_manager: AppQueueManager,
             conversation: Conversation,
             message: Message) -> None:
    """
    Initialize the advanced chat generate task pipeline.

    :param application_generate_entity: application generate entity
    :param queue_manager: queue manager whose events the pipeline consumes
    :param conversation: conversation
    :param message: message
    """
    self._application_generate_entity = application_generate_entity
    self._queue_manager = queue_manager
    self._conversation = conversation
    self._message = message
    # TaskState only declares `answer` and `metadata`; usage figures are stored
    # later under metadata['usage'], so no extra keyword arguments are passed here.
    self._task_state = TaskState()
    # Reference point for the provider response latency computed in _save_message.
    self._start_at = time.perf_counter()
    self._output_moderation_handler = self._init_output_moderation()
|
||||||
|
|
||||||
|
def process(self, stream: bool) -> Union[dict, Generator]:
    """
    Run the generate task pipeline in streaming or blocking mode.

    :param stream: True to get the streaming generator, False to get the blocking response
    :return: the result of the streaming or the blocking handler
    """
    handler = self._process_stream_response if stream else self._process_blocking_response
    return handler()
|
||||||
|
|
||||||
|
def _process_blocking_response(self) -> dict:
    """
    Process blocking response.

    Consumes queue events until a stop or workflow-finished event arrives, then
    applies output moderation (if configured), saves the message and returns the
    complete response.

    :return: response dict with event 'message'
    :raises Exception: the exception built by ``_handle_error`` when an error event
        is received or the workflow run did not succeed
    """
    for queue_message in self._queue_manager.listen():
        event = queue_message.event

        if isinstance(event, QueueErrorEvent):
            raise self._handle_error(event)
        elif isinstance(event, QueueRetrieverResourcesEvent):
            self._task_state.metadata['retriever_resources'] = event.retriever_resources
        elif isinstance(event, QueueAnnotationReplyEvent):
            annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
            if annotation:
                account = annotation.account
                self._task_state.metadata['annotation_reply'] = {
                    'id': annotation.id,
                    'account': {
                        'id': annotation.account_id,
                        'name': account.name if account else 'Dify user'
                    }
                }

                self._task_state.answer = annotation.content
        elif isinstance(event, QueueNodeFinishedEvent):
            workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
            if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
                if workflow_node_execution.node_type == 'llm':  # todo use enum
                    # Fall back to an empty dict so a node without recorded outputs
                    # does not break usage collection.
                    outputs = workflow_node_execution.outputs_dict or {}
                    self._task_state.metadata['usage'] = outputs.get('usage', {})
        elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
            if isinstance(event, QueueWorkflowFinishedEvent):
                workflow_run = self._get_workflow_run(event.workflow_run_id)
                if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value:
                    raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))

                # Use the dict form of the outputs (the same accessor the streaming
                # path serialises) so `.get` is always called on a mapping.
                outputs = workflow_run.outputs_dict or {}
                self._task_state.answer = outputs.get('text', '')

            # response moderation
            if self._output_moderation_handler:
                self._output_moderation_handler.stop_thread()

                self._task_state.answer = self._output_moderation_handler.moderation_completion(
                    completion=self._task_state.answer,
                    public_event=False
                )

            # Save message
            self._save_message()

            response = {
                'event': 'message',
                'task_id': self._application_generate_entity.task_id,
                'id': self._message.id,
                'message_id': self._message.id,
                'conversation_id': self._conversation.id,
                'mode': self._conversation.mode,
                'answer': self._task_state.answer,
                'metadata': {},
                'created_at': int(self._message.created_at.timestamp())
            }

            if self._task_state.metadata:
                response['metadata'] = self._get_response_metadata()

            return response
        # Any other event type is ignored and the loop keeps listening.
|
||||||
|
|
||||||
|
def _process_stream_response(self) -> Generator:
    """
    Process stream response.

    Translates queue events into server-sent-event strings. Workflow and node
    lifecycle events, text chunks, message files and replacements are forwarded
    as ``data: <json>`` frames; ping events are forwarded as ``event: ping``.
    An error event (or a workflow run that did not succeed) yields a single
    error frame and ends the stream.

    :return: generator of SSE-formatted strings
    """
    for message in self._queue_manager.listen():
        event = message.event

        if isinstance(event, QueueErrorEvent):
            data = self._error_to_stream_response_data(self._handle_error(event))
            yield self._yield_response(data)
            break
        elif isinstance(event, QueueWorkflowStartedEvent):
            workflow_run = self._get_workflow_run(event.workflow_run_id)
            response = {
                'event': 'workflow_started',
                'task_id': self._application_generate_entity.task_id,
                'workflow_run_id': event.workflow_run_id,
                'data': {
                    'id': workflow_run.id,
                    'workflow_id': workflow_run.workflow_id,
                    'created_at': int(workflow_run.created_at.timestamp())
                }
            }

            yield self._yield_response(response)
        elif isinstance(event, QueueNodeStartedEvent):
            workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
            response = {
                'event': 'node_started',
                'task_id': self._application_generate_entity.task_id,
                'workflow_run_id': workflow_node_execution.workflow_run_id,
                'data': {
                    'id': workflow_node_execution.id,
                    'node_id': workflow_node_execution.node_id,
                    'index': workflow_node_execution.index,
                    'predecessor_node_id': workflow_node_execution.predecessor_node_id,
                    'inputs': workflow_node_execution.inputs_dict,
                    'created_at': int(workflow_node_execution.created_at.timestamp())
                }
            }

            yield self._yield_response(response)
        elif isinstance(event, QueueNodeFinishedEvent):
            workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
            if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
                if workflow_node_execution.node_type == 'llm':  # todo use enum
                    # Fall back to an empty dict so a node without recorded outputs
                    # does not break usage collection.
                    outputs = workflow_node_execution.outputs_dict or {}
                    self._task_state.metadata['usage'] = outputs.get('usage', {})

            response = {
                'event': 'node_finished',
                'task_id': self._application_generate_entity.task_id,
                'workflow_run_id': workflow_node_execution.workflow_run_id,
                'data': {
                    'id': workflow_node_execution.id,
                    'node_id': workflow_node_execution.node_id,
                    'index': workflow_node_execution.index,
                    'predecessor_node_id': workflow_node_execution.predecessor_node_id,
                    'inputs': workflow_node_execution.inputs_dict,
                    'process_data': workflow_node_execution.process_data_dict,
                    'outputs': workflow_node_execution.outputs_dict,
                    'status': workflow_node_execution.status,
                    'error': workflow_node_execution.error,
                    'elapsed_time': workflow_node_execution.elapsed_time,
                    'execution_metadata': workflow_node_execution.execution_metadata_dict,
                    'created_at': int(workflow_node_execution.created_at.timestamp()),
                    'finished_at': int(workflow_node_execution.finished_at.timestamp())
                }
            }

            yield self._yield_response(response)
        elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
            # Only a workflow-finished event identifies a run to report on. A stop
            # event (built in this method with just `stopped_by`) skips straight to
            # moderation, persistence and the message_end frame instead of reading
            # attributes and locals that may not exist on that path.
            if isinstance(event, QueueWorkflowFinishedEvent):
                workflow_run = self._get_workflow_run(event.workflow_run_id)
                if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value:
                    err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
                    data = self._error_to_stream_response_data(self._handle_error(err_event))
                    yield self._yield_response(data)
                    break

                # Use the dict form of the outputs, matching what is serialised below.
                outputs = workflow_run.outputs_dict or {}
                self._task_state.answer = outputs.get('text', '')

                workflow_run_response = {
                    'event': 'workflow_finished',
                    'task_id': self._application_generate_entity.task_id,
                    'workflow_run_id': event.workflow_run_id,
                    'data': {
                        'id': workflow_run.id,
                        'workflow_id': workflow_run.workflow_id,
                        'status': workflow_run.status,
                        'outputs': workflow_run.outputs_dict,
                        'error': workflow_run.error,
                        'elapsed_time': workflow_run.elapsed_time,
                        'total_tokens': workflow_run.total_tokens,
                        'total_price': workflow_run.total_price,
                        'currency': workflow_run.currency,
                        'total_steps': workflow_run.total_steps,
                        'created_at': int(workflow_run.created_at.timestamp()),
                        'finished_at': int(workflow_run.finished_at.timestamp())
                    }
                }

                yield self._yield_response(workflow_run_response)

            # response moderation
            if self._output_moderation_handler:
                self._output_moderation_handler.stop_thread()

                self._task_state.answer = self._output_moderation_handler.moderation_completion(
                    completion=self._task_state.answer,
                    public_event=False
                )

                # The handler is dropped once the final moderation pass has run.
                self._output_moderation_handler = None

                replace_response = {
                    'event': 'message_replace',
                    'task_id': self._application_generate_entity.task_id,
                    'message_id': self._message.id,
                    'conversation_id': self._conversation.id,
                    'answer': self._task_state.answer,
                    'created_at': int(self._message.created_at.timestamp())
                }

                yield self._yield_response(replace_response)

            # Save message
            self._save_message()

            response = {
                'event': 'message_end',
                'task_id': self._application_generate_entity.task_id,
                'id': self._message.id,
                'message_id': self._message.id,
                'conversation_id': self._conversation.id,
            }

            if self._task_state.metadata:
                response['metadata'] = self._get_response_metadata()

            yield self._yield_response(response)
        elif isinstance(event, QueueRetrieverResourcesEvent):
            self._task_state.metadata['retriever_resources'] = event.retriever_resources
        elif isinstance(event, QueueAnnotationReplyEvent):
            annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
            if annotation:
                account = annotation.account
                self._task_state.metadata['annotation_reply'] = {
                    'id': annotation.id,
                    'account': {
                        'id': annotation.account_id,
                        'name': account.name if account else 'Dify user'
                    }
                }

                self._task_state.answer = annotation.content
        elif isinstance(event, QueueMessageFileEvent):
            message_file: MessageFile = (
                db.session.query(MessageFile)
                .filter(MessageFile.id == event.message_file_id)
                .first()
            )
            # `.first()` returns None when no row matches; check before touching
            # any attribute of the record.
            if not message_file:
                continue

            # get extension
            if '.' in message_file.url:
                extension = f'.{message_file.url.split(".")[-1]}'
                # Overlong "extensions" are treated as unknown.
                if len(extension) > 10:
                    extension = '.bin'
            else:
                extension = '.bin'
            # add sign url
            url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)

            response = {
                'event': 'message_file',
                'conversation_id': self._conversation.id,
                'id': message_file.id,
                'type': message_file.type,
                'belongs_to': message_file.belongs_to or 'user',
                'url': url
            }

            yield self._yield_response(response)
        elif isinstance(event, QueueTextChunkEvent):
            delta_text = event.chunk_text
            if delta_text is None:
                continue

            if self._output_moderation_handler:
                if self._output_moderation_handler.should_direct_output():
                    # stop subscribe new token when output moderation should direct output
                    self._task_state.answer = self._output_moderation_handler.get_final_output()
                    self._queue_manager.publish_text_chunk(self._task_state.answer, PublishFrom.TASK_PIPELINE)
                    self._queue_manager.publish(
                        QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
                        PublishFrom.TASK_PIPELINE
                    )
                    continue
                else:
                    self._output_moderation_handler.append_new_token(delta_text)

            self._task_state.answer += delta_text
            response = self._handle_chunk(delta_text)
            yield self._yield_response(response)
        elif isinstance(event, QueueMessageReplaceEvent):
            response = {
                'event': 'message_replace',
                'task_id': self._application_generate_entity.task_id,
                'message_id': self._message.id,
                'conversation_id': self._conversation.id,
                'answer': event.text,
                'created_at': int(self._message.created_at.timestamp())
            }

            yield self._yield_response(response)
        elif isinstance(event, QueuePingEvent):
            yield "event: ping\n\n"
        # Any other event type is ignored and the loop keeps listening.
|
||||||
|
|
||||||
|
def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
    """
    Look up a workflow run by its id.

    :param workflow_run_id: workflow run id
    :return: first matching WorkflowRun row (whatever ``.first()`` yields for the query)
    """
    query = db.session.query(WorkflowRun)
    matching = query.filter(WorkflowRun.id == workflow_run_id)
    return matching.first()
|
||||||
|
|
||||||
|
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
    """
    Look up a workflow node execution by its id.

    :param workflow_node_execution_id: workflow node execution id
    :return: first matching WorkflowNodeExecution row (whatever ``.first()`` yields for the query)
    """
    query = db.session.query(WorkflowNodeExecution)
    matching = query.filter(WorkflowNodeExecution.id == workflow_node_execution_id)
    return matching.first()
|
||||||
|
|
||||||
|
def _save_message(self) -> None:
    """
    Save message.

    Reloads the message row, stores the final answer, the LLM usage figures
    (when usage metadata was collected) and the response latency, commits the
    session and then sends the ``message_was_created`` signal.

    :return:
    """
    self._message = db.session.query(Message).filter(Message.id == self._message.id).first()

    self._message.answer = self._task_state.answer

    usage_dict = (self._task_state.metadata or {}).get('usage')
    if usage_dict:
        usage = LLMUsage(**usage_dict)

        self._message.message_tokens = usage.prompt_tokens
        self._message.message_unit_price = usage.prompt_unit_price
        self._message.message_price_unit = usage.prompt_price_unit
        self._message.answer_tokens = usage.completion_tokens
        self._message.answer_unit_price = usage.completion_unit_price
        self._message.answer_price_unit = usage.completion_price_unit
        self._message.total_price = usage.total_price
        self._message.currency = usage.currency

    # Latency is measured once, relative to the pipeline's construction time.
    self._message.provider_response_latency = time.perf_counter() - self._start_at

    db.session.commit()

    message_was_created.send(
        self._message,
        application_generate_entity=self._application_generate_entity,
        conversation=self._conversation,
        is_first_message=self._application_generate_entity.conversation_id is None,
        extras=self._application_generate_entity.extras
    )
|
||||||
|
|
||||||
|
def _handle_chunk(self, text: str) -> dict:
|
||||||
|
"""
|
||||||
|
Handle completed event.
|
||||||
|
:param text: text
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
response = {
|
||||||
|
'event': 'message',
|
||||||
|
'id': self._message.id,
|
||||||
|
'task_id': self._application_generate_entity.task_id,
|
||||||
|
'message_id': self._message.id,
|
||||||
|
'conversation_id': self._conversation.id,
|
||||||
|
'answer': text,
|
||||||
|
'created_at': int(self._message.created_at.timestamp())
|
||||||
|
}
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _handle_error(self, event: QueueErrorEvent) -> Exception:
    """
    Convert an error event into the exception to surface to the caller.

    :param event: event
    :return: a fresh InvokeAuthorizationError for authorization failures, the original
        error for other InvokeError / ValueError instances, otherwise a plain Exception
        carrying the error's description (or its string form)
    """
    logger.debug("error: %s", event.error)
    error = event.error

    # Authorization failures are replaced with a fixed message.
    if isinstance(error, InvokeAuthorizationError):
        return InvokeAuthorizationError('Incorrect API key provided')

    if isinstance(error, (InvokeError, ValueError)):
        return error

    description = getattr(error, 'description', None)
    if description is None:
        return Exception(str(error))
    return Exception(error.description)
|
||||||
|
|
||||||
|
def _error_to_stream_response_data(self, e: Exception) -> dict:
    """
    Error to stream response.

    Maps a known exception type to its error code and HTTP status. Unknown
    exceptions are logged and reported as a generic internal server error.

    :param e: exception
    :return: 'error' event payload containing task_id, message_id, code, status and message
    """
    error_responses = {
        ValueError: {'code': 'invalid_param', 'status': 400},
        ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
        QuotaExceededError: {
            'code': 'provider_quota_exceeded',
            'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
                       "Please go to Settings -> Model Provider to complete your own provider credentials.",
            'status': 400
        },
        ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
        InvokeError: {'code': 'completion_request_error', 'status': 400}
    }

    # Determine the response based on the type of exception.
    # The scan deliberately does not stop at the first hit: when several entries
    # match, the last one in declaration order is used.
    data = None
    for k, v in error_responses.items():
        if isinstance(e, k):
            data = v

    if data:
        # Prefer the exception's description, but fall back to its string form
        # when the attribute is absent or None (same rule as _handle_error).
        description = getattr(e, 'description', None)
        data.setdefault('message', description if description is not None else str(e))
    else:
        # Use the module-level logger so the record carries this module's name.
        logger.error(e)
        data = {
            'code': 'internal_server_error',
            'message': 'Internal Server Error, please contact support.',
            'status': 500
        }

    return {
        'event': 'error',
        'task_id': self._application_generate_entity.task_id,
        'message_id': self._message.id,
        **data
    }
|
||||||
|
|
||||||
|
def _get_response_metadata(self) -> dict:
    """
    Get response metadata by invoke from.

    Debugger and service-API invocations receive the full retriever resources,
    the annotation reply and the usage figures. Other invocation sources only
    receive a reduced view of the retriever resources.

    :return: metadata dict for the response
    """
    metadata = {}
    task_metadata = self._task_state.metadata
    full_detail = self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]

    # show_retrieve_source
    if 'retriever_resources' in task_metadata:
        if full_detail:
            metadata['retriever_resources'] = task_metadata['retriever_resources']
        else:
            # Expose only a fixed subset of fields of each resource.
            metadata['retriever_resources'] = [
                {
                    'segment_id': resource['segment_id'],
                    'position': resource['position'],
                    'document_name': resource['document_name'],
                    'score': resource['score'],
                    'content': resource['content'],
                }
                for resource in task_metadata['retriever_resources']
            ]

    # show annotation reply
    if full_detail and 'annotation_reply' in task_metadata:
        metadata['annotation_reply'] = task_metadata['annotation_reply']

    # show usage
    # 'usage' is only recorded after an LLM node succeeds, so it may be absent
    # even when other metadata exists.
    if full_detail and 'usage' in task_metadata:
        metadata['usage'] = task_metadata['usage']

    return metadata
|
||||||
|
|
||||||
|
def _yield_response(self, response: dict) -> str:
|
||||||
|
"""
|
||||||
|
Yield response.
|
||||||
|
:param response: response
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return "data: " + json.dumps(response) + "\n\n"
|
||||||
|
|
||||||
|
def _init_output_moderation(self) -> Optional[OutputModeration]:
|
||||||
|
"""
|
||||||
|
Init output moderation.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
app_config = self._application_generate_entity.app_config
|
||||||
|
sensitive_word_avoidance = app_config.sensitive_word_avoidance
|
||||||
|
|
||||||
|
if sensitive_word_avoidance:
|
||||||
|
return OutputModeration(
|
||||||
|
tenant_id=app_config.tenant_id,
|
||||||
|
app_id=app_config.app_id,
|
||||||
|
rule=ModerationRule(
|
||||||
|
type=sensitive_word_avoidance.type,
|
||||||
|
config=sensitive_word_avoidance.config
|
||||||
|
),
|
||||||
|
on_message_replace_func=self._queue_manager.publish_message_replace
|
||||||
|
)
|
||||||
@ -14,12 +14,12 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
InvokeFrom,
|
InvokeFrom,
|
||||||
)
|
)
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
AnnotationReplyEvent,
|
|
||||||
QueueAgentMessageEvent,
|
QueueAgentMessageEvent,
|
||||||
QueueAgentThoughtEvent,
|
QueueAgentThoughtEvent,
|
||||||
|
QueueAnnotationReplyEvent,
|
||||||
QueueErrorEvent,
|
QueueErrorEvent,
|
||||||
|
QueueLLMChunkEvent,
|
||||||
QueueMessageEndEvent,
|
QueueMessageEndEvent,
|
||||||
QueueMessageEvent,
|
|
||||||
QueueMessageFileEvent,
|
QueueMessageFileEvent,
|
||||||
QueueMessageReplaceEvent,
|
QueueMessageReplaceEvent,
|
||||||
QueuePingEvent,
|
QueuePingEvent,
|
||||||
@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
|
|||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||||
|
from core.prompt.simple_prompt_transform import ModelMode
|
||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
from events.message_event import message_was_created
|
from events.message_event import message_was_created
|
||||||
@ -58,9 +59,9 @@ class TaskState(BaseModel):
|
|||||||
metadata: dict = {}
|
metadata: dict = {}
|
||||||
|
|
||||||
|
|
||||||
class GenerateTaskPipeline:
|
class EasyUIBasedGenerateTaskPipeline:
|
||||||
"""
|
"""
|
||||||
GenerateTaskPipeline is a class that generate stream output and state management for Application.
|
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, application_generate_entity: Union[
|
def __init__(self, application_generate_entity: Union[
|
||||||
@ -79,12 +80,13 @@ class GenerateTaskPipeline:
|
|||||||
:param message: message
|
:param message: message
|
||||||
"""
|
"""
|
||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
|
self._model_config = application_generate_entity.model_config
|
||||||
self._queue_manager = queue_manager
|
self._queue_manager = queue_manager
|
||||||
self._conversation = conversation
|
self._conversation = conversation
|
||||||
self._message = message
|
self._message = message
|
||||||
self._task_state = TaskState(
|
self._task_state = TaskState(
|
||||||
llm_result=LLMResult(
|
llm_result=LLMResult(
|
||||||
model=self._application_generate_entity.model_config.model,
|
model=self._model_config.model,
|
||||||
prompt_messages=[],
|
prompt_messages=[],
|
||||||
message=AssistantPromptMessage(content=""),
|
message=AssistantPromptMessage(content=""),
|
||||||
usage=LLMUsage.empty_usage()
|
usage=LLMUsage.empty_usage()
|
||||||
@ -119,7 +121,7 @@ class GenerateTaskPipeline:
|
|||||||
raise self._handle_error(event)
|
raise self._handle_error(event)
|
||||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||||
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||||
elif isinstance(event, AnnotationReplyEvent):
|
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||||
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||||
if annotation:
|
if annotation:
|
||||||
account = annotation.account
|
account = annotation.account
|
||||||
@ -136,7 +138,7 @@ class GenerateTaskPipeline:
|
|||||||
if isinstance(event, QueueMessageEndEvent):
|
if isinstance(event, QueueMessageEndEvent):
|
||||||
self._task_state.llm_result = event.llm_result
|
self._task_state.llm_result = event.llm_result
|
||||||
else:
|
else:
|
||||||
model_config = self._application_generate_entity.model_config
|
model_config = self._model_config
|
||||||
model = model_config.model
|
model = model_config.model
|
||||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
@ -193,7 +195,7 @@ class GenerateTaskPipeline:
|
|||||||
'created_at': int(self._message.created_at.timestamp())
|
'created_at': int(self._message.created_at.timestamp())
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._conversation.mode == 'chat':
|
if self._conversation.mode != AppMode.COMPLETION.value:
|
||||||
response['conversation_id'] = self._conversation.id
|
response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
if self._task_state.metadata:
|
if self._task_state.metadata:
|
||||||
@ -219,7 +221,7 @@ class GenerateTaskPipeline:
|
|||||||
if isinstance(event, QueueMessageEndEvent):
|
if isinstance(event, QueueMessageEndEvent):
|
||||||
self._task_state.llm_result = event.llm_result
|
self._task_state.llm_result = event.llm_result
|
||||||
else:
|
else:
|
||||||
model_config = self._application_generate_entity.model_config
|
model_config = self._model_config
|
||||||
model = model_config.model
|
model = model_config.model
|
||||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
@ -272,7 +274,7 @@ class GenerateTaskPipeline:
|
|||||||
'created_at': int(self._message.created_at.timestamp())
|
'created_at': int(self._message.created_at.timestamp())
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._conversation.mode == 'chat':
|
if self._conversation.mode != AppMode.COMPLETION.value:
|
||||||
replace_response['conversation_id'] = self._conversation.id
|
replace_response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
yield self._yield_response(replace_response)
|
yield self._yield_response(replace_response)
|
||||||
@ -287,7 +289,7 @@ class GenerateTaskPipeline:
|
|||||||
'message_id': self._message.id,
|
'message_id': self._message.id,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._conversation.mode == 'chat':
|
if self._conversation.mode != AppMode.COMPLETION.value:
|
||||||
response['conversation_id'] = self._conversation.id
|
response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
if self._task_state.metadata:
|
if self._task_state.metadata:
|
||||||
@ -296,7 +298,7 @@ class GenerateTaskPipeline:
|
|||||||
yield self._yield_response(response)
|
yield self._yield_response(response)
|
||||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||||
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||||
elif isinstance(event, AnnotationReplyEvent):
|
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||||
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||||
if annotation:
|
if annotation:
|
||||||
account = annotation.account
|
account = annotation.account
|
||||||
@ -334,7 +336,7 @@ class GenerateTaskPipeline:
|
|||||||
'message_files': agent_thought.files
|
'message_files': agent_thought.files
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._conversation.mode == 'chat':
|
if self._conversation.mode != AppMode.COMPLETION.value:
|
||||||
response['conversation_id'] = self._conversation.id
|
response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
yield self._yield_response(response)
|
yield self._yield_response(response)
|
||||||
@ -365,12 +367,12 @@ class GenerateTaskPipeline:
|
|||||||
'url': url
|
'url': url
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._conversation.mode == 'chat':
|
if self._conversation.mode != AppMode.COMPLETION.value:
|
||||||
response['conversation_id'] = self._conversation.id
|
response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
yield self._yield_response(response)
|
yield self._yield_response(response)
|
||||||
|
|
||||||
elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
|
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
|
||||||
chunk = event.chunk
|
chunk = event.chunk
|
||||||
delta_text = chunk.delta.message.content
|
delta_text = chunk.delta.message.content
|
||||||
if delta_text is None:
|
if delta_text is None:
|
||||||
@ -383,7 +385,7 @@ class GenerateTaskPipeline:
|
|||||||
if self._output_moderation_handler.should_direct_output():
|
if self._output_moderation_handler.should_direct_output():
|
||||||
# stop subscribe new token when output moderation should direct output
|
# stop subscribe new token when output moderation should direct output
|
||||||
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
|
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
|
||||||
self._queue_manager.publish_chunk_message(LLMResultChunk(
|
self._queue_manager.publish_llm_chunk(LLMResultChunk(
|
||||||
model=self._task_state.llm_result.model,
|
model=self._task_state.llm_result.model,
|
||||||
prompt_messages=self._task_state.llm_result.prompt_messages,
|
prompt_messages=self._task_state.llm_result.prompt_messages,
|
||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
@ -411,7 +413,7 @@ class GenerateTaskPipeline:
|
|||||||
'created_at': int(self._message.created_at.timestamp())
|
'created_at': int(self._message.created_at.timestamp())
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._conversation.mode == 'chat':
|
if self._conversation.mode != AppMode.COMPLETION.value:
|
||||||
response['conversation_id'] = self._conversation.id
|
response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
yield self._yield_response(response)
|
yield self._yield_response(response)
|
||||||
@ -452,8 +454,7 @@ class GenerateTaskPipeline:
|
|||||||
conversation=self._conversation,
|
conversation=self._conversation,
|
||||||
is_first_message=self._application_generate_entity.app_config.app_mode in [
|
is_first_message=self._application_generate_entity.app_config.app_mode in [
|
||||||
AppMode.AGENT_CHAT,
|
AppMode.AGENT_CHAT,
|
||||||
AppMode.CHAT,
|
AppMode.CHAT
|
||||||
AppMode.ADVANCED_CHAT
|
|
||||||
] and self._application_generate_entity.conversation_id is None,
|
] and self._application_generate_entity.conversation_id is None,
|
||||||
extras=self._application_generate_entity.extras
|
extras=self._application_generate_entity.extras
|
||||||
)
|
)
|
||||||
@ -473,7 +474,7 @@ class GenerateTaskPipeline:
|
|||||||
'created_at': int(self._message.created_at.timestamp())
|
'created_at': int(self._message.created_at.timestamp())
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._conversation.mode == 'chat':
|
if self._conversation.mode != AppMode.COMPLETION.value:
|
||||||
response['conversation_id'] = self._conversation.id
|
response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
return response
|
return response
|
||||||
@ -583,7 +584,7 @@ class GenerateTaskPipeline:
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
prompts = []
|
prompts = []
|
||||||
if self._application_generate_entity.model_config.mode == 'chat':
|
if self._model_config.mode == ModelMode.CHAT.value:
|
||||||
for prompt_message in prompt_messages:
|
for prompt_message in prompt_messages:
|
||||||
if prompt_message.role == PromptMessageRole.USER:
|
if prompt_message.role == PromptMessageRole.USER:
|
||||||
role = 'user'
|
role = 'user'
|
||||||
Loading…
Reference in New Issue
Block a user