mirror of https://github.com/langgenius/dify.git
add workflow logics
This commit is contained in:
parent
9ad6bd78f5
commit
f067947266
|
|
@ -1,10 +1,10 @@
|
|||
import json
|
||||
|
||||
model_templates = {
|
||||
# completion default mode
|
||||
'completion_default': {
|
||||
# workflow default mode
|
||||
'workflow_default': {
|
||||
'app': {
|
||||
'mode': 'completion',
|
||||
'mode': 'workflow',
|
||||
'enable_site': True,
|
||||
'enable_api': True,
|
||||
'is_demo': False,
|
||||
|
|
@ -15,24 +15,7 @@ model_templates = {
|
|||
'model_config': {
|
||||
'provider': '',
|
||||
'model_id': '',
|
||||
'configs': {},
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {}
|
||||
}),
|
||||
'user_input_form': json.dumps([
|
||||
{
|
||||
"paragraph": {
|
||||
"label": "Query",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
]),
|
||||
'pre_prompt': '{{query}}'
|
||||
'configs': {}
|
||||
}
|
||||
},
|
||||
|
||||
|
|
@ -48,14 +31,70 @@ model_templates = {
|
|||
'status': 'normal'
|
||||
},
|
||||
'model_config': {
|
||||
'provider': '',
|
||||
'model_id': '',
|
||||
'configs': {},
|
||||
'provider': 'openai',
|
||||
'model_id': 'gpt-4',
|
||||
'configs': {
|
||||
'prompt_template': '',
|
||||
'prompt_variables': [],
|
||||
'completion_params': {
|
||||
'max_token': 512,
|
||||
'temperature': 1,
|
||||
'top_p': 1,
|
||||
'presence_penalty': 0,
|
||||
'frequency_penalty': 0,
|
||||
}
|
||||
},
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"name": "gpt-4",
|
||||
"mode": "chat",
|
||||
"completion_params": {}
|
||||
"completion_params": {
|
||||
"max_tokens": 512,
|
||||
"temperature": 1,
|
||||
"top_p": 1,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0
|
||||
}
|
||||
})
|
||||
}
|
||||
},
|
||||
|
||||
# agent default mode
|
||||
'agent_default': {
|
||||
'app': {
|
||||
'mode': 'agent',
|
||||
'enable_site': True,
|
||||
'enable_api': True,
|
||||
'is_demo': False,
|
||||
'api_rpm': 0,
|
||||
'api_rph': 0,
|
||||
'status': 'normal'
|
||||
},
|
||||
'model_config': {
|
||||
'provider': 'openai',
|
||||
'model_id': 'gpt-4',
|
||||
'configs': {
|
||||
'prompt_template': '',
|
||||
'prompt_variables': [],
|
||||
'completion_params': {
|
||||
'max_token': 512,
|
||||
'temperature': 1,
|
||||
'top_p': 1,
|
||||
'presence_penalty': 0,
|
||||
'frequency_penalty': 0,
|
||||
}
|
||||
},
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-4",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 512,
|
||||
"temperature": 1,
|
||||
"top_p": 1,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0
|
||||
}
|
||||
})
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ bp = Blueprint('console', __name__, url_prefix='/console/api')
|
|||
api = ExternalApi(bp)
|
||||
|
||||
# Import other controllers
|
||||
from . import admin, apikey, extension, feature, setup, version
|
||||
from . import admin, apikey, extension, feature, setup, version, ping
|
||||
# Import app controllers
|
||||
from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message,
|
||||
model_config, site, statistic, workflow)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from fields.app_fields import (
|
|||
template_list_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppModelConfig, Site
|
||||
from models.model import App, AppModelConfig, Site, AppMode
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
|
@ -80,7 +80,7 @@ class AppListApi(Resource):
|
|||
"""Create app"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('mode', type=str, choices=['completion', 'chat', 'assistant'], location='json')
|
||||
parser.add_argument('mode', type=str, choices=[mode.value for mode in AppMode], location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
parser.add_argument('model_config', type=dict, location='json')
|
||||
|
|
@ -90,18 +90,7 @@ class AppListApi(Resource):
|
|||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_manager = ProviderManager()
|
||||
default_model_entity = provider_manager.get_default_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
except (ProviderTokenNotInitError, LLMBadRequestError):
|
||||
default_model_entity = None
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
default_model_entity = None
|
||||
|
||||
# TODO: MOVE TO IMPORT API
|
||||
if args['model_config'] is not None:
|
||||
# validate config
|
||||
model_config_dict = args['model_config']
|
||||
|
|
@ -150,27 +139,30 @@ class AppListApi(Resource):
|
|||
if 'mode' not in args or args['mode'] is None:
|
||||
abort(400, message="mode is required")
|
||||
|
||||
model_config_template = model_templates[args['mode'] + '_default']
|
||||
app_mode = AppMode.value_of(args['mode'])
|
||||
|
||||
model_config_template = model_templates[app_mode.value + '_default']
|
||||
|
||||
app = App(**model_config_template['app'])
|
||||
app_model_config = AppModelConfig(**model_config_template['model_config'])
|
||||
|
||||
# get model provider
|
||||
model_manager = ModelManager()
|
||||
if app_mode in [AppMode.CHAT, AppMode.AGENT]:
|
||||
# get model provider
|
||||
model_manager = ModelManager()
|
||||
|
||||
try:
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
model_instance = None
|
||||
try:
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
model_instance = None
|
||||
|
||||
if model_instance:
|
||||
model_dict = app_model_config.model_dict
|
||||
model_dict['provider'] = model_instance.provider
|
||||
model_dict['name'] = model_instance.model
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
if model_instance:
|
||||
model_dict = app_model_config.model_dict
|
||||
model_dict['provider'] = model_instance.provider
|
||||
model_dict['name'] = model_instance.model
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
|
||||
app.name = args['name']
|
||||
app.mode = args['mode']
|
||||
|
|
|
|||
|
|
@ -20,10 +20,10 @@ from controllers.console.app.error import (
|
|||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.entities.application_entities import AppMode
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
|
|
|
|||
|
|
@ -22,11 +22,12 @@ from controllers.console.app.wraps import get_app_model
|
|||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import AppMode, InvokeFrom
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from services.completion_service import CompletionService
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from controllers.console import api
|
|||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.entities.application_entities import AppMode
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
conversation_detail_fields,
|
||||
|
|
@ -22,7 +21,7 @@ from fields.conversation_fields import (
|
|||
)
|
||||
from libs.helper import datetime_string
|
||||
from libs.login import login_required
|
||||
from models.model import Conversation, Message, MessageAnnotation
|
||||
from models.model import Conversation, Message, MessageAnnotation, AppMode
|
||||
|
||||
|
||||
class CompletionConversationApi(Resource):
|
||||
|
|
|
|||
|
|
@ -85,3 +85,9 @@ class TooManyFilesError(BaseHTTPException):
|
|||
error_code = 'too_many_files'
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
|
||||
class DraftWorkflowNotExist(BaseHTTPException):
|
||||
error_code = 'draft_workflow_not_exist'
|
||||
description = "Draft workflow need to be initialized."
|
||||
code = 400
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
|
|
@ -20,7 +19,6 @@ from controllers.console.app.error import (
|
|||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.entities.application_entities import AppMode, InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -28,10 +26,8 @@ from fields.conversation_fields import annotation_fields, message_detail_fields
|
|||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import login_required
|
||||
from models.model import Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from models.model import Conversation, Message, MessageAnnotation, MessageFeedback, AppMode
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.message_service import MessageService
|
||||
|
|
@ -183,49 +179,6 @@ class MessageAnnotationCountApi(Resource):
|
|||
return {'count': count}
|
||||
|
||||
|
||||
class MessageMoreLikeThisApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def get(self, app_model, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'],
|
||||
location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
try:
|
||||
response = CompletionService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=streaming
|
||||
)
|
||||
return compact_response(response)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
|
|
@ -291,7 +244,6 @@ class MessageApi(Resource):
|
|||
return message
|
||||
|
||||
|
||||
api.add_resource(MessageMoreLikeThisApi, '/apps/<uuid:app_id>/completion-messages/<uuid:message_id>/more-like-this')
|
||||
api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions')
|
||||
api.add_resource(ChatMessageListApi, '/apps/<uuid:app_id>/chat-messages', endpoint='console_chat_messages')
|
||||
api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks')
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ from controllers.console import api
|
|||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.entities.application_entities import AppMode
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import datetime_string
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class DailyConversationStatistic(Resource):
|
||||
|
|
|
|||
|
|
@ -1,30 +1,88 @@
|
|||
from flask_restful import Resource, reqparse
|
||||
from flask_restful import Resource, reqparse, marshal_with
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.entities.application_entities import AppMode
|
||||
from libs.login import login_required
|
||||
from fields.workflow_fields import workflow_fields
|
||||
from libs.login import login_required, current_user
|
||||
from models.model import App, ChatbotAppEngine, AppMode
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class DraftWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW)
|
||||
@marshal_with(workflow_fields)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not workflow:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# return workflow, if not found, return None (initiate graph by frontend)
|
||||
return workflow
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW)
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Sync draft workflow
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_service.sync_draft_workflow(app_model=app_model, graph=args.get('graph'), account=current_user)
|
||||
|
||||
return {
|
||||
"result": "success"
|
||||
}
|
||||
|
||||
|
||||
class DefaultBlockConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('app_mode', type=str, required=True, nullable=False,
|
||||
choices=[AppMode.CHAT.value, AppMode.WORKFLOW.value], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_mode = args.get('app_mode')
|
||||
app_mode = AppMode.value_of(app_mode)
|
||||
|
||||
# TODO: implement this
|
||||
|
||||
return {
|
||||
"blocks": []
|
||||
}
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
# Get default block configs
|
||||
workflow_service = WorkflowService()
|
||||
return workflow_service.get_default_block_configs()
|
||||
|
||||
|
||||
api.add_resource(DefaultBlockConfigApi, '/default-workflow-block-configs')
|
||||
class ConvertToWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.CHAT)
|
||||
@marshal_with(workflow_fields)
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Convert basic mode of chatbot app to workflow
|
||||
"""
|
||||
# convert to workflow mode
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.chatbot_convert_to_workflow(app_model=app_model)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
|
||||
api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft')
|
||||
api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
|
||||
api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow')
|
||||
|
|
|
|||
|
|
@ -3,13 +3,14 @@ from functools import wraps
|
|||
from typing import Optional, Union
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from core.entities.application_entities import AppMode
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.model import App
|
||||
from models.model import App, ChatbotAppEngine, AppMode
|
||||
|
||||
|
||||
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
|
||||
def get_app_model(view: Optional[Callable] = None, *,
|
||||
mode: Union[AppMode, list[AppMode]] = None,
|
||||
app_engine: ChatbotAppEngine = None):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
|
|
@ -37,14 +38,20 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
|
|||
else:
|
||||
modes = [mode]
|
||||
|
||||
# [temp] if workflow is in the mode list, then completion should be in the mode list
|
||||
if AppMode.WORKFLOW in modes:
|
||||
modes.append(AppMode.COMPLETION)
|
||||
|
||||
if app_mode not in modes:
|
||||
mode_values = {m.value for m in modes}
|
||||
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
|
||||
|
||||
if app_engine is not None:
|
||||
if app_mode not in [AppMode.CHAT, AppMode.WORKFLOW]:
|
||||
raise AppNotFoundError(f"App mode is not supported for {app_engine.value} app engine.")
|
||||
|
||||
if app_mode == AppMode.CHAT:
|
||||
# fetch current app model config
|
||||
app_model_config = app_model.app_model_config
|
||||
if not app_model_config or app_model_config.chatbot_app_engine != app_engine.value:
|
||||
raise AppNotFoundError(f"{app_engine.value} app engine is not supported.")
|
||||
|
||||
kwargs['app_model'] = app_model
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from werkzeug.exceptions import InternalServerError, NotFound
|
|||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
|
|
@ -24,13 +23,10 @@ from controllers.console.explore.error import (
|
|||
NotCompletionAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
|
|
@ -76,48 +72,6 @@ class MessageFeedbackApi(InstalledAppResource):
|
|||
return {'result': 'success'}
|
||||
|
||||
|
||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
def get(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
try:
|
||||
response = CompletionService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=streaming
|
||||
)
|
||||
return compact_response(response)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
|
|
@ -166,5 +120,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
|||
|
||||
api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
|
||||
api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
|
||||
api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this')
|
||||
api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')
|
||||
|
|
|
|||
|
|
@ -0,0 +1,17 @@
|
|||
from flask_restful import Resource
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
|
||||
class PingApi(Resource):
|
||||
|
||||
def get(self):
|
||||
"""
|
||||
For connection health check
|
||||
"""
|
||||
return {
|
||||
"result": "pong"
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(PingApi, '/ping')
|
||||
|
|
@ -16,26 +16,13 @@ from controllers.console.workspace.error import (
|
|||
)
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from libs.helper import TimestampField, timezone
|
||||
from libs.login import login_required
|
||||
from models.account import AccountIntegrate, InvitationCode
|
||||
from services.account_service import AccountService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'avatar': fields.String,
|
||||
'email': fields.String,
|
||||
'is_password_set': fields.Boolean,
|
||||
'interface_language': fields.String,
|
||||
'interface_theme': fields.String,
|
||||
'timezone': fields.String,
|
||||
'last_login_at': TimestampField,
|
||||
'last_login_ip': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
||||
class AccountInitApi(Resource):
|
||||
|
||||
|
|
|
|||
|
|
@ -1,33 +1,18 @@
|
|||
from flask import current_app
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, abort, fields, marshal_with, reqparse
|
||||
from flask_restful import Resource, abort, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from fields.member_fields import account_with_role_list_fields
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.account_service import RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'avatar': fields.String,
|
||||
'email': fields.String,
|
||||
'last_login_at': TimestampField,
|
||||
'created_at': TimestampField,
|
||||
'role': fields.String,
|
||||
'status': fields.String,
|
||||
}
|
||||
|
||||
account_list_fields = {
|
||||
'accounts': fields.List(fields.Nested(account_fields))
|
||||
}
|
||||
|
||||
|
||||
class MemberListApi(Resource):
|
||||
"""List all members of current tenant."""
|
||||
|
|
@ -35,7 +20,7 @@ class MemberListApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_list_fields)
|
||||
@marshal_with(account_with_role_list_fields)
|
||||
def get(self):
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
return {'result': 'success', 'accounts': members}, 200
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from werkzeug.exceptions import InternalServerError, NotFound
|
|||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
CompletionRequestError,
|
||||
NotChatAppError,
|
||||
|
|
@ -21,14 +20,11 @@ from controllers.web.error import (
|
|||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from fields.message_fields import agent_thought_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
|
|
@ -113,48 +109,6 @@ class MessageFeedbackApi(WebApiResource):
|
|||
return {'result': 'success'}
|
||||
|
||||
|
||||
class MessageMoreLikeThisApi(WebApiResource):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
try:
|
||||
response = CompletionService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
return compact_response(response)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
|
|
@ -202,5 +156,4 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
|||
|
||||
api.add_resource(MessageListApi, '/messages')
|
||||
api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
|
||||
api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this')
|
||||
api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions')
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
|||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.application_entities import (
|
||||
ApplicationGenerateEntity,
|
||||
AppMode,
|
||||
DatasetEntity,
|
||||
InvokeFrom,
|
||||
ModelConfigEntity,
|
||||
|
|
@ -16,7 +15,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
|
|||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationException
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
from models.model import App, Conversation, Message, AppMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -250,6 +249,7 @@ class BasicApplicationRunner(AppRunner):
|
|||
invoke_from
|
||||
)
|
||||
|
||||
# TODO
|
||||
if (app_record.mode == AppMode.COMPLETION.value and dataset_config
|
||||
and dataset_config.retrieve_config.query_variable):
|
||||
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from core.entities.application_entities import (
|
|||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
SensitiveWordAvoidanceEntity,
|
||||
TextToSpeechEntity,
|
||||
TextToSpeechEntity, VariableEntity,
|
||||
)
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
|
|
@ -93,7 +93,7 @@ class ApplicationManager:
|
|||
app_id=app_id,
|
||||
app_model_config_id=app_model_config_id,
|
||||
app_model_config_dict=app_model_config_dict,
|
||||
app_orchestration_config_entity=self._convert_from_app_model_config_dict(
|
||||
app_orchestration_config_entity=self.convert_from_app_model_config_dict(
|
||||
tenant_id=tenant_id,
|
||||
app_model_config_dict=app_model_config_dict
|
||||
),
|
||||
|
|
@ -234,7 +234,7 @@ class ApplicationManager:
|
|||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
|
||||
def convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
|
||||
-> AppOrchestrationConfigEntity:
|
||||
"""
|
||||
Convert app model config dict to entity.
|
||||
|
|
@ -384,8 +384,10 @@ class ApplicationManager:
|
|||
config=external_data_tool['config']
|
||||
)
|
||||
)
|
||||
|
||||
properties['variables'] = []
|
||||
|
||||
# current external_data_tools
|
||||
# variables and external_data_tools
|
||||
for variable in copy_app_model_config_dict.get('user_input_form', []):
|
||||
typ = list(variable.keys())[0]
|
||||
if typ == 'external_data_tool':
|
||||
|
|
@ -397,6 +399,30 @@ class ApplicationManager:
|
|||
config=val['config']
|
||||
)
|
||||
)
|
||||
elif typ in [VariableEntity.Type.TEXT_INPUT.value, VariableEntity.Type.PARAGRAPH.value]:
|
||||
properties['variables'].append(
|
||||
VariableEntity(
|
||||
type=VariableEntity.Type.TEXT_INPUT,
|
||||
variable=variable[typ].get('variable'),
|
||||
description=variable[typ].get('description'),
|
||||
label=variable[typ].get('label'),
|
||||
required=variable[typ].get('required', False),
|
||||
max_length=variable[typ].get('max_length'),
|
||||
default=variable[typ].get('default'),
|
||||
)
|
||||
)
|
||||
elif typ == VariableEntity.Type.SELECT.value:
|
||||
properties['variables'].append(
|
||||
VariableEntity(
|
||||
type=VariableEntity.Type.SELECT,
|
||||
variable=variable[typ].get('variable'),
|
||||
description=variable[typ].get('description'),
|
||||
label=variable[typ].get('label'),
|
||||
required=variable[typ].get('required', False),
|
||||
options=variable[typ].get('options'),
|
||||
default=variable[typ].get('default'),
|
||||
)
|
||||
)
|
||||
|
||||
# show retrieve source
|
||||
show_retrieve_source = False
|
||||
|
|
|
|||
|
|
@ -9,26 +9,6 @@ from core.model_runtime.entities.message_entities import PromptMessageRole
|
|||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
|
||||
|
||||
class AppMode(Enum):
|
||||
COMPLETION = 'completion' # will be deprecated in the future
|
||||
WORKFLOW = 'workflow' # instead of 'completion'
|
||||
CHAT = 'chat'
|
||||
AGENT = 'agent'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'AppMode':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
|
||||
|
||||
class ModelConfigEntity(BaseModel):
|
||||
"""
|
||||
Model Config Entity.
|
||||
|
|
@ -106,6 +86,38 @@ class PromptTemplateEntity(BaseModel):
|
|||
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
"""
|
||||
Variable Entity.
|
||||
"""
|
||||
class Type(Enum):
|
||||
TEXT_INPUT = 'text-input'
|
||||
SELECT = 'select'
|
||||
PARAGRAPH = 'paragraph'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'VariableEntity.Type':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid variable type value {value}')
|
||||
|
||||
variable: str
|
||||
label: str
|
||||
description: Optional[str] = None
|
||||
type: Type
|
||||
required: bool = False
|
||||
max_length: Optional[int] = None
|
||||
options: Optional[list[str]] = None
|
||||
default: Optional[str] = None
|
||||
|
||||
|
||||
class ExternalDataVariableEntity(BaseModel):
|
||||
"""
|
||||
External Data Variable Entity.
|
||||
|
|
@ -245,6 +257,7 @@ class AppOrchestrationConfigEntity(BaseModel):
|
|||
"""
|
||||
model_config: ModelConfigEntity
|
||||
prompt_template: PromptTemplateEntity
|
||||
variables: list[VariableEntity] = []
|
||||
external_data_variables: list[ExternalDataVariableEntity] = []
|
||||
agent: Optional[AgentEntity] = None
|
||||
|
||||
|
|
@ -256,7 +269,7 @@ class AppOrchestrationConfigEntity(BaseModel):
|
|||
show_retrieve_source: bool = False
|
||||
more_like_this: bool = False
|
||||
speech_to_text: bool = False
|
||||
text_to_speech: dict = {}
|
||||
text_to_speech: Optional[TextToSpeechEntity] = None
|
||||
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from typing import Optional, cast
|
|||
|
||||
from core.entities.application_entities import (
|
||||
AdvancedCompletionPromptTemplateEntity,
|
||||
AppMode,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
|
|
@ -24,6 +23,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey
|
|||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class ModelMode(enum.Enum):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,32 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
"""
|
||||
Node Types.
|
||||
"""
|
||||
START = 'start'
|
||||
END = 'end'
|
||||
DIRECT_ANSWER = 'direct-answer'
|
||||
LLM = 'llm'
|
||||
KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval'
|
||||
IF_ELSE = 'if-else'
|
||||
CODE = 'code'
|
||||
TEMPLATE_TRANSFORM = 'template-transform'
|
||||
QUESTION_CLASSIFIER = 'question-classifier'
|
||||
HTTP_REQUEST = 'http-request'
|
||||
TOOL = 'tool'
|
||||
VARIABLE_ASSIGNER = 'variable-assigner'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'BlockType':
|
||||
"""
|
||||
Get value of given block type.
|
||||
|
||||
:param value: block type value
|
||||
:return: block type
|
||||
"""
|
||||
for block_type in cls:
|
||||
if block_type.value == value:
|
||||
return block_type
|
||||
raise ValueError(f'invalid block type value {value}')
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class EndNodeOutputType(Enum):
|
||||
"""
|
||||
END Node Output Types.
|
||||
|
||||
none, plain-text, structured
|
||||
"""
|
||||
NONE = 'none'
|
||||
PLAIN_TEXT = 'plain-text'
|
||||
STRUCTURED = 'structured'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'OutputType':
|
||||
"""
|
||||
Get value of given output type.
|
||||
|
||||
:param value: output type value
|
||||
:return: output type
|
||||
"""
|
||||
for output_type in cls:
|
||||
if output_type.value == value:
|
||||
return output_type
|
||||
raise ValueError(f'invalid output type value {value}')
|
||||
|
|
@ -2,12 +2,6 @@ from flask_restful import fields
|
|||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'email': fields.String
|
||||
}
|
||||
|
||||
|
||||
annotation_fields = {
|
||||
"id": fields.String,
|
||||
|
|
@ -15,7 +9,7 @@ annotation_fields = {
|
|||
"answer": fields.Raw(attribute='content'),
|
||||
"hit_count": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
# 'account': fields.Nested(account_fields, allow_null=True)
|
||||
# 'account': fields.Nested(simple_account_fields, allow_null=True)
|
||||
}
|
||||
|
||||
annotation_list_fields = {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from flask_restful import fields
|
||||
|
||||
from fields.member_fields import simple_account_fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
|
||||
|
|
@ -8,31 +9,25 @@ class MessageTextField(fields.Raw):
|
|||
return value[0]['text'] if value else ''
|
||||
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'email': fields.String
|
||||
}
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String,
|
||||
'content': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account': fields.Nested(account_fields, allow_null=True),
|
||||
'from_account': fields.Nested(simple_account_fields, allow_null=True),
|
||||
}
|
||||
|
||||
annotation_fields = {
|
||||
'id': fields.String,
|
||||
'question': fields.String,
|
||||
'content': fields.String,
|
||||
'account': fields.Nested(account_fields, allow_null=True),
|
||||
'account': fields.Nested(simple_account_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
annotation_hit_history_fields = {
|
||||
'annotation_id': fields.String(attribute='id'),
|
||||
'annotation_create_account': fields.Nested(account_fields, allow_null=True),
|
||||
'annotation_create_account': fields.Nested(simple_account_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
from flask_restful import fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
simple_account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'email': fields.String
|
||||
}
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'avatar': fields.String,
|
||||
'email': fields.String,
|
||||
'is_password_set': fields.Boolean,
|
||||
'interface_language': fields.String,
|
||||
'interface_theme': fields.String,
|
||||
'timezone': fields.String,
|
||||
'last_login_at': TimestampField,
|
||||
'last_login_ip': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
account_with_role_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'avatar': fields.String,
|
||||
'email': fields.String,
|
||||
'last_login_at': TimestampField,
|
||||
'created_at': TimestampField,
|
||||
'role': fields.String,
|
||||
'status': fields.String,
|
||||
}
|
||||
|
||||
account_with_role_list_fields = {
|
||||
'accounts': fields.List(fields.Nested(account_with_role_fields))
|
||||
}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
import json
|
||||
|
||||
from flask_restful import fields
|
||||
|
||||
from fields.member_fields import simple_account_fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
|
||||
workflow_fields = {
|
||||
'id': fields.String,
|
||||
'graph': fields.Raw(attribute=lambda x: json.loads(x.graph) if hasattr(x, 'graph') else None),
|
||||
'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'),
|
||||
'created_at': TimestampField,
|
||||
'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True),
|
||||
'updated_at': TimestampField
|
||||
}
|
||||
|
|
@ -102,7 +102,7 @@ def upgrade():
|
|||
sa.PrimaryKeyConstraint('id', name='workflow_pkey')
|
||||
)
|
||||
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||
batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'type', 'version'], unique=False)
|
||||
batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False)
|
||||
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('chatbot_app_engine', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False))
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import json
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import UserMixin
|
||||
|
|
@ -25,6 +27,25 @@ class DifySetup(db.Model):
|
|||
setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
|
||||
class AppMode(Enum):
|
||||
WORKFLOW = 'workflow'
|
||||
CHAT = 'chat'
|
||||
AGENT = 'agent'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'AppMode':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
|
||||
|
||||
class App(db.Model):
|
||||
__tablename__ = 'apps'
|
||||
__table_args__ = (
|
||||
|
|
@ -56,7 +77,7 @@ class App(db.Model):
|
|||
return site
|
||||
|
||||
@property
|
||||
def app_model_config(self):
|
||||
def app_model_config(self) -> Optional['AppModelConfig']:
|
||||
app_model_config = db.session.query(AppModelConfig).filter(
|
||||
AppModelConfig.id == self.app_model_config_id).first()
|
||||
return app_model_config
|
||||
|
|
@ -130,6 +151,12 @@ class App(db.Model):
|
|||
|
||||
return deleted_tools
|
||||
|
||||
|
||||
class ChatbotAppEngine(Enum):
|
||||
NORMAL = 'normal'
|
||||
WORKFLOW = 'workflow'
|
||||
|
||||
|
||||
class AppModelConfig(db.Model):
|
||||
__tablename__ = 'app_model_configs'
|
||||
__table_args__ = (
|
||||
|
|
|
|||
|
|
@ -1,6 +1,43 @@
|
|||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class WorkflowType(Enum):
|
||||
"""
|
||||
Workflow Type Enum
|
||||
"""
|
||||
WORKFLOW = 'workflow'
|
||||
CHAT = 'chat'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'WorkflowType':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid workflow type value {value}')
|
||||
|
||||
@classmethod
|
||||
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> 'WorkflowType':
|
||||
"""
|
||||
Get workflow type from app mode.
|
||||
|
||||
:param app_mode: app mode
|
||||
:return: workflow type
|
||||
"""
|
||||
app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode)
|
||||
return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
|
||||
|
||||
|
||||
class Workflow(db.Model):
|
||||
|
|
@ -39,7 +76,7 @@ class Workflow(db.Model):
|
|||
__tablename__ = 'workflows'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='workflow_pkey'),
|
||||
db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'type', 'version'),
|
||||
db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'),
|
||||
)
|
||||
|
||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||
|
|
@ -53,6 +90,14 @@ class Workflow(db.Model):
|
|||
updated_by = db.Column(UUID)
|
||||
updated_at = db.Column(db.DateTime)
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
return Account.query.get(self.created_by)
|
||||
|
||||
@property
|
||||
def updated_by_account(self):
|
||||
return Account.query.get(self.updated_by)
|
||||
|
||||
|
||||
class WorkflowRun(db.Model):
|
||||
"""
|
||||
|
|
@ -116,6 +161,14 @@ class WorkflowRun(db.Model):
|
|||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
finished_at = db.Column(db.DateTime)
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
return Account.query.get(self.created_by)
|
||||
|
||||
@property
|
||||
def updated_by_account(self):
|
||||
return Account.query.get(self.updated_by)
|
||||
|
||||
|
||||
class WorkflowNodeExecution(db.Model):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
|
||||
import copy
|
||||
|
||||
from core.entities.application_entities import AppMode
|
||||
from core.prompt.advanced_prompt_templates import (
|
||||
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
|
|
@ -14,6 +13,7 @@ from core.prompt.advanced_prompt_templates import (
|
|||
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
CONTEXT,
|
||||
)
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class AdvancedPromptTemplateService:
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from core.model_runtime.model_providers import model_provider_factory
|
|||
from core.moderation.factory import ModerationFactory
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.account import Account
|
||||
from models.model import AppMode
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
|
||||
|
|
@ -315,9 +316,6 @@ class AppModelConfigService:
|
|||
if "tool_parameters" not in tool:
|
||||
raise ValueError("tool_parameters is required in agent_mode.tools")
|
||||
|
||||
# dataset_query_variable
|
||||
cls.is_dataset_query_variable_valid(config, app_mode)
|
||||
|
||||
# advanced prompt validation
|
||||
cls.is_advanced_prompt_valid(config, app_mode)
|
||||
|
||||
|
|
@ -443,21 +441,6 @@ class AppModelConfigService:
|
|||
config=config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None:
|
||||
# Only check when mode is completion
|
||||
if mode != 'completion':
|
||||
return
|
||||
|
||||
agent_mode = config.get("agent_mode", {})
|
||||
tools = agent_mode.get("tools", [])
|
||||
dataset_exists = "dataset" in str(tools)
|
||||
|
||||
dataset_query_variable = config.get("dataset_query_variable")
|
||||
|
||||
if dataset_exists and not dataset_query_variable:
|
||||
raise ValueError("Dataset query variable is required when dataset is exist")
|
||||
|
||||
@classmethod
|
||||
def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None:
|
||||
# prompt_type
|
||||
|
|
|
|||
|
|
@ -8,12 +8,10 @@ from core.application_manager import ApplicationManager
|
|||
from core.entities.application_entities import InvokeFrom
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from extensions.ext_database import db
|
||||
from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message
|
||||
from models.model import Account, App, AppModelConfig, Conversation, EndUser
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
class CompletionService:
|
||||
|
|
@ -157,62 +155,6 @@ class CompletionService:
|
|||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
|
||||
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
|
||||
-> Union[dict, Generator]:
|
||||
if not user:
|
||||
raise ValueError('user cannot be None')
|
||||
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
).first()
|
||||
|
||||
if not message:
|
||||
raise MessageNotExistsError()
|
||||
|
||||
current_app_model_config = app_model.app_model_config
|
||||
more_like_this = current_app_model_config.more_like_this_dict
|
||||
|
||||
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
|
||||
raise MoreLikeThisDisabledError()
|
||||
|
||||
app_model_config = message.app_model_config
|
||||
model_dict = app_model_config.model_dict
|
||||
completion_params = model_dict.get('completion_params')
|
||||
completion_params['temperature'] = 0.9
|
||||
model_dict['completion_params'] = completion_params
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
|
||||
# parse files
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_objs = message_file_parser.transform_message_files(
|
||||
message.files, app_model_config
|
||||
)
|
||||
|
||||
application_manager = ApplicationManager()
|
||||
return application_manager.generate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=app_model_config.to_dict(),
|
||||
app_model_config_override=True,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
inputs=message.inputs,
|
||||
query=message.query,
|
||||
files=file_objs,
|
||||
conversation=None,
|
||||
stream=streaming,
|
||||
extras={
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
|
||||
if user_inputs is None:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
__all__ = [
|
||||
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
|
||||
'app', 'completion', 'audio', 'file'
|
||||
'completion', 'audio', 'file'
|
||||
]
|
||||
|
||||
from . import *
|
||||
|
|
|
|||
|
|
@ -1,2 +0,0 @@
|
|||
class MoreLikeThisDisabledError(Exception):
|
||||
pass
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
# default block config
|
||||
default_block_configs = [
|
||||
{
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"prompt_templates": {
|
||||
"chat_model": {
|
||||
"prompts": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "You are a helpful AI assistant."
|
||||
}
|
||||
]
|
||||
},
|
||||
"completion_model": {
|
||||
"conversation_histories_role": {
|
||||
"user_prefix": "Human",
|
||||
"assistant_prefix": "Assistant"
|
||||
},
|
||||
"prompt": {
|
||||
"text": "Here is the chat histories between human and assistant, inside "
|
||||
"<histories></histories> XML tags.\n\n<histories>\n{{"
|
||||
"#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant:"
|
||||
},
|
||||
"stop": ["Human:"]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "code",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"variable": "arg1",
|
||||
"value_selector": []
|
||||
},
|
||||
{
|
||||
"variable": "arg2",
|
||||
"value_selector": []
|
||||
}
|
||||
],
|
||||
"code_language": "python3",
|
||||
"code": "def main(\n arg1: int,\n arg2: int,\n) -> int:\n return {\n \"result\": arg1 "
|
||||
"+ arg2\n }",
|
||||
"outputs": [
|
||||
{
|
||||
"variable": "result",
|
||||
"variable_type": "number"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "template-transform",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"variable": "arg1",
|
||||
"value_selector": []
|
||||
}
|
||||
],
|
||||
"template": "{{ arg1 }}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "question-classifier",
|
||||
"config": {
|
||||
"instructions": "" # TODO
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,259 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
|
||||
from core.application_manager import ApplicationManager
|
||||
from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, FileUploadEntity, \
|
||||
ExternalDataVariableEntity, DatasetEntity, VariableEntity
|
||||
from core.model_runtime.utils import helper
|
||||
from core.workflow.entities.NodeEntities import NodeType
|
||||
from core.workflow.nodes.end.entities import EndNodeOutputType
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode, ChatbotAppEngine
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
|
||||
|
||||
class WorkflowConverter:
|
||||
"""
|
||||
App Convert to Workflow Mode
|
||||
"""
|
||||
|
||||
def convert_to_workflow(self, app_model: App, account: Account) -> Workflow:
|
||||
"""
|
||||
Convert to workflow mode
|
||||
|
||||
- basic mode of chatbot app
|
||||
|
||||
- advanced mode of assistant app (for migration)
|
||||
|
||||
- completion app (for migration)
|
||||
|
||||
:param app_model: App instance
|
||||
:param account: Account instance
|
||||
:return: workflow instance
|
||||
"""
|
||||
# get original app config
|
||||
app_model_config = app_model.app_model_config
|
||||
|
||||
# convert app model config
|
||||
application_manager = ApplicationManager()
|
||||
application_manager.convert_from_app_model_config_dict(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_model_config_dict=app_model_config.to_dict()
|
||||
)
|
||||
|
||||
# init workflow graph
|
||||
graph = {
|
||||
"nodes": [],
|
||||
"edges": []
|
||||
}
|
||||
|
||||
# Convert list:
|
||||
# - variables -> start
|
||||
# - model_config -> llm
|
||||
# - prompt_template -> llm
|
||||
# - file_upload -> llm
|
||||
# - external_data_variables -> http-request
|
||||
# - dataset -> knowledge-retrieval
|
||||
# - show_retrieve_source -> knowledge-retrieval
|
||||
|
||||
# convert to start node
|
||||
start_node = self._convert_to_start_node(
|
||||
variables=app_model_config.variables
|
||||
)
|
||||
|
||||
graph['nodes'].append(start_node)
|
||||
|
||||
# convert to http request node
|
||||
if app_model_config.external_data_variables:
|
||||
http_request_node = self._convert_to_http_request_node(
|
||||
external_data_variables=app_model_config.external_data_variables
|
||||
)
|
||||
|
||||
graph = self._append_node(graph, http_request_node)
|
||||
|
||||
# convert to knowledge retrieval node
|
||||
if app_model_config.dataset:
|
||||
knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node(
|
||||
dataset=app_model_config.dataset,
|
||||
show_retrieve_source=app_model_config.show_retrieve_source
|
||||
)
|
||||
|
||||
graph = self._append_node(graph, knowledge_retrieval_node)
|
||||
|
||||
# convert to llm node
|
||||
llm_node = self._convert_to_llm_node(
|
||||
model_config=app_model_config.model_config,
|
||||
prompt_template=app_model_config.prompt_template,
|
||||
file_upload=app_model_config.file_upload
|
||||
)
|
||||
|
||||
graph = self._append_node(graph, llm_node)
|
||||
|
||||
# convert to end node by app mode
|
||||
end_node = self._convert_to_end_node(app_model=app_model)
|
||||
|
||||
graph = self._append_node(graph, end_node)
|
||||
|
||||
# get new app mode
|
||||
app_mode = self._get_new_app_mode(app_model)
|
||||
|
||||
# create workflow record
|
||||
workflow = Workflow(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=WorkflowType.from_app_mode(app_mode).value,
|
||||
version='draft',
|
||||
graph=json.dumps(graph),
|
||||
created_by=account.id
|
||||
)
|
||||
|
||||
db.session.add(workflow)
|
||||
db.session.flush()
|
||||
|
||||
# create new app model config record
|
||||
new_app_model_config = app_model_config.copy()
|
||||
new_app_model_config.external_data_tools = ''
|
||||
new_app_model_config.model = ''
|
||||
new_app_model_config.user_input_form = ''
|
||||
new_app_model_config.dataset_query_variable = None
|
||||
new_app_model_config.pre_prompt = None
|
||||
new_app_model_config.agent_mode = ''
|
||||
new_app_model_config.prompt_type = 'simple'
|
||||
new_app_model_config.chat_prompt_config = ''
|
||||
new_app_model_config.completion_prompt_config = ''
|
||||
new_app_model_config.dataset_configs = ''
|
||||
new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \
|
||||
if app_mode == AppMode.CHAT else ChatbotAppEngine.NORMAL.value
|
||||
new_app_model_config.workflow_id = workflow.id
|
||||
|
||||
db.session.add(new_app_model_config)
|
||||
db.session.commit()
|
||||
|
||||
return workflow
|
||||
|
||||
def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
|
||||
"""
|
||||
Convert to Start Node
|
||||
:param variables: list of variables
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
"id": "start",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "START",
|
||||
"type": NodeType.START.value,
|
||||
"variables": [helper.dump_model(v) for v in variables]
|
||||
}
|
||||
}
|
||||
|
||||
def _convert_to_http_request_node(self, external_data_variables: list[ExternalDataVariableEntity]) -> dict:
|
||||
"""
|
||||
Convert API Based Extension to HTTP Request Node
|
||||
:param external_data_variables: list of external data variables
|
||||
:return:
|
||||
"""
|
||||
# TODO: implement
|
||||
pass
|
||||
|
||||
def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset: DatasetEntity) -> dict:
|
||||
"""
|
||||
Convert datasets to Knowledge Retrieval Node
|
||||
:param new_app_mode: new app mode
|
||||
:param dataset: dataset
|
||||
:return:
|
||||
"""
|
||||
# TODO: implement
|
||||
if new_app_mode == AppMode.CHAT:
|
||||
query_variable_selector = ["start", "sys.query"]
|
||||
else:
|
||||
pass
|
||||
|
||||
return {
|
||||
"id": "knowledge-retrieval",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "KNOWLEDGE RETRIEVAL",
|
||||
"type": NodeType.KNOWLEDGE_RETRIEVAL.value,
|
||||
}
|
||||
}
|
||||
|
||||
def _convert_to_llm_node(self, model_config: ModelConfigEntity,
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileUploadEntity] = None) -> dict:
|
||||
"""
|
||||
Convert to LLM Node
|
||||
:param model_config: model config
|
||||
:param prompt_template: prompt template
|
||||
:param file_upload: file upload config (optional)
|
||||
"""
|
||||
# TODO: implement
|
||||
pass
|
||||
|
||||
def _convert_to_end_node(self, app_model: App) -> dict:
|
||||
"""
|
||||
Convert to End Node
|
||||
:param app_model: App instance
|
||||
:return:
|
||||
"""
|
||||
if app_model.mode == AppMode.CHAT.value:
|
||||
return {
|
||||
"id": "end",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "END",
|
||||
"type": NodeType.END.value,
|
||||
}
|
||||
}
|
||||
elif app_model.mode == "completion":
|
||||
# for original completion app
|
||||
return {
|
||||
"id": "end",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "END",
|
||||
"type": NodeType.END.value,
|
||||
"outputs": {
|
||||
"type": EndNodeOutputType.PLAIN_TEXT.value,
|
||||
"plain_text_selector": ["llm", "text"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _create_edge(self, source: str, target: str) -> dict:
|
||||
"""
|
||||
Create Edge
|
||||
:param source: source node id
|
||||
:param target: target node id
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
"id": f"{source}-{target}",
|
||||
"source": source,
|
||||
"target": target
|
||||
}
|
||||
|
||||
def _append_node(self, graph: dict, node: dict) -> dict:
|
||||
"""
|
||||
Append Node to Graph
|
||||
|
||||
:param graph: Graph, include: nodes, edges
|
||||
:param node: Node to append
|
||||
:return:
|
||||
"""
|
||||
previous_node = graph['nodes'][-1]
|
||||
graph['nodes'].append(node)
|
||||
graph['edges'].append(self._create_edge(previous_node['id'], node['id']))
|
||||
return graph
|
||||
|
||||
def _get_new_app_mode(self, app_model: App) -> AppMode:
|
||||
"""
|
||||
Get new app mode
|
||||
:param app_model: App instance
|
||||
:return: AppMode
|
||||
"""
|
||||
if app_model.mode == "completion":
|
||||
return AppMode.WORKFLOW
|
||||
else:
|
||||
return AppMode.value_of(app_model.mode)
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, ChatbotAppEngine
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from services.workflow.defaults import default_block_configs
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
|
||||
class WorkflowService:
|
||||
"""
|
||||
Workflow Service
|
||||
"""
|
||||
|
||||
def get_draft_workflow(self, app_model: App) -> Workflow:
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == 'draft'
|
||||
).first()
|
||||
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
# create draft workflow if not found
|
||||
if not workflow:
|
||||
workflow = Workflow(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=WorkflowType.from_app_mode(app_model.mode).value,
|
||||
version='draft',
|
||||
graph=json.dumps(graph),
|
||||
created_by=account.id
|
||||
)
|
||||
db.session.add(workflow)
|
||||
# update draft workflow if found
|
||||
else:
|
||||
workflow.graph = json.dumps(graph)
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = datetime.utcnow()
|
||||
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def get_default_block_configs(self) -> dict:
|
||||
"""
|
||||
Get default block configs
|
||||
"""
|
||||
# return default block config
|
||||
return default_block_configs
|
||||
|
||||
def chatbot_convert_to_workflow(self, app_model: App) -> Workflow:
|
||||
"""
|
||||
basic mode of chatbot app to workflow
|
||||
|
||||
:param app_model: App instance
|
||||
:return:
|
||||
"""
|
||||
# check if chatbot app is in basic mode
|
||||
if app_model.app_model_config.chatbot_app_engine != ChatbotAppEngine.NORMAL:
|
||||
raise ValueError('Chatbot app already in workflow mode')
|
||||
|
||||
# convert to workflow mode
|
||||
workflow_converter = WorkflowConverter()
|
||||
workflow = workflow_converter.convert_to_workflow(app_model=app_model)
|
||||
|
||||
return workflow
|
||||
Loading…
Reference in New Issue