From a047a9846276a390a438e8af066dc1c83644bf5d Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 16 Mar 2024 14:25:04 +0800 Subject: [PATCH] advanced chat support --- api/controllers/console/explore/completion.py | 7 +++-- .../console/explore/conversation.py | 16 ++++++---- api/controllers/console/explore/error.py | 2 +- api/controllers/console/explore/message.py | 9 ++++-- api/controllers/service_api/app/completion.py | 8 +++-- .../service_api/app/conversation.py | 12 +++++--- api/controllers/service_api/app/error.py | 2 +- api/controllers/service_api/app/message.py | 8 +++-- api/controllers/web/completion.py | 7 +++-- api/controllers/web/conversation.py | 16 ++++++---- api/controllers/web/error.py | 2 +- api/controllers/web/message.py | 7 +++-- api/services/message_service.py | 29 +++++++++---------- 13 files changed, 77 insertions(+), 48 deletions(-) diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index bff494dccb..f0bf46f1a6 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -24,6 +24,7 @@ from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs import helper from libs.helper import uuid_value +from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -95,7 +96,8 @@ class CompletionStopApi(InstalledAppResource): class ChatApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() parser = reqparse.RequestParser() @@ -148,7 +150,8 @@ class ChatApi(InstalledAppResource): class ChatStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 34a5904eca..7892840aeb 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -8,6 +8,7 @@ from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value +from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService @@ -18,7 +19,8 @@ class ConversationListApi(InstalledAppResource): @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() parser = reqparse.RequestParser() @@ -47,7 +49,8 @@ class ConversationListApi(InstalledAppResource): class ConversationApi(InstalledAppResource): def delete(self, installed_app, c_id): app_model = installed_app.app - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() conversation_id = str(c_id) @@ -65,7 +68,8 @@ class ConversationRenameApi(InstalledAppResource): @marshal_with(simple_conversation_fields) def post(self, installed_app, c_id): app_model = installed_app.app - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() conversation_id = str(c_id) @@ -91,7 +95,8 @@ class ConversationPinApi(InstalledAppResource): def patch(self, installed_app, c_id): app_model = installed_app.app - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() conversation_id = str(c_id) @@ -107,7 +112,8 @@ class ConversationPinApi(InstalledAppResource): class ConversationUnPinApi(InstalledAppResource): def patch(self, installed_app, c_id): app_model = installed_app.app - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index 89c4d113a3..e1e3a2a877 100644 --- a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -9,7 +9,7 @@ class NotCompletionAppError(BaseHTTPException): class NotChatAppError(BaseHTTPException): error_code = 'not_chat_app' - description = "Not Chat App" + description = "App mode is invalid." code = 400 diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index ef051233b0..50e7eeb551 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -26,6 +26,7 @@ from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs import helper from libs.helper import uuid_value +from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError @@ -38,7 +39,8 @@ class MessageListApi(InstalledAppResource): def get(self, installed_app): app_model = installed_app.app - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() parser = reqparse.RequestParser() @@ -118,8 +120,9 @@ class MessageMoreLikeThisApi(InstalledAppResource): class MessageSuggestedQuestionApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app - if app_model.mode != 'chat': - raise NotCompletionAppError() + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + raise NotChatAppError() message_id = str(message_id) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 3f284d2326..c1fdf249bb 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -21,7 +21,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from core.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value -from models.model import App, EndUser +from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService @@ -90,7 +90,8 @@ class CompletionStopApi(Resource): class ChatApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() parser = reqparse.RequestParser() @@ -141,7 +142,8 @@ class ChatApi(Resource): class ChatStopApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 4a5fe2f19f..fc60f94ec9 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -8,7 +8,7 @@ from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value -from models.model import App, EndUser +from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService @@ -17,7 +17,8 @@ class ConversationApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() parser = reqparse.RequestParser() @@ -30,11 +31,13 @@ class ConversationApi(Resource): except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") + class ConversationDetailApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @marshal_with(simple_conversation_fields) def delete(self, app_model: App, end_user: EndUser, c_id): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() conversation_id = str(c_id) @@ -51,7 +54,8 @@ class ConversationRenameApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @marshal_with(simple_conversation_fields) def post(self, app_model: App, end_user: EndUser, c_id): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index eb953d0950..590d462deb 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -15,7 +15,7 @@ class NotCompletionAppError(BaseHTTPException): class NotChatAppError(BaseHTTPException): error_code = 'not_chat_app' - description = "Please check if your Chat app mode matches the right API route." + description = "Please check if your app mode matches the right API route." code = 400 diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 0050ab1aee..4e96a924b0 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -8,7 +8,7 @@ from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from fields.conversation_fields import message_file_fields from libs.helper import TimestampField, uuid_value -from models.model import App, EndUser +from models.model import App, AppMode, EndUser from services.message_service import MessageService @@ -71,7 +71,8 @@ class MessageListApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() parser = reqparse.RequestParser() @@ -110,7 +111,8 @@ class MessageSuggestedApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) def get(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() try: diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 452ce8709e..948d5fabb5 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -22,6 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from core.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value +from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -88,7 +89,8 @@ class CompletionStopApi(WebApiResource): class ChatApi(WebApiResource): def post(self, app_model, end_user): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() parser = reqparse.RequestParser() @@ -138,7 +140,8 @@ class ChatApi(WebApiResource): class ChatStopApi(WebApiResource): def post(self, app_model, end_user, task_id): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index c287f2a879..bbc57c7d61 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -7,6 +7,7 @@ from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value +from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService @@ -16,7 +17,8 @@ class ConversationListApi(WebApiResource): @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() parser = reqparse.RequestParser() @@ -43,7 +45,8 @@ class ConversationListApi(WebApiResource): class ConversationApi(WebApiResource): def delete(self, app_model, end_user, c_id): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() conversation_id = str(c_id) @@ -60,7 +63,8 @@ class ConversationRenameApi(WebApiResource): @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() conversation_id = str(c_id) @@ -85,7 +89,8 @@ class ConversationRenameApi(WebApiResource): class ConversationPinApi(WebApiResource): def patch(self, app_model, end_user, c_id): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() conversation_id = str(c_id) @@ -100,7 +105,8 @@ class ConversationPinApi(WebApiResource): class ConversationUnPinApi(WebApiResource): def patch(self, app_model, end_user, c_id): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 9cb3c8f235..453d08d2fa 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -15,7 +15,7 @@ class NotCompletionAppError(BaseHTTPException): class NotChatAppError(BaseHTTPException): error_code = 'not_chat_app' - description = "Please check if your Chat app mode matches the right API route." + description = "Please check if your app mode matches the right API route." code = 400 diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index c4e49118d8..51a48ee9fb 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -24,6 +24,7 @@ from fields.conversation_fields import message_file_fields from fields.message_fields import agent_thought_fields from libs import helper from libs.helper import TimestampField, uuid_value +from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError @@ -76,7 +77,8 @@ class MessageListApi(WebApiResource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotChatAppError() parser = reqparse.RequestParser() @@ -154,7 +156,8 @@ class MessageMoreLikeThisApi(WebApiResource): class MessageSuggestedQuestionApi(WebApiResource): def get(self, app_model, end_user, message_id): - if app_model.mode != 'chat': + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: raise NotCompletionAppError() message_id = str(message_id) diff --git a/api/services/message_service.py b/api/services/message_service.py index 20918a8781..ced4b812b7 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -10,13 +10,11 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account from models.model import App, AppModelConfig, EndUser, Message, MessageFeedback from services.conversation_service import ConversationService -from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError from services.errors.message import ( FirstMessageNotExistsError, LastMessageNotExistsError, MessageNotExistsError, - SuggestedQuestionsAfterAnswerDisabledError, ) @@ -204,9 +202,6 @@ class MessageService: AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id ).first() - - if not app_model_config: - raise AppModelConfigBrokenError() else: conversation_override_model_configs = json.loads(conversation.override_model_configs) app_model_config = AppModelConfig( @@ -216,19 +211,21 @@ class MessageService: app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) - suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict - - if check_enabled and suggested_questions_after_answer.get("enabled", False) is False: - raise SuggestedQuestionsAfterAnswerDisabledError() - # get memory of conversation (read-only) model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=app_model.tenant_id, - provider=app_model_config.model_dict['provider'], - model_type=ModelType.LLM, - model=app_model_config.model_dict['name'] - ) + + if app_model_config: + model_instance = model_manager.get_model_instance( + tenant_id=app_model.tenant_id, + provider=app_model_config.model_dict['provider'], + model_type=ModelType.LLM, + model=app_model_config.model_dict['name'] + ) + else: + model_instance = model_manager.get_default_model_instance( + tenant_id=app_model.tenant_id, + model_type=ModelType.LLM + ) memory = TokenBufferMemory( conversation=conversation,