From fce20e483cf4cc4eadd8f3386f4478ac5a50bbfd Mon Sep 17 00:00:00 2001
From: takatost
Date: Sun, 25 Feb 2024 21:30:36 +0800
Subject: [PATCH] restore completion app

---
 api/controllers/console/app/app.py          |  2 +-
 api/controllers/console/app/completion.py   |  4 +-
 api/controllers/console/app/conversation.py |  4 +-
 api/controllers/console/app/statistic.py    |  2 +-
 api/controllers/console/explore/message.py  | 47 +++++++++++++++
 api/controllers/web/message.py              | 47 +++++++++++++++
 api/core/app_runner/app_runner.py           | 19 ++++--
 api/core/prompt/prompt_transform.py         |  7 +--
 api/core/prompt/simple_prompt_transform.py  | 38 +++++++-----
 api/services/app_model_config_service.py    | 18 ++++++
 api/services/completion_service.py          | 60 ++++++++++++++++++-
 api/services/errors/__init__.py             |  2 +-
 api/services/errors/app.py                  |  2 +
 .../prompt/test_simple_prompt_transform.py  |  2 +
 14 files changed, 224 insertions(+), 30 deletions(-)
 create mode 100644 api/services/errors/app.py

diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index cf505bedb8..93dc1ca34a 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -80,7 +80,7 @@ class AppListApi(Resource):
         """Create app"""
         parser = reqparse.RequestParser()
         parser.add_argument('name', type=str, required=True, location='json')
-        parser.add_argument('mode', type=str, choices=[mode.value for mode in AppMode], location='json')
+        parser.add_argument('mode', type=str, choices=['chat', 'agent', 'workflow'], location='json')
         parser.add_argument('icon', type=str, location='json')
         parser.add_argument('icon_background', type=str, location='json')
         parser.add_argument('model_config', type=dict, location='json')
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index 11fdba177d..e62475308f 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -37,7 +37,7 @@ class CompletionMessageApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @get_app_model(mode=AppMode.WORKFLOW)
+    @get_app_model(mode=AppMode.COMPLETION)
     def post(self, app_model):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
@@ -90,7 +90,7 @@ class CompletionMessageStopApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @get_app_model(mode=AppMode.WORKFLOW)
+    @get_app_model(mode=AppMode.COMPLETION)
     def post(self, app_model, task_id):
         account = flask_login.current_user
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index daf9641121..b808d62eb0 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -29,7 +29,7 @@ class CompletionConversationApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
-    @get_app_model(mode=AppMode.WORKFLOW)
+    @get_app_model(mode=AppMode.COMPLETION)
    @marshal_with(conversation_pagination_fields)
    def get(self, app_model):
        parser = reqparse.RequestParser()
@@ -102,7 +102,7 @@ class CompletionConversationDetailApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
-    @get_app_model(mode=AppMode.WORKFLOW)
+    @get_app_model(mode=AppMode.COMPLETION)
    @marshal_with(conversation_message_detail_fields)
    def get(self, app_model, conversation_id):
        conversation_id = str(conversation_id)
diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py
index ea4d597112..e3a5112200 100644
--- a/api/controllers/console/app/statistic.py
+++ b/api/controllers/console/app/statistic.py
@@ -330,7 +330,7 @@ class AverageResponseTimeStatistic(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @get_app_model(mode=AppMode.WORKFLOW)
+    @get_app_model(mode=AppMode.COMPLETION)
     def get(self, app_model):
         account = current_user
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index bef26b4d99..47af28425f 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -12,6 +12,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
 import services
 from controllers.console import api
 from controllers.console.app.error import (
+    AppMoreLikeThisDisabledError,
     CompletionRequestError,
     ProviderModelCurrentlyNotSupportError,
     ProviderNotInitializeError,
@@ -23,10 +24,13 @@ from controllers.console.explore.error import (
     NotCompletionAppError,
 )
 from controllers.console.explore.wraps import InstalledAppResource
+from core.entities.application_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from fields.message_fields import message_infinite_scroll_pagination_fields
 from libs.helper import uuid_value
+from services.completion_service import CompletionService
+from services.errors.app import MoreLikeThisDisabledError
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
 from services.message_service import MessageService
@@ -72,6 +76,48 @@ class MessageFeedbackApi(InstalledAppResource):
         return {'result': 'success'}
 
 
+class MessageMoreLikeThisApi(InstalledAppResource):
+    def get(self, installed_app, message_id):
+        app_model = installed_app.app
+        if app_model.mode != 'completion':
+            raise NotCompletionAppError()
+
+        message_id = str(message_id)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
+        args = parser.parse_args()
+
+        streaming = args['response_mode'] == 'streaming'
+
+        try:
+            response = CompletionService.generate_more_like_this(
+                app_model=app_model,
+                user=current_user,
+                message_id=message_id,
+                invoke_from=InvokeFrom.EXPLORE,
+                streaming=streaming
+            )
+            return compact_response(response)
+        except MessageNotExistsError:
+            raise NotFound("Message Not Exists.")
+        except MoreLikeThisDisabledError:
+            raise AppMoreLikeThisDisabledError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+        except ValueError as e:
+            raise e
+        except Exception:
+            logging.exception("internal server error.")
+            raise InternalServerError()
+
+
 def compact_response(response: Union[dict, Generator]) -> Response:
     if isinstance(response, dict):
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
@@ -120,4 +166,5 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
 api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
 api.add_resource(MessageFeedbackApi,
                  '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
+api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this')
 api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 5120f49c5e..e03bdd63bb 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -11,6 +11,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
 import services
 from controllers.web import api
 from controllers.web.error import (
+    AppMoreLikeThisDisabledError,
     AppSuggestedQuestionsAfterAnswerDisabledError,
     CompletionRequestError,
     NotChatAppError,
@@ -20,11 +21,14 @@ from controllers.web.error import (
     ProviderQuotaExceededError,
 )
 from controllers.web.wraps import WebApiResource
+from core.entities.application_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from fields.conversation_fields import message_file_fields
 from fields.message_fields import agent_thought_fields
 from libs.helper import TimestampField, uuid_value
+from services.completion_service import CompletionService
+from services.errors.app import MoreLikeThisDisabledError
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
 from services.message_service import MessageService
@@ -109,6 +113,48 @@ class MessageFeedbackApi(WebApiResource):
         return {'result': 'success'}
 
 
+class MessageMoreLikeThisApi(WebApiResource):
+    def get(self, app_model, end_user, message_id):
+        if app_model.mode != 'completion':
+            raise NotCompletionAppError()
+
+        message_id = str(message_id)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
+        args = parser.parse_args()
+
+        streaming = args['response_mode'] == 'streaming'
+
+        try:
+            response = CompletionService.generate_more_like_this(
+                app_model=app_model,
+                user=end_user,
+                message_id=message_id,
+                invoke_from=InvokeFrom.WEB_APP,
+                streaming=streaming
+            )
+
+            return compact_response(response)
+        except MessageNotExistsError:
+            raise NotFound("Message Not Exists.")
+        except MoreLikeThisDisabledError:
+            raise AppMoreLikeThisDisabledError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+        except ValueError as e:
+            raise e
+        except Exception:
+            logging.exception("internal server error.")
+            raise InternalServerError()
+
+
 def compact_response(response: Union[dict, Generator]) -> Response:
     if isinstance(response, dict):
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
@@ -156,4 +202,5 @@ class MessageSuggestedQuestionApi(WebApiResource):
 api.add_resource(MessageListApi, '/messages')
 api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
+api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this')
 api.add_resource(MessageSuggestedQuestionApi,
                  '/messages/<uuid:message_id>/suggested-questions')
diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py
index c6f6268a7a..231530ef08 100644
--- a/api/core/app_runner/app_runner.py
+++ b/api/core/app_runner/app_runner.py
@@ -22,8 +22,9 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.errors.invoke import InvokeBadRequestError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.simple_prompt_transform import SimplePromptTransform
-from models.model import App, Message, MessageAnnotation
+from models.model import App, Message, MessageAnnotation, AppMode
 
 
 class AppRunner:
@@ -140,11 +141,11 @@ class AppRunner:
         :param memory: memory
         :return:
         """
-        prompt_transform = SimplePromptTransform()
-
         # get prompt without memory and context
         if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+            prompt_transform = SimplePromptTransform()
             prompt_messages, stop = prompt_transform.get_prompt(
+                app_mode=AppMode.value_of(app_record.mode),
                 prompt_template_entity=prompt_template_entity,
                 inputs=inputs,
                 query=query if query else '',
                 files=files,
                 context=context,
                 memory=memory,
                 model_config=model_config
             )
         else:
-            raise NotImplementedError("Advanced prompt is not supported yet.")
+            prompt_transform = AdvancedPromptTransform()
+            prompt_messages = prompt_transform.get_prompt(
+                prompt_template_entity=prompt_template_entity,
+                inputs=inputs,
+                query=query if query else '',
+                files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config
+            )
+            stop = model_config.stop
 
         return prompt_messages, stop
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index 9596976b6e..9c554140b7 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -11,10 +11,9 @@ class PromptTransform:
     def _append_chat_histories(self, memory: TokenBufferMemory,
                                prompt_messages: list[PromptMessage],
                                model_config: ModelConfigEntity) -> list[PromptMessage]:
-        if memory:
-            rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
-            histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
-            prompt_messages.extend(histories)
+        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
+        histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
+        prompt_messages.extend(histories)
 
         return prompt_messages
 
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index 2f98fbcae8..a929416be4 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -47,6 +47,7 @@ class SimplePromptTransform(PromptTransform):
     """
 
     def get_prompt(self,
+                   app_mode: AppMode,
                    prompt_template_entity: PromptTemplateEntity,
                    inputs: dict,
                    query: str,
@@ -58,6 +59,7 @@ class SimplePromptTransform(PromptTransform):
         model_mode = ModelMode.value_of(model_config.mode)
         if model_mode == ModelMode.CHAT:
             prompt_messages, stops = self._get_chat_model_prompt_messages(
+                app_mode=app_mode,
                 pre_prompt=prompt_template_entity.simple_prompt_template,
                 inputs=inputs,
                 query=query,
@@ -68,6 +70,7 @@ class SimplePromptTransform(PromptTransform):
             )
         else:
             prompt_messages, stops = self._get_completion_model_prompt_messages(
+                app_mode=app_mode,
                 pre_prompt=prompt_template_entity.simple_prompt_template,
                 inputs=inputs,
                 query=query,
@@ -154,7 +157,8 @@ class SimplePromptTransform(PromptTransform):
             "prompt_rules": prompt_rules
         }
 
-    def _get_chat_model_prompt_messages(self, pre_prompt: str,
+    def _get_chat_model_prompt_messages(self, app_mode: AppMode,
+                                        pre_prompt: str,
                                         inputs: dict,
                                         query: str,
                                         context: Optional[str],
@@ -166,7 +170,7 @@ class SimplePromptTransform(PromptTransform):
 
         # get prompt
         prompt, _ = self.get_prompt_str_and_rules(
-            app_mode=AppMode.CHAT,
+            app_mode=app_mode,
             model_config=model_config,
             pre_prompt=pre_prompt,
             inputs=inputs,
@@ -175,19 +179,25 @@ class SimplePromptTransform(PromptTransform):
         )
 
         if prompt:
-            prompt_messages.append(SystemPromptMessage(content=prompt))
+            if query:
+                prompt_messages.append(SystemPromptMessage(content=prompt))
+            else:
+                prompt_messages.append(UserPromptMessage(content=prompt))
 
-        prompt_messages = self._append_chat_histories(
-            memory=memory,
-            prompt_messages=prompt_messages,
-            model_config=model_config
-        )
+        if memory:
+            prompt_messages = self._append_chat_histories(
+                memory=memory,
+                prompt_messages=prompt_messages,
+                model_config=model_config
+            )
 
-        prompt_messages.append(self.get_last_user_message(query, files))
+        if query:
+            prompt_messages.append(self.get_last_user_message(query, files))
 
         return prompt_messages, None
 
-    def _get_completion_model_prompt_messages(self, pre_prompt: str,
+    def _get_completion_model_prompt_messages(self, app_mode: AppMode,
+                                              pre_prompt: str,
                                               inputs: dict,
                                               query: str,
                                               context: Optional[str],
@@ -197,7 +207,7 @@ class SimplePromptTransform(PromptTransform):
             -> tuple[list[PromptMessage], Optional[list[str]]]:
         # get prompt
         prompt, prompt_rules = self.get_prompt_str_and_rules(
-            app_mode=AppMode.CHAT,
+            app_mode=app_mode,
             model_config=model_config,
             pre_prompt=pre_prompt,
             inputs=inputs,
@@ -220,7 +230,7 @@ class SimplePromptTransform(PromptTransform):
 
         # get prompt
         prompt, prompt_rules = self.get_prompt_str_and_rules(
-            app_mode=AppMode.CHAT,
+            app_mode=app_mode,
             model_config=model_config,
             pre_prompt=pre_prompt,
             inputs=inputs,
@@ -289,13 +299,13 @@ class SimplePromptTransform(PromptTransform):
             is_baichuan = True
 
         if is_baichuan:
-            if app_mode == AppMode.WORKFLOW:
+            if app_mode == AppMode.COMPLETION:
                 return 'baichuan_completion'
             else:
                 return 'baichuan_chat'
 
         # common
-        if app_mode == AppMode.WORKFLOW:
+        if app_mode == AppMode.COMPLETION:
             return 'common_completion'
         else:
             return 'common_chat'
diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py
index aa8cd73ea7..34b6d62d51 100644
--- a/api/services/app_model_config_service.py
+++ b/api/services/app_model_config_service.py
@@ -316,6 +316,9 @@ class AppModelConfigService:
                 if "tool_parameters" not in tool:
                     raise ValueError("tool_parameters is required in agent_mode.tools")
 
+        # dataset_query_variable
+        cls.is_dataset_query_variable_valid(config, app_mode)
+
         # advanced prompt validation
         cls.is_advanced_prompt_valid(config, app_mode)
 
@@ -441,6 +444,21 @@ class AppModelConfigService:
             config=config
         )
 
+    @classmethod
+    def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None:
+        # Only check when mode is completion
+        if mode != 'completion':
+            return
+
+        agent_mode = config.get("agent_mode", {})
+        tools = agent_mode.get("tools", [])
+        dataset_exists = "dataset" in str(tools)
+
+        dataset_query_variable = config.get("dataset_query_variable")
+
+        if dataset_exists and not dataset_query_variable:
+            raise ValueError("Dataset query variable is required when dataset is exist")
+
     @classmethod
     def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None:
         # prompt_type
diff --git a/api/services/completion_service.py b/api/services/completion_service.py
index 5599c60113..cbfbe9ef41 100644
--- a/api/services/completion_service.py
+++ b/api/services/completion_service.py
@@ -8,10 +8,12 @@ from core.application_manager import ApplicationManager
 from core.entities.application_entities import InvokeFrom
 from core.file.message_file_parser import MessageFileParser
 from extensions.ext_database import db
-from models.model import Account, App, AppModelConfig, Conversation, EndUser
+from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message
 from services.app_model_config_service import AppModelConfigService
+from services.errors.app import MoreLikeThisDisabledError
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
+from services.errors.message import MessageNotExistsError
 
 
 class CompletionService:
@@ -155,6 +157,62 @@ class CompletionService:
             }
         )
 
+    @classmethod
+    def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
+                                message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
+            -> Union[dict, Generator]:
+        if not user:
+            raise ValueError('user cannot be None')
+
+        message = db.session.query(Message).filter(
+            Message.id == message_id,
+            Message.app_id == app_model.id,
+            Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
+            Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
+            Message.from_account_id == (user.id if isinstance(user, Account) else None),
+        ).first()
+
+        if not message:
+            raise MessageNotExistsError()
+
+        current_app_model_config = app_model.app_model_config
+        more_like_this = current_app_model_config.more_like_this_dict
+
+        if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
+            raise MoreLikeThisDisabledError()
+
+        app_model_config = message.app_model_config
+        model_dict = app_model_config.model_dict
+        completion_params = model_dict.get('completion_params')
+        completion_params['temperature'] = 0.9
+        model_dict['completion_params'] = completion_params
+        app_model_config.model = json.dumps(model_dict)
+
+        # parse files
+        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
+        file_objs = message_file_parser.transform_message_files(
+            message.files, app_model_config
+        )
+
+        application_manager = ApplicationManager()
+        return application_manager.generate(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            app_model_config_id=app_model_config.id,
+            app_model_config_dict=app_model_config.to_dict(),
+            app_model_config_override=True,
+            user=user,
+            invoke_from=invoke_from,
+            inputs=message.inputs,
+            query=message.query,
+            files=file_objs,
+            conversation=None,
+            stream=streaming,
+            extras={
+                "auto_generate_conversation_name": False
+            }
+        )
+
     @classmethod
     def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
         if user_inputs is None:
diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py
index a44c190cbc..5804f599fe 100644
--- a/api/services/errors/__init__.py
+++ b/api/services/errors/__init__.py
@@ -1,7 +1,7 @@
 # -*- coding:utf-8 -*-
 __all__ = [
     'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
-    'completion', 'audio', 'file'
+    'app', 'completion', 'audio', 'file'
 ]
 
 from . import *
diff --git a/api/services/errors/app.py b/api/services/errors/app.py
new file mode 100644
index 0000000000..7c4ca99c2a
--- /dev/null
+++ b/api/services/errors/app.py
@@ -0,0 +1,2 @@
+class MoreLikeThisDisabledError(Exception):
+    pass
diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
index c174983e38..a95a6dc52f 100644
--- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
@@ -160,6 +160,7 @@ def test__get_chat_model_prompt_messages():
     context = "yes or no."
     query = "How are you?"
     prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
+        app_mode=AppMode.CHAT,
         pre_prompt=pre_prompt,
         inputs=inputs,
         query=query,
@@ -214,6 +215,7 @@ def test__get_completion_model_prompt_messages():
     context = "yes or no."
     query = "How are you?"
     prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
+        app_mode=AppMode.CHAT,
         pre_prompt=pre_prompt,
         inputs=inputs,
         query=query,