mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 19:27:23 +08:00
restore completion app
This commit is contained in:
parent
97c4733e79
commit
fce20e483c
@ -80,7 +80,7 @@ class AppListApi(Resource):
|
|||||||
"""Create app"""
|
"""Create app"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=True, location='json')
|
parser.add_argument('name', type=str, required=True, location='json')
|
||||||
parser.add_argument('mode', type=str, choices=[mode.value for mode in AppMode], location='json')
|
parser.add_argument('mode', type=str, choices=['chat', 'agent', 'workflow'], location='json')
|
||||||
parser.add_argument('icon', type=str, location='json')
|
parser.add_argument('icon', type=str, location='json')
|
||||||
parser.add_argument('icon_background', type=str, location='json')
|
parser.add_argument('icon_background', type=str, location='json')
|
||||||
parser.add_argument('model_config', type=dict, location='json')
|
parser.add_argument('model_config', type=dict, location='json')
|
||||||
|
|||||||
@ -37,7 +37,7 @@ class CompletionMessageApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.WORKFLOW)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||||
@ -90,7 +90,7 @@ class CompletionMessageStopApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.WORKFLOW)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
def post(self, app_model, task_id):
|
def post(self, app_model, task_id):
|
||||||
account = flask_login.current_user
|
account = flask_login.current_user
|
||||||
|
|
||||||
|
|||||||
@ -29,7 +29,7 @@ class CompletionConversationApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.WORKFLOW)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
@marshal_with(conversation_pagination_fields)
|
@marshal_with(conversation_pagination_fields)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -102,7 +102,7 @@ class CompletionConversationDetailApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.WORKFLOW)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
@marshal_with(conversation_message_detail_fields)
|
@marshal_with(conversation_message_detail_fields)
|
||||||
def get(self, app_model, conversation_id):
|
def get(self, app_model, conversation_id):
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|||||||
@ -330,7 +330,7 @@ class AverageResponseTimeStatistic(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.WORKFLOW)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
|
|||||||
import services
|
import services
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
|
AppMoreLikeThisDisabledError,
|
||||||
CompletionRequestError,
|
CompletionRequestError,
|
||||||
ProviderModelCurrentlyNotSupportError,
|
ProviderModelCurrentlyNotSupportError,
|
||||||
ProviderNotInitializeError,
|
ProviderNotInitializeError,
|
||||||
@ -23,10 +24,13 @@ from controllers.console.explore.error import (
|
|||||||
NotCompletionAppError,
|
NotCompletionAppError,
|
||||||
)
|
)
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
|
from core.entities.application_entities import InvokeFrom
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
|
from services.completion_service import CompletionService
|
||||||
|
from services.errors.app import MoreLikeThisDisabledError
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||||
from services.message_service import MessageService
|
from services.message_service import MessageService
|
||||||
@ -72,6 +76,48 @@ class MessageFeedbackApi(InstalledAppResource):
|
|||||||
return {'result': 'success'}
|
return {'result': 'success'}
|
||||||
|
|
||||||
|
|
||||||
|
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||||
|
def get(self, installed_app, message_id):
|
||||||
|
app_model = installed_app.app
|
||||||
|
if app_model.mode != 'completion':
|
||||||
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
|
message_id = str(message_id)
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
streaming = args['response_mode'] == 'streaming'
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = CompletionService.generate_more_like_this(
|
||||||
|
app_model=app_model,
|
||||||
|
user=current_user,
|
||||||
|
message_id=message_id,
|
||||||
|
invoke_from=InvokeFrom.EXPLORE,
|
||||||
|
streaming=streaming
|
||||||
|
)
|
||||||
|
return compact_response(response)
|
||||||
|
except MessageNotExistsError:
|
||||||
|
raise NotFound("Message Not Exists.")
|
||||||
|
except MoreLikeThisDisabledError:
|
||||||
|
raise AppMoreLikeThisDisabledError()
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except InvokeError as e:
|
||||||
|
raise CompletionRequestError(e.description)
|
||||||
|
except ValueError as e:
|
||||||
|
raise e
|
||||||
|
except Exception:
|
||||||
|
logging.exception("internal server error.")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||||
if isinstance(response, dict):
|
if isinstance(response, dict):
|
||||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||||
@ -120,4 +166,5 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
|||||||
|
|
||||||
api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
|
api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
|
||||||
api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
|
api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
|
||||||
|
api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this')
|
||||||
api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')
|
api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
|
|||||||
import services
|
import services
|
||||||
from controllers.web import api
|
from controllers.web import api
|
||||||
from controllers.web.error import (
|
from controllers.web.error import (
|
||||||
|
AppMoreLikeThisDisabledError,
|
||||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||||
CompletionRequestError,
|
CompletionRequestError,
|
||||||
NotChatAppError,
|
NotChatAppError,
|
||||||
@ -20,11 +21,14 @@ from controllers.web.error import (
|
|||||||
ProviderQuotaExceededError,
|
ProviderQuotaExceededError,
|
||||||
)
|
)
|
||||||
from controllers.web.wraps import WebApiResource
|
from controllers.web.wraps import WebApiResource
|
||||||
|
from core.entities.application_entities import InvokeFrom
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from fields.conversation_fields import message_file_fields
|
from fields.conversation_fields import message_file_fields
|
||||||
from fields.message_fields import agent_thought_fields
|
from fields.message_fields import agent_thought_fields
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
|
from services.completion_service import CompletionService
|
||||||
|
from services.errors.app import MoreLikeThisDisabledError
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||||
from services.message_service import MessageService
|
from services.message_service import MessageService
|
||||||
@ -109,6 +113,48 @@ class MessageFeedbackApi(WebApiResource):
|
|||||||
return {'result': 'success'}
|
return {'result': 'success'}
|
||||||
|
|
||||||
|
|
||||||
|
class MessageMoreLikeThisApi(WebApiResource):
|
||||||
|
def get(self, app_model, end_user, message_id):
|
||||||
|
if app_model.mode != 'completion':
|
||||||
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
|
message_id = str(message_id)
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
streaming = args['response_mode'] == 'streaming'
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = CompletionService.generate_more_like_this(
|
||||||
|
app_model=app_model,
|
||||||
|
user=end_user,
|
||||||
|
message_id=message_id,
|
||||||
|
invoke_from=InvokeFrom.WEB_APP,
|
||||||
|
streaming=streaming
|
||||||
|
)
|
||||||
|
|
||||||
|
return compact_response(response)
|
||||||
|
except MessageNotExistsError:
|
||||||
|
raise NotFound("Message Not Exists.")
|
||||||
|
except MoreLikeThisDisabledError:
|
||||||
|
raise AppMoreLikeThisDisabledError()
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except InvokeError as e:
|
||||||
|
raise CompletionRequestError(e.description)
|
||||||
|
except ValueError as e:
|
||||||
|
raise e
|
||||||
|
except Exception:
|
||||||
|
logging.exception("internal server error.")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||||
if isinstance(response, dict):
|
if isinstance(response, dict):
|
||||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||||
@ -156,4 +202,5 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
|||||||
|
|
||||||
api.add_resource(MessageListApi, '/messages')
|
api.add_resource(MessageListApi, '/messages')
|
||||||
api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
|
api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
|
||||||
|
api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this')
|
||||||
api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions')
|
api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions')
|
||||||
|
|||||||
@ -22,8 +22,9 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
|
|||||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||||
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
||||||
from models.model import App, Message, MessageAnnotation
|
from models.model import App, Message, MessageAnnotation, AppMode
|
||||||
|
|
||||||
|
|
||||||
class AppRunner:
|
class AppRunner:
|
||||||
@ -140,11 +141,11 @@ class AppRunner:
|
|||||||
:param memory: memory
|
:param memory: memory
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
prompt_transform = SimplePromptTransform()
|
|
||||||
|
|
||||||
# get prompt without memory and context
|
# get prompt without memory and context
|
||||||
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||||
|
prompt_transform = SimplePromptTransform()
|
||||||
prompt_messages, stop = prompt_transform.get_prompt(
|
prompt_messages, stop = prompt_transform.get_prompt(
|
||||||
|
app_mode=AppMode.value_of(app_record.mode),
|
||||||
prompt_template_entity=prompt_template_entity,
|
prompt_template_entity=prompt_template_entity,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
query=query if query else '',
|
query=query if query else '',
|
||||||
@ -154,7 +155,17 @@ class AppRunner:
|
|||||||
model_config=model_config
|
model_config=model_config
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Advanced prompt is not supported yet.")
|
prompt_transform = AdvancedPromptTransform()
|
||||||
|
prompt_messages = prompt_transform.get_prompt(
|
||||||
|
prompt_template_entity=prompt_template_entity,
|
||||||
|
inputs=inputs,
|
||||||
|
query=query if query else '',
|
||||||
|
files=files,
|
||||||
|
context=context,
|
||||||
|
memory=memory,
|
||||||
|
model_config=model_config
|
||||||
|
)
|
||||||
|
stop = model_config.stop
|
||||||
|
|
||||||
return prompt_messages, stop
|
return prompt_messages, stop
|
||||||
|
|
||||||
|
|||||||
@ -11,10 +11,9 @@ class PromptTransform:
|
|||||||
def _append_chat_histories(self, memory: TokenBufferMemory,
|
def _append_chat_histories(self, memory: TokenBufferMemory,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_config: ModelConfigEntity) -> list[PromptMessage]:
|
model_config: ModelConfigEntity) -> list[PromptMessage]:
|
||||||
if memory:
|
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
|
||||||
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
|
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
|
||||||
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
|
prompt_messages.extend(histories)
|
||||||
prompt_messages.extend(histories)
|
|
||||||
|
|
||||||
return prompt_messages
|
return prompt_messages
|
||||||
|
|
||||||
|
|||||||
@ -47,6 +47,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def get_prompt(self,
|
def get_prompt(self,
|
||||||
|
app_mode: AppMode,
|
||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
@ -58,6 +59,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
model_mode = ModelMode.value_of(model_config.mode)
|
model_mode = ModelMode.value_of(model_config.mode)
|
||||||
if model_mode == ModelMode.CHAT:
|
if model_mode == ModelMode.CHAT:
|
||||||
prompt_messages, stops = self._get_chat_model_prompt_messages(
|
prompt_messages, stops = self._get_chat_model_prompt_messages(
|
||||||
|
app_mode=app_mode,
|
||||||
pre_prompt=prompt_template_entity.simple_prompt_template,
|
pre_prompt=prompt_template_entity.simple_prompt_template,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
query=query,
|
query=query,
|
||||||
@ -68,6 +70,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt_messages, stops = self._get_completion_model_prompt_messages(
|
prompt_messages, stops = self._get_completion_model_prompt_messages(
|
||||||
|
app_mode=app_mode,
|
||||||
pre_prompt=prompt_template_entity.simple_prompt_template,
|
pre_prompt=prompt_template_entity.simple_prompt_template,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
query=query,
|
query=query,
|
||||||
@ -154,7 +157,8 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
"prompt_rules": prompt_rules
|
"prompt_rules": prompt_rules
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_chat_model_prompt_messages(self, pre_prompt: str,
|
def _get_chat_model_prompt_messages(self, app_mode: AppMode,
|
||||||
|
pre_prompt: str,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
@ -166,7 +170,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
|
|
||||||
# get prompt
|
# get prompt
|
||||||
prompt, _ = self.get_prompt_str_and_rules(
|
prompt, _ = self.get_prompt_str_and_rules(
|
||||||
app_mode=AppMode.CHAT,
|
app_mode=app_mode,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
pre_prompt=pre_prompt,
|
pre_prompt=pre_prompt,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
@ -175,19 +179,25 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if prompt:
|
if prompt:
|
||||||
prompt_messages.append(SystemPromptMessage(content=prompt))
|
if query:
|
||||||
|
prompt_messages.append(SystemPromptMessage(content=prompt))
|
||||||
|
else:
|
||||||
|
prompt_messages.append(UserPromptMessage(content=prompt))
|
||||||
|
|
||||||
prompt_messages = self._append_chat_histories(
|
if memory:
|
||||||
memory=memory,
|
prompt_messages = self._append_chat_histories(
|
||||||
prompt_messages=prompt_messages,
|
memory=memory,
|
||||||
model_config=model_config
|
prompt_messages=prompt_messages,
|
||||||
)
|
model_config=model_config
|
||||||
|
)
|
||||||
|
|
||||||
prompt_messages.append(self.get_last_user_message(query, files))
|
if query:
|
||||||
|
prompt_messages.append(self.get_last_user_message(query, files))
|
||||||
|
|
||||||
return prompt_messages, None
|
return prompt_messages, None
|
||||||
|
|
||||||
def _get_completion_model_prompt_messages(self, pre_prompt: str,
|
def _get_completion_model_prompt_messages(self, app_mode: AppMode,
|
||||||
|
pre_prompt: str,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
@ -197,7 +207,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||||
# get prompt
|
# get prompt
|
||||||
prompt, prompt_rules = self.get_prompt_str_and_rules(
|
prompt, prompt_rules = self.get_prompt_str_and_rules(
|
||||||
app_mode=AppMode.CHAT,
|
app_mode=app_mode,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
pre_prompt=pre_prompt,
|
pre_prompt=pre_prompt,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
@ -220,7 +230,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
|
|
||||||
# get prompt
|
# get prompt
|
||||||
prompt, prompt_rules = self.get_prompt_str_and_rules(
|
prompt, prompt_rules = self.get_prompt_str_and_rules(
|
||||||
app_mode=AppMode.CHAT,
|
app_mode=app_mode,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
pre_prompt=pre_prompt,
|
pre_prompt=pre_prompt,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
@ -289,13 +299,13 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
is_baichuan = True
|
is_baichuan = True
|
||||||
|
|
||||||
if is_baichuan:
|
if is_baichuan:
|
||||||
if app_mode == AppMode.WORKFLOW:
|
if app_mode == AppMode.COMPLETION:
|
||||||
return 'baichuan_completion'
|
return 'baichuan_completion'
|
||||||
else:
|
else:
|
||||||
return 'baichuan_chat'
|
return 'baichuan_chat'
|
||||||
|
|
||||||
# common
|
# common
|
||||||
if app_mode == AppMode.WORKFLOW:
|
if app_mode == AppMode.COMPLETION:
|
||||||
return 'common_completion'
|
return 'common_completion'
|
||||||
else:
|
else:
|
||||||
return 'common_chat'
|
return 'common_chat'
|
||||||
|
|||||||
@ -316,6 +316,9 @@ class AppModelConfigService:
|
|||||||
if "tool_parameters" not in tool:
|
if "tool_parameters" not in tool:
|
||||||
raise ValueError("tool_parameters is required in agent_mode.tools")
|
raise ValueError("tool_parameters is required in agent_mode.tools")
|
||||||
|
|
||||||
|
# dataset_query_variable
|
||||||
|
cls.is_dataset_query_variable_valid(config, app_mode)
|
||||||
|
|
||||||
# advanced prompt validation
|
# advanced prompt validation
|
||||||
cls.is_advanced_prompt_valid(config, app_mode)
|
cls.is_advanced_prompt_valid(config, app_mode)
|
||||||
|
|
||||||
@ -441,6 +444,21 @@ class AppModelConfigService:
|
|||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None:
|
||||||
|
# Only check when mode is completion
|
||||||
|
if mode != 'completion':
|
||||||
|
return
|
||||||
|
|
||||||
|
agent_mode = config.get("agent_mode", {})
|
||||||
|
tools = agent_mode.get("tools", [])
|
||||||
|
dataset_exists = "dataset" in str(tools)
|
||||||
|
|
||||||
|
dataset_query_variable = config.get("dataset_query_variable")
|
||||||
|
|
||||||
|
if dataset_exists and not dataset_query_variable:
|
||||||
|
raise ValueError("Dataset query variable is required when dataset is exist")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None:
|
def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None:
|
||||||
# prompt_type
|
# prompt_type
|
||||||
|
|||||||
@ -8,10 +8,12 @@ from core.application_manager import ApplicationManager
|
|||||||
from core.entities.application_entities import InvokeFrom
|
from core.entities.application_entities import InvokeFrom
|
||||||
from core.file.message_file_parser import MessageFileParser
|
from core.file.message_file_parser import MessageFileParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import Account, App, AppModelConfig, Conversation, EndUser
|
from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message
|
||||||
from services.app_model_config_service import AppModelConfigService
|
from services.app_model_config_service import AppModelConfigService
|
||||||
|
from services.errors.app import MoreLikeThisDisabledError
|
||||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||||
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
||||||
|
from services.errors.message import MessageNotExistsError
|
||||||
|
|
||||||
|
|
||||||
class CompletionService:
|
class CompletionService:
|
||||||
@ -155,6 +157,62 @@ class CompletionService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
|
||||||
|
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
|
||||||
|
-> Union[dict, Generator]:
|
||||||
|
if not user:
|
||||||
|
raise ValueError('user cannot be None')
|
||||||
|
|
||||||
|
message = db.session.query(Message).filter(
|
||||||
|
Message.id == message_id,
|
||||||
|
Message.app_id == app_model.id,
|
||||||
|
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||||
|
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||||
|
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not message:
|
||||||
|
raise MessageNotExistsError()
|
||||||
|
|
||||||
|
current_app_model_config = app_model.app_model_config
|
||||||
|
more_like_this = current_app_model_config.more_like_this_dict
|
||||||
|
|
||||||
|
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
|
||||||
|
raise MoreLikeThisDisabledError()
|
||||||
|
|
||||||
|
app_model_config = message.app_model_config
|
||||||
|
model_dict = app_model_config.model_dict
|
||||||
|
completion_params = model_dict.get('completion_params')
|
||||||
|
completion_params['temperature'] = 0.9
|
||||||
|
model_dict['completion_params'] = completion_params
|
||||||
|
app_model_config.model = json.dumps(model_dict)
|
||||||
|
|
||||||
|
# parse files
|
||||||
|
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||||
|
file_objs = message_file_parser.transform_message_files(
|
||||||
|
message.files, app_model_config
|
||||||
|
)
|
||||||
|
|
||||||
|
application_manager = ApplicationManager()
|
||||||
|
return application_manager.generate(
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
app_id=app_model.id,
|
||||||
|
app_model_config_id=app_model_config.id,
|
||||||
|
app_model_config_dict=app_model_config.to_dict(),
|
||||||
|
app_model_config_override=True,
|
||||||
|
user=user,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
inputs=message.inputs,
|
||||||
|
query=message.query,
|
||||||
|
files=file_objs,
|
||||||
|
conversation=None,
|
||||||
|
stream=streaming,
|
||||||
|
extras={
|
||||||
|
"auto_generate_conversation_name": False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
|
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
|
||||||
if user_inputs is None:
|
if user_inputs is None:
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
|
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
|
||||||
'completion', 'audio', 'file'
|
'app', 'completion', 'audio', 'file'
|
||||||
]
|
]
|
||||||
|
|
||||||
from . import *
|
from . import *
|
||||||
|
|||||||
2
api/services/errors/app.py
Normal file
2
api/services/errors/app.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
class MoreLikeThisDisabledError(Exception):
|
||||||
|
pass
|
||||||
@ -160,6 +160,7 @@ def test__get_chat_model_prompt_messages():
|
|||||||
context = "yes or no."
|
context = "yes or no."
|
||||||
query = "How are you?"
|
query = "How are you?"
|
||||||
prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
|
prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
|
||||||
|
app_mode=AppMode.CHAT,
|
||||||
pre_prompt=pre_prompt,
|
pre_prompt=pre_prompt,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
query=query,
|
query=query,
|
||||||
@ -214,6 +215,7 @@ def test__get_completion_model_prompt_messages():
|
|||||||
context = "yes or no."
|
context = "yes or no."
|
||||||
query = "How are you?"
|
query = "How are you?"
|
||||||
prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
|
prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
|
||||||
|
app_mode=AppMode.CHAT,
|
||||||
pre_prompt=pre_prompt,
|
pre_prompt=pre_prompt,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
query=query,
|
query=query,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user