diff --git a/api/constants/model_template.py b/api/constants/model_template.py index d87f7c3926..c22306ac87 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,10 +1,10 @@ import json model_templates = { - # completion default mode - 'completion_default': { + # workflow default mode + 'workflow_default': { 'app': { - 'mode': 'completion', + 'mode': 'workflow', 'enable_site': True, 'enable_api': True, 'is_demo': False, @@ -15,24 +15,7 @@ model_templates = { 'model_config': { 'provider': '', 'model_id': '', - 'configs': {}, - 'model': json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": {} - }), - 'user_input_form': json.dumps([ - { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]), - 'pre_prompt': '{{query}}' + 'configs': {} } }, @@ -48,14 +31,70 @@ model_templates = { 'status': 'normal' }, 'model_config': { - 'provider': '', - 'model_id': '', - 'configs': {}, + 'provider': 'openai', + 'model_id': 'gpt-4', + 'configs': { + 'prompt_template': '', + 'prompt_variables': [], + 'completion_params': { + 'max_token': 512, + 'temperature': 1, + 'top_p': 1, + 'presence_penalty': 0, + 'frequency_penalty': 0, + } + }, 'model': json.dumps({ "provider": "openai", - "name": "gpt-3.5-turbo", + "name": "gpt-4", "mode": "chat", - "completion_params": {} + "completion_params": { + "max_tokens": 512, + "temperature": 1, + "top_p": 1, + "presence_penalty": 0, + "frequency_penalty": 0 + } + }) + } + }, + + # agent default mode + 'agent_default': { + 'app': { + 'mode': 'agent', + 'enable_site': True, + 'enable_api': True, + 'is_demo': False, + 'api_rpm': 0, + 'api_rph': 0, + 'status': 'normal' + }, + 'model_config': { + 'provider': 'openai', + 'model_id': 'gpt-4', + 'configs': { + 'prompt_template': '', + 'prompt_variables': [], + 'completion_params': { + 'max_token': 512, + 'temperature': 1, + 'top_p': 1, + 'presence_penalty': 0, + 'frequency_penalty': 0, + } + }, + 'model': json.dumps({ + "provider": "openai", + "name": "gpt-4", + "mode": "chat", + "completion_params": { + "max_tokens": 512, + "temperature": 1, + "top_p": 1, + "presence_penalty": 0, + "frequency_penalty": 0 + } }) } }, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 934b19116b..649df278ec 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -5,7 +5,7 @@ bp = Blueprint('console', __name__, url_prefix='/console/api') api = ExternalApi(bp) # Import other controllers -from . import admin, apikey, extension, feature, setup, version +from . import admin, apikey, extension, feature, setup, version, ping # Import app controllers from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, model_config, site, statistic, workflow) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index f291f8e81a..8e6da3bd4f 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -26,7 +26,7 @@ from fields.app_fields import ( template_list_fields, ) from libs.login import login_required -from models.model import App, AppModelConfig, Site +from models.model import App, AppModelConfig, Site, AppMode from services.app_model_config_service import AppModelConfigService from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager @@ -80,7 +80,7 @@ class AppListApi(Resource): """Create app""" parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('mode', type=str, choices=['completion', 'chat', 'assistant'], location='json') + parser.add_argument('mode', type=str, choices=[mode.value for mode in AppMode], location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') parser.add_argument('model_config', type=dict, location='json') @@ -90,18 +90,7 @@ class AppListApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - try: - provider_manager = ProviderManager() - default_model_entity = provider_manager.get_default_model( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - except (ProviderTokenNotInitError, LLMBadRequestError): - default_model_entity = None - except Exception as e: - logging.exception(e) - default_model_entity = None - + # TODO: MOVE TO IMPORT API if args['model_config'] is not None: # validate config model_config_dict = args['model_config'] @@ -150,27 +139,30 @@ class AppListApi(Resource): if 'mode' not in args or args['mode'] is None: abort(400, message="mode is required") - model_config_template = model_templates[args['mode'] + '_default'] + app_mode = AppMode.value_of(args['mode']) + + model_config_template = model_templates[app_mode.value + '_default'] app = App(**model_config_template['app']) app_model_config = AppModelConfig(**model_config_template['model_config']) - # get model provider - model_manager = ModelManager() + if app_mode in [AppMode.CHAT, AppMode.AGENT]: + # get model provider + model_manager = ModelManager() - try: - model_instance = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - except ProviderTokenNotInitError: - model_instance = None + try: + model_instance = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, + model_type=ModelType.LLM + ) + except ProviderTokenNotInitError: + model_instance = None - if model_instance: - model_dict = app_model_config.model_dict - model_dict['provider'] = model_instance.provider - model_dict['name'] = model_instance.model - app_model_config.model = json.dumps(model_dict) + if model_instance: + model_dict = app_model_config.model_dict + model_dict['provider'] = model_instance.provider + model_dict['name'] = model_instance.model + app_model_config.model = json.dumps(model_dict) app.name = args['name'] app.mode = args['mode'] diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index daa5570f9a..458fa5098f 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -20,10 +20,10 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.entities.application_entities import AppMode from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required +from models.model import AppMode from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 381d0bbb6b..11fdba177d 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -22,11 +22,12 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.application_queue_manager import ApplicationQueueManager -from core.entities.application_entities import AppMode, InvokeFrom +from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from libs.login import login_required +from models.model import AppMode from services.completion_service import CompletionService diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 4ee1ee4035..5d312149f7 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -12,7 +12,6 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.entities.application_entities import AppMode from extensions.ext_database import db from fields.conversation_fields import ( conversation_detail_fields, @@ -22,7 +21,7 @@ from fields.conversation_fields import ( ) from libs.helper import datetime_string from libs.login import login_required -from models.model import Conversation, Message, MessageAnnotation +from models.model import Conversation, Message, MessageAnnotation, AppMode class CompletionConversationApi(Resource): diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index d7b31906c8..b1abb38248 100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -85,3 +85,9 @@ class TooManyFilesError(BaseHTTPException): error_code = 'too_many_files' description = "Only one file is allowed." code = 400 + + +class DraftWorkflowNotExist(BaseHTTPException): + error_code = 'draft_workflow_not_exist' + description = "Draft workflow need to be initialized." + code = 400 diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 5d4f6b7e26..9a177116ea 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -11,7 +11,6 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from controllers.console import api from controllers.console.app.error import ( - AppMoreLikeThisDisabledError, CompletionRequestError, ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, @@ -20,7 +19,6 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.entities.application_entities import AppMode, InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db @@ -28,10 +26,8 @@ from fields.conversation_fields import annotation_fields, message_detail_fields from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import login_required -from models.model import Conversation, Message, MessageAnnotation, MessageFeedback +from models.model import Conversation, Message, MessageAnnotation, MessageFeedback, AppMode from services.annotation_service import AppAnnotationService -from services.completion_service import CompletionService -from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError from services.message_service import MessageService @@ -183,49 +179,6 @@ class MessageAnnotationCountApi(Resource): return {'count': count} -class MessageMoreLikeThisApi(Resource): - @setup_required - @login_required - @account_initialization_required - @get_app_model(mode=AppMode.COMPLETION) - def get(self, app_model, message_id): - message_id = str(message_id) - - parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], - location='args') - args = parser.parse_args() - - streaming = args['response_mode'] == 'streaming' - - try: - response = CompletionService.generate_more_like_this( - app_model=app_model, - user=current_user, - message_id=message_id, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming - ) - return compact_response(response) - except MessageNotExistsError: - raise NotFound("Message Not Exists.") - except MoreLikeThisDisabledError: - raise AppMoreLikeThisDisabledError() - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise e - except Exception as e: - logging.exception("internal server error.") - raise InternalServerError() - - def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -291,7 +244,6 @@ class MessageApi(Resource): return message -api.add_resource(MessageMoreLikeThisApi, '/apps//completion-messages//more-like-this') api.add_resource(MessageSuggestedQuestionApi, '/apps//chat-messages//suggested-questions') api.add_resource(ChatMessageListApi, '/apps//chat-messages', endpoint='console_chat_messages') api.add_resource(MessageFeedbackApi, '/apps//feedbacks') diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index e3bc44d6e9..ea4d597112 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -10,10 +10,10 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.entities.application_entities import AppMode from extensions.ext_database import db from libs.helper import datetime_string from libs.login import login_required +from models.model import AppMode class DailyConversationStatistic(Resource): diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5689c0fd92..2794735bbb 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,30 +1,88 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse, marshal_with from controllers.console import api +from controllers.console.app.error import DraftWorkflowNotExist +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.entities.application_entities import AppMode -from libs.login import login_required +from fields.workflow_fields import workflow_fields +from libs.login import login_required, current_user +from models.model import App, ChatbotAppEngine, AppMode +from services.workflow_service import WorkflowService + + +class DraftWorkflowApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) + @marshal_with(workflow_fields) + def get(self, app_model: App): + """ + Get draft workflow + """ + # fetch draft workflow by app_model + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model=app_model) + + if not workflow: + raise DraftWorkflowNotExist() + + # return workflow, if not found, return None (initiate graph by frontend) + return workflow + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) + def post(self, app_model: App): + """ + Sync draft workflow + """ + parser = reqparse.RequestParser() + parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() + + workflow_service = WorkflowService() + workflow_service.sync_draft_workflow(app_model=app_model, graph=args.get('graph'), account=current_user) + + return { + "result": "success" + } class DefaultBlockConfigApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - parser = reqparse.RequestParser() - parser.add_argument('app_mode', type=str, required=True, nullable=False, - choices=[AppMode.CHAT.value, AppMode.WORKFLOW.value], location='args') - args = parser.parse_args() - - app_mode = args.get('app_mode') - app_mode = AppMode.value_of(app_mode) - - # TODO: implement this - - return { - "blocks": [] - } + @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) + def get(self, app_model: App): + """ + Get default block config + """ + # Get default block configs + workflow_service = WorkflowService() + return workflow_service.get_default_block_configs() -api.add_resource(DefaultBlockConfigApi, '/default-workflow-block-configs') +class ConvertToWorkflowApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=AppMode.CHAT) + @marshal_with(workflow_fields) + def post(self, app_model: App): + """ + Convert basic mode of chatbot app to workflow + """ + # convert to workflow mode + workflow_service = WorkflowService() + workflow = workflow_service.chatbot_convert_to_workflow(app_model=app_model) + + # return workflow + return workflow + + +api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') +api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs') +api.add_resource(ConvertToWorkflowApi, '/apps//convert-to-workflow') diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index fe2b408702..fe35e72304 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -3,13 +3,14 @@ from functools import wraps from typing import Optional, Union from controllers.console.app.error import AppNotFoundError -from core.entities.application_entities import AppMode from extensions.ext_database import db from libs.login import current_user -from models.model import App +from models.model import App, ChatbotAppEngine, AppMode -def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): +def get_app_model(view: Optional[Callable] = None, *, + mode: Union[AppMode, list[AppMode]] = None, + app_engine: ChatbotAppEngine = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): @@ -37,14 +38,20 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[ else: modes = [mode] - # [temp] if workflow is in the mode list, then completion should be in the mode list - if AppMode.WORKFLOW in modes: - modes.append(AppMode.COMPLETION) - if app_mode not in modes: mode_values = {m.value for m in modes} raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") + if app_engine is not None: + if app_mode not in [AppMode.CHAT, AppMode.WORKFLOW]: + raise AppNotFoundError(f"App mode is not supported for {app_engine.value} app engine.") + + if app_mode == AppMode.CHAT: + # fetch current app model config + app_model_config = app_model.app_model_config + if not app_model_config or app_model_config.chatbot_app_engine != app_engine.value: + raise AppNotFoundError(f"{app_engine.value} app engine is not supported.") + kwargs['app_model'] = app_model return view_func(*args, **kwargs) diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 47af28425f..bef26b4d99 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -12,7 +12,6 @@ from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.console import api from controllers.console.app.error import ( - AppMoreLikeThisDisabledError, CompletionRequestError, ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, @@ -24,13 +23,10 @@ from controllers.console.explore.error import ( NotCompletionAppError, ) from controllers.console.explore.wraps import InstalledAppResource -from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs.helper import uuid_value -from services.completion_service import CompletionService -from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService @@ -76,48 +72,6 @@ class MessageFeedbackApi(InstalledAppResource): return {'result': 'success'} -class MessageMoreLikeThisApi(InstalledAppResource): - def get(self, installed_app, message_id): - app_model = installed_app.app - if app_model.mode != 'completion': - raise NotCompletionAppError() - - message_id = str(message_id) - - parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') - args = parser.parse_args() - - streaming = args['response_mode'] == 'streaming' - - try: - response = CompletionService.generate_more_like_this( - app_model=app_model, - user=current_user, - message_id=message_id, - invoke_from=InvokeFrom.EXPLORE, - streaming=streaming - ) - return compact_response(response) - except MessageNotExistsError: - raise NotFound("Message Not Exists.") - except MoreLikeThisDisabledError: - raise AppMoreLikeThisDisabledError() - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise e - except Exception: - logging.exception("internal server error.") - raise InternalServerError() - - def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -166,5 +120,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource): api.add_resource(MessageListApi, '/installed-apps//messages', endpoint='installed_app_messages') api.add_resource(MessageFeedbackApi, '/installed-apps//messages//feedbacks', endpoint='installed_app_message_feedback') -api.add_resource(MessageMoreLikeThisApi, '/installed-apps//messages//more-like-this', endpoint='installed_app_more_like_this') api.add_resource(MessageSuggestedQuestionApi, '/installed-apps//messages//suggested-questions', endpoint='installed_app_suggested_question') diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py new file mode 100644 index 0000000000..7664ba8c16 --- /dev/null +++ b/api/controllers/console/ping.py @@ -0,0 +1,17 @@ +from flask_restful import Resource + +from controllers.console import api + + +class PingApi(Resource): + + def get(self): + """ + For connection health check + """ + return { + "result": "pong" + } + + +api.add_resource(PingApi, '/ping') diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index b7cfba9d04..656a4d4cee 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -16,26 +16,13 @@ from controllers.console.workspace.error import ( ) from controllers.console.wraps import account_initialization_required from extensions.ext_database import db +from fields.member_fields import account_fields from libs.helper import TimestampField, timezone from libs.login import login_required from models.account import AccountIntegrate, InvitationCode from services.account_service import AccountService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'is_password_set': fields.Boolean, - 'interface_language': fields.String, - 'interface_theme': fields.String, - 'timezone': fields.String, - 'last_login_at': TimestampField, - 'last_login_ip': fields.String, - 'created_at': TimestampField -} - class AccountInitApi(Resource): diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index cf57cd4b24..f40ccebf25 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,33 +1,18 @@ from flask import current_app from flask_login import current_user -from flask_restful import Resource, abort, fields, marshal_with, reqparse +from flask_restful import Resource, abort, marshal_with, reqparse import services from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db -from libs.helper import TimestampField +from fields.member_fields import account_with_role_list_fields from libs.login import login_required from models.account import Account from services.account_service import RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'last_login_at': TimestampField, - 'created_at': TimestampField, - 'role': fields.String, - 'status': fields.String, -} - -account_list_fields = { - 'accounts': fields.List(fields.Nested(account_fields)) -} - class MemberListApi(Resource): """List all members of current tenant.""" @@ -35,7 +20,7 @@ class MemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_list_fields) + @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_tenant_members(current_user.current_tenant) return {'result': 'success', 'accounts': members}, 200 diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index e03bdd63bb..5120f49c5e 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -11,7 +11,6 @@ from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.web import api from controllers.web.error import ( - AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, CompletionRequestError, NotChatAppError, @@ -21,14 +20,11 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource -from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import message_file_fields from fields.message_fields import agent_thought_fields from libs.helper import TimestampField, uuid_value -from services.completion_service import CompletionService -from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService @@ -113,48 +109,6 @@ class MessageFeedbackApi(WebApiResource): return {'result': 'success'} -class MessageMoreLikeThisApi(WebApiResource): - def get(self, app_model, end_user, message_id): - if app_model.mode != 'completion': - raise NotCompletionAppError() - - message_id = str(message_id) - - parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') - args = parser.parse_args() - - streaming = args['response_mode'] == 'streaming' - - try: - response = CompletionService.generate_more_like_this( - app_model=app_model, - user=end_user, - message_id=message_id, - invoke_from=InvokeFrom.WEB_APP, - streaming=streaming - ) - - return compact_response(response) - except MessageNotExistsError: - raise NotFound("Message Not Exists.") - except MoreLikeThisDisabledError: - raise AppMoreLikeThisDisabledError() - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise e - except Exception: - logging.exception("internal server error.") - raise InternalServerError() - - def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -202,5 +156,4 @@ class MessageSuggestedQuestionApi(WebApiResource): api.add_resource(MessageListApi, '/messages') api.add_resource(MessageFeedbackApi, '/messages//feedbacks') -api.add_resource(MessageMoreLikeThisApi, '/messages//more-like-this') api.add_resource(MessageSuggestedQuestionApi, '/messages//suggested-questions') diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index d87302c717..26e9cc84aa 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -6,7 +6,6 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( ApplicationGenerateEntity, - AppMode, DatasetEntity, InvokeFrom, ModelConfigEntity, @@ -16,7 +15,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException from extensions.ext_database import db -from models.model import App, Conversation, Message +from models.model import App, Conversation, Message, AppMode logger = logging.getLogger(__name__) @@ -250,6 +249,7 @@ class BasicApplicationRunner(AppRunner): invoke_from ) + # TODO if (app_record.mode == AppMode.COMPLETION.value and dataset_config and dataset_config.retrieve_config.query_variable): query = inputs.get(dataset_config.retrieve_config.query_variable, "") diff --git a/api/core/application_manager.py b/api/core/application_manager.py index 9aca61c7bb..2fde422d47 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -28,7 +28,7 @@ from core.entities.application_entities import ( ModelConfigEntity, PromptTemplateEntity, SensitiveWordAvoidanceEntity, - TextToSpeechEntity, + TextToSpeechEntity, VariableEntity, ) from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError @@ -93,7 +93,7 @@ class ApplicationManager: app_id=app_id, app_model_config_id=app_model_config_id, app_model_config_dict=app_model_config_dict, - app_orchestration_config_entity=self._convert_from_app_model_config_dict( + app_orchestration_config_entity=self.convert_from_app_model_config_dict( tenant_id=tenant_id, app_model_config_dict=app_model_config_dict ), @@ -234,7 +234,7 @@ class ApplicationManager: logger.exception(e) raise e - def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ + def convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ -> AppOrchestrationConfigEntity: """ Convert app model config dict to entity. @@ -384,8 +384,10 @@ class ApplicationManager: config=external_data_tool['config'] ) ) + + properties['variables'] = [] - # current external_data_tools + # variables and external_data_tools for variable in copy_app_model_config_dict.get('user_input_form', []): typ = list(variable.keys())[0] if typ == 'external_data_tool': @@ -397,6 +399,30 @@ class ApplicationManager: config=val['config'] ) ) + elif typ in [VariableEntity.Type.TEXT_INPUT.value, VariableEntity.Type.PARAGRAPH.value]: + properties['variables'].append( + VariableEntity( + type=VariableEntity.Type.TEXT_INPUT, + variable=variable[typ].get('variable'), + description=variable[typ].get('description'), + label=variable[typ].get('label'), + required=variable[typ].get('required', False), + max_length=variable[typ].get('max_length'), + default=variable[typ].get('default'), + ) + ) + elif typ == VariableEntity.Type.SELECT.value: + properties['variables'].append( + VariableEntity( + type=VariableEntity.Type.SELECT, + variable=variable[typ].get('variable'), + description=variable[typ].get('description'), + label=variable[typ].get('label'), + required=variable[typ].get('required', False), + options=variable[typ].get('options'), + default=variable[typ].get('default'), + ) + ) # show retrieve source show_retrieve_source = False diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index d3231affb2..092591a73f 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -9,26 +9,6 @@ from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import AIModelEntity -class AppMode(Enum): - COMPLETION = 'completion' # will be deprecated in the future - WORKFLOW = 'workflow' # instead of 'completion' - CHAT = 'chat' - AGENT = 'agent' - - @classmethod - def value_of(cls, value: str) -> 'AppMode': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid mode value {value}') - - class ModelConfigEntity(BaseModel): """ Model Config Entity. @@ -106,6 +86,38 @@ class PromptTemplateEntity(BaseModel): advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None +class VariableEntity(BaseModel): + """ + Variable Entity. + """ + class Type(Enum): + TEXT_INPUT = 'text-input' + SELECT = 'select' + PARAGRAPH = 'paragraph' + + @classmethod + def value_of(cls, value: str) -> 'VariableEntity.Type': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid variable type value {value}') + + variable: str + label: str + description: Optional[str] = None + type: Type + required: bool = False + max_length: Optional[int] = None + options: Optional[list[str]] = None + default: Optional[str] = None + + class ExternalDataVariableEntity(BaseModel): """ External Data Variable Entity. @@ -245,6 +257,7 @@ class AppOrchestrationConfigEntity(BaseModel): """ model_config: ModelConfigEntity prompt_template: PromptTemplateEntity + variables: list[VariableEntity] = [] external_data_variables: list[ExternalDataVariableEntity] = [] agent: Optional[AgentEntity] = None @@ -256,7 +269,7 @@ class AppOrchestrationConfigEntity(BaseModel): show_retrieve_source: bool = False more_like_this: bool = False speech_to_text: bool = False - text_to_speech: dict = {} + text_to_speech: Optional[TextToSpeechEntity] = None sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 4bf96ce265..abbfa96249 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -6,7 +6,6 @@ from typing import Optional, cast from core.entities.application_entities import ( AdvancedCompletionPromptTemplateEntity, - AppMode, ModelConfigEntity, PromptTemplateEntity, ) @@ -24,6 +23,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.prompt_builder import PromptBuilder from core.prompt.prompt_template import PromptTemplateParser +from models.model import AppMode class ModelMode(enum.Enum): diff --git a/api/core/workflow/__init__.py b/api/core/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/entities/NodeEntities.py b/api/core/workflow/entities/NodeEntities.py new file mode 100644 index 0000000000..d72b000dfb --- /dev/null +++ b/api/core/workflow/entities/NodeEntities.py @@ -0,0 +1,32 @@ +from enum import Enum + + +class NodeType(Enum): + """ + Node Types. + """ + START = 'start' + END = 'end' + DIRECT_ANSWER = 'direct-answer' + LLM = 'llm' + KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval' + IF_ELSE = 'if-else' + CODE = 'code' + TEMPLATE_TRANSFORM = 'template-transform' + QUESTION_CLASSIFIER = 'question-classifier' + HTTP_REQUEST = 'http-request' + TOOL = 'tool' + VARIABLE_ASSIGNER = 'variable-assigner' + + @classmethod + def value_of(cls, value: str) -> 'BlockType': + """ + Get value of given block type. + + :param value: block type value + :return: block type + """ + for block_type in cls: + if block_type.value == value: + return block_type + raise ValueError(f'invalid block type value {value}') diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/end/__init__.py b/api/core/workflow/nodes/end/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py new file mode 100644 index 0000000000..045e7effc4 --- /dev/null +++ b/api/core/workflow/nodes/end/entities.py @@ -0,0 +1,25 @@ +from enum import Enum + + +class EndNodeOutputType(Enum): + """ + END Node Output Types. + + none, plain-text, structured + """ + NONE = 'none' + PLAIN_TEXT = 'plain-text' + STRUCTURED = 'structured' + + @classmethod + def value_of(cls, value: str) -> 'OutputType': + """ + Get value of given output type. + + :param value: output type value + :return: output type + """ + for output_type in cls: + if output_type.value == value: + return output_type + raise ValueError(f'invalid output type value {value}') diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 5974de34de..d9cd6c03bb 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -2,12 +2,6 @@ from flask_restful import fields from libs.helper import TimestampField -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'email': fields.String -} - annotation_fields = { "id": fields.String, @@ -15,7 +9,7 @@ annotation_fields = { "answer": fields.Raw(attribute='content'), "hit_count": fields.Integer, "created_at": TimestampField, - # 'account': fields.Nested(account_fields, allow_null=True) + # 'account': fields.Nested(simple_account_fields, allow_null=True) } annotation_list_fields = { diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 1adc836aa2..afa486f1cd 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,5 +1,6 @@ from flask_restful import fields +from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -8,31 +9,25 @@ class MessageTextField(fields.Raw): return value[0]['text'] if value else '' -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'email': fields.String -} - feedback_fields = { 'rating': fields.String, 'content': fields.String, 'from_source': fields.String, 'from_end_user_id': fields.String, - 'from_account': fields.Nested(account_fields, allow_null=True), + 'from_account': fields.Nested(simple_account_fields, allow_null=True), } annotation_fields = { 'id': fields.String, 'question': fields.String, 'content': fields.String, - 'account': fields.Nested(account_fields, allow_null=True), + 'account': fields.Nested(simple_account_fields, allow_null=True), 'created_at': TimestampField } annotation_hit_history_fields = { 'annotation_id': fields.String(attribute='id'), - 'annotation_create_account': fields.Nested(account_fields, allow_null=True), + 'annotation_create_account': fields.Nested(simple_account_fields, allow_null=True), 'created_at': TimestampField } diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py new file mode 100644 index 0000000000..79164b3848 --- /dev/null +++ b/api/fields/member_fields.py @@ -0,0 +1,38 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +simple_account_fields = { + 'id': fields.String, + 'name': fields.String, + 'email': fields.String +} + +account_fields = { + 'id': fields.String, + 'name': fields.String, + 'avatar': fields.String, + 'email': fields.String, + 'is_password_set': fields.Boolean, + 'interface_language': fields.String, + 'interface_theme': fields.String, + 'timezone': fields.String, + 'last_login_at': TimestampField, + 'last_login_ip': fields.String, + 'created_at': TimestampField +} + +account_with_role_fields = { + 'id': fields.String, + 'name': fields.String, + 'avatar': fields.String, + 'email': fields.String, + 'last_login_at': TimestampField, + 'created_at': TimestampField, + 'role': fields.String, + 'status': fields.String, +} + +account_with_role_list_fields = { + 'accounts': fields.List(fields.Nested(account_with_role_fields)) +} diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py new file mode 100644 index 0000000000..9dc92ea43b --- /dev/null +++ b/api/fields/workflow_fields.py @@ -0,0 +1,16 @@ +import json + +from flask_restful import fields + +from fields.member_fields import simple_account_fields +from libs.helper import TimestampField + + +workflow_fields = { + 'id': fields.String, + 'graph': fields.Raw(attribute=lambda x: json.loads(x.graph) if hasattr(x, 'graph') else None), + 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), + 'created_at': TimestampField, + 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), + 'updated_at': TimestampField +} diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 605c66bed1..e9cd2caf3a 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -102,7 +102,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='workflow_pkey') ) with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'type', 'version'], unique=False) + batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False) with op.batch_alter_table('app_model_configs', schema=None) as batch_op: batch_op.add_column(sa.Column('chatbot_app_engine', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False)) diff --git a/api/models/model.py b/api/models/model.py index 6c726928eb..6a0e5df568 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,5 +1,7 @@ import json import uuid +from enum import Enum +from typing import Optional from flask import current_app, request from flask_login import UserMixin @@ -25,6 +27,25 @@ class DifySetup(db.Model): setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) +class AppMode(Enum): + WORKFLOW = 'workflow' + CHAT = 'chat' + AGENT = 'agent' + + @classmethod + def value_of(cls, value: str) -> 'AppMode': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + + class App(db.Model): __tablename__ = 'apps' __table_args__ = ( @@ -56,7 +77,7 @@ class App(db.Model): return site @property - def app_model_config(self): + def app_model_config(self) -> Optional['AppModelConfig']: app_model_config = db.session.query(AppModelConfig).filter( AppModelConfig.id == self.app_model_config_id).first() return app_model_config @@ -130,6 +151,12 @@ class App(db.Model): return deleted_tools + +class ChatbotAppEngine(Enum): + NORMAL = 'normal' + WORKFLOW = 'workflow' + + class AppModelConfig(db.Model): __tablename__ = 'app_model_configs' __table_args__ = ( diff --git a/api/models/workflow.py b/api/models/workflow.py index 59b8eeb6cd..ed26e98896 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,6 +1,43 @@ +from enum import Enum +from typing import Union + from sqlalchemy.dialects.postgresql import UUID from extensions.ext_database import db +from models.account import Account +from models.model import AppMode + + +class WorkflowType(Enum): + """ + Workflow Type Enum + """ + WORKFLOW = 'workflow' + CHAT = 'chat' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowType': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow type value {value}') + + @classmethod + def from_app_mode(cls, app_mode: Union[str, AppMode]) -> 'WorkflowType': + """ + Get workflow type from app mode. + + :param app_mode: app mode + :return: workflow type + """ + app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) + return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT class Workflow(db.Model): @@ -39,7 +76,7 @@ class Workflow(db.Model): __tablename__ = 'workflows' __table_args__ = ( db.PrimaryKeyConstraint('id', name='workflow_pkey'), - db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'type', 'version'), + db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), ) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) @@ -53,6 +90,14 @@ class Workflow(db.Model): updated_by = db.Column(UUID) updated_at = db.Column(db.DateTime) + @property + def created_by_account(self): + return Account.query.get(self.created_by) + + @property + def updated_by_account(self): + return Account.query.get(self.updated_by) + class WorkflowRun(db.Model): """ @@ -116,6 +161,14 @@ class WorkflowRun(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) finished_at = db.Column(db.DateTime) + @property + def created_by_account(self): + return Account.query.get(self.created_by) + + @property + def updated_by_account(self): + return Account.query.get(self.updated_by) + class WorkflowNodeExecution(db.Model): """ diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 3cf58d8e09..1e893e0eca 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,7 +1,6 @@ import copy -from core.entities.application_entities import AppMode from core.prompt.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, @@ -14,6 +13,7 @@ from core.prompt.advanced_prompt_templates import ( COMPLETION_APP_COMPLETION_PROMPT_CONFIG, CONTEXT, ) +from models.model import AppMode class AdvancedPromptTemplateService: diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index ccfb101405..3ac11c645c 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -9,6 +9,7 @@ from core.model_runtime.model_providers import model_provider_factory from core.moderation.factory import ModerationFactory from core.provider_manager import ProviderManager from models.account import Account +from models.model import AppMode from services.dataset_service import DatasetService SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] @@ -315,9 +316,6 @@ class AppModelConfigService: if "tool_parameters" not in tool: raise ValueError("tool_parameters is required in agent_mode.tools") - # dataset_query_variable - cls.is_dataset_query_variable_valid(config, app_mode) - # advanced prompt validation cls.is_advanced_prompt_valid(config, app_mode) @@ -443,21 +441,6 @@ class AppModelConfigService: config=config ) - @classmethod - def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None: - # Only check when mode is completion - if mode != 'completion': - return - - agent_mode = config.get("agent_mode", {}) - tools = agent_mode.get("tools", []) - dataset_exists = "dataset" in str(tools) - - dataset_query_variable = config.get("dataset_query_variable") - - if dataset_exists and not dataset_query_variable: - raise ValueError("Dataset query variable is required when dataset is exist") - @classmethod def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: # prompt_type diff --git a/api/services/completion_service.py b/api/services/completion_service.py index cbfbe9ef41..5599c60113 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -8,12 +8,10 @@ from core.application_manager import ApplicationManager from core.entities.application_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db -from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message +from models.model import Account, App, AppModelConfig, Conversation, EndUser from services.app_model_config_service import AppModelConfigService -from services.errors.app import MoreLikeThisDisabledError from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError -from services.errors.message import MessageNotExistsError class CompletionService: @@ -157,62 +155,6 @@ class CompletionService: } ) - @classmethod - def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], - message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ - -> Union[dict, Generator]: - if not user: - raise ValueError('user cannot be None') - - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() - - if not message: - raise MessageNotExistsError() - - current_app_model_config = app_model.app_model_config - more_like_this = current_app_model_config.more_like_this_dict - - if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: - raise MoreLikeThisDisabledError() - - app_model_config = message.app_model_config - model_dict = app_model_config.model_dict - completion_params = model_dict.get('completion_params') - completion_params['temperature'] = 0.9 - model_dict['completion_params'] = completion_params - app_model_config.model = json.dumps(model_dict) - - # parse files - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_objs = message_file_parser.transform_message_files( - message.files, app_model_config - ) - - application_manager = ApplicationManager() - return application_manager.generate( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - app_model_config_id=app_model_config.id, - app_model_config_dict=app_model_config.to_dict(), - app_model_config_override=True, - user=user, - invoke_from=invoke_from, - inputs=message.inputs, - query=message.query, - files=file_objs, - conversation=None, - stream=streaming, - extras={ - "auto_generate_conversation_name": False - } - ) - @classmethod def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): if user_inputs is None: diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index 5804f599fe..a44c190cbc 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -1,7 +1,7 @@ # -*- coding:utf-8 -*- __all__ = [ 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', - 'app', 'completion', 'audio', 'file' + 'completion', 'audio', 'file' ] from . import * diff --git a/api/services/errors/app.py b/api/services/errors/app.py deleted file mode 100644 index 7c4ca99c2a..0000000000 --- a/api/services/errors/app.py +++ /dev/null @@ -1,2 +0,0 @@ -class MoreLikeThisDisabledError(Exception): - pass diff --git a/api/services/workflow/__init__.py b/api/services/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/workflow/defaults.py b/api/services/workflow/defaults.py new file mode 100644 index 0000000000..67804fa4eb --- /dev/null +++ b/api/services/workflow/defaults.py @@ -0,0 +1,72 @@ +# default block config +default_block_configs = [ + { + "type": "llm", + "config": { + "prompt_templates": { + "chat_model": { + "prompts": [ + { + "role": "system", + "text": "You are a helpful AI assistant." + } + ] + }, + "completion_model": { + "conversation_histories_role": { + "user_prefix": "Human", + "assistant_prefix": "Assistant" + }, + "prompt": { + "text": "Here is the chat histories between human and assistant, inside " + " XML tags.\n\n\n{{" + "#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant:" + }, + "stop": ["Human:"] + } + } + } + }, + { + "type": "code", + "config": { + "variables": [ + { + "variable": "arg1", + "value_selector": [] + }, + { + "variable": "arg2", + "value_selector": [] + } + ], + "code_language": "python3", + "code": "def main(\n arg1: int,\n arg2: int,\n) -> int:\n return {\n \"result\": arg1 " + "+ arg2\n }", + "outputs": [ + { + "variable": "result", + "variable_type": "number" + } + ] + } + }, + { + "type": "template-transform", + "config": { + "variables": [ + { + "variable": "arg1", + "value_selector": [] + } + ], + "template": "{{ arg1 }}" + } + }, + { + "type": "question-classifier", + "config": { + "instructions": "" # TODO + } + } +] diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py new file mode 100644 index 0000000000..c2fad83aaf --- /dev/null +++ b/api/services/workflow/workflow_converter.py @@ -0,0 +1,259 @@ +import json +from typing import Optional + +from core.application_manager import ApplicationManager +from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, FileUploadEntity, \ + ExternalDataVariableEntity, DatasetEntity, VariableEntity +from core.model_runtime.utils import helper +from core.workflow.entities.NodeEntities import NodeType +from core.workflow.nodes.end.entities import EndNodeOutputType +from extensions.ext_database import db +from models.account import Account +from models.model import App, AppMode, ChatbotAppEngine +from models.workflow import Workflow, WorkflowType + + +class WorkflowConverter: + """ + App Convert to Workflow Mode + """ + + def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: + """ + Convert to workflow mode + + - basic mode of chatbot app + + - advanced mode of assistant app (for migration) + + - completion app (for migration) + + :param app_model: App instance + :param account: Account instance + :return: workflow instance + """ + # get original app config + app_model_config = app_model.app_model_config + + # convert app model config + application_manager = ApplicationManager() + application_manager.convert_from_app_model_config_dict( + tenant_id=app_model.tenant_id, + app_model_config_dict=app_model_config.to_dict() + ) + + # init workflow graph + graph = { + "nodes": [], + "edges": [] + } + + # Convert list: + # - variables -> start + # - model_config -> llm + # - prompt_template -> llm + # - file_upload -> llm + # - external_data_variables -> http-request + # - dataset -> knowledge-retrieval + # - show_retrieve_source -> knowledge-retrieval + + # convert to start node + start_node = self._convert_to_start_node( + variables=app_model_config.variables + ) + + graph['nodes'].append(start_node) + + # convert to http request node + if app_model_config.external_data_variables: + http_request_node = self._convert_to_http_request_node( + external_data_variables=app_model_config.external_data_variables + ) + + graph = self._append_node(graph, http_request_node) + + # convert to knowledge retrieval node + if app_model_config.dataset: + knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( + dataset=app_model_config.dataset, + show_retrieve_source=app_model_config.show_retrieve_source + ) + + graph = self._append_node(graph, knowledge_retrieval_node) + + # convert to llm node + llm_node = self._convert_to_llm_node( + model_config=app_model_config.model_config, + prompt_template=app_model_config.prompt_template, + file_upload=app_model_config.file_upload + ) + + graph = self._append_node(graph, llm_node) + + # convert to end node by app mode + end_node = self._convert_to_end_node(app_model=app_model) + + graph = self._append_node(graph, end_node) + + # get new app mode + app_mode = self._get_new_app_mode(app_model) + + # create workflow record + workflow = Workflow( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=WorkflowType.from_app_mode(app_mode).value, + version='draft', + graph=json.dumps(graph), + created_by=account.id + ) + + db.session.add(workflow) + db.session.flush() + + # create new app model config record + new_app_model_config = app_model_config.copy() + new_app_model_config.external_data_tools = '' + new_app_model_config.model = '' + new_app_model_config.user_input_form = '' + new_app_model_config.dataset_query_variable = None + new_app_model_config.pre_prompt = None + new_app_model_config.agent_mode = '' + new_app_model_config.prompt_type = 'simple' + new_app_model_config.chat_prompt_config = '' + new_app_model_config.completion_prompt_config = '' + new_app_model_config.dataset_configs = '' + new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \ + if app_mode == AppMode.CHAT else ChatbotAppEngine.NORMAL.value + new_app_model_config.workflow_id = workflow.id + + db.session.add(new_app_model_config) + db.session.commit() + + return workflow + + def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: + """ + Convert to Start Node + :param variables: list of variables + :return: + """ + return { + "id": "start", + "position": None, + "data": { + "title": "START", + "type": NodeType.START.value, + "variables": [helper.dump_model(v) for v in variables] + } + } + + def _convert_to_http_request_node(self, external_data_variables: list[ExternalDataVariableEntity]) -> dict: + """ + Convert API Based Extension to HTTP Request Node + :param external_data_variables: list of external data variables + :return: + """ + # TODO: implement + pass + + def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset: DatasetEntity) -> dict: + """ + Convert datasets to Knowledge Retrieval Node + :param new_app_mode: new app mode + :param dataset: dataset + :return: + """ + # TODO: implement + if new_app_mode == AppMode.CHAT: + query_variable_selector = ["start", "sys.query"] + else: + pass + + return { + "id": "knowledge-retrieval", + "position": None, + "data": { + "title": "KNOWLEDGE RETRIEVAL", + "type": NodeType.KNOWLEDGE_RETRIEVAL.value, + } + } + + def _convert_to_llm_node(self, model_config: ModelConfigEntity, + prompt_template: PromptTemplateEntity, + file_upload: Optional[FileUploadEntity] = None) -> dict: + """ + Convert to LLM Node + :param model_config: model config + :param prompt_template: prompt template + :param file_upload: file upload config (optional) + """ + # TODO: implement + pass + + def _convert_to_end_node(self, app_model: App) -> dict: + """ + Convert to End Node + :param app_model: App instance + :return: + """ + if app_model.mode == AppMode.CHAT.value: + return { + "id": "end", + "position": None, + "data": { + "title": "END", + "type": NodeType.END.value, + } + } + elif app_model.mode == "completion": + # for original completion app + return { + "id": "end", + "position": None, + "data": { + "title": "END", + "type": NodeType.END.value, + "outputs": { + "type": EndNodeOutputType.PLAIN_TEXT.value, + "plain_text_selector": ["llm", "text"] + } + } + } + + def _create_edge(self, source: str, target: str) -> dict: + """ + Create Edge + :param source: source node id + :param target: target node id + :return: + """ + return { + "id": f"{source}-{target}", + "source": source, + "target": target + } + + def _append_node(self, graph: dict, node: dict) -> dict: + """ + Append Node to Graph + + :param graph: Graph, include: nodes, edges + :param node: Node to append + :return: + """ + previous_node = graph['nodes'][-1] + graph['nodes'].append(node) + graph['edges'].append(self._create_edge(previous_node['id'], node['id'])) + return graph + + def _get_new_app_mode(self, app_model: App) -> AppMode: + """ + Get new app mode + :param app_model: App instance + :return: AppMode + """ + if app_model.mode == "completion": + return AppMode.WORKFLOW + else: + return AppMode.value_of(app_model.mode) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py new file mode 100644 index 0000000000..6a967e86ff --- /dev/null +++ b/api/services/workflow_service.py @@ -0,0 +1,83 @@ +import json +from datetime import datetime + +from extensions.ext_database import db +from models.account import Account +from models.model import App, ChatbotAppEngine +from models.workflow import Workflow, WorkflowType +from services.workflow.defaults import default_block_configs +from services.workflow.workflow_converter import WorkflowConverter + + +class WorkflowService: + """ + Workflow Service + """ + + def get_draft_workflow(self, app_model: App) -> Workflow: + """ + Get draft workflow + """ + # fetch draft workflow by app_model + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == 'draft' + ).first() + + # return draft workflow + return workflow + + def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> Workflow: + """ + Sync draft workflow + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + # create draft workflow if not found + if not workflow: + workflow = Workflow( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=WorkflowType.from_app_mode(app_model.mode).value, + version='draft', + graph=json.dumps(graph), + created_by=account.id + ) + db.session.add(workflow) + # update draft workflow if found + else: + workflow.graph = json.dumps(graph) + workflow.updated_by = account.id + workflow.updated_at = datetime.utcnow() + + # commit db session changes + db.session.commit() + + # return draft workflow + return workflow + + def get_default_block_configs(self) -> dict: + """ + Get default block configs + """ + # return default block config + return default_block_configs + + def chatbot_convert_to_workflow(self, app_model: App) -> Workflow: + """ + basic mode of chatbot app to workflow + + :param app_model: App instance + :return: + """ + # check if chatbot app is in basic mode + if app_model.app_model_config.chatbot_app_engine != ChatbotAppEngine.NORMAL: + raise ValueError('Chatbot app already in workflow mode') + + # convert to workflow mode + workflow_converter = WorkflowConverter() + workflow = workflow_converter.convert_to_workflow(app_model=app_model) + + return workflow