From 451ea5308f1fb803fe06bb5c0a9c65247d584b3d Mon Sep 17 00:00:00 2001
From: takatost
Date: Mon, 4 Mar 2024 02:04:46 +0800
Subject: [PATCH] lint fix

---
 .../advanced_chat/generate_task_pipeline.py   | 563 ++++++++++++++++++
 .../easy_ui_based_generate_task_pipeline.py}  |  43 +-
 2 files changed, 585 insertions(+), 21 deletions(-)
 create mode 100644 api/core/app/apps/advanced_chat/generate_task_pipeline.py
 rename api/core/app/{generate_task_pipeline.py => apps/easy_ui_based_generate_task_pipeline.py} (95%)

diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
new file mode 100644
index 0000000000..d443435fc1
--- /dev/null
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -0,0 +1,563 @@
+import json
+import logging
+import time
+from collections.abc import Generator
+from typing import Optional, Union
+
+from pydantic import BaseModel
+
+from core.app.app_queue_manager import AppQueueManager, PublishFrom
+from core.app.entities.app_invoke_entities import (
+    AdvancedChatAppGenerateEntity,
+    InvokeFrom,
+)
+from core.app.entities.queue_entities import (
+    QueueAnnotationReplyEvent,
+    QueueErrorEvent,
+    QueueMessageFileEvent,
+    QueueMessageReplaceEvent,
+    QueueNodeFinishedEvent,
+    QueueNodeStartedEvent,
+    QueuePingEvent,
+    QueueRetrieverResourcesEvent,
+    QueueStopEvent,
+    QueueTextChunkEvent,
+    QueueWorkflowFinishedEvent,
+    QueueWorkflowStartedEvent,
+)
+from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.moderation.output_moderation import ModerationRule, OutputModeration
+from core.tools.tool_file_manager import ToolFileManager
+from events.message_event import message_was_created
+from extensions.ext_database import db
+from models.model import Conversation, Message, MessageFile
+from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus
+from services.annotation_service import AppAnnotationService
+
+logger = logging.getLogger(__name__)
+
+
+class TaskState(BaseModel):
+    """
+    TaskState entity
+    """
+    answer: str = ""
+    metadata: dict = {}
+
+
+class AdvancedChatAppGenerateTaskPipeline:
+    """
+    AdvancedChatAppGenerateTaskPipeline generates the stream output and manages the task state for an application.
+    """
+
+    def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
+                 queue_manager: AppQueueManager,
+                 conversation: Conversation,
+                 message: Message) -> None:
+        """
+        Initialize AdvancedChatAppGenerateTaskPipeline.
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param conversation: conversation
+        :param message: message
+        """
+        self._application_generate_entity = application_generate_entity
+        self._queue_manager = queue_manager
+        self._conversation = conversation
+        self._message = message
+        self._task_state = TaskState()
+        self._start_at = time.perf_counter()
+        self._output_moderation_handler = self._init_output_moderation()
+
+    def process(self, stream: bool) -> Union[dict, Generator]:
+        """
+        Process the generate task pipeline.
+        :param stream: whether to return a streaming response
+        :return: a response dict when blocking, otherwise a generator of SSE frames
+        """
+        if stream:
+            return self._process_stream_response()
+        else:
+            return self._process_blocking_response()
+
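+    # Both response processors below drain the same event queue; the blocking
+    # variant accumulates a single response dict, while the streaming variant
+    # yields one SSE frame per queue event as it arrives.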
+    def _process_blocking_response(self) -> dict:
+        """
+        Process blocking response.
+        :return:
+        """
+        for queue_message in self._queue_manager.listen():
+            event = queue_message.event
+
+            if isinstance(event, QueueErrorEvent):
+                raise self._handle_error(event)
+            elif isinstance(event, QueueRetrieverResourcesEvent):
+                self._task_state.metadata['retriever_resources'] = event.retriever_resources
+            elif isinstance(event, QueueAnnotationReplyEvent):
+                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
+                if annotation:
+                    account = annotation.account
+                    self._task_state.metadata['annotation_reply'] = {
+                        'id': annotation.id,
+                        'account': {
+                            'id': annotation.account_id,
+                            'name': account.name if account else 'Dify user'
+                        }
+                    }
+
+                    self._task_state.answer = annotation.content
+            elif isinstance(event, QueueNodeFinishedEvent):
+                workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
+                if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
+                    if workflow_node_execution.node_type == 'llm':  # TODO: use enum
+                        outputs = workflow_node_execution.outputs_dict
+                        usage_dict = outputs.get('usage', {})
+                        self._task_state.metadata['usage'] = usage_dict
+            elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
+                if isinstance(event, QueueWorkflowFinishedEvent):
+                    workflow_run = self._get_workflow_run(event.workflow_run_id)
+                    if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
+                        outputs = workflow_run.outputs_dict
+                        self._task_state.answer = outputs.get('text', '') if outputs else ''
+                    else:
+                        raise self._handle_error(
+                            QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
+                        )
+
+                # response moderation
+                if self._output_moderation_handler:
+                    self._output_moderation_handler.stop_thread()
+
+                    self._task_state.answer = self._output_moderation_handler.moderation_completion(
+                        completion=self._task_state.answer,
+                        public_event=False
+                    )
+
+                # Save message
+                self._save_message()
+
+                response = {
+                    'event': 'message',
+                    'task_id': self._application_generate_entity.task_id,
+                    'id': self._message.id,
+                    'message_id': self._message.id,
+                    'conversation_id': self._conversation.id,
+                    'mode': self._conversation.mode,
+                    'answer': self._task_state.answer,
+                    'metadata': {},
+                    'created_at': int(self._message.created_at.timestamp())
+                }
+
+                if self._task_state.metadata:
+                    response['metadata'] = self._get_response_metadata()
+
+                return response
+            else:
+                continue
+
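+    # Note: `isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent)`
+    # relies on isinstance() accepting PEP 604 union types, which requires
+    # Python 3.10 or newer.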
+    def _process_stream_response(self) -> Generator:
+        """
+        Process stream response.
+        :return:
+        """
+        for message in self._queue_manager.listen():
+            event = message.event
+
+            if isinstance(event, QueueErrorEvent):
+                data = self._error_to_stream_response_data(self._handle_error(event))
+                yield self._yield_response(data)
+                break
+            elif isinstance(event, QueueWorkflowStartedEvent):
+                workflow_run = self._get_workflow_run(event.workflow_run_id)
+                response = {
+                    'event': 'workflow_started',
+                    'task_id': self._application_generate_entity.task_id,
+                    'workflow_run_id': event.workflow_run_id,
+                    'data': {
+                        'id': workflow_run.id,
+                        'workflow_id': workflow_run.workflow_id,
+                        'created_at': int(workflow_run.created_at.timestamp())
+                    }
+                }
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueueNodeStartedEvent):
+                workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
+                response = {
+                    'event': 'node_started',
+                    'task_id': self._application_generate_entity.task_id,
+                    'workflow_run_id': workflow_node_execution.workflow_run_id,
+                    'data': {
+                        'id': workflow_node_execution.id,
+                        'node_id': workflow_node_execution.node_id,
+                        'index': workflow_node_execution.index,
+                        'predecessor_node_id': workflow_node_execution.predecessor_node_id,
+                        'inputs': workflow_node_execution.inputs_dict,
+                        'created_at': int(workflow_node_execution.created_at.timestamp())
+                    }
+                }
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueueNodeFinishedEvent):
+                workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
+                if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
+                    if workflow_node_execution.node_type == 'llm':  # TODO: use enum
+                        outputs = workflow_node_execution.outputs_dict
+                        usage_dict = outputs.get('usage', {})
+                        self._task_state.metadata['usage'] = usage_dict
+
+                response = {
+                    'event': 'node_finished',
+                    'task_id': self._application_generate_entity.task_id,
+                    'workflow_run_id': workflow_node_execution.workflow_run_id,
+                    'data': {
+                        'id': workflow_node_execution.id,
+                        'node_id': workflow_node_execution.node_id,
+                        'index': workflow_node_execution.index,
+                        'predecessor_node_id': workflow_node_execution.predecessor_node_id,
+                        'inputs': workflow_node_execution.inputs_dict,
+                        'process_data': workflow_node_execution.process_data_dict,
+                        'outputs': workflow_node_execution.outputs_dict,
+                        'status': workflow_node_execution.status,
+                        'error': workflow_node_execution.error,
+                        'elapsed_time': workflow_node_execution.elapsed_time,
+                        'execution_metadata': workflow_node_execution.execution_metadata_dict,
+                        'created_at': int(workflow_node_execution.created_at.timestamp()),
+                        'finished_at': int(workflow_node_execution.finished_at.timestamp())
+                    }
+                }
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
+                if isinstance(event, QueueWorkflowFinishedEvent):
+                    workflow_run = self._get_workflow_run(event.workflow_run_id)
+                    if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
+                        outputs = workflow_run.outputs_dict
+                        self._task_state.answer = outputs.get('text', '') if outputs else ''
+                    else:
+                        err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
+                        data = self._error_to_stream_response_data(self._handle_error(err_event))
+                        yield self._yield_response(data)
+                        break
+
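+                    # A workflow-run summary is only emitted for a
+                    # QueueWorkflowFinishedEvent; a bare stop event has no run
+                    # to report and falls through to moderation and save.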
+                    workflow_run_response = {
+                        'event': 'workflow_finished',
+                        'task_id': self._application_generate_entity.task_id,
+                        'workflow_run_id': event.workflow_run_id,
+                        'data': {
+                            'id': workflow_run.id,
+                            'workflow_id': workflow_run.workflow_id,
+                            'status': workflow_run.status,
+                            'outputs': workflow_run.outputs_dict,
+                            'error': workflow_run.error,
+                            'elapsed_time': workflow_run.elapsed_time,
+                            'total_tokens': workflow_run.total_tokens,
+                            'total_price': workflow_run.total_price,
+                            'currency': workflow_run.currency,
+                            'total_steps': workflow_run.total_steps,
+                            'created_at': int(workflow_run.created_at.timestamp()),
+                            'finished_at': int(workflow_run.finished_at.timestamp())
+                        }
+                    }
+
+                    yield self._yield_response(workflow_run_response)
+
+                # response moderation
+                if self._output_moderation_handler:
+                    self._output_moderation_handler.stop_thread()
+
+                    self._task_state.answer = self._output_moderation_handler.moderation_completion(
+                        completion=self._task_state.answer,
+                        public_event=False
+                    )
+
+                    self._output_moderation_handler = None
+
+                    replace_response = {
+                        'event': 'message_replace',
+                        'task_id': self._application_generate_entity.task_id,
+                        'message_id': self._message.id,
+                        'conversation_id': self._conversation.id,
+                        'answer': self._task_state.answer,
+                        'created_at': int(self._message.created_at.timestamp())
+                    }
+
+                    yield self._yield_response(replace_response)
+
+                # Save message
+                self._save_message()
+
+                response = {
+                    'event': 'message_end',
+                    'task_id': self._application_generate_entity.task_id,
+                    'id': self._message.id,
+                    'message_id': self._message.id,
+                    'conversation_id': self._conversation.id,
+                }
+
+                if self._task_state.metadata:
+                    response['metadata'] = self._get_response_metadata()
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueueRetrieverResourcesEvent):
+                self._task_state.metadata['retriever_resources'] = event.retriever_resources
+            elif isinstance(event, QueueAnnotationReplyEvent):
+                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
+                if annotation:
+                    account = annotation.account
+                    self._task_state.metadata['annotation_reply'] = {
+                        'id': annotation.id,
+                        'account': {
+                            'id': annotation.account_id,
+                            'name': account.name if account else 'Dify user'
+                        }
+                    }
+
+                    self._task_state.answer = annotation.content
+            elif isinstance(event, QueueMessageFileEvent):
+                message_file: MessageFile = (
+                    db.session.query(MessageFile)
+                    .filter(MessageFile.id == event.message_file_id)
+                    .first()
+                )
+
+                if message_file:
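+                    # Sniff the extension from the URL suffix, falling back to
+                    # '.bin' when it is missing or implausibly long; the signed
+                    # URL below is what actually gets sent to the client.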
+                    # get extension
+                    if '.' in message_file.url:
+                        extension = f'.{message_file.url.split(".")[-1]}'
+                        if len(extension) > 10:
+                            extension = '.bin'
+                    else:
+                        extension = '.bin'
+
+                    # add sign url
+                    url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)
+
+                    response = {
+                        'event': 'message_file',
+                        'conversation_id': self._conversation.id,
+                        'id': message_file.id,
+                        'type': message_file.type,
+                        'belongs_to': message_file.belongs_to or 'user',
+                        'url': url
+                    }
+
+                    yield self._yield_response(response)
+            elif isinstance(event, QueueTextChunkEvent):
+                delta_text = event.chunk_text
+                if delta_text is None:
+                    continue
+
+                if self._output_moderation_handler:
+                    if self._output_moderation_handler.should_direct_output():
+                        # stop subscribing to new tokens when output moderation should direct output
+                        self._task_state.answer = self._output_moderation_handler.get_final_output()
+                        self._queue_manager.publish_text_chunk(self._task_state.answer, PublishFrom.TASK_PIPELINE)
+                        self._queue_manager.publish(
+                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
+                            PublishFrom.TASK_PIPELINE
+                        )
+                        continue
+                    else:
+                        self._output_moderation_handler.append_new_token(delta_text)
+
+                self._task_state.answer += delta_text
+                response = self._handle_chunk(delta_text)
+                yield self._yield_response(response)
+            elif isinstance(event, QueueMessageReplaceEvent):
+                response = {
+                    'event': 'message_replace',
+                    'task_id': self._application_generate_entity.task_id,
+                    'message_id': self._message.id,
+                    'conversation_id': self._conversation.id,
+                    'answer': event.text,
+                    'created_at': int(self._message.created_at.timestamp())
+                }
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueuePingEvent):
+                yield "event: ping\n\n"
+            else:
+                continue
+
+    def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
+        """
+        Get workflow run.
+        :param workflow_run_id: workflow run id
+        :return:
+        """
+        return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
+
+    def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
+        """
+        Get workflow node execution.
+        :param workflow_node_execution_id: workflow node execution id
+        :return:
+        """
+        return (
+            db.session.query(WorkflowNodeExecution)
+            .filter(WorkflowNodeExecution.id == workflow_node_execution_id)
+            .first()
+        )
+
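+    # The save runs once per task, after moderation has settled the final
+    # answer, so the persisted message matches what was returned to the client.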
+    def _save_message(self) -> None:
+        """
+        Save message.
+        :return:
+        """
+        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
+
+        self._message.answer = self._task_state.answer
+        self._message.provider_response_latency = time.perf_counter() - self._start_at
+
+        if self._task_state.metadata and self._task_state.metadata.get('usage'):
+            usage = LLMUsage(**self._task_state.metadata['usage'])
+
+            self._message.message_tokens = usage.prompt_tokens
+            self._message.message_unit_price = usage.prompt_unit_price
+            self._message.message_price_unit = usage.prompt_price_unit
+            self._message.answer_tokens = usage.completion_tokens
+            self._message.answer_unit_price = usage.completion_unit_price
+            self._message.answer_price_unit = usage.completion_price_unit
+            self._message.total_price = usage.total_price
+            self._message.currency = usage.currency
+
+        db.session.commit()
+
+        message_was_created.send(
+            self._message,
+            application_generate_entity=self._application_generate_entity,
+            conversation=self._conversation,
+            is_first_message=self._application_generate_entity.conversation_id is None,
+            extras=self._application_generate_entity.extras
+        )
+
+    def _handle_chunk(self, text: str) -> dict:
+        """
+        Handle a message chunk and wrap it in a stream response dict.
+        :param text: chunk text
+        :return:
+        """
+        response = {
+            'event': 'message',
+            'id': self._message.id,
+            'task_id': self._application_generate_entity.task_id,
+            'message_id': self._message.id,
+            'conversation_id': self._conversation.id,
+            'answer': text,
+            'created_at': int(self._message.created_at.timestamp())
+        }
+
+        return response
+
+    def _handle_error(self, event: QueueErrorEvent) -> Exception:
+        """
+        Handle error event.
+        :param event: event
+        :return:
+        """
+        logger.debug("error: %s", event.error)
+        e = event.error
+
+        if isinstance(e, InvokeAuthorizationError):
+            return InvokeAuthorizationError('Incorrect API key provided')
+        elif isinstance(e, (InvokeError, ValueError)):
+            return e
+        else:
+            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
+
+    def _error_to_stream_response_data(self, e: Exception) -> dict:
+        """
+        Error to stream response.
+        :param e: exception
+        :return:
+        """
+        error_responses = {
+            ValueError: {'code': 'invalid_param', 'status': 400},
+            ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
+            QuotaExceededError: {
+                'code': 'provider_quota_exceeded',
+                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
+                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
+                'status': 400
+            },
+            ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
+            InvokeError: {'code': 'completion_request_error', 'status': 400}
+        }
+
+        # Determine the response based on the type of exception
+        data = None
+        for k, v in error_responses.items():
+            if isinstance(e, k):
+                # copy so the setdefault below cannot mutate the shared template
+                data = v.copy()
+
+        if data:
+            data.setdefault('message', getattr(e, 'description', str(e)))
+        else:
+            logger.error(e)
+            data = {
+                'code': 'internal_server_error',
+                'message': 'Internal Server Error, please contact support.',
+                'status': 500
+            }
+
+        return {
+            'event': 'error',
+            'task_id': self._application_generate_entity.task_id,
+            'message_id': self._message.id,
+            **data
+        }
+
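+    # Metadata exposure depends on the invoke source: debugger and service-api
+    # callers receive full retriever resources, annotation replies and usage;
+    # other callers get only a trimmed retriever-resource view.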
+    def _get_response_metadata(self) -> dict:
+        """
+        Get response metadata by invoke from.
+        :return:
+        """
+        metadata = {}
+
+        # show_retrieve_source
+        if 'retriever_resources' in self._task_state.metadata:
+            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
+                metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
+            else:
+                metadata['retriever_resources'] = []
+                for resource in self._task_state.metadata['retriever_resources']:
+                    metadata['retriever_resources'].append({
+                        'segment_id': resource['segment_id'],
+                        'position': resource['position'],
+                        'document_name': resource['document_name'],
+                        'score': resource['score'],
+                        'content': resource['content'],
+                    })
+
+        # show annotation reply
+        if 'annotation_reply' in self._task_state.metadata:
+            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
+                metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']
+
+        # show usage
+        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
+            metadata['usage'] = self._task_state.metadata.get('usage')
+
+        return metadata
+
+    def _yield_response(self, response: dict) -> str:
+        """
+        Yield response.
+        :param response: response
+        :return: a serialized SSE `data:` frame
+        """
+        return "data: " + json.dumps(response) + "\n\n"
+
+    def _init_output_moderation(self) -> Optional[OutputModeration]:
+        """
+        Init output moderation.
+        :return:
+        """
+        app_config = self._application_generate_entity.app_config
+        sensitive_word_avoidance = app_config.sensitive_word_avoidance
+
+        if sensitive_word_avoidance:
+            return OutputModeration(
+                tenant_id=app_config.tenant_id,
+                app_id=app_config.app_id,
+                rule=ModerationRule(
+                    type=sensitive_word_avoidance.type,
+                    config=sensitive_word_avoidance.config
+                ),
+                on_message_replace_func=self._queue_manager.publish_message_replace
+            )
""" def __init__(self, application_generate_entity: Union[ @@ -79,12 +80,13 @@ class GenerateTaskPipeline: :param message: message """ self._application_generate_entity = application_generate_entity + self._model_config = application_generate_entity.model_config self._queue_manager = queue_manager self._conversation = conversation self._message = message self._task_state = TaskState( llm_result=LLMResult( - model=self._application_generate_entity.model_config.model, + model=self._model_config.model, prompt_messages=[], message=AssistantPromptMessage(content=""), usage=LLMUsage.empty_usage() @@ -119,7 +121,7 @@ class GenerateTaskPipeline: raise self._handle_error(event) elif isinstance(event, QueueRetrieverResourcesEvent): self._task_state.metadata['retriever_resources'] = event.retriever_resources - elif isinstance(event, AnnotationReplyEvent): + elif isinstance(event, QueueAnnotationReplyEvent): annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account @@ -136,7 +138,7 @@ class GenerateTaskPipeline: if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: - model_config = self._application_generate_entity.model_config + model_config = self._model_config model = model_config.model model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -193,7 +195,7 @@ class GenerateTaskPipeline: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id if self._task_state.metadata: @@ -219,7 +221,7 @@ class GenerateTaskPipeline: if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: - model_config = self._application_generate_entity.model_config + model_config = self._model_config model = model_config.model model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -272,7 +274,7 @@ class GenerateTaskPipeline: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: replace_response['conversation_id'] = self._conversation.id yield self._yield_response(replace_response) @@ -287,7 +289,7 @@ class GenerateTaskPipeline: 'message_id': self._message.id, } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id if self._task_state.metadata: @@ -296,7 +298,7 @@ class GenerateTaskPipeline: yield self._yield_response(response) elif isinstance(event, QueueRetrieverResourcesEvent): self._task_state.metadata['retriever_resources'] = event.retriever_resources - elif isinstance(event, AnnotationReplyEvent): + elif isinstance(event, QueueAnnotationReplyEvent): annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account @@ -334,7 +336,7 @@ class GenerateTaskPipeline: 'message_files': agent_thought.files } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id yield self._yield_response(response) @@ -365,12 +367,12 @@ class GenerateTaskPipeline: 'url': url } - if self._conversation.mode == 'chat': + 
+                if self._conversation.mode != AppMode.COMPLETION.value:
                     response['conversation_id'] = self._conversation.id
 
                 yield self._yield_response(response)
 
-            elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
+            elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
                 chunk = event.chunk
                 delta_text = chunk.delta.message.content
                 if delta_text is None:
@@ -383,7 +385,7 @@
                     if self._output_moderation_handler.should_direct_output():
                         # stop subscribe new token when output moderation should direct output
                         self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
-                        self._queue_manager.publish_chunk_message(LLMResultChunk(
+                        self._queue_manager.publish_llm_chunk(LLMResultChunk(
                             model=self._task_state.llm_result.model,
                             prompt_messages=self._task_state.llm_result.prompt_messages,
                             delta=LLMResultChunkDelta(
@@ -411,7 +413,7 @@
                     'created_at': int(self._message.created_at.timestamp())
                 }
 
-                if self._conversation.mode == 'chat':
+                if self._conversation.mode != AppMode.COMPLETION.value:
                     response['conversation_id'] = self._conversation.id
 
                 yield self._yield_response(response)
@@ -452,8 +454,7 @@
             conversation=self._conversation,
             is_first_message=self._application_generate_entity.app_config.app_mode in [
                 AppMode.AGENT_CHAT,
-                AppMode.CHAT,
-                AppMode.ADVANCED_CHAT
+                AppMode.CHAT
             ] and self._application_generate_entity.conversation_id is None,
             extras=self._application_generate_entity.extras
         )
@@ -473,7 +474,7 @@
             'created_at': int(self._message.created_at.timestamp())
         }
 
-        if self._conversation.mode == 'chat':
+        if self._conversation.mode != AppMode.COMPLETION.value:
            response['conversation_id'] = self._conversation.id
 
         return response
@@ -583,7 +584,7 @@
         :return:
         """
         prompts = []
-        if self._application_generate_entity.model_config.mode == 'chat':
+        if self._model_config.mode == ModelMode.CHAT.value:
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = 'user'