From 34e8d2f6bba03a01c0bec3c445a422adc6c41857 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 21 Mar 2024 18:30:23 +0800 Subject: [PATCH] add message error record --- .../advanced_chat/generate_task_pipeline.py | 6 ++- .../app/apps/message_based_app_generator.py | 2 + .../based_generate_task_pipeline.py | 51 +++++++++++++++++-- .../easy_ui_based_generate_task_pipeline.py | 5 +- .../task_pipeline/workflow_cycle_manage.py | 21 ++++++-- api/core/workflow/workflow_engine_manager.py | 5 ++ .../e2eacc9a1b63_add_status_for_message.py | 43 ++++++++++++++++ api/models/model.py | 5 ++ 8 files changed, 126 insertions(+), 12 deletions(-) create mode 100644 api/migrations/versions/e2eacc9a1b63_add_status_for_message.py diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 5a7adda3e8..042bc5c8f1 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -38,6 +38,7 @@ from core.app.task_pipeline.message_cycle_manage import MessageCycleManage from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.file.file_obj import FileVar from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.node_entities import NodeType, SystemVariable from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk @@ -167,7 +168,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc event = message.event if isinstance(event, QueueErrorEvent): - err = self._handle_error(event) + err = self._handle_error(event, self._message) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): @@ -285,6 +286,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at + self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ + if self._task_state.metadata else None if self._task_state.metadata and self._task_state.metadata.get('usage'): usage = LLMUsage(**self._task_state.metadata['usage']) @@ -295,7 +298,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._message.answer_tokens = usage.completion_tokens self._message.answer_unit_price = usage.completion_unit_price self._message.answer_price_unit = usage.completion_price_unit - self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.total_price = usage.total_price self._message.currency = usage.currency diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 8c475b755f..c70c5a97ae 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -182,6 +182,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): system_instruction="", system_instruction_tokens=0, status='normal', + invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, from_account_id=account_id, @@ -210,6 +211,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): provider_response_latency=0, total_price=0, currency='USD', + invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, from_account_id=account_id diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 9e50926ebb..b8d7d731b8 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -14,10 +14,12 @@ from core.app.entities.task_entities import ( PingStreamResponse, TaskState, ) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration +from extensions.ext_database import db from models.account import Account -from models.model import EndUser +from models.model import EndUser, Message logger = logging.getLogger(__name__) @@ -48,21 +50,60 @@ class BasedGenerateTaskPipeline: self._output_moderation_handler = self._init_output_moderation() self._stream = stream - def _handle_error(self, event: QueueErrorEvent) -> Exception: + def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None) -> Exception: """ Handle error event. :param event: event + :param message: message :return: """ logger.debug("error: %s", event.error) e = event.error if isinstance(e, InvokeAuthorizationError): - return InvokeAuthorizationError('Incorrect API key provided') + err = InvokeAuthorizationError('Incorrect API key provided') elif isinstance(e, InvokeError) or isinstance(e, ValueError): - return e + err = e else: - return Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + err = Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + + if message: + message = db.session.query(Message).filter(Message.id == message.id).first() + err_desc = self._error_to_desc(err) + message.status = 'error' + message.error = err_desc + + db.session.commit() + + return err + + def _error_to_desc(cls, e: Exception) -> str: + """ + Error to desc. + :param e: exception + :return: + """ + error_responses = { + ValueError: None, + ProviderTokenNotInitError: None, + QuotaExceededError: "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials.", + ModelCurrentlyNotSupportError: None, + InvokeError: None + } + + # Determine the response based on the type of exception + data = None + for k, v in error_responses.items(): + if isinstance(e, k): + data = v + + if data: + message = getattr(e, 'description', str(e)) if data is None else data + else: + message = 'Internal Server Error, please contact support.' + + return message def _error_to_stream_response(self, e: Exception) -> ErrorStreamResponse: """ diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 3d936e2b44..4fc9d6abaa 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Generator @@ -195,7 +196,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan event = message.event if isinstance(event, QueueErrorEvent): - err = self._handle_error(event) + err = self._handle_error(event, self._message) yield self._error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): @@ -281,6 +282,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.total_price = usage.total_price self._message.currency = usage.currency + self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ + if self._task_state.metadata else None db.session.commit() diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index eb2170fad0..7600a57854 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -458,11 +458,24 @@ class WorkflowCycleManage: def _handle_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ -> Optional[WorkflowRun]: - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - if not workflow_run: - return None - if isinstance(event, QueueStopEvent): + latest_node_execution_info = self._task_state.latest_node_execution_info + if latest_node_execution_info: + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first() + if (workflow_node_execution + and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value): + self._workflow_node_execution_failed( + workflow_node_execution=workflow_node_execution, + start_at=latest_node_execution_info.start_at, + error='Workflow stopped.' + ) + + workflow_run = db.session.query(WorkflowRun).filter( + WorkflowRun.id == self._task_state.workflow_run_id).first() + if not workflow_run: + return None + workflow_run = self._workflow_run_failed( workflow_run=workflow_run, start_at=self._task_state.start_at, diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index be5bd1c17a..a9fa646bb5 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -413,6 +413,11 @@ class WorkflowEngineManager: node_run_result = node.run( variable_pool=workflow_run_state.variable_pool ) + except GenerateTaskStoppedException as e: + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error='Workflow stopped.' + ) except Exception as e: logger.exception(f"Node {node.node_data.title} run failed: {str(e)}") node_run_result = NodeRunResult( diff --git a/api/migrations/versions/e2eacc9a1b63_add_status_for_message.py b/api/migrations/versions/e2eacc9a1b63_add_status_for_message.py new file mode 100644 index 0000000000..08f994a41f --- /dev/null +++ b/api/migrations/versions/e2eacc9a1b63_add_status_for_message.py @@ -0,0 +1,43 @@ +"""add status for message + +Revision ID: e2eacc9a1b63 +Revises: 563cf8bf777b +Create Date: 2024-03-21 09:31:27.342221 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'e2eacc9a1b63' +down_revision = '563cf8bf777b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.add_column(sa.Column('invoke_from', sa.String(length=255), nullable=True)) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.add_column(sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False)) + batch_op.add_column(sa.Column('error', sa.Text(), nullable=True)) + batch_op.add_column(sa.Column('message_metadata', sa.Text(), nullable=True)) + batch_op.add_column(sa.Column('invoke_from', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_column('invoke_from') + batch_op.drop_column('message_metadata') + batch_op.drop_column('error') + batch_op.drop_column('status') + + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.drop_column('invoke_from') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 6571a31c43..9914658272 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -475,6 +475,7 @@ class Conversation(db.Model): system_instruction = db.Column(db.Text) system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) status = db.Column(db.String(255), nullable=False) + invoke_from = db.Column(db.String(255), nullable=True) from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(UUID) from_account_id = db.Column(UUID) @@ -619,6 +620,10 @@ class Message(db.Model): provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + error = db.Column(db.Text) + message_metadata = db.Column(db.Text) + invoke_from = db.Column(db.String(255), nullable=True) from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(UUID) from_account_id = db.Column(UUID)