add message error record

This commit is contained in:
takatost 2024-03-21 18:30:23 +08:00
parent c4e6ed1aa2
commit 34e8d2f6bb
8 changed files with 126 additions and 12 deletions

View File

@ -38,6 +38,7 @@ from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
@ -167,7 +168,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
event = message.event
if isinstance(event, QueueErrorEvent):
err = self._handle_error(event)
err = self._handle_error(event, self._message)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
@ -285,6 +286,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
if self._task_state.metadata and self._task_state.metadata.get('usage'):
usage = LLMUsage(**self._task_state.metadata['usage'])
@ -295,7 +298,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.total_price = usage.total_price
self._message.currency = usage.currency

View File

@ -182,6 +182,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
system_instruction="",
system_instruction_tokens=0,
status='normal',
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
@ -210,6 +211,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
provider_response_latency=0,
total_price=0,
currency='USD',
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id

View File

@ -14,10 +14,12 @@ from core.app.entities.task_entities import (
PingStreamResponse,
TaskState,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
from models.model import EndUser, Message
logger = logging.getLogger(__name__)
@ -48,21 +50,60 @@ class BasedGenerateTaskPipeline:
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
def _handle_error(self, event: QueueErrorEvent) -> Exception:
def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None) -> Exception:
"""
Handle error event.
:param event: event
:param message: message
:return:
"""
logger.debug("error: %s", event.error)
e = event.error
if isinstance(e, InvokeAuthorizationError):
return InvokeAuthorizationError('Incorrect API key provided')
err = InvokeAuthorizationError('Incorrect API key provided')
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
return e
err = e
else:
return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
if message:
message = db.session.query(Message).filter(Message.id == message.id).first()
err_desc = self._error_to_desc(err)
message.status = 'error'
message.error = err_desc
db.session.commit()
return err
def _error_to_desc(cls, e: Exception) -> str:
"""
Error to desc.
:param e: exception
:return:
"""
error_responses = {
ValueError: None,
ProviderTokenNotInitError: None,
QuotaExceededError: "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.",
ModelCurrentlyNotSupportError: None,
InvokeError: None
}
# Determine the response based on the type of exception
data = None
for k, v in error_responses.items():
if isinstance(e, k):
data = v
if data:
message = getattr(e, 'description', str(e)) if data is None else data
else:
message = 'Internal Server Error, please contact support.'
return message
def _error_to_stream_response(self, e: Exception) -> ErrorStreamResponse:
"""

View File

@ -1,3 +1,4 @@
import json
import logging
import time
from collections.abc import Generator
@ -195,7 +196,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
event = message.event
if isinstance(event, QueueErrorEvent):
err = self._handle_error(event)
err = self._handle_error(event, self._message)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
@ -281,6 +282,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.total_price = usage.total_price
self._message.currency = usage.currency
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()

View File

@ -458,11 +458,24 @@ class WorkflowCycleManage:
def _handle_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
-> Optional[WorkflowRun]:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
if not workflow_run:
return None
if isinstance(event, QueueStopEvent):
latest_node_execution_info = self._task_state.latest_node_execution_info
if latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first()
if (workflow_node_execution
and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value):
self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=latest_node_execution_info.start_at,
error='Workflow stopped.'
)
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == self._task_state.workflow_run_id).first()
if not workflow_run:
return None
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
start_at=self._task_state.start_at,

View File

@ -413,6 +413,11 @@ class WorkflowEngineManager:
node_run_result = node.run(
variable_pool=workflow_run_state.variable_pool
)
except GenerateTaskStoppedException as e:
node_run_result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error='Workflow stopped.'
)
except Exception as e:
logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
node_run_result = NodeRunResult(

View File

@ -0,0 +1,43 @@
"""add status for message
Revision ID: e2eacc9a1b63
Revises: 563cf8bf777b
Create Date: 2024-03-21 09:31:27.342221
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = 'e2eacc9a1b63'
down_revision = '563cf8bf777b'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.add_column(sa.Column('invoke_from', sa.String(length=255), nullable=True))
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.add_column(sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False))
batch_op.add_column(sa.Column('error', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('message_metadata', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('invoke_from', sa.String(length=255), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.drop_column('invoke_from')
batch_op.drop_column('message_metadata')
batch_op.drop_column('error')
batch_op.drop_column('status')
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.drop_column('invoke_from')
# ### end Alembic commands ###

View File

@ -475,6 +475,7 @@ class Conversation(db.Model):
system_instruction = db.Column(db.Text)
system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
status = db.Column(db.String(255), nullable=False)
invoke_from = db.Column(db.String(255), nullable=True)
from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(UUID)
from_account_id = db.Column(UUID)
@ -619,6 +620,10 @@ class Message(db.Model):
provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0'))
total_price = db.Column(db.Numeric(10, 7))
currency = db.Column(db.String(255), nullable=False)
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
error = db.Column(db.Text)
message_metadata = db.Column(db.Text)
invoke_from = db.Column(db.String(255), nullable=True)
from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(UUID)
from_account_id = db.Column(UUID)