This commit is contained in:
takatost 2024-03-19 15:32:10 +08:00
parent 24ac4996c0
commit 133d52deb9
5 changed files with 39 additions and 23 deletions

View File

@ -13,8 +13,10 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
@ -25,7 +27,7 @@ from libs.login import login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
@ -187,7 +189,8 @@ class MessageSuggestedQuestionApi(Resource):
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
message_id=message_id,
user=current_user
user=current_user,
invoke_from=InvokeFrom.DEBUGGER
)
except MessageNotExistsError:
raise NotFound("Message not found")
@ -201,6 +204,8 @@ class MessageSuggestedQuestionApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except Exception:
logging.exception("internal server error.")
raise InternalServerError()

View File

@ -130,7 +130,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=current_user,
message_id=message_id
message_id=message_id,
invoke_from=InvokeFrom.EXPLORE
)
except MessageNotExistsError:
raise NotFound("Message not found")

View File

@ -1,6 +1,8 @@
import logging
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.service_api import api
@ -9,6 +11,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from models.model import App, AppMode, EndUser
from services.errors.message import SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
@ -119,10 +122,16 @@ class MessageSuggestedApi(Resource):
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=end_user,
message_id=message_id
message_id=message_id,
invoke_from=InvokeFrom.SERVICE_API
)
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
except SuggestedQuestionsAfterAnswerDisabledError:
raise BadRequest("Message Not Exists.")
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
return {'result': 'success', 'data': questions}

View File

@ -166,7 +166,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=end_user,
message_id=message_id
message_id=message_id,
invoke_from=InvokeFrom.WEB_APP
)
except MessageNotExistsError:
raise NotFound("Message not found")

View File

@ -2,6 +2,7 @@ import json
from typing import Optional, Union
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelManager
@ -18,6 +19,7 @@ from services.errors.message import (
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.workflow_service import WorkflowService
class MessageService:
@ -177,7 +179,7 @@ class MessageService:
@classmethod
def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]],
message_id: str) -> list[Message]:
message_id: str, invoke_from: InvokeFrom) -> list[Message]:
if not user:
raise ValueError('user cannot be None')
@ -201,8 +203,13 @@ class MessageService:
model_manager = ModelManager()
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
workflow = app_model.workflow
if app_model.mode == AppMode.ADVANCED_CHAT.value:
workflow_service = WorkflowService()
if invoke_from == InvokeFrom.DEBUGGER:
workflow = workflow_service.get_draft_workflow(app_model=app_model)
else:
workflow = workflow_service.get_published_workflow(app_model=app_model)
if workflow is None:
return []
@ -233,24 +240,17 @@ class MessageService:
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
if app_model_config:
model_instance = model_manager.get_model_instance(
tenant_id=app_model.tenant_id,
provider=app_model_config.model_dict['provider'],
model_type=ModelType.LLM,
model=app_model_config.model_dict['name']
)
else:
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id,
model_type=ModelType.LLM
)
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
if suggested_questions_after_answer.get("enabled", False) is False:
raise SuggestedQuestionsAfterAnswerDisabledError()
model_instance = model_manager.get_model_instance(
tenant_id=app_model.tenant_id,
provider=app_model_config.model_dict['provider'],
model_type=ModelType.LLM,
model=app_model_config.model_dict['name']
)
# get memory of conversation (read-only)
memory = TokenBufferMemory(
conversation=conversation,