mirror of https://github.com/langgenius/dify.git
lint fix
parent be709d4b84
commit e9004a06a5

@@ -0,0 +1,563 @@
import json
import logging
import time
from collections.abc import Generator
from typing import Optional, Union

from pydantic import BaseModel

from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
    AdvancedChatAppGenerateEntity,
    InvokeFrom,
)
from core.app.entities.queue_entities import (
    QueueAnnotationReplyEvent,
    QueueErrorEvent,
    QueueMessageFileEvent,
    QueueMessageReplaceEvent,
    QueueNodeFinishedEvent,
    QueueNodeStartedEvent,
    QueuePingEvent,
    QueueRetrieverResourcesEvent,
    QueueStopEvent,
    QueueTextChunkEvent,
    QueueWorkflowFinishedEvent,
    QueueWorkflowStartedEvent,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
from core.tools.tool_file_manager import ToolFileManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Conversation, Message, MessageFile
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus
from services.annotation_service import AppAnnotationService

logger = logging.getLogger(__name__)


class TaskState(BaseModel):
    """
    TaskState entity
    """
    answer: str = ""
    metadata: dict = {}


class AdvancedChatAppGenerateTaskPipeline:
    """
    AdvancedChatAppGenerateTaskPipeline generates stream output and manages state for an application.
    """

    def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
                 queue_manager: AppQueueManager,
                 conversation: Conversation,
                 message: Message) -> None:
        """
        Initialize AdvancedChatAppGenerateTaskPipeline.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        """
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._conversation = conversation
        self._message = message
        self._task_state = TaskState()
        self._start_at = time.perf_counter()
        self._output_moderation_handler = self._init_output_moderation()

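    # The pipeline consumes events published through the AppQueueManager and
    # turns them into either a single blocking response dict or a generator of
    # pre-framed server-sent-event strings.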
    def process(self, stream: bool) -> Union[dict, Generator]:
        """
        Process generate task pipeline.
        :return:
        """
        if stream:
            return self._process_stream_response()
        else:
            return self._process_blocking_response()

    def _process_blocking_response(self) -> dict:
        """
        Process blocking response.
        :return:
        """
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, QueueAnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.answer = annotation.content
            elif isinstance(event, QueueNodeFinishedEvent):
                workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
                if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
                    if workflow_node_execution.node_type == 'llm':  # todo: use enum
                        outputs = workflow_node_execution.outputs_dict
                        usage_dict = outputs.get('usage', {})
                        self._task_state.metadata['usage'] = usage_dict
            elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
                if isinstance(event, QueueWorkflowFinishedEvent):
                    workflow_run = self._get_workflow_run(event.workflow_run_id)
                    if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
                        outputs = workflow_run.outputs
                        self._task_state.answer = outputs.get('text', '')
                    else:
                        raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.answer = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.answer,
                        public_event=False
                    )

                # Save message
                self._save_message()

                response = {
                    'event': 'message',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                    'conversation_id': self._conversation.id,
                    'mode': self._conversation.mode,
                    'answer': self._task_state.answer,
                    'metadata': {},
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                return response
            else:
                continue

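    # Every payload yielded below is framed by _yield_response() as a
    # server-sent event ("data: {...}\n\n"); keep-alive pings go out as bare
    # "event: ping" frames.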
    def _process_stream_response(self) -> Generator:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
            event = message.event

            if isinstance(event, QueueErrorEvent):
                data = self._error_to_stream_response_data(self._handle_error(event))
                yield self._yield_response(data)
                break
            elif isinstance(event, QueueWorkflowStartedEvent):
                workflow_run = self._get_workflow_run(event.workflow_run_id)
                response = {
                    'event': 'workflow_started',
                    'task_id': self._application_generate_entity.task_id,
                    'workflow_run_id': event.workflow_run_id,
                    'data': {
                        'id': workflow_run.id,
                        'workflow_id': workflow_run.workflow_id,
                        'created_at': int(workflow_run.created_at.timestamp())
                    }
                }

                yield self._yield_response(response)
            elif isinstance(event, QueueNodeStartedEvent):
                workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
                response = {
                    'event': 'node_started',
                    'task_id': self._application_generate_entity.task_id,
                    'workflow_run_id': workflow_node_execution.workflow_run_id,
                    'data': {
                        'id': workflow_node_execution.id,
                        'node_id': workflow_node_execution.node_id,
                        'index': workflow_node_execution.index,
                        'predecessor_node_id': workflow_node_execution.predecessor_node_id,
                        'inputs': workflow_node_execution.inputs_dict,
                        'created_at': int(workflow_node_execution.created_at.timestamp())
                    }
                }

                yield self._yield_response(response)
            elif isinstance(event, QueueNodeFinishedEvent):
                workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
                if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
                    if workflow_node_execution.node_type == 'llm':  # todo: use enum
                        outputs = workflow_node_execution.outputs_dict
                        usage_dict = outputs.get('usage', {})
                        self._task_state.metadata['usage'] = usage_dict

                response = {
                    'event': 'node_finished',
                    'task_id': self._application_generate_entity.task_id,
                    'workflow_run_id': workflow_node_execution.workflow_run_id,
                    'data': {
                        'id': workflow_node_execution.id,
                        'node_id': workflow_node_execution.node_id,
                        'index': workflow_node_execution.index,
                        'predecessor_node_id': workflow_node_execution.predecessor_node_id,
                        'inputs': workflow_node_execution.inputs_dict,
                        'process_data': workflow_node_execution.process_data_dict,
                        'outputs': workflow_node_execution.outputs_dict,
                        'status': workflow_node_execution.status,
                        'error': workflow_node_execution.error,
                        'elapsed_time': workflow_node_execution.elapsed_time,
                        'execution_metadata': workflow_node_execution.execution_metadata_dict,
                        'created_at': int(workflow_node_execution.created_at.timestamp()),
                        'finished_at': int(workflow_node_execution.finished_at.timestamp())
                    }
                }

                yield self._yield_response(response)
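            # Both a stop request and workflow completion terminate the stream:
            # run final output moderation, persist the message, then emit a
            # closing message_end event.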
            elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
                if isinstance(event, QueueWorkflowFinishedEvent):
                    workflow_run = self._get_workflow_run(event.workflow_run_id)
                    if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
                        outputs = workflow_run.outputs
                        self._task_state.answer = outputs.get('text', '')
                    else:
                        err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
                        data = self._error_to_stream_response_data(self._handle_error(err_event))
                        yield self._yield_response(data)
                        break

                    workflow_run_response = {
                        'event': 'workflow_finished',
                        'task_id': self._application_generate_entity.task_id,
                        'workflow_run_id': event.workflow_run_id,
                        'data': {
                            'id': workflow_run.id,
                            'workflow_id': workflow_run.workflow_id,
                            'status': workflow_run.status,
                            'outputs': workflow_run.outputs_dict,
                            'error': workflow_run.error,
                            'elapsed_time': workflow_run.elapsed_time,
                            'total_tokens': workflow_run.total_tokens,
                            'total_price': workflow_run.total_price,
                            'currency': workflow_run.currency,
                            'total_steps': workflow_run.total_steps,
                            'created_at': int(workflow_run.created_at.timestamp()),
                            'finished_at': int(workflow_run.finished_at.timestamp())
                        }
                    }

                    yield self._yield_response(workflow_run_response)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.answer = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.answer,
                        public_event=False
                    )

                    self._output_moderation_handler = None

                    replace_response = {
                        'event': 'message_replace',
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'conversation_id': self._conversation.id,
                        'answer': self._task_state.answer,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    yield self._yield_response(replace_response)

                # Save message
                self._save_message()

                response = {
                    'event': 'message_end',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                    'conversation_id': self._conversation.id,
                }

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                yield self._yield_response(response)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, QueueAnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.answer = annotation.content
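            # Files produced during the run are exposed through signed URLs
            # generated by ToolFileManager rather than raw storage paths.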
            elif isinstance(event, QueueMessageFileEvent):
                message_file: MessageFile = (
                    db.session.query(MessageFile)
                    .filter(MessageFile.id == event.message_file_id)
                    .first()
                )

                if message_file:
                    # get extension
                    if '.' in message_file.url:
                        extension = f'.{message_file.url.split(".")[-1]}'
                        if len(extension) > 10:
                            extension = '.bin'
                    else:
                        extension = '.bin'

                    # add sign url
                    url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)

                    response = {
                        'event': 'message_file',
                        'conversation_id': self._conversation.id,
                        'id': message_file.id,
                        'type': message_file.type,
                        'belongs_to': message_file.belongs_to or 'user',
                        'url': url
                    }

                    yield self._yield_response(response)
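            # Text chunks pass through output moderation first; if moderation
            # decides to replace the output, the buffered final text is
            # published and the run is stopped.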
            elif isinstance(event, QueueTextChunkEvent):
                delta_text = event.chunk_text
                if delta_text is None:
                    continue

                if self._output_moderation_handler:
                    if self._output_moderation_handler.should_direct_output():
                        # stop subscribing to new tokens when output moderation should direct output
                        self._task_state.answer = self._output_moderation_handler.get_final_output()
                        self._queue_manager.publish_text_chunk(self._task_state.answer, PublishFrom.TASK_PIPELINE)
                        self._queue_manager.publish(
                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
                            PublishFrom.TASK_PIPELINE
                        )
                        continue
                    else:
                        self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.answer += delta_text
                response = self._handle_chunk(delta_text)
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
                    'event': 'message_replace',
                    'task_id': self._application_generate_entity.task_id,
                    'message_id': self._message.id,
                    'conversation_id': self._conversation.id,
                    'answer': event.text,
                    'created_at': int(self._message.created_at.timestamp())
                }

                yield self._yield_response(response)
            elif isinstance(event, QueuePingEvent):
                yield "event: ping\n\n"
            else:
                continue

    def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
        """
        Get workflow run.
        :param workflow_run_id: workflow run id
        :return:
        """
        return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()

    def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
        """
        Get workflow node execution.
        :param workflow_node_execution_id: workflow node execution id
        :return:
        """
        return db.session.query(WorkflowNodeExecution) \
            .filter(WorkflowNodeExecution.id == workflow_node_execution_id) \
            .first()

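    # The message is re-queried below so the instance being mutated is
    # attached to the active database session before commit.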
    def _save_message(self) -> None:
        """
        Save message.
        :return:
        """
        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()

        self._message.answer = self._task_state.answer
        self._message.provider_response_latency = time.perf_counter() - self._start_at

        if self._task_state.metadata and self._task_state.metadata.get('usage'):
            usage = LLMUsage(**self._task_state.metadata['usage'])

            self._message.message_tokens = usage.prompt_tokens
            self._message.message_unit_price = usage.prompt_unit_price
            self._message.message_price_unit = usage.prompt_price_unit
            self._message.answer_tokens = usage.completion_tokens
            self._message.answer_unit_price = usage.completion_unit_price
            self._message.answer_price_unit = usage.completion_price_unit
            self._message.total_price = usage.total_price
            self._message.currency = usage.currency

        db.session.commit()

        message_was_created.send(
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
        )

    def _handle_chunk(self, text: str) -> dict:
        """
        Handle completed event.
        :param text: text
        :return:
        """
        response = {
            'event': 'message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            'conversation_id': self._conversation.id,
            'answer': text,
            'created_at': int(self._message.created_at.timestamp())
        }

        return response

    def _handle_error(self, event: QueueErrorEvent) -> Exception:
        """
        Handle error event.
        :param event: event
        :return:
        """
        logger.debug("error: %s", event.error)
        e = event.error

        if isinstance(e, InvokeAuthorizationError):
            return InvokeAuthorizationError('Incorrect API key provided')
        elif isinstance(e, (InvokeError, ValueError)):
            return e
        else:
            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

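    # Known exception types map onto the public error codes carried by the
    # SSE "error" event; anything unrecognized is reported as a 500.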
    def _error_to_stream_response_data(self, e: Exception) -> dict:
        """
        Error to stream response.
        :param e: exception
        :return:
        """
        error_responses = {
            ValueError: {'code': 'invalid_param', 'status': 400},
            ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
            QuotaExceededError: {
                'code': 'provider_quota_exceeded',
                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
                'status': 400
            },
            ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
            InvokeError: {'code': 'completion_request_error', 'status': 400}
        }

        # Determine the response based on the type of exception
        data = None
        for k, v in error_responses.items():
            if isinstance(e, k):
                data = v

        if data:
            data.setdefault('message', getattr(e, 'description', str(e)))
        else:
            logger.error(e)
            data = {
                'code': 'internal_server_error',
                'message': 'Internal Server Error, please contact support.',
                'status': 500
            }

        return {
            'event': 'error',
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            **data
        }

    def _get_response_metadata(self) -> dict:
        """
        Get response metadata by invoke from.
        :return:
        """
        metadata = {}

        # show_retrieve_source
        if 'retriever_resources' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
            else:
                metadata['retriever_resources'] = []
                for resource in self._task_state.metadata['retriever_resources']:
                    metadata['retriever_resources'].append({
                        'segment_id': resource['segment_id'],
                        'position': resource['position'],
                        'document_name': resource['document_name'],
                        'score': resource['score'],
                        'content': resource['content'],
                    })

        # show annotation reply
        if 'annotation_reply' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']

        # show usage
        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
            metadata['usage'] = self._task_state.metadata['usage']

        return metadata

    def _yield_response(self, response: dict) -> str:
        """
        Yield response.
        :param response: response
        :return:
        """
        return "data: " + json.dumps(response) + "\n\n"

    def _init_output_moderation(self) -> Optional[OutputModeration]:
        """
        Init output moderation.
        :return:
        """
        app_config = self._application_generate_entity.app_config
        sensitive_word_avoidance = app_config.sensitive_word_avoidance

        if sensitive_word_avoidance:
            return OutputModeration(
                tenant_id=app_config.tenant_id,
                app_id=app_config.app_id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance.type,
                    config=sensitive_word_avoidance.config
                ),
                on_message_replace_func=self._queue_manager.publish_message_replace
            )

        return None
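For orientation, here is a minimal, hypothetical sketch of how a caller might drive the pipeline above (in Dify the entity, queue manager, conversation and message are constructed inside the app generator; the bare names below are placeholders):

    # Hypothetical driver; construction of the four collaborators is elided.
    pipeline = AdvancedChatAppGenerateTaskPipeline(
        application_generate_entity=application_generate_entity,  # AdvancedChatAppGenerateEntity
        queue_manager=queue_manager,    # AppQueueManager
        conversation=conversation,      # models.model.Conversation
        message=message,                # models.model.Message
    )

    result = pipeline.process(stream=False)  # blocking: a single response dict

    for sse_frame in pipeline.process(stream=True):  # streaming: "data: {...}\n\n" strings
        print(sse_frame, end="")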
@@ -14,12 +14,12 @@ from core.app.entities.app_invoke_entities import (
     InvokeFrom,
 )
 from core.app.entities.queue_entities import (
-    AnnotationReplyEvent,
     QueueAgentMessageEvent,
     QueueAgentThoughtEvent,
+    QueueAnnotationReplyEvent,
     QueueErrorEvent,
+    QueueLLMChunkEvent,
     QueueMessageEndEvent,
-    QueueMessageEvent,
     QueueMessageFileEvent,
     QueueMessageReplaceEvent,
     QueuePingEvent,
@@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.moderation.output_moderation import ModerationRule, OutputModeration
+from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.tools.tool_file_manager import ToolFileManager
 from events.message_event import message_was_created
@@ -58,9 +59,9 @@ class TaskState(BaseModel):
     metadata: dict = {}


-class GenerateTaskPipeline:
+class EasyUIBasedGenerateTaskPipeline:
     """
-    GenerateTaskPipeline is a class that generate stream output and state management for Application.
+    EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """

     def __init__(self, application_generate_entity: Union[
@@ -79,12 +80,13 @@ class GenerateTaskPipeline:
         :param message: message
         """
         self._application_generate_entity = application_generate_entity
+        self._model_config = application_generate_entity.model_config
         self._queue_manager = queue_manager
         self._conversation = conversation
         self._message = message
         self._task_state = TaskState(
             llm_result=LLMResult(
-                model=self._application_generate_entity.model_config.model,
+                model=self._model_config.model,
                 prompt_messages=[],
                 message=AssistantPromptMessage(content=""),
                 usage=LLMUsage.empty_usage()
@@ -119,7 +121,7 @@ class GenerateTaskPipeline:
                 raise self._handle_error(event)
             elif isinstance(event, QueueRetrieverResourcesEvent):
                 self._task_state.metadata['retriever_resources'] = event.retriever_resources
-            elif isinstance(event, AnnotationReplyEvent):
+            elif isinstance(event, QueueAnnotationReplyEvent):
                 annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                 if annotation:
                     account = annotation.account
@@ -136,7 +138,7 @@ class GenerateTaskPipeline:
             if isinstance(event, QueueMessageEndEvent):
                 self._task_state.llm_result = event.llm_result
             else:
-                model_config = self._application_generate_entity.model_config
+                model_config = self._model_config
                 model = model_config.model
                 model_type_instance = model_config.provider_model_bundle.model_type_instance
                 model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -193,7 +195,7 @@ class GenerateTaskPipeline:
                 'created_at': int(self._message.created_at.timestamp())
             }

-            if self._conversation.mode == 'chat':
+            if self._conversation.mode != AppMode.COMPLETION.value:
                 response['conversation_id'] = self._conversation.id

             if self._task_state.metadata:
@@ -219,7 +221,7 @@ class GenerateTaskPipeline:
             if isinstance(event, QueueMessageEndEvent):
                 self._task_state.llm_result = event.llm_result
             else:
-                model_config = self._application_generate_entity.model_config
+                model_config = self._model_config
                 model = model_config.model
                 model_type_instance = model_config.provider_model_bundle.model_type_instance
                 model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -272,7 +274,7 @@ class GenerateTaskPipeline:
                 'created_at': int(self._message.created_at.timestamp())
             }

-            if self._conversation.mode == 'chat':
+            if self._conversation.mode != AppMode.COMPLETION.value:
                 replace_response['conversation_id'] = self._conversation.id

             yield self._yield_response(replace_response)
@@ -287,7 +289,7 @@ class GenerateTaskPipeline:
                 'message_id': self._message.id,
             }

-            if self._conversation.mode == 'chat':
+            if self._conversation.mode != AppMode.COMPLETION.value:
                 response['conversation_id'] = self._conversation.id

             if self._task_state.metadata:
@@ -296,7 +298,7 @@ class GenerateTaskPipeline:
                 yield self._yield_response(response)
             elif isinstance(event, QueueRetrieverResourcesEvent):
                 self._task_state.metadata['retriever_resources'] = event.retriever_resources
-            elif isinstance(event, AnnotationReplyEvent):
+            elif isinstance(event, QueueAnnotationReplyEvent):
                 annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                 if annotation:
                     account = annotation.account
@@ -334,7 +336,7 @@ class GenerateTaskPipeline:
                     'message_files': agent_thought.files
                 }

-                if self._conversation.mode == 'chat':
+                if self._conversation.mode != AppMode.COMPLETION.value:
                     response['conversation_id'] = self._conversation.id

                 yield self._yield_response(response)
@@ -365,12 +367,12 @@ class GenerateTaskPipeline:
                     'url': url
                 }

-                if self._conversation.mode == 'chat':
+                if self._conversation.mode != AppMode.COMPLETION.value:
                     response['conversation_id'] = self._conversation.id

                 yield self._yield_response(response)

-            elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
+            elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
                 chunk = event.chunk
                 delta_text = chunk.delta.message.content
                 if delta_text is None:
@@ -383,7 +385,7 @@ class GenerateTaskPipeline:
                     if self._output_moderation_handler.should_direct_output():
                         # stop subscribe new token when output moderation should direct output
                         self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
-                        self._queue_manager.publish_chunk_message(LLMResultChunk(
+                        self._queue_manager.publish_llm_chunk(LLMResultChunk(
                             model=self._task_state.llm_result.model,
                             prompt_messages=self._task_state.llm_result.prompt_messages,
                             delta=LLMResultChunkDelta(
@@ -411,7 +413,7 @@ class GenerateTaskPipeline:
                 'created_at': int(self._message.created_at.timestamp())
             }

-            if self._conversation.mode == 'chat':
+            if self._conversation.mode != AppMode.COMPLETION.value:
                 response['conversation_id'] = self._conversation.id

             yield self._yield_response(response)
@@ -452,8 +454,7 @@ class GenerateTaskPipeline:
             conversation=self._conversation,
             is_first_message=self._application_generate_entity.app_config.app_mode in [
                 AppMode.AGENT_CHAT,
-                AppMode.CHAT,
-                AppMode.ADVANCED_CHAT
+                AppMode.CHAT
             ] and self._application_generate_entity.conversation_id is None,
             extras=self._application_generate_entity.extras
         )
@@ -473,7 +474,7 @@ class GenerateTaskPipeline:
             'created_at': int(self._message.created_at.timestamp())
         }

-        if self._conversation.mode == 'chat':
+        if self._conversation.mode != AppMode.COMPLETION.value:
             response['conversation_id'] = self._conversation.id

         return response
@@ -583,7 +584,7 @@ class GenerateTaskPipeline:
         :return:
         """
         prompts = []
-        if self._application_generate_entity.model_config.mode == 'chat':
+        if self._model_config.mode == ModelMode.CHAT.value:
             for prompt_message in prompt_messages:
                 if prompt_message.role == PromptMessageRole.USER:
                     role = 'user'