mirror of https://github.com/langgenius/dify.git
restore completion app
This commit is contained in:
parent
97c4733e79
commit
fce20e483c
|
|
@ -80,7 +80,7 @@ class AppListApi(Resource):
|
|||
"""Create app"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('mode', type=str, choices=[mode.value for mode in AppMode], location='json')
|
||||
parser.add_argument('mode', type=str, choices=['chat', 'agent', 'workflow'], location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
parser.add_argument('model_config', type=dict, location='json')
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class CompletionMessageApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
|
|
@ -90,7 +90,7 @@ class CompletionMessageStopApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model, task_id):
|
||||
account = flask_login.current_user
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class CompletionConversationApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_pagination_fields)
|
||||
def get(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -102,7 +102,7 @@ class CompletionConversationDetailApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_message_detail_fields)
|
||||
def get(self, app_model, conversation_id):
|
||||
conversation_id = str(conversation_id)
|
||||
|
|
|
|||
|
|
@ -330,7 +330,7 @@ class AverageResponseTimeStatistic(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
|
|||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
|
|
@ -23,10 +24,13 @@ from controllers.console.explore.error import (
|
|||
NotCompletionAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
|
|
@ -72,6 +76,48 @@ class MessageFeedbackApi(InstalledAppResource):
|
|||
return {'result': 'success'}
|
||||
|
||||
|
||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
def get(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
try:
|
||||
response = CompletionService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=streaming
|
||||
)
|
||||
return compact_response(response)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
|
|
@ -120,4 +166,5 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
|||
|
||||
api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
|
||||
api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
|
||||
api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this')
|
||||
api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
|
|||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
CompletionRequestError,
|
||||
NotChatAppError,
|
||||
|
|
@ -20,11 +21,14 @@ from controllers.web.error import (
|
|||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from fields.message_fields import agent_thought_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
|
|
@ -109,6 +113,48 @@ class MessageFeedbackApi(WebApiResource):
|
|||
return {'result': 'success'}
|
||||
|
||||
|
||||
class MessageMoreLikeThisApi(WebApiResource):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
try:
|
||||
response = CompletionService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
return compact_response(response)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
|
|
@ -156,4 +202,5 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
|||
|
||||
api.add_resource(MessageListApi, '/messages')
|
||||
api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
|
||||
api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this')
|
||||
api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions')
|
||||
|
|
|
|||
|
|
@ -22,8 +22,9 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
|
|||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
||||
from models.model import App, Message, MessageAnnotation
|
||||
from models.model import App, Message, MessageAnnotation, AppMode
|
||||
|
||||
|
||||
class AppRunner:
|
||||
|
|
@ -140,11 +141,11 @@ class AppRunner:
|
|||
:param memory: memory
|
||||
:return:
|
||||
"""
|
||||
prompt_transform = SimplePromptTransform()
|
||||
|
||||
# get prompt without memory and context
|
||||
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_messages, stop = prompt_transform.get_prompt(
|
||||
app_mode=AppMode.value_of(app_record.mode),
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
query=query if query else '',
|
||||
|
|
@ -154,7 +155,17 @@ class AppRunner:
|
|||
model_config=model_config
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Advanced prompt is not supported yet.")
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
query=query if query else '',
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
)
|
||||
stop = model_config.stop
|
||||
|
||||
return prompt_messages, stop
|
||||
|
||||
|
|
|
|||
|
|
@ -11,10 +11,9 @@ class PromptTransform:
|
|||
def _append_chat_histories(self, memory: TokenBufferMemory,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_config: ModelConfigEntity) -> list[PromptMessage]:
|
||||
if memory:
|
||||
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
|
||||
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
|
||||
prompt_messages.extend(histories)
|
||||
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
|
||||
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
|
||||
prompt_messages.extend(histories)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
"""
|
||||
|
||||
def get_prompt(self,
|
||||
app_mode: AppMode,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
|
|
@ -58,6 +59,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
model_mode = ModelMode.value_of(model_config.mode)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
prompt_messages, stops = self._get_chat_model_prompt_messages(
|
||||
app_mode=app_mode,
|
||||
pre_prompt=prompt_template_entity.simple_prompt_template,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
|
|
@ -68,6 +70,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
)
|
||||
else:
|
||||
prompt_messages, stops = self._get_completion_model_prompt_messages(
|
||||
app_mode=app_mode,
|
||||
pre_prompt=prompt_template_entity.simple_prompt_template,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
|
|
@ -154,7 +157,8 @@ class SimplePromptTransform(PromptTransform):
|
|||
"prompt_rules": prompt_rules
|
||||
}
|
||||
|
||||
def _get_chat_model_prompt_messages(self, pre_prompt: str,
|
||||
def _get_chat_model_prompt_messages(self, app_mode: AppMode,
|
||||
pre_prompt: str,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
|
|
@ -166,7 +170,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
|
||||
# get prompt
|
||||
prompt, _ = self.get_prompt_str_and_rules(
|
||||
app_mode=AppMode.CHAT,
|
||||
app_mode=app_mode,
|
||||
model_config=model_config,
|
||||
pre_prompt=pre_prompt,
|
||||
inputs=inputs,
|
||||
|
|
@ -175,19 +179,25 @@ class SimplePromptTransform(PromptTransform):
|
|||
)
|
||||
|
||||
if prompt:
|
||||
prompt_messages.append(SystemPromptMessage(content=prompt))
|
||||
if query:
|
||||
prompt_messages.append(SystemPromptMessage(content=prompt))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=prompt))
|
||||
|
||||
prompt_messages = self._append_chat_histories(
|
||||
memory=memory,
|
||||
prompt_messages=prompt_messages,
|
||||
model_config=model_config
|
||||
)
|
||||
if memory:
|
||||
prompt_messages = self._append_chat_histories(
|
||||
memory=memory,
|
||||
prompt_messages=prompt_messages,
|
||||
model_config=model_config
|
||||
)
|
||||
|
||||
prompt_messages.append(self.get_last_user_message(query, files))
|
||||
if query:
|
||||
prompt_messages.append(self.get_last_user_message(query, files))
|
||||
|
||||
return prompt_messages, None
|
||||
|
||||
def _get_completion_model_prompt_messages(self, pre_prompt: str,
|
||||
def _get_completion_model_prompt_messages(self, app_mode: AppMode,
|
||||
pre_prompt: str,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
|
|
@ -197,7 +207,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
# get prompt
|
||||
prompt, prompt_rules = self.get_prompt_str_and_rules(
|
||||
app_mode=AppMode.CHAT,
|
||||
app_mode=app_mode,
|
||||
model_config=model_config,
|
||||
pre_prompt=pre_prompt,
|
||||
inputs=inputs,
|
||||
|
|
@ -220,7 +230,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
|
||||
# get prompt
|
||||
prompt, prompt_rules = self.get_prompt_str_and_rules(
|
||||
app_mode=AppMode.CHAT,
|
||||
app_mode=app_mode,
|
||||
model_config=model_config,
|
||||
pre_prompt=pre_prompt,
|
||||
inputs=inputs,
|
||||
|
|
@ -289,13 +299,13 @@ class SimplePromptTransform(PromptTransform):
|
|||
is_baichuan = True
|
||||
|
||||
if is_baichuan:
|
||||
if app_mode == AppMode.WORKFLOW:
|
||||
if app_mode == AppMode.COMPLETION:
|
||||
return 'baichuan_completion'
|
||||
else:
|
||||
return 'baichuan_chat'
|
||||
|
||||
# common
|
||||
if app_mode == AppMode.WORKFLOW:
|
||||
if app_mode == AppMode.COMPLETION:
|
||||
return 'common_completion'
|
||||
else:
|
||||
return 'common_chat'
|
||||
|
|
|
|||
|
|
@ -316,6 +316,9 @@ class AppModelConfigService:
|
|||
if "tool_parameters" not in tool:
|
||||
raise ValueError("tool_parameters is required in agent_mode.tools")
|
||||
|
||||
# dataset_query_variable
|
||||
cls.is_dataset_query_variable_valid(config, app_mode)
|
||||
|
||||
# advanced prompt validation
|
||||
cls.is_advanced_prompt_valid(config, app_mode)
|
||||
|
||||
|
|
@ -441,6 +444,21 @@ class AppModelConfigService:
|
|||
config=config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None:
|
||||
# Only check when mode is completion
|
||||
if mode != 'completion':
|
||||
return
|
||||
|
||||
agent_mode = config.get("agent_mode", {})
|
||||
tools = agent_mode.get("tools", [])
|
||||
dataset_exists = "dataset" in str(tools)
|
||||
|
||||
dataset_query_variable = config.get("dataset_query_variable")
|
||||
|
||||
if dataset_exists and not dataset_query_variable:
|
||||
raise ValueError("Dataset query variable is required when dataset is exist")
|
||||
|
||||
@classmethod
|
||||
def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None:
|
||||
# prompt_type
|
||||
|
|
|
|||
|
|
@ -8,10 +8,12 @@ from core.application_manager import ApplicationManager
|
|||
from core.entities.application_entities import InvokeFrom
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from extensions.ext_database import db
|
||||
from models.model import Account, App, AppModelConfig, Conversation, EndUser
|
||||
from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
class CompletionService:
|
||||
|
|
@ -155,6 +157,62 @@ class CompletionService:
|
|||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
|
||||
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
|
||||
-> Union[dict, Generator]:
|
||||
if not user:
|
||||
raise ValueError('user cannot be None')
|
||||
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
).first()
|
||||
|
||||
if not message:
|
||||
raise MessageNotExistsError()
|
||||
|
||||
current_app_model_config = app_model.app_model_config
|
||||
more_like_this = current_app_model_config.more_like_this_dict
|
||||
|
||||
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
|
||||
raise MoreLikeThisDisabledError()
|
||||
|
||||
app_model_config = message.app_model_config
|
||||
model_dict = app_model_config.model_dict
|
||||
completion_params = model_dict.get('completion_params')
|
||||
completion_params['temperature'] = 0.9
|
||||
model_dict['completion_params'] = completion_params
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
|
||||
# parse files
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_objs = message_file_parser.transform_message_files(
|
||||
message.files, app_model_config
|
||||
)
|
||||
|
||||
application_manager = ApplicationManager()
|
||||
return application_manager.generate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=app_model_config.to_dict(),
|
||||
app_model_config_override=True,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
inputs=message.inputs,
|
||||
query=message.query,
|
||||
files=file_objs,
|
||||
conversation=None,
|
||||
stream=streaming,
|
||||
extras={
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
|
||||
if user_inputs is None:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
__all__ = [
|
||||
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
|
||||
'completion', 'audio', 'file'
|
||||
'app', 'completion', 'audio', 'file'
|
||||
]
|
||||
|
||||
from . import *
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
class MoreLikeThisDisabledError(Exception):
|
||||
pass
|
||||
|
|
@ -160,6 +160,7 @@ def test__get_chat_model_prompt_messages():
|
|||
context = "yes or no."
|
||||
query = "How are you?"
|
||||
prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
|
||||
app_mode=AppMode.CHAT,
|
||||
pre_prompt=pre_prompt,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
|
|
@ -214,6 +215,7 @@ def test__get_completion_model_prompt_messages():
|
|||
context = "yes or no."
|
||||
query = "How are you?"
|
||||
prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
|
||||
app_mode=AppMode.CHAT,
|
||||
pre_prompt=pre_prompt,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
|
|
|
|||
Loading…
Reference in New Issue