advanced chat support

This commit is contained in:
takatost 2024-03-16 14:25:04 +08:00
parent 1df68a546e
commit a047a98462
13 changed files with 77 additions and 48 deletions

View File

@ -24,6 +24,7 @@ from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper
from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -95,7 +96,8 @@ class CompletionStopApi(InstalledAppResource):
class ChatApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -148,7 +150,8 @@ class ChatApi(InstalledAppResource):
class ChatStopApi(InstalledAppResource):
def post(self, installed_app, task_id):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

View File

@ -8,6 +8,7 @@ from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.web_conversation_service import WebConversationService
@ -18,7 +19,8 @@ class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, installed_app):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -47,7 +49,8 @@ class ConversationListApi(InstalledAppResource):
class ConversationApi(InstalledAppResource):
def delete(self, installed_app, c_id):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -65,7 +68,8 @@ class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields)
def post(self, installed_app, c_id):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -91,7 +95,8 @@ class ConversationPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -107,7 +112,8 @@ class ConversationPinApi(InstalledAppResource):
class ConversationUnPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -9,7 +9,7 @@ class NotCompletionAppError(BaseHTTPException):
class NotChatAppError(BaseHTTPException):
error_code = 'not_chat_app'
description = "Not Chat App"
description = "App mode is invalid."
code = 400

View File

@ -26,6 +26,7 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@ -38,7 +39,8 @@ class MessageListApi(InstalledAppResource):
def get(self, installed_app):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -118,8 +120,9 @@ class MessageMoreLikeThisApi(InstalledAppResource):
class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id):
app_model = installed_app.app
if app_model.mode != 'chat':
raise NotCompletionAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
message_id = str(message_id)

View File

@ -21,7 +21,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from models.model import App, EndUser
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
@ -90,7 +90,8 @@ class CompletionStopApi(Resource):
class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -141,7 +142,8 @@ class ChatApi(Resource):
class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

View File

@ -8,7 +8,7 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import App, EndUser
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@ -17,7 +17,8 @@ class ConversationApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -30,11 +31,13 @@ class ConversationApi(Resource):
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def delete(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -51,7 +54,8 @@ class ConversationRenameApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def post(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -15,7 +15,7 @@ class NotCompletionAppError(BaseHTTPException):
class NotChatAppError(BaseHTTPException):
error_code = 'not_chat_app'
description = "Please check if your Chat app mode matches the right API route."
description = "Please check if your app mode matches the right API route."
code = 400

View File

@ -8,7 +8,7 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from models.model import App, EndUser
from models.model import App, AppMode, EndUser
from services.message_service import MessageService
@ -71,7 +71,8 @@ class MessageListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -110,7 +111,8 @@ class MessageSuggestedApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
try:

View File

@ -22,6 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -88,7 +89,8 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource):
def post(self, app_model, end_user):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -138,7 +140,8 @@ class ChatApi(WebApiResource):
class ChatStopApi(WebApiResource):
def post(self, app_model, end_user, task_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

View File

@ -7,6 +7,7 @@ from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.web_conversation_service import WebConversationService
@ -16,7 +17,8 @@ class ConversationListApi(WebApiResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -43,7 +45,8 @@ class ConversationListApi(WebApiResource):
class ConversationApi(WebApiResource):
def delete(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -60,7 +63,8 @@ class ConversationRenameApi(WebApiResource):
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -85,7 +89,8 @@ class ConversationRenameApi(WebApiResource):
class ConversationPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -100,7 +105,8 @@ class ConversationPinApi(WebApiResource):
class ConversationUnPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -15,7 +15,7 @@ class NotCompletionAppError(BaseHTTPException):
class NotChatAppError(BaseHTTPException):
error_code = 'not_chat_app'
description = "Please check if your Chat app mode matches the right API route."
description = "Please check if your app mode matches the right API route."
code = 400

View File

@ -24,6 +24,7 @@ from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields
from libs import helper
from libs.helper import TimestampField, uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@ -76,7 +77,8 @@ class MessageListApi(WebApiResource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -154,7 +156,8 @@ class MessageMoreLikeThisApi(WebApiResource):
class MessageSuggestedQuestionApi(WebApiResource):
def get(self, app_model, end_user, message_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotCompletionAppError()
message_id = str(message_id)

View File

@ -10,13 +10,11 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.model import App, AppModelConfig, EndUser, Message, MessageFeedback
from services.conversation_service import ConversationService
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
from services.errors.message import (
FirstMessageNotExistsError,
LastMessageNotExistsError,
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
@ -204,9 +202,6 @@ class MessageService:
AppModelConfig.id == conversation.app_model_config_id,
AppModelConfig.app_id == app_model.id
).first()
if not app_model_config:
raise AppModelConfigBrokenError()
else:
conversation_override_model_configs = json.loads(conversation.override_model_configs)
app_model_config = AppModelConfig(
@ -216,19 +211,21 @@ class MessageService:
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
if check_enabled and suggested_questions_after_answer.get("enabled", False) is False:
raise SuggestedQuestionsAfterAnswerDisabledError()
# get memory of conversation (read-only)
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=app_model.tenant_id,
provider=app_model_config.model_dict['provider'],
model_type=ModelType.LLM,
model=app_model_config.model_dict['name']
)
if app_model_config:
model_instance = model_manager.get_model_instance(
tenant_id=app_model.tenant_id,
provider=app_model_config.model_dict['provider'],
model_type=ModelType.LLM,
model=app_model_config.model_dict['name']
)
else:
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id,
model_type=ModelType.LLM
)
memory = TokenBufferMemory(
conversation=conversation,