From 49992925e29f05f3cbd14cda72b657a495de6c7a Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 19 Feb 2024 16:55:59 +0800 Subject: [PATCH] optimize get app model to wraps --- api/controllers/console/__init__.py | 2 +- api/controllers/console/app/__init__.py | 21 ---- api/controllers/console/app/app.py | 100 +++++++----------- api/controllers/console/app/audio.py | 23 ++-- api/controllers/console/app/completion.py | 36 ++----- api/controllers/console/app/conversation.py | 59 ++++------- api/controllers/console/app/message.py | 64 ++++------- api/controllers/console/app/model_config.py | 17 ++- api/controllers/console/app/site.py | 14 +-- api/controllers/console/app/statistic.py | 38 +++---- api/controllers/console/app/workflow.py | 20 ++++ api/controllers/console/app/wraps.py | 55 ++++++++++ api/core/app_runner/basic_app_runner.py | 4 +- api/core/entities/application_entities.py | 20 ++++ api/core/prompt/prompt_transform.py | 20 +--- .../advanced_prompt_template_service.py | 2 +- api/services/app_model_config_service.py | 2 +- 17 files changed, 232 insertions(+), 265 deletions(-) create mode 100644 api/controllers/console/app/workflow.py create mode 100644 api/controllers/console/app/wraps.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index ecfdc38612..934b19116b 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -8,7 +8,7 @@ api = ExternalApi(bp) from . import admin, apikey, extension, feature, setup, version # Import app controllers from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, - model_config, site, statistic) + model_config, site, statistic, workflow) # Import auth controllers from .auth import activate, data_source_oauth, login, oauth # Import billing controllers diff --git a/api/controllers/console/app/__init__.py b/api/controllers/console/app/__init__.py index b0b07517f1..e69de29bb2 100644 --- a/api/controllers/console/app/__init__.py +++ b/api/controllers/console/app/__init__.py @@ -1,21 +0,0 @@ -from controllers.console.app.error import AppUnavailableError -from extensions.ext_database import db -from flask_login import current_user -from models.model import App -from werkzeug.exceptions import NotFound - - -def _get_app(app_id, mode=None): - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() - - if not app: - raise NotFound("App not found") - - if mode and app.mode != mode: - raise NotFound("The {} app not found".format(mode)) - - return app diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 4b648a4e28..f291f8e81a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -9,7 +9,8 @@ from werkzeug.exceptions import Forbidden from constants.languages import demo_model_templates, languages from constants.model_template import model_templates from controllers.console import api -from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError +from controllers.console.app.error import ProviderNotInitializeError +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError @@ -31,13 +32,6 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager from core.entities.application_entities import AgentToolEntity -def _get_app(app_id, tenant_id): - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() - if not app: - raise AppNotFoundError - return app - - class AppListApi(Resource): @setup_required @@ -234,14 +228,12 @@ class AppApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields_with_site) - def get(self, app_id): + def get(self, app_model): """Get app detail""" - app_id = str(app_id) - app: App = _get_app(app_id, current_user.current_tenant_id) - # get original app model config - model_config: AppModelConfig = app.app_model_config + model_config: AppModelConfig = app_model.app_model_config agent_mode = model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input for tool in agent_mode.get('tools') or []: @@ -272,27 +264,24 @@ class AppApi(Resource): # override agent mode model_config.agent_mode = json.dumps(agent_mode) - return app + return app_model @setup_required @login_required @account_initialization_required - def delete(self, app_id): + @get_app_model + def delete(self, app_model): """Delete app""" - app_id = str(app_id) - if not current_user.is_admin_or_owner: raise Forbidden() - app = _get_app(app_id, current_user.current_tenant_id) - - db.session.delete(app) + db.session.delete(app_model) db.session.commit() # todo delete related data?? # model_config, site, api_token, conversation, message, message_feedback, message_annotation - app_was_deleted.send(app) + app_was_deleted.send(app_model) return {'result': 'success'}, 204 @@ -301,86 +290,77 @@ class AppNameApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') args = parser.parse_args() - app.name = args.get('name') - app.updated_at = datetime.utcnow() + app_model.name = args.get('name') + app_model.updated_at = datetime.utcnow() db.session.commit() - return app + return app_model class AppIconApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() - app.icon = args.get('icon') - app.icon_background = args.get('icon_background') - app.updated_at = datetime.utcnow() + app_model.icon = args.get('icon') + app_model.icon_background = args.get('icon_background') + app_model.updated_at = datetime.utcnow() db.session.commit() - return app + return app_model class AppSiteStatus(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('enable_site', type=bool, required=True, location='json') args = parser.parse_args() - app_id = str(app_id) - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id).first() - if not app: - raise AppNotFoundError - if args.get('enable_site') == app.enable_site: - return app + if args.get('enable_site') == app_model.enable_site: + return app_model - app.enable_site = args.get('enable_site') - app.updated_at = datetime.utcnow() + app_model.enable_site = args.get('enable_site') + app_model.updated_at = datetime.utcnow() db.session.commit() - return app + return app_model class AppApiStatus(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('enable_api', type=bool, required=True, location='json') args = parser.parse_args() - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) + if args.get('enable_api') == app_model.enable_api: + return app_model - if args.get('enable_api') == app.enable_api: - return app - - app.enable_api = args.get('enable_api') - app.updated_at = datetime.utcnow() + app_model.enable_api = args.get('enable_api') + app_model.updated_at = datetime.utcnow() db.session.commit() - return app + return app_model class AppCopy(Resource): @@ -410,16 +390,14 @@ class AppCopy(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - - copy_app = self.create_app_copy(app) + def post(self, app_model): + copy_app = self.create_app_copy(app_model) db.session.add(copy_app) app_config = db.session.query(AppModelConfig). \ - filter(AppModelConfig.app_id == app_id). \ + filter(AppModelConfig.app_id == app_model.id). \ one_or_none() if app_config: diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 77eaf136fc..daa5570f9a 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -6,7 +6,6 @@ from werkzeug.exceptions import InternalServerError import services from controllers.console import api -from controllers.console.app import _get_app from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -18,8 +17,10 @@ from controllers.console.app.error import ( ProviderQuotaExceededError, UnsupportedAudioTypeError, ) +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.entities.application_entities import AppMode from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required @@ -36,10 +37,8 @@ class ChatMessageAudioApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - app_model = _get_app(app_id, 'chat') - + @get_app_model(mode=AppMode.CHAT) + def post(self, app_model): file = request.files['file'] try: @@ -80,10 +79,8 @@ class ChatMessageTextApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - app_model = _get_app(app_id, None) - + @get_app_model + def post(self, app_model): try: response = AudioService.transcript_tts( tenant_id=app_model.tenant_id, @@ -120,9 +117,11 @@ class ChatMessageTextApi(Resource): class TextModesApi(Resource): - def get(self, app_id: str): - app_model = _get_app(str(app_id)) - + @setup_required + @login_required + @account_initialization_required + @get_app_model + def get(self, app_model): try: parser = reqparse.RequestParser() parser.add_argument('language', type=str, required=True, location='args') diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index f01d2afa03..f378f7b218 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -10,7 +10,6 @@ from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.console import api -from controllers.console.app import _get_app from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -19,10 +18,11 @@ from controllers.console.app.error import ( ProviderNotInitializeError, ProviderQuotaExceededError, ) +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.application_queue_manager import ApplicationQueueManager -from core.entities.application_entities import InvokeFrom +from core.entities.application_entities import InvokeFrom, AppMode from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value @@ -36,12 +36,8 @@ class CompletionMessageApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - - # get app info - app_model = _get_app(app_id, 'completion') - + @get_app_model(mode=AppMode.WORKFLOW) + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, location='json', default='') @@ -93,12 +89,8 @@ class CompletionMessageStopApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id, task_id): - app_id = str(app_id) - - # get app info - _get_app(app_id, 'completion') - + @get_app_model(mode=AppMode.WORKFLOW) + def post(self, app_model, task_id): account = flask_login.current_user ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) @@ -110,12 +102,8 @@ class ChatMessageApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - - # get app info - app_model = _get_app(app_id, 'chat') - + @get_app_model(mode=AppMode.CHAT) + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, required=True, location='json') @@ -179,12 +167,8 @@ class ChatMessageStopApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id, task_id): - app_id = str(app_id) - - # get app info - _get_app(app_id, 'chat') - + @get_app_model(mode=AppMode.CHAT) + def post(self, app_model, task_id): account = flask_login.current_user ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 452b0fddf6..4ee1ee4035 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -9,9 +9,10 @@ from sqlalchemy.orm import joinedload from werkzeug.exceptions import NotFound from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.entities.application_entities import AppMode from extensions.ext_database import db from fields.conversation_fields import ( conversation_detail_fields, @@ -29,10 +30,9 @@ class CompletionConversationApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=AppMode.WORKFLOW) @marshal_with(conversation_pagination_fields) - def get(self, app_id): - app_id = str(app_id) - + def get(self, app_model): parser = reqparse.RequestParser() parser.add_argument('keyword', type=str, location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -43,10 +43,7 @@ class CompletionConversationApi(Resource): parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') args = parser.parse_args() - # get app info - app = _get_app(app_id, 'completion') - - query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'completion') + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion') if args['keyword']: query = query.join( @@ -106,24 +103,22 @@ class CompletionConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=AppMode.WORKFLOW) @marshal_with(conversation_message_detail_fields) - def get(self, app_id, conversation_id): - app_id = str(app_id) + def get(self, app_model, conversation_id): conversation_id = str(conversation_id) - return _get_conversation(app_id, conversation_id, 'completion') + return _get_conversation(app_model, conversation_id) @setup_required @login_required @account_initialization_required - def delete(self, app_id, conversation_id): - app_id = str(app_id) + @get_app_model(mode=AppMode.CHAT) + def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) - app = _get_app(app_id, 'chat') - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first() + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() if not conversation: raise NotFound("Conversation Not Exists.") @@ -139,10 +134,9 @@ class ChatConversationApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=AppMode.CHAT) @marshal_with(conversation_with_summary_pagination_fields) - def get(self, app_id): - app_id = str(app_id) - + def get(self, app_model): parser = reqparse.RequestParser() parser.add_argument('keyword', type=str, location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -154,10 +148,7 @@ class ChatConversationApi(Resource): parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') args = parser.parse_args() - # get app info - app = _get_app(app_id, 'chat') - - query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'chat') + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'chat') if args['keyword']: query = query.join( @@ -228,25 +219,22 @@ class ChatConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=AppMode.CHAT) @marshal_with(conversation_detail_fields) - def get(self, app_id, conversation_id): - app_id = str(app_id) + def get(self, app_model, conversation_id): conversation_id = str(conversation_id) - return _get_conversation(app_id, conversation_id, 'chat') + return _get_conversation(app_model, conversation_id) @setup_required @login_required + @get_app_model(mode=AppMode.CHAT) @account_initialization_required - def delete(self, app_id, conversation_id): - app_id = str(app_id) + def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) - # get app info - app = _get_app(app_id, 'chat') - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first() + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() if not conversation: raise NotFound("Conversation Not Exists.") @@ -263,12 +251,9 @@ api.add_resource(ChatConversationApi, '/apps//chat-conversations') api.add_resource(ChatConversationDetailApi, '/apps//chat-conversations/') -def _get_conversation(app_id, conversation_id, mode): - # get app info - app = _get_app(app_id, mode) - +def _get_conversation(app_model, conversation_id): conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first() + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() if not conversation: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 0064dbe663..360602b9c2 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -10,7 +10,6 @@ from flask_restful.inputs import int_range from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from controllers.console import api -from controllers.console.app import _get_app from controllers.console.app.error import ( AppMoreLikeThisDisabledError, CompletionRequestError, @@ -18,9 +17,10 @@ from controllers.console.app.error import ( ProviderNotInitializeError, ProviderQuotaExceededError, ) +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.entities.application_entities import InvokeFrom +from core.entities.application_entities import InvokeFrom, AppMode from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db @@ -46,14 +46,10 @@ class ChatMessageListApi(Resource): @setup_required @login_required + @get_app_model(mode=AppMode.CHAT) @account_initialization_required @marshal_with(message_infinite_scroll_pagination_fields) - def get(self, app_id): - app_id = str(app_id) - - # get app info - app = _get_app(app_id, 'chat') - + def get(self, app_model): parser = reqparse.RequestParser() parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') parser.add_argument('first_id', type=uuid_value, location='args') @@ -62,7 +58,7 @@ class ChatMessageListApi(Resource): conversation = db.session.query(Conversation).filter( Conversation.id == args['conversation_id'], - Conversation.app_id == app.id + Conversation.app_id == app_model.id ).first() if not conversation: @@ -110,12 +106,8 @@ class MessageFeedbackApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - - # get app info - app = _get_app(app_id) - + @get_app_model + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('message_id', required=True, type=uuid_value, location='json') parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') @@ -125,7 +117,7 @@ class MessageFeedbackApi(Resource): message = db.session.query(Message).filter( Message.id == message_id, - Message.app_id == app.id + Message.app_id == app_model.id ).first() if not message: @@ -141,7 +133,7 @@ class MessageFeedbackApi(Resource): raise ValueError('rating cannot be None when feedback not exists') else: feedback = MessageFeedback( - app_id=app.id, + app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, rating=args['rating'], @@ -160,21 +152,20 @@ class MessageAnnotationApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check('annotation') + @get_app_model @marshal_with(annotation_fields) - def post(self, app_id): + def post(self, app_model): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - app_id = str(app_id) - parser = reqparse.RequestParser() parser.add_argument('message_id', required=False, type=uuid_value, location='json') parser.add_argument('question', required=True, type=str, location='json') parser.add_argument('answer', required=True, type=str, location='json') parser.add_argument('annotation_reply', required=False, type=dict, location='json') args = parser.parse_args() - annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id) + annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) return annotation @@ -183,14 +174,10 @@ class MessageAnnotationCountApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): - app_id = str(app_id) - - # get app info - app = _get_app(app_id) - + @get_app_model + def get(self, app_model): count = db.session.query(MessageAnnotation).filter( - MessageAnnotation.app_id == app.id + MessageAnnotation.app_id == app_model.id ).count() return {'count': count} @@ -200,8 +187,8 @@ class MessageMoreLikeThisApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id, message_id): - app_id = str(app_id) + @get_app_model(mode=AppMode.COMPLETION) + def get(self, app_model, message_id): message_id = str(message_id) parser = reqparse.RequestParser() @@ -211,9 +198,6 @@ class MessageMoreLikeThisApi(Resource): streaming = args['response_mode'] == 'streaming' - # get app info - app_model = _get_app(app_id, 'completion') - try: response = CompletionService.generate_more_like_this( app_model=app_model, @@ -257,13 +241,10 @@ class MessageSuggestedQuestionApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id, message_id): - app_id = str(app_id) + @get_app_model(mode=AppMode.CHAT) + def get(self, app_model, message_id): message_id = str(message_id) - # get app info - app_model = _get_app(app_id, 'chat') - try: questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, @@ -294,14 +275,11 @@ class MessageApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(message_detail_fields) - def get(self, app_id, message_id): - app_id = str(app_id) + def get(self, app_model, message_id): message_id = str(message_id) - # get app info - app_model = _get_app(app_id) - message = db.session.query(Message).filter( Message.id == message_id, Message.app_id == app_model.id diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 117007d055..912c4eab9a 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -5,7 +5,7 @@ from flask_login import current_user from flask_restful import Resource from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.entities.application_entities import AgentToolEntity @@ -23,22 +23,19 @@ class ModelConfigResource(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): + @get_app_model + def post(self, app_model): """Modify app model config""" - app_id = str(app_id) - - app = _get_app(app_id) - # validate config model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_user.current_tenant_id, account=current_user, config=request.json, - app_mode=app.mode + app_mode=app_model.mode ) new_app_model_config = AppModelConfig( - app_id=app.id, + app_id=app_model.id, ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) @@ -121,11 +118,11 @@ class ModelConfigResource(Resource): db.session.add(new_app_model_config) db.session.flush() - app.app_model_config_id = new_app_model_config.id + app_model.app_model_config_id = new_app_model_config.id db.session.commit() app_model_config_was_updated.send( - app, + app_model, app_model_config=new_app_model_config ) diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 4e9d9ed9b4..256824981e 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -4,7 +4,7 @@ from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db @@ -34,13 +34,11 @@ class AppSite(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_site_fields) - def post(self, app_id): + def post(self, app_model): args = parse_app_site_args() - app_id = str(app_id) - app_model = _get_app(app_id) - # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() @@ -82,11 +80,9 @@ class AppSiteAccessTokenReset(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_site_fields) - def post(self, app_id): - app_id = str(app_id) - app_model = _get_app(app_id) - + def post(self, app_model): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 7aed7da404..e3bc44d6e9 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -7,9 +7,10 @@ from flask_login import current_user from flask_restful import Resource, reqparse from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.entities.application_entities import AppMode from extensions.ext_database import db from libs.helper import datetime_string from libs.login import login_required @@ -20,10 +21,9 @@ class DailyConversationStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -81,10 +81,9 @@ class DailyTerminalsStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -141,10 +140,9 @@ class DailyTokenCostStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -205,10 +203,9 @@ class AverageSessionInteractionStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model(mode=AppMode.CHAT) + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id, 'chat') parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -271,10 +268,9 @@ class UserSatisfactionRateStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -334,10 +330,9 @@ class AverageResponseTimeStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model(mode=AppMode.WORKFLOW) + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id, 'completion') parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -396,10 +391,9 @@ class TokensPerSecondStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py new file mode 100644 index 0000000000..5a08e31c16 --- /dev/null +++ b/api/controllers/console/app/workflow.py @@ -0,0 +1,20 @@ +from flask_restful import Resource + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from core.entities.application_entities import AppMode +from libs.login import login_required + + +class DefaultBlockConfigApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW]) + def post(self, app_model): + return 'success', 200 + + +api.add_resource(DefaultBlockConfigApi, '/apps//default-workflow-block-configs') diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py new file mode 100644 index 0000000000..b3aca51871 --- /dev/null +++ b/api/controllers/console/app/wraps.py @@ -0,0 +1,55 @@ +from functools import wraps +from typing import Union, Optional, Callable + +from controllers.console.app.error import AppNotFoundError +from core.entities.application_entities import AppMode +from extensions.ext_database import db +from libs.login import current_user +from models.model import App + + +def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): + def decorator(view_func): + @wraps(view_func) + def decorated_view(*args, **kwargs): + if not kwargs.get('app_id'): + raise ValueError('missing app_id in path parameters') + + app_id = kwargs.get('app_id') + app_id = str(app_id) + + del kwargs['app_id'] + + app_model = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 'normal' + ).first() + + if not app_model: + raise AppNotFoundError() + + app_mode = AppMode.value_of(app_model.mode) + if mode is not None: + if isinstance(mode, list): + modes = mode + else: + modes = [mode] + + # [temp] if workflow is in the mode list, then completion should be in the mode list + if AppMode.WORKFLOW in modes: + modes.append(AppMode.COMPLETION) + + if app_mode not in modes: + mode_values = {m.value for m in modes} + raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") + + kwargs['app_model'] = app_model + + return view_func(*args, **kwargs) + return decorated_view + + if view is None: + return decorator + else: + return decorator(view) diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index d3c91337c8..d1e16f860c 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -4,12 +4,12 @@ from typing import Optional from core.app_runner.app_runner import AppRunner from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity +from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity, \ + AppMode from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException -from core.prompt.prompt_transform import AppMode from extensions.ext_database import db from models.model import App, Conversation, Message diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index abcf605c92..d3231affb2 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -9,6 +9,26 @@ from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import AIModelEntity +class AppMode(Enum): + COMPLETION = 'completion' # will be deprecated in the future + WORKFLOW = 'workflow' # instead of 'completion' + CHAT = 'chat' + AGENT = 'agent' + + @classmethod + def value_of(cls, value: str) -> 'AppMode': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + + class ModelConfigEntity(BaseModel): """ Model Config Entity. diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 0a373b7c42..08d94661b7 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -7,7 +7,7 @@ from typing import Optional, cast from core.entities.application_entities import ( AdvancedCompletionPromptTemplateEntity, ModelConfigEntity, - PromptTemplateEntity, + PromptTemplateEntity, AppMode, ) from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory @@ -25,24 +25,6 @@ from core.prompt.prompt_builder import PromptBuilder from core.prompt.prompt_template import PromptTemplateParser -class AppMode(enum.Enum): - COMPLETION = 'completion' - CHAT = 'chat' - - @classmethod - def value_of(cls, value: str) -> 'AppMode': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid mode value {value}') - - class ModelMode(enum.Enum): COMPLETION = 'completion' CHAT = 'chat' diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index d52f6e20c2..3cf58d8e09 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,6 +1,7 @@ import copy +from core.entities.application_entities import AppMode from core.prompt.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, @@ -13,7 +14,6 @@ from core.prompt.advanced_prompt_templates import ( COMPLETION_APP_COMPLETION_PROMPT_CONFIG, CONTEXT, ) -from core.prompt.prompt_transform import AppMode class AdvancedPromptTemplateService: diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 2e21e56266..ccfb101405 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -2,11 +2,11 @@ import re import uuid from core.entities.agent_entities import PlanningStrategy +from core.entities.application_entities import AppMode from core.external_data_tool.factory import ExternalDataToolFactory from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers import model_provider_factory from core.moderation.factory import ModerationFactory -from core.prompt.prompt_transform import AppMode from core.provider_manager import ProviderManager from models.account import Account from services.dataset_service import DatasetService