diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 9a8de8ae3d..a29900fc8d 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -187,8 +187,7 @@ class MessageSuggestedQuestionApi(Resource): questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, message_id=message_id, - user=current_user, - check_enabled=False + user=current_user ) except MessageNotExistsError: raise NotFound("Message not found") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 703ff6e258..d7ccd25c2a 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -119,8 +119,7 @@ class MessageSuggestedApi(Resource): questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, user=end_user, - message_id=message_id, - check_enabled=False + message_id=message_id ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/services/message_service.py b/api/services/message_service.py index ced4b812b7..8236362b52 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,6 +1,7 @@ import json from typing import Optional, Union +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.llm_generator.llm_generator import LLMGenerator from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelManager @@ -8,13 +9,14 @@ from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account -from models.model import App, AppModelConfig, EndUser, Message, MessageFeedback +from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback from services.conversation_service import ConversationService from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError from services.errors.message import ( FirstMessageNotExistsError, LastMessageNotExistsError, MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, ) @@ -175,7 +177,7 @@ class MessageService: @classmethod def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]], - message_id: str, check_enabled: bool = True) -> list[Message]: + message_id: str) -> list[Message]: if not user: raise ValueError('user cannot be None') @@ -197,36 +199,59 @@ class MessageService: if conversation.status != 'normal': raise ConversationCompletedError() - if not conversation.override_model_configs: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() - else: - conversation_override_model_configs = json.loads(conversation.override_model_configs) - app_model_config = AppModelConfig( - id=conversation.app_model_config_id, - app_id=app_model.id, - ) - - app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) - - # get memory of conversation (read-only) model_manager = ModelManager() - if app_model_config: - model_instance = model_manager.get_model_instance( - tenant_id=app_model.tenant_id, - provider=app_model_config.model_dict['provider'], - model_type=ModelType.LLM, - model=app_model_config.model_dict['name'] + if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + workflow = app_model.workflow + if workflow is None: + return [] + + app_config = AdvancedChatAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow ) - else: + + if not app_config.additional_features.suggested_questions_after_answer: + raise SuggestedQuestionsAfterAnswerDisabledError() + model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.LLM ) + else: + if not conversation.override_model_configs: + app_model_config = db.session.query(AppModelConfig).filter( + AppModelConfig.id == conversation.app_model_config_id, + AppModelConfig.app_id == app_model.id + ).first() + else: + conversation_override_model_configs = json.loads(conversation.override_model_configs) + app_model_config = AppModelConfig( + id=conversation.app_model_config_id, + app_id=app_model.id, + ) + app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) + + if app_model_config: + model_instance = model_manager.get_model_instance( + tenant_id=app_model.tenant_id, + provider=app_model_config.model_dict['provider'], + model_type=ModelType.LLM, + model=app_model_config.model_dict['name'] + ) + else: + model_instance = model_manager.get_default_model_instance( + tenant_id=app_model.tenant_id, + model_type=ModelType.LLM + ) + + suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict + + if check_enabled and suggested_questions_after_answer.get("enabled", False) is False: + raise SuggestedQuestionsAfterAnswerDisabledError() + + # get memory of conversation (read-only) memory = TokenBufferMemory( conversation=conversation, model_instance=model_instance