mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 00:18:28 +08:00
refactor app
This commit is contained in:
parent
896c200211
commit
799db69e4f
@ -21,7 +21,7 @@ from controllers.console.app.error import (
|
|||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.application_queue_manager import ApplicationQueueManager
|
from core.app.app_queue_manager import AppQueueManager
|
||||||
from core.entities.application_entities import InvokeFrom
|
from core.entities.application_entities import InvokeFrom
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
@ -94,7 +94,7 @@ class CompletionMessageStopApi(Resource):
|
|||||||
def post(self, app_model, task_id):
|
def post(self, app_model, task_id):
|
||||||
account = flask_login.current_user
|
account = flask_login.current_user
|
||||||
|
|
||||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {'result': 'success'}, 200
|
||||||
|
|
||||||
@ -172,7 +172,7 @@ class ChatMessageStopApi(Resource):
|
|||||||
def post(self, app_model, task_id):
|
def post(self, app_model, task_id):
|
||||||
account = flask_login.current_user
|
account = flask_login.current_user
|
||||||
|
|
||||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {'result': 'success'}, 200
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from controllers.console.app.error import (
|
|||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
from core.generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
|
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from controllers.console.app.error import (
|
|||||||
)
|
)
|
||||||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from core.application_queue_manager import ApplicationQueueManager
|
from core.app.app_queue_manager import AppQueueManager
|
||||||
from core.entities.application_entities import InvokeFrom
|
from core.entities.application_entities import InvokeFrom
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
@ -90,7 +90,7 @@ class CompletionStopApi(InstalledAppResource):
|
|||||||
if app_model.mode != 'completion':
|
if app_model.mode != 'completion':
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {'result': 'success'}, 200
|
||||||
|
|
||||||
@ -154,7 +154,7 @@ class ChatStopApi(InstalledAppResource):
|
|||||||
if app_model.mode != 'chat':
|
if app_model.mode != 'chat':
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {'result': 'success'}, 200
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from controllers.service_api.app.error import (
|
|||||||
ProviderQuotaExceededError,
|
ProviderQuotaExceededError,
|
||||||
)
|
)
|
||||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||||
from core.application_queue_manager import ApplicationQueueManager
|
from core.app.app_queue_manager import AppQueueManager
|
||||||
from core.entities.application_entities import InvokeFrom
|
from core.entities.application_entities import InvokeFrom
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
@ -85,7 +85,7 @@ class CompletionStopApi(Resource):
|
|||||||
if app_model.mode != 'completion':
|
if app_model.mode != 'completion':
|
||||||
raise AppUnavailableError()
|
raise AppUnavailableError()
|
||||||
|
|
||||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {'result': 'success'}, 200
|
||||||
|
|
||||||
@ -147,7 +147,7 @@ class ChatStopApi(Resource):
|
|||||||
if app_model.mode != 'chat':
|
if app_model.mode != 'chat':
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {'result': 'success'}, 200
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from controllers.web.error import (
|
|||||||
ProviderQuotaExceededError,
|
ProviderQuotaExceededError,
|
||||||
)
|
)
|
||||||
from controllers.web.wraps import WebApiResource
|
from controllers.web.wraps import WebApiResource
|
||||||
from core.application_queue_manager import ApplicationQueueManager
|
from core.app.app_queue_manager import AppQueueManager
|
||||||
from core.entities.application_entities import InvokeFrom
|
from core.entities.application_entities import InvokeFrom
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
@ -84,7 +84,7 @@ class CompletionStopApi(WebApiResource):
|
|||||||
if app_model.mode != 'completion':
|
if app_model.mode != 'completion':
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {'result': 'success'}, 200
|
||||||
|
|
||||||
@ -144,7 +144,7 @@ class ChatStopApi(WebApiResource):
|
|||||||
if app_model.mode != 'chat':
|
if app_model.mode != 'chat':
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {'result': 'success'}, 200
|
||||||
|
|
||||||
|
|||||||
@ -5,8 +5,8 @@ from datetime import datetime
|
|||||||
from mimetypes import guess_extension
|
from mimetypes import guess_extension
|
||||||
from typing import Optional, Union, cast
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
from core.app_runner.app_runner import AppRunner
|
from core.app.base_app_runner import AppRunner
|
||||||
from core.application_queue_manager import ApplicationQueueManager
|
from core.app.app_queue_manager import AppQueueManager
|
||||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
from core.entities.application_entities import (
|
from core.entities.application_entities import (
|
||||||
@ -48,13 +48,13 @@ from models.tools import ToolConversationVariables
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class BaseAssistantApplicationRunner(AppRunner):
|
class BaseAgentRunner(AppRunner):
|
||||||
def __init__(self, tenant_id: str,
|
def __init__(self, tenant_id: str,
|
||||||
application_generate_entity: ApplicationGenerateEntity,
|
application_generate_entity: ApplicationGenerateEntity,
|
||||||
app_orchestration_config: AppOrchestrationConfigEntity,
|
app_orchestration_config: AppOrchestrationConfigEntity,
|
||||||
model_config: ModelConfigEntity,
|
model_config: ModelConfigEntity,
|
||||||
config: AgentEntity,
|
config: AgentEntity,
|
||||||
queue_manager: ApplicationQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
message: Message,
|
message: Message,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
memory: Optional[TokenBufferMemory] = None,
|
memory: Optional[TokenBufferMemory] = None,
|
||||||
@ -3,9 +3,9 @@ import re
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Literal, Union
|
from typing import Literal, Union
|
||||||
|
|
||||||
from core.application_queue_manager import PublishFrom
|
from core.app.app_queue_manager import PublishFrom
|
||||||
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
|
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
|
||||||
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
|
from core.agent.base_agent_runner import BaseAgentRunner
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
@ -262,7 +262,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
tool_call_args = json.loads(tool_call_args)
|
tool_call_args = json.loads(tool_call_args)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
tool_response = tool_instance.invoke(
|
tool_response = tool_instance.invoke(
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
tool_parameters=tool_call_args
|
tool_parameters=tool_call_args
|
||||||
@ -3,8 +3,8 @@ import logging
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from core.application_queue_manager import PublishFrom
|
from core.app.app_queue_manager import PublishFrom
|
||||||
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
|
from core.agent.base_agent_runner import BaseAgentRunner
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
@ -26,7 +26,7 @@ from models.model import Conversation, Message, MessageAgentThought
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
class FunctionCallAgentRunner(BaseAgentRunner):
|
||||||
def run(self, conversation: Conversation,
|
def run(self, conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
query: str,
|
query: str,
|
||||||
@ -1,10 +1,10 @@
|
|||||||
from core.apps.config_validators.file_upload import FileUploadValidator
|
from core.app.validators.file_upload import FileUploadValidator
|
||||||
from core.apps.config_validators.moderation import ModerationValidator
|
from core.app.validators.moderation import ModerationValidator
|
||||||
from core.apps.config_validators.opening_statement import OpeningStatementValidator
|
from core.app.validators.opening_statement import OpeningStatementValidator
|
||||||
from core.apps.config_validators.retriever_resource import RetrieverResourceValidator
|
from core.app.validators.retriever_resource import RetrieverResourceValidator
|
||||||
from core.apps.config_validators.speech_to_text import SpeechToTextValidator
|
from core.app.validators.speech_to_text import SpeechToTextValidator
|
||||||
from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator
|
from core.app.validators.suggested_questions import SuggestedQuestionsValidator
|
||||||
from core.apps.config_validators.text_to_speech import TextToSpeechValidator
|
from core.app.validators.text_to_speech import TextToSpeechValidator
|
||||||
|
|
||||||
|
|
||||||
class AdvancedChatAppConfigValidator:
|
class AdvancedChatAppConfigValidator:
|
||||||
@ -1,11 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from core.app_runner.app_runner import AppRunner
|
from core.app.base_app_runner import AppRunner
|
||||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity
|
from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity
|
||||||
from core.features.assistant_cot_runner import AssistantCotApplicationRunner
|
from core.agent.cot_agent_runner import CotAgentRunner
|
||||||
from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner
|
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
@ -19,12 +19,13 @@ from models.tools import ToolConversationVariables
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class AssistantApplicationRunner(AppRunner):
|
|
||||||
|
class AgentChatAppRunner(AppRunner):
|
||||||
"""
|
"""
|
||||||
Assistant Application Runner
|
Agent Application Runner
|
||||||
"""
|
"""
|
||||||
def run(self, application_generate_entity: ApplicationGenerateEntity,
|
def run(self, application_generate_entity: ApplicationGenerateEntity,
|
||||||
queue_manager: ApplicationQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message) -> None:
|
message: Message) -> None:
|
||||||
"""
|
"""
|
||||||
@ -201,7 +202,7 @@ class AssistantApplicationRunner(AppRunner):
|
|||||||
|
|
||||||
# start agent runner
|
# start agent runner
|
||||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||||
assistant_cot_runner = AssistantCotApplicationRunner(
|
assistant_cot_runner = CotAgentRunner(
|
||||||
tenant_id=application_generate_entity.tenant_id,
|
tenant_id=application_generate_entity.tenant_id,
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
app_orchestration_config=app_orchestration_config,
|
app_orchestration_config=app_orchestration_config,
|
||||||
@ -223,7 +224,7 @@ class AssistantApplicationRunner(AppRunner):
|
|||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
)
|
)
|
||||||
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||||
assistant_fc_runner = AssistantFunctionCallApplicationRunner(
|
assistant_fc_runner = FunctionCallAgentRunner(
|
||||||
tenant_id=application_generate_entity.tenant_id,
|
tenant_id=application_generate_entity.tenant_id,
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
app_orchestration_config=app_orchestration_config,
|
app_orchestration_config=app_orchestration_config,
|
||||||
162
api/core/app/agent_chat/config_validator.py
Normal file
162
api/core/app/agent_chat/config_validator.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
import uuid
|
||||||
|
|
||||||
|
from core.entities.agent_entities import PlanningStrategy
|
||||||
|
from core.app.validators.dataset_retrieval import DatasetValidator
|
||||||
|
from core.app.validators.external_data_fetch import ExternalDataFetchValidator
|
||||||
|
from core.app.validators.file_upload import FileUploadValidator
|
||||||
|
from core.app.validators.model_validator import ModelValidator
|
||||||
|
from core.app.validators.moderation import ModerationValidator
|
||||||
|
from core.app.validators.opening_statement import OpeningStatementValidator
|
||||||
|
from core.app.validators.prompt import PromptValidator
|
||||||
|
from core.app.validators.retriever_resource import RetrieverResourceValidator
|
||||||
|
from core.app.validators.speech_to_text import SpeechToTextValidator
|
||||||
|
from core.app.validators.suggested_questions import SuggestedQuestionsValidator
|
||||||
|
from core.app.validators.text_to_speech import TextToSpeechValidator
|
||||||
|
from core.app.validators.user_input_form import UserInputFormValidator
|
||||||
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
|
OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentChatAppConfigValidator:
|
||||||
|
@classmethod
|
||||||
|
def config_validate(cls, tenant_id: str, config: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Validate for agent chat app model config
|
||||||
|
|
||||||
|
:param tenant_id: tenant id
|
||||||
|
:param config: app model config args
|
||||||
|
"""
|
||||||
|
app_mode = AppMode.AGENT_CHAT
|
||||||
|
|
||||||
|
related_config_keys = []
|
||||||
|
|
||||||
|
# model
|
||||||
|
config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# user_input_form
|
||||||
|
config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# external data tools validation
|
||||||
|
config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# file upload validation
|
||||||
|
config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# prompt
|
||||||
|
config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# agent_mode
|
||||||
|
config, current_related_config_keys = cls.validate_agent_mode_and_set_defaults(tenant_id, config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# opening_statement
|
||||||
|
config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# suggested_questions_after_answer
|
||||||
|
config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# speech_to_text
|
||||||
|
config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# text_to_speech
|
||||||
|
config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# return retriever resource
|
||||||
|
config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# moderation validation
|
||||||
|
config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
related_config_keys = list(set(related_config_keys))
|
||||||
|
|
||||||
|
# Filter out extra parameters
|
||||||
|
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||||
|
|
||||||
|
return filtered_config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||||
|
"""
|
||||||
|
Validate agent_mode and set defaults for agent feature
|
||||||
|
|
||||||
|
:param tenant_id: tenant ID
|
||||||
|
:param config: app model config args
|
||||||
|
"""
|
||||||
|
if not config.get("agent_mode"):
|
||||||
|
config["agent_mode"] = {
|
||||||
|
"enabled": False,
|
||||||
|
"tools": []
|
||||||
|
}
|
||||||
|
|
||||||
|
if not isinstance(config["agent_mode"], dict):
|
||||||
|
raise ValueError("agent_mode must be of object type")
|
||||||
|
|
||||||
|
if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
|
||||||
|
config["agent_mode"]["enabled"] = False
|
||||||
|
|
||||||
|
if not isinstance(config["agent_mode"]["enabled"], bool):
|
||||||
|
raise ValueError("enabled in agent_mode must be of boolean type")
|
||||||
|
|
||||||
|
if not config["agent_mode"].get("strategy"):
|
||||||
|
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||||
|
|
||||||
|
if config["agent_mode"]["strategy"] not in [member.value for member in
|
||||||
|
list(PlanningStrategy.__members__.values())]:
|
||||||
|
raise ValueError("strategy in agent_mode must be in the specified strategy list")
|
||||||
|
|
||||||
|
if not config["agent_mode"].get("tools"):
|
||||||
|
config["agent_mode"]["tools"] = []
|
||||||
|
|
||||||
|
if not isinstance(config["agent_mode"]["tools"], list):
|
||||||
|
raise ValueError("tools in agent_mode must be a list of objects")
|
||||||
|
|
||||||
|
for tool in config["agent_mode"]["tools"]:
|
||||||
|
key = list(tool.keys())[0]
|
||||||
|
if key in OLD_TOOLS:
|
||||||
|
# old style, use tool name as key
|
||||||
|
tool_item = tool[key]
|
||||||
|
|
||||||
|
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||||
|
tool_item["enabled"] = False
|
||||||
|
|
||||||
|
if not isinstance(tool_item["enabled"], bool):
|
||||||
|
raise ValueError("enabled in agent_mode.tools must be of boolean type")
|
||||||
|
|
||||||
|
if key == "dataset":
|
||||||
|
if 'id' not in tool_item:
|
||||||
|
raise ValueError("id is required in dataset")
|
||||||
|
|
||||||
|
try:
|
||||||
|
uuid.UUID(tool_item["id"])
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("id in dataset must be of UUID type")
|
||||||
|
|
||||||
|
if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]):
|
||||||
|
raise ValueError("Dataset ID does not exist, please check your permission.")
|
||||||
|
else:
|
||||||
|
# latest style, use key-value pair
|
||||||
|
if "enabled" not in tool or not tool["enabled"]:
|
||||||
|
tool["enabled"] = False
|
||||||
|
if "provider_type" not in tool:
|
||||||
|
raise ValueError("provider_type is required in agent_mode.tools")
|
||||||
|
if "provider_id" not in tool:
|
||||||
|
raise ValueError("provider_id is required in agent_mode.tools")
|
||||||
|
if "tool_name" not in tool:
|
||||||
|
raise ValueError("tool_name is required in agent_mode.tools")
|
||||||
|
if "tool_parameters" not in tool:
|
||||||
|
raise ValueError("tool_parameters is required in agent_mode.tools")
|
||||||
|
|
||||||
|
return config, ["agent_mode"]
|
||||||
382
api/core/app/app_manager.py
Normal file
382
api/core/app/app_manager.py
Normal file
@ -0,0 +1,382 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Any, Optional, Union, cast
|
||||||
|
|
||||||
|
from flask import Flask, current_app
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from core.app.app_orchestration_config_converter import AppOrchestrationConfigConverter
|
||||||
|
from core.app.agent_chat.app_runner import AgentChatAppRunner
|
||||||
|
from core.app.chat.app_runner import ChatAppRunner
|
||||||
|
from core.app.generate_task_pipeline import GenerateTaskPipeline
|
||||||
|
from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||||
|
from core.entities.application_entities import (
|
||||||
|
ApplicationGenerateEntity,
|
||||||
|
InvokeFrom,
|
||||||
|
)
|
||||||
|
from core.file.file_obj import FileObj
|
||||||
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.account import Account
|
||||||
|
from models.model import App, Conversation, EndUser, Message, MessageFile
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AppManager:
|
||||||
|
"""
|
||||||
|
This class is responsible for managing application
|
||||||
|
"""
|
||||||
|
|
||||||
|
def generate(self, tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
app_model_config_id: str,
|
||||||
|
app_model_config_dict: dict,
|
||||||
|
app_model_config_override: bool,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
inputs: dict[str, str],
|
||||||
|
query: Optional[str] = None,
|
||||||
|
files: Optional[list[FileObj]] = None,
|
||||||
|
conversation: Optional[Conversation] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
extras: Optional[dict[str, Any]] = None) \
|
||||||
|
-> Union[dict, Generator]:
|
||||||
|
"""
|
||||||
|
Generate App response.
|
||||||
|
|
||||||
|
:param tenant_id: workspace ID
|
||||||
|
:param app_id: app ID
|
||||||
|
:param app_model_config_id: app model config id
|
||||||
|
:param app_model_config_dict: app model config dict
|
||||||
|
:param app_model_config_override: app model config override
|
||||||
|
:param user: account or end user
|
||||||
|
:param invoke_from: invoke from source
|
||||||
|
:param inputs: inputs
|
||||||
|
:param query: query
|
||||||
|
:param files: file obj list
|
||||||
|
:param conversation: conversation
|
||||||
|
:param stream: is stream
|
||||||
|
:param extras: extras
|
||||||
|
"""
|
||||||
|
# init task id
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# init application generate entity
|
||||||
|
application_generate_entity = ApplicationGenerateEntity(
|
||||||
|
task_id=task_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
app_id=app_id,
|
||||||
|
app_model_config_id=app_model_config_id,
|
||||||
|
app_model_config_dict=app_model_config_dict,
|
||||||
|
app_orchestration_config_entity=AppOrchestrationConfigConverter.convert_from_app_model_config_dict(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
app_model_config_dict=app_model_config_dict
|
||||||
|
),
|
||||||
|
app_model_config_override=app_model_config_override,
|
||||||
|
conversation_id=conversation.id if conversation else None,
|
||||||
|
inputs=conversation.inputs if conversation else inputs,
|
||||||
|
query=query.replace('\x00', '') if query else None,
|
||||||
|
files=files if files else [],
|
||||||
|
user_id=user.id,
|
||||||
|
stream=stream,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
extras=extras
|
||||||
|
)
|
||||||
|
|
||||||
|
if not stream and application_generate_entity.app_orchestration_config_entity.agent:
|
||||||
|
raise ValueError("Agent app is not supported in blocking mode.")
|
||||||
|
|
||||||
|
# init generate records
|
||||||
|
(
|
||||||
|
conversation,
|
||||||
|
message
|
||||||
|
) = self._init_generate_records(application_generate_entity)
|
||||||
|
|
||||||
|
# init queue manager
|
||||||
|
queue_manager = AppQueueManager(
|
||||||
|
task_id=application_generate_entity.task_id,
|
||||||
|
user_id=application_generate_entity.user_id,
|
||||||
|
invoke_from=application_generate_entity.invoke_from,
|
||||||
|
conversation_id=conversation.id,
|
||||||
|
app_mode=conversation.mode,
|
||||||
|
message_id=message.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# new thread
|
||||||
|
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||||
|
'flask_app': current_app._get_current_object(),
|
||||||
|
'application_generate_entity': application_generate_entity,
|
||||||
|
'queue_manager': queue_manager,
|
||||||
|
'conversation_id': conversation.id,
|
||||||
|
'message_id': message.id,
|
||||||
|
})
|
||||||
|
|
||||||
|
worker_thread.start()
|
||||||
|
|
||||||
|
# return response or stream generator
|
||||||
|
return self._handle_response(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
conversation=conversation,
|
||||||
|
message=message,
|
||||||
|
stream=stream
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_worker(self, flask_app: Flask,
|
||||||
|
application_generate_entity: ApplicationGenerateEntity,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
conversation_id: str,
|
||||||
|
message_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Generate worker in a new thread.
|
||||||
|
:param flask_app: Flask app
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param conversation_id: conversation ID
|
||||||
|
:param message_id: message ID
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
with flask_app.app_context():
|
||||||
|
try:
|
||||||
|
# get conversation and message
|
||||||
|
conversation = self._get_conversation(conversation_id)
|
||||||
|
message = self._get_message(message_id)
|
||||||
|
|
||||||
|
if application_generate_entity.app_orchestration_config_entity.agent:
|
||||||
|
# agent app
|
||||||
|
runner = AgentChatAppRunner()
|
||||||
|
runner.run(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
conversation=conversation,
|
||||||
|
message=message
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# basic app
|
||||||
|
runner = ChatAppRunner()
|
||||||
|
runner.run(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
conversation=conversation,
|
||||||
|
message=message
|
||||||
|
)
|
||||||
|
except ConversationTaskStoppedException:
|
||||||
|
pass
|
||||||
|
except InvokeAuthorizationError:
|
||||||
|
queue_manager.publish_error(
|
||||||
|
InvokeAuthorizationError('Incorrect API key provided'),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.exception("Validation Error when generating")
|
||||||
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
except (ValueError, InvokeError) as e:
|
||||||
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Unknown Error when generating")
|
||||||
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
finally:
|
||||||
|
db.session.remove()
|
||||||
|
|
||||||
|
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
conversation: Conversation,
|
||||||
|
message: Message,
|
||||||
|
stream: bool = False) -> Union[dict, Generator]:
|
||||||
|
"""
|
||||||
|
Handle response.
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param conversation: conversation
|
||||||
|
:param message: message
|
||||||
|
:param stream: is stream
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# init generate task pipeline
|
||||||
|
generate_task_pipeline = GenerateTaskPipeline(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
conversation=conversation,
|
||||||
|
message=message
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return generate_task_pipeline.process(stream=stream)
|
||||||
|
except ValueError as e:
|
||||||
|
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||||
|
raise ConversationTaskStoppedException()
|
||||||
|
else:
|
||||||
|
logger.exception(e)
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
db.session.remove()
|
||||||
|
|
||||||
|
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
|
||||||
|
-> tuple[Conversation, Message]:
|
||||||
|
"""
|
||||||
|
Initialize generate records
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
|
||||||
|
|
||||||
|
model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
|
||||||
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
|
model_schema = model_type_instance.get_model_schema(
|
||||||
|
model=app_orchestration_config_entity.model_config.model,
|
||||||
|
credentials=app_orchestration_config_entity.model_config.credentials
|
||||||
|
)
|
||||||
|
|
||||||
|
app_record = (db.session.query(App)
|
||||||
|
.filter(App.id == application_generate_entity.app_id).first())
|
||||||
|
|
||||||
|
app_mode = app_record.mode
|
||||||
|
|
||||||
|
# get from source
|
||||||
|
end_user_id = None
|
||||||
|
account_id = None
|
||||||
|
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||||
|
from_source = 'api'
|
||||||
|
end_user_id = application_generate_entity.user_id
|
||||||
|
else:
|
||||||
|
from_source = 'console'
|
||||||
|
account_id = application_generate_entity.user_id
|
||||||
|
|
||||||
|
override_model_configs = None
|
||||||
|
if application_generate_entity.app_model_config_override:
|
||||||
|
override_model_configs = application_generate_entity.app_model_config_dict
|
||||||
|
|
||||||
|
introduction = ''
|
||||||
|
if app_mode == 'chat':
|
||||||
|
# get conversation introduction
|
||||||
|
introduction = self._get_conversation_introduction(application_generate_entity)
|
||||||
|
|
||||||
|
if not application_generate_entity.conversation_id:
|
||||||
|
conversation = Conversation(
|
||||||
|
app_id=app_record.id,
|
||||||
|
app_model_config_id=application_generate_entity.app_model_config_id,
|
||||||
|
model_provider=app_orchestration_config_entity.model_config.provider,
|
||||||
|
model_id=app_orchestration_config_entity.model_config.model,
|
||||||
|
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||||
|
mode=app_mode,
|
||||||
|
name='New conversation',
|
||||||
|
inputs=application_generate_entity.inputs,
|
||||||
|
introduction=introduction,
|
||||||
|
system_instruction="",
|
||||||
|
system_instruction_tokens=0,
|
||||||
|
status='normal',
|
||||||
|
from_source=from_source,
|
||||||
|
from_end_user_id=end_user_id,
|
||||||
|
from_account_id=account_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(conversation)
|
||||||
|
db.session.commit()
|
||||||
|
else:
|
||||||
|
conversation = (
|
||||||
|
db.session.query(Conversation)
|
||||||
|
.filter(
|
||||||
|
Conversation.id == application_generate_entity.conversation_id,
|
||||||
|
Conversation.app_id == app_record.id
|
||||||
|
).first()
|
||||||
|
)
|
||||||
|
|
||||||
|
currency = model_schema.pricing.currency if model_schema.pricing else 'USD'
|
||||||
|
|
||||||
|
message = Message(
|
||||||
|
app_id=app_record.id,
|
||||||
|
model_provider=app_orchestration_config_entity.model_config.provider,
|
||||||
|
model_id=app_orchestration_config_entity.model_config.model,
|
||||||
|
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||||
|
conversation_id=conversation.id,
|
||||||
|
inputs=application_generate_entity.inputs,
|
||||||
|
query=application_generate_entity.query or "",
|
||||||
|
message="",
|
||||||
|
message_tokens=0,
|
||||||
|
message_unit_price=0,
|
||||||
|
message_price_unit=0,
|
||||||
|
answer="",
|
||||||
|
answer_tokens=0,
|
||||||
|
answer_unit_price=0,
|
||||||
|
answer_price_unit=0,
|
||||||
|
provider_response_latency=0,
|
||||||
|
total_price=0,
|
||||||
|
currency=currency,
|
||||||
|
from_source=from_source,
|
||||||
|
from_end_user_id=end_user_id,
|
||||||
|
from_account_id=account_id,
|
||||||
|
agent_based=app_orchestration_config_entity.agent is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(message)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
for file in application_generate_entity.files:
|
||||||
|
message_file = MessageFile(
|
||||||
|
message_id=message.id,
|
||||||
|
type=file.type.value,
|
||||||
|
transfer_method=file.transfer_method.value,
|
||||||
|
belongs_to='user',
|
||||||
|
url=file.url,
|
||||||
|
upload_file_id=file.upload_file_id,
|
||||||
|
created_by_role=('account' if account_id else 'end_user'),
|
||||||
|
created_by=account_id or end_user_id,
|
||||||
|
)
|
||||||
|
db.session.add(message_file)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return conversation, message
|
||||||
|
|
||||||
|
def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
|
||||||
|
"""
|
||||||
|
Get conversation introduction
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:return: conversation introduction
|
||||||
|
"""
|
||||||
|
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
|
||||||
|
introduction = app_orchestration_config_entity.opening_statement
|
||||||
|
|
||||||
|
if introduction:
|
||||||
|
try:
|
||||||
|
inputs = application_generate_entity.inputs
|
||||||
|
prompt_template = PromptTemplateParser(template=introduction)
|
||||||
|
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||||
|
introduction = prompt_template.format(prompt_inputs)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return introduction
|
||||||
|
|
||||||
|
def _get_conversation(self, conversation_id: str) -> Conversation:
|
||||||
|
"""
|
||||||
|
Get conversation by conversation id
|
||||||
|
:param conversation_id: conversation id
|
||||||
|
:return: conversation
|
||||||
|
"""
|
||||||
|
conversation = (
|
||||||
|
db.session.query(Conversation)
|
||||||
|
.filter(Conversation.id == conversation_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
return conversation
|
||||||
|
|
||||||
|
def _get_message(self, message_id: str) -> Message:
|
||||||
|
"""
|
||||||
|
Get message by message id
|
||||||
|
:param message_id: message id
|
||||||
|
:return: message
|
||||||
|
"""
|
||||||
|
message = (
|
||||||
|
db.session.query(Message)
|
||||||
|
.filter(Message.id == message_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
return message
|
||||||
@ -1,241 +1,21 @@
|
|||||||
import json
|
from typing import cast
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
import uuid
|
|
||||||
from collections.abc import Generator
|
|
||||||
from typing import Any, Optional, Union, cast
|
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from core.entities.application_entities import AppOrchestrationConfigEntity, SensitiveWordAvoidanceEntity, \
|
||||||
from pydantic import ValidationError
|
TextToSpeechEntity, DatasetRetrieveConfigEntity, DatasetEntity, AgentPromptEntity, AgentEntity, AgentToolEntity, \
|
||||||
|
ExternalDataVariableEntity, VariableEntity, AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity, \
|
||||||
from core.app_runner.assistant_app_runner import AssistantApplicationRunner
|
AdvancedChatPromptTemplateEntity, ModelConfigEntity, FileUploadEntity
|
||||||
from core.app_runner.basic_app_runner import BasicApplicationRunner
|
|
||||||
from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
|
|
||||||
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
|
|
||||||
from core.entities.application_entities import (
|
|
||||||
AdvancedChatPromptTemplateEntity,
|
|
||||||
AdvancedCompletionPromptTemplateEntity,
|
|
||||||
AgentEntity,
|
|
||||||
AgentPromptEntity,
|
|
||||||
AgentToolEntity,
|
|
||||||
ApplicationGenerateEntity,
|
|
||||||
AppOrchestrationConfigEntity,
|
|
||||||
DatasetEntity,
|
|
||||||
DatasetRetrieveConfigEntity,
|
|
||||||
ExternalDataVariableEntity,
|
|
||||||
FileUploadEntity,
|
|
||||||
InvokeFrom,
|
|
||||||
ModelConfigEntity,
|
|
||||||
PromptTemplateEntity,
|
|
||||||
SensitiveWordAvoidanceEntity,
|
|
||||||
TextToSpeechEntity,
|
|
||||||
VariableEntity,
|
|
||||||
)
|
|
||||||
from core.entities.model_entities import ModelStatus
|
from core.entities.model_entities import ModelStatus
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError
|
||||||
from core.file.file_obj import FileObj
|
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.prompt.prompt_template import PromptTemplateParser
|
|
||||||
from core.provider_manager import ProviderManager
|
from core.provider_manager import ProviderManager
|
||||||
from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
|
from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.account import Account
|
|
||||||
from models.model import App, Conversation, EndUser, Message, MessageFile
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ApplicationManager:
|
class AppOrchestrationConfigConverter:
|
||||||
"""
|
@classmethod
|
||||||
This class is responsible for managing application
|
def convert_from_app_model_config_dict(cls, tenant_id: str,
|
||||||
"""
|
|
||||||
|
|
||||||
def generate(self, tenant_id: str,
|
|
||||||
app_id: str,
|
|
||||||
app_model_config_id: str,
|
|
||||||
app_model_config_dict: dict,
|
|
||||||
app_model_config_override: bool,
|
|
||||||
user: Union[Account, EndUser],
|
|
||||||
invoke_from: InvokeFrom,
|
|
||||||
inputs: dict[str, str],
|
|
||||||
query: Optional[str] = None,
|
|
||||||
files: Optional[list[FileObj]] = None,
|
|
||||||
conversation: Optional[Conversation] = None,
|
|
||||||
stream: bool = False,
|
|
||||||
extras: Optional[dict[str, Any]] = None) \
|
|
||||||
-> Union[dict, Generator]:
|
|
||||||
"""
|
|
||||||
Generate App response.
|
|
||||||
|
|
||||||
:param tenant_id: workspace ID
|
|
||||||
:param app_id: app ID
|
|
||||||
:param app_model_config_id: app model config id
|
|
||||||
:param app_model_config_dict: app model config dict
|
|
||||||
:param app_model_config_override: app model config override
|
|
||||||
:param user: account or end user
|
|
||||||
:param invoke_from: invoke from source
|
|
||||||
:param inputs: inputs
|
|
||||||
:param query: query
|
|
||||||
:param files: file obj list
|
|
||||||
:param conversation: conversation
|
|
||||||
:param stream: is stream
|
|
||||||
:param extras: extras
|
|
||||||
"""
|
|
||||||
# init task id
|
|
||||||
task_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
# init application generate entity
|
|
||||||
application_generate_entity = ApplicationGenerateEntity(
|
|
||||||
task_id=task_id,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
app_id=app_id,
|
|
||||||
app_model_config_id=app_model_config_id,
|
|
||||||
app_model_config_dict=app_model_config_dict,
|
|
||||||
app_orchestration_config_entity=self.convert_from_app_model_config_dict(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
app_model_config_dict=app_model_config_dict
|
|
||||||
),
|
|
||||||
app_model_config_override=app_model_config_override,
|
|
||||||
conversation_id=conversation.id if conversation else None,
|
|
||||||
inputs=conversation.inputs if conversation else inputs,
|
|
||||||
query=query.replace('\x00', '') if query else None,
|
|
||||||
files=files if files else [],
|
|
||||||
user_id=user.id,
|
|
||||||
stream=stream,
|
|
||||||
invoke_from=invoke_from,
|
|
||||||
extras=extras
|
|
||||||
)
|
|
||||||
|
|
||||||
if not stream and application_generate_entity.app_orchestration_config_entity.agent:
|
|
||||||
raise ValueError("Agent app is not supported in blocking mode.")
|
|
||||||
|
|
||||||
# init generate records
|
|
||||||
(
|
|
||||||
conversation,
|
|
||||||
message
|
|
||||||
) = self._init_generate_records(application_generate_entity)
|
|
||||||
|
|
||||||
# init queue manager
|
|
||||||
queue_manager = ApplicationQueueManager(
|
|
||||||
task_id=application_generate_entity.task_id,
|
|
||||||
user_id=application_generate_entity.user_id,
|
|
||||||
invoke_from=application_generate_entity.invoke_from,
|
|
||||||
conversation_id=conversation.id,
|
|
||||||
app_mode=conversation.mode,
|
|
||||||
message_id=message.id
|
|
||||||
)
|
|
||||||
|
|
||||||
# new thread
|
|
||||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
|
||||||
'flask_app': current_app._get_current_object(),
|
|
||||||
'application_generate_entity': application_generate_entity,
|
|
||||||
'queue_manager': queue_manager,
|
|
||||||
'conversation_id': conversation.id,
|
|
||||||
'message_id': message.id,
|
|
||||||
})
|
|
||||||
|
|
||||||
worker_thread.start()
|
|
||||||
|
|
||||||
# return response or stream generator
|
|
||||||
return self._handle_response(
|
|
||||||
application_generate_entity=application_generate_entity,
|
|
||||||
queue_manager=queue_manager,
|
|
||||||
conversation=conversation,
|
|
||||||
message=message,
|
|
||||||
stream=stream
|
|
||||||
)
|
|
||||||
|
|
||||||
def _generate_worker(self, flask_app: Flask,
|
|
||||||
application_generate_entity: ApplicationGenerateEntity,
|
|
||||||
queue_manager: ApplicationQueueManager,
|
|
||||||
conversation_id: str,
|
|
||||||
message_id: str) -> None:
|
|
||||||
"""
|
|
||||||
Generate worker in a new thread.
|
|
||||||
:param flask_app: Flask app
|
|
||||||
:param application_generate_entity: application generate entity
|
|
||||||
:param queue_manager: queue manager
|
|
||||||
:param conversation_id: conversation ID
|
|
||||||
:param message_id: message ID
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
with flask_app.app_context():
|
|
||||||
try:
|
|
||||||
# get conversation and message
|
|
||||||
conversation = self._get_conversation(conversation_id)
|
|
||||||
message = self._get_message(message_id)
|
|
||||||
|
|
||||||
if application_generate_entity.app_orchestration_config_entity.agent:
|
|
||||||
# agent app
|
|
||||||
runner = AssistantApplicationRunner()
|
|
||||||
runner.run(
|
|
||||||
application_generate_entity=application_generate_entity,
|
|
||||||
queue_manager=queue_manager,
|
|
||||||
conversation=conversation,
|
|
||||||
message=message
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# basic app
|
|
||||||
runner = BasicApplicationRunner()
|
|
||||||
runner.run(
|
|
||||||
application_generate_entity=application_generate_entity,
|
|
||||||
queue_manager=queue_manager,
|
|
||||||
conversation=conversation,
|
|
||||||
message=message
|
|
||||||
)
|
|
||||||
except ConversationTaskStoppedException:
|
|
||||||
pass
|
|
||||||
except InvokeAuthorizationError:
|
|
||||||
queue_manager.publish_error(
|
|
||||||
InvokeAuthorizationError('Incorrect API key provided'),
|
|
||||||
PublishFrom.APPLICATION_MANAGER
|
|
||||||
)
|
|
||||||
except ValidationError as e:
|
|
||||||
logger.exception("Validation Error when generating")
|
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
|
||||||
except (ValueError, InvokeError) as e:
|
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Unknown Error when generating")
|
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
|
||||||
finally:
|
|
||||||
db.session.close()
|
|
||||||
|
|
||||||
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
|
|
||||||
queue_manager: ApplicationQueueManager,
|
|
||||||
conversation: Conversation,
|
|
||||||
message: Message,
|
|
||||||
stream: bool = False) -> Union[dict, Generator]:
|
|
||||||
"""
|
|
||||||
Handle response.
|
|
||||||
:param application_generate_entity: application generate entity
|
|
||||||
:param queue_manager: queue manager
|
|
||||||
:param conversation: conversation
|
|
||||||
:param message: message
|
|
||||||
:param stream: is stream
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# init generate task pipeline
|
|
||||||
generate_task_pipeline = GenerateTaskPipeline(
|
|
||||||
application_generate_entity=application_generate_entity,
|
|
||||||
queue_manager=queue_manager,
|
|
||||||
conversation=conversation,
|
|
||||||
message=message
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return generate_task_pipeline.process(stream=stream)
|
|
||||||
except ValueError as e:
|
|
||||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
|
||||||
raise ConversationTaskStoppedException()
|
|
||||||
else:
|
|
||||||
logger.exception(e)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def convert_from_app_model_config_dict(self, tenant_id: str,
|
|
||||||
app_model_config_dict: dict,
|
app_model_config_dict: dict,
|
||||||
skip_check: bool = False) \
|
skip_check: bool = False) \
|
||||||
-> AppOrchestrationConfigEntity:
|
-> AppOrchestrationConfigEntity:
|
||||||
@ -394,7 +174,7 @@ class ApplicationManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
properties['variables'] = []
|
properties['variables'] = []
|
||||||
|
|
||||||
# variables and external_data_tools
|
# variables and external_data_tools
|
||||||
for variable in copy_app_model_config_dict.get('user_input_form', []):
|
for variable in copy_app_model_config_dict.get('user_input_form', []):
|
||||||
typ = list(variable.keys())[0]
|
typ = list(variable.keys())[0]
|
||||||
@ -444,7 +224,7 @@ class ApplicationManager:
|
|||||||
show_retrieve_source = True
|
show_retrieve_source = True
|
||||||
|
|
||||||
properties['show_retrieve_source'] = show_retrieve_source
|
properties['show_retrieve_source'] = show_retrieve_source
|
||||||
|
|
||||||
dataset_ids = []
|
dataset_ids = []
|
||||||
if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}):
|
if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}):
|
||||||
datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', {
|
datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', {
|
||||||
@ -452,26 +232,23 @@ class ApplicationManager:
|
|||||||
'datasets': []
|
'datasets': []
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
for dataset in datasets.get('datasets', []):
|
for dataset in datasets.get('datasets', []):
|
||||||
keys = list(dataset.keys())
|
keys = list(dataset.keys())
|
||||||
if len(keys) == 0 or keys[0] != 'dataset':
|
if len(keys) == 0 or keys[0] != 'dataset':
|
||||||
continue
|
continue
|
||||||
dataset = dataset['dataset']
|
dataset = dataset['dataset']
|
||||||
|
|
||||||
if 'enabled' not in dataset or not dataset['enabled']:
|
if 'enabled' not in dataset or not dataset['enabled']:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
dataset_id = dataset.get('id', None)
|
dataset_id = dataset.get('id', None)
|
||||||
if dataset_id:
|
if dataset_id:
|
||||||
dataset_ids.append(dataset_id)
|
dataset_ids.append(dataset_id)
|
||||||
else:
|
|
||||||
datasets = {'strategy': 'router', 'datasets': []}
|
|
||||||
|
|
||||||
if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
|
if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
|
||||||
and 'enabled' in copy_app_model_config_dict['agent_mode'] \
|
and 'enabled' in copy_app_model_config_dict['agent_mode'] \
|
||||||
and copy_app_model_config_dict['agent_mode']['enabled']:
|
and copy_app_model_config_dict['agent_mode']['enabled']:
|
||||||
|
|
||||||
agent_dict = copy_app_model_config_dict.get('agent_mode', {})
|
agent_dict = copy_app_model_config_dict.get('agent_mode', {})
|
||||||
agent_strategy = agent_dict.get('strategy', 'cot')
|
agent_strategy = agent_dict.get('strategy', 'cot')
|
||||||
|
|
||||||
@ -515,7 +292,7 @@ class ApplicationManager:
|
|||||||
|
|
||||||
dataset_id = tool_item['id']
|
dataset_id = tool_item['id']
|
||||||
dataset_ids.append(dataset_id)
|
dataset_ids.append(dataset_id)
|
||||||
|
|
||||||
if 'strategy' in copy_app_model_config_dict['agent_mode'] and \
|
if 'strategy' in copy_app_model_config_dict['agent_mode'] and \
|
||||||
copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']:
|
copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']:
|
||||||
agent_prompt = agent_dict.get('prompt', None) or {}
|
agent_prompt = agent_dict.get('prompt', None) or {}
|
||||||
@ -523,13 +300,18 @@ class ApplicationManager:
|
|||||||
model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion')
|
model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion')
|
||||||
if model_mode == 'completion':
|
if model_mode == 'completion':
|
||||||
agent_prompt_entity = AgentPromptEntity(
|
agent_prompt_entity = AgentPromptEntity(
|
||||||
first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
|
first_prompt=agent_prompt.get('first_prompt',
|
||||||
next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']),
|
REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
|
||||||
|
next_iteration=agent_prompt.get('next_iteration',
|
||||||
|
REACT_PROMPT_TEMPLATES['english']['completion'][
|
||||||
|
'agent_scratchpad']),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
agent_prompt_entity = AgentPromptEntity(
|
agent_prompt_entity = AgentPromptEntity(
|
||||||
first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
|
first_prompt=agent_prompt.get('first_prompt',
|
||||||
next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
|
REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
|
||||||
|
next_iteration=agent_prompt.get('next_iteration',
|
||||||
|
REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
|
||||||
)
|
)
|
||||||
|
|
||||||
properties['agent'] = AgentEntity(
|
properties['agent'] = AgentEntity(
|
||||||
@ -551,7 +333,7 @@ class ApplicationManager:
|
|||||||
dataset_ids=dataset_ids,
|
dataset_ids=dataset_ids,
|
||||||
retrieve_config=DatasetRetrieveConfigEntity(
|
retrieve_config=DatasetRetrieveConfigEntity(
|
||||||
query_variable=query_variable,
|
query_variable=query_variable,
|
||||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||||
dataset_configs['retrieval_model']
|
dataset_configs['retrieval_model']
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -624,169 +406,3 @@ class ApplicationManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return AppOrchestrationConfigEntity(**properties)
|
return AppOrchestrationConfigEntity(**properties)
|
||||||
|
|
||||||
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
|
|
||||||
-> tuple[Conversation, Message]:
|
|
||||||
"""
|
|
||||||
Initialize generate records
|
|
||||||
:param application_generate_entity: application generate entity
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
|
|
||||||
|
|
||||||
model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
|
|
||||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
|
||||||
model_schema = model_type_instance.get_model_schema(
|
|
||||||
model=app_orchestration_config_entity.model_config.model,
|
|
||||||
credentials=app_orchestration_config_entity.model_config.credentials
|
|
||||||
)
|
|
||||||
|
|
||||||
app_record = (db.session.query(App)
|
|
||||||
.filter(App.id == application_generate_entity.app_id).first())
|
|
||||||
|
|
||||||
app_mode = app_record.mode
|
|
||||||
|
|
||||||
# get from source
|
|
||||||
end_user_id = None
|
|
||||||
account_id = None
|
|
||||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
|
||||||
from_source = 'api'
|
|
||||||
end_user_id = application_generate_entity.user_id
|
|
||||||
else:
|
|
||||||
from_source = 'console'
|
|
||||||
account_id = application_generate_entity.user_id
|
|
||||||
|
|
||||||
override_model_configs = None
|
|
||||||
if application_generate_entity.app_model_config_override:
|
|
||||||
override_model_configs = application_generate_entity.app_model_config_dict
|
|
||||||
|
|
||||||
introduction = ''
|
|
||||||
if app_mode == 'chat':
|
|
||||||
# get conversation introduction
|
|
||||||
introduction = self._get_conversation_introduction(application_generate_entity)
|
|
||||||
|
|
||||||
if not application_generate_entity.conversation_id:
|
|
||||||
conversation = Conversation(
|
|
||||||
app_id=app_record.id,
|
|
||||||
app_model_config_id=application_generate_entity.app_model_config_id,
|
|
||||||
model_provider=app_orchestration_config_entity.model_config.provider,
|
|
||||||
model_id=app_orchestration_config_entity.model_config.model,
|
|
||||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
|
||||||
mode=app_mode,
|
|
||||||
name='New conversation',
|
|
||||||
inputs=application_generate_entity.inputs,
|
|
||||||
introduction=introduction,
|
|
||||||
system_instruction="",
|
|
||||||
system_instruction_tokens=0,
|
|
||||||
status='normal',
|
|
||||||
from_source=from_source,
|
|
||||||
from_end_user_id=end_user_id,
|
|
||||||
from_account_id=account_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.add(conversation)
|
|
||||||
db.session.commit()
|
|
||||||
db.session.refresh(conversation)
|
|
||||||
else:
|
|
||||||
conversation = (
|
|
||||||
db.session.query(Conversation)
|
|
||||||
.filter(
|
|
||||||
Conversation.id == application_generate_entity.conversation_id,
|
|
||||||
Conversation.app_id == app_record.id
|
|
||||||
).first()
|
|
||||||
)
|
|
||||||
|
|
||||||
currency = model_schema.pricing.currency if model_schema.pricing else 'USD'
|
|
||||||
|
|
||||||
message = Message(
|
|
||||||
app_id=app_record.id,
|
|
||||||
model_provider=app_orchestration_config_entity.model_config.provider,
|
|
||||||
model_id=app_orchestration_config_entity.model_config.model,
|
|
||||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
|
||||||
conversation_id=conversation.id,
|
|
||||||
inputs=application_generate_entity.inputs,
|
|
||||||
query=application_generate_entity.query or "",
|
|
||||||
message="",
|
|
||||||
message_tokens=0,
|
|
||||||
message_unit_price=0,
|
|
||||||
message_price_unit=0,
|
|
||||||
answer="",
|
|
||||||
answer_tokens=0,
|
|
||||||
answer_unit_price=0,
|
|
||||||
answer_price_unit=0,
|
|
||||||
provider_response_latency=0,
|
|
||||||
total_price=0,
|
|
||||||
currency=currency,
|
|
||||||
from_source=from_source,
|
|
||||||
from_end_user_id=end_user_id,
|
|
||||||
from_account_id=account_id,
|
|
||||||
agent_based=app_orchestration_config_entity.agent is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.add(message)
|
|
||||||
db.session.commit()
|
|
||||||
db.session.refresh(message)
|
|
||||||
|
|
||||||
for file in application_generate_entity.files:
|
|
||||||
message_file = MessageFile(
|
|
||||||
message_id=message.id,
|
|
||||||
type=file.type.value,
|
|
||||||
transfer_method=file.transfer_method.value,
|
|
||||||
belongs_to='user',
|
|
||||||
url=file.url,
|
|
||||||
upload_file_id=file.upload_file_id,
|
|
||||||
created_by_role=('account' if account_id else 'end_user'),
|
|
||||||
created_by=account_id or end_user_id,
|
|
||||||
)
|
|
||||||
db.session.add(message_file)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
return conversation, message
|
|
||||||
|
|
||||||
def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
|
|
||||||
"""
|
|
||||||
Get conversation introduction
|
|
||||||
:param application_generate_entity: application generate entity
|
|
||||||
:return: conversation introduction
|
|
||||||
"""
|
|
||||||
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
|
|
||||||
introduction = app_orchestration_config_entity.opening_statement
|
|
||||||
|
|
||||||
if introduction:
|
|
||||||
try:
|
|
||||||
inputs = application_generate_entity.inputs
|
|
||||||
prompt_template = PromptTemplateParser(template=introduction)
|
|
||||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
|
||||||
introduction = prompt_template.format(prompt_inputs)
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return introduction
|
|
||||||
|
|
||||||
def _get_conversation(self, conversation_id: str) -> Conversation:
|
|
||||||
"""
|
|
||||||
Get conversation by conversation id
|
|
||||||
:param conversation_id: conversation id
|
|
||||||
:return: conversation
|
|
||||||
"""
|
|
||||||
conversation = (
|
|
||||||
db.session.query(Conversation)
|
|
||||||
.filter(Conversation.id == conversation_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
return conversation
|
|
||||||
|
|
||||||
def _get_message(self, message_id: str) -> Message:
|
|
||||||
"""
|
|
||||||
Get message by message id
|
|
||||||
:param message_id: message id
|
|
||||||
:return: message
|
|
||||||
"""
|
|
||||||
message = (
|
|
||||||
db.session.query(Message)
|
|
||||||
.filter(Message.id == message_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
return message
|
|
||||||
@ -32,7 +32,7 @@ class PublishFrom(Enum):
|
|||||||
TASK_PIPELINE = 2
|
TASK_PIPELINE = 2
|
||||||
|
|
||||||
|
|
||||||
class ApplicationQueueManager:
|
class AppQueueManager:
|
||||||
def __init__(self, task_id: str,
|
def __init__(self, task_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
@ -50,7 +50,7 @@ class ApplicationQueueManager:
|
|||||||
self._message_id = str(message_id)
|
self._message_id = str(message_id)
|
||||||
|
|
||||||
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||||
redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
|
redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
|
||||||
|
|
||||||
q = queue.Queue()
|
q = queue.Queue()
|
||||||
|
|
||||||
@ -239,7 +239,7 @@ class ApplicationQueueManager:
|
|||||||
Check if task is stopped
|
Check if task is stopped
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
|
stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id)
|
||||||
result = redis_client.get(stopped_cache_key)
|
result = redis_client.get(stopped_cache_key)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return True
|
return True
|
||||||
@ -2,7 +2,7 @@ import time
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Optional, Union, cast
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.entities.application_entities import (
|
from core.entities.application_entities import (
|
||||||
ApplicationGenerateEntity,
|
ApplicationGenerateEntity,
|
||||||
AppOrchestrationConfigEntity,
|
AppOrchestrationConfigEntity,
|
||||||
@ -11,10 +11,10 @@ from core.entities.application_entities import (
|
|||||||
ModelConfigEntity,
|
ModelConfigEntity,
|
||||||
PromptTemplateEntity,
|
PromptTemplateEntity,
|
||||||
)
|
)
|
||||||
from core.features.annotation_reply import AnnotationReplyFeature
|
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||||
from core.features.external_data_fetch import ExternalDataFetchFeature
|
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||||
from core.features.hosting_moderation import HostingModerationFeature
|
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||||
from core.features.moderation import ModerationFeature
|
from core.moderation.input_moderation import InputModeration
|
||||||
from core.file.file_obj import FileObj
|
from core.file.file_obj import FileObj
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
@ -169,7 +169,7 @@ class AppRunner:
|
|||||||
|
|
||||||
return prompt_messages, stop
|
return prompt_messages, stop
|
||||||
|
|
||||||
def direct_output(self, queue_manager: ApplicationQueueManager,
|
def direct_output(self, queue_manager: AppQueueManager,
|
||||||
app_orchestration_config: AppOrchestrationConfigEntity,
|
app_orchestration_config: AppOrchestrationConfigEntity,
|
||||||
prompt_messages: list,
|
prompt_messages: list,
|
||||||
text: str,
|
text: str,
|
||||||
@ -210,7 +210,7 @@ class AppRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
|
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
|
||||||
queue_manager: ApplicationQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
agent: bool = False) -> None:
|
agent: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
@ -234,7 +234,7 @@ class AppRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _handle_invoke_result_direct(self, invoke_result: LLMResult,
|
def _handle_invoke_result_direct(self, invoke_result: LLMResult,
|
||||||
queue_manager: ApplicationQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
agent: bool) -> None:
|
agent: bool) -> None:
|
||||||
"""
|
"""
|
||||||
Handle invoke result direct
|
Handle invoke result direct
|
||||||
@ -248,7 +248,7 @@ class AppRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _handle_invoke_result_stream(self, invoke_result: Generator,
|
def _handle_invoke_result_stream(self, invoke_result: Generator,
|
||||||
queue_manager: ApplicationQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
agent: bool) -> None:
|
agent: bool) -> None:
|
||||||
"""
|
"""
|
||||||
Handle invoke result
|
Handle invoke result
|
||||||
@ -306,7 +306,7 @@ class AppRunner:
|
|||||||
:param query: query
|
:param query: query
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
moderation_feature = ModerationFeature()
|
moderation_feature = InputModeration()
|
||||||
return moderation_feature.check(
|
return moderation_feature.check(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@ -316,7 +316,7 @@ class AppRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
|
def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
|
||||||
queue_manager: ApplicationQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
prompt_messages: list[PromptMessage]) -> bool:
|
prompt_messages: list[PromptMessage]) -> bool:
|
||||||
"""
|
"""
|
||||||
Check hosting moderation
|
Check hosting moderation
|
||||||
@ -358,7 +358,7 @@ class AppRunner:
|
|||||||
:param query: the query
|
:param query: the query
|
||||||
:return: the filled inputs
|
:return: the filled inputs
|
||||||
"""
|
"""
|
||||||
external_data_fetch_feature = ExternalDataFetchFeature()
|
external_data_fetch_feature = ExternalDataFetch()
|
||||||
return external_data_fetch_feature.fetch(
|
return external_data_fetch_feature.fetch(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
@ -388,4 +388,4 @@ class AppRunner:
|
|||||||
query=query,
|
query=query,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
invoke_from=invoke_from
|
invoke_from=invoke_from
|
||||||
)
|
)
|
||||||
@ -1,8 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.app_runner.app_runner import AppRunner
|
from core.app.base_app_runner import AppRunner
|
||||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
from core.entities.application_entities import (
|
from core.entities.application_entities import (
|
||||||
ApplicationGenerateEntity,
|
ApplicationGenerateEntity,
|
||||||
@ -10,7 +10,7 @@ from core.entities.application_entities import (
|
|||||||
InvokeFrom,
|
InvokeFrom,
|
||||||
ModelConfigEntity,
|
ModelConfigEntity,
|
||||||
)
|
)
|
||||||
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.moderation.base import ModerationException
|
from core.moderation.base import ModerationException
|
||||||
@ -20,13 +20,13 @@ from models.model import App, AppMode, Conversation, Message
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BasicApplicationRunner(AppRunner):
|
class ChatAppRunner(AppRunner):
|
||||||
"""
|
"""
|
||||||
Basic Application Runner
|
Chat Application Runner
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def run(self, application_generate_entity: ApplicationGenerateEntity,
|
def run(self, application_generate_entity: ApplicationGenerateEntity,
|
||||||
queue_manager: ApplicationQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message) -> None:
|
message: Message) -> None:
|
||||||
"""
|
"""
|
||||||
@ -215,7 +215,7 @@ class BasicApplicationRunner(AppRunner):
|
|||||||
|
|
||||||
def retrieve_dataset_context(self, tenant_id: str,
|
def retrieve_dataset_context(self, tenant_id: str,
|
||||||
app_record: App,
|
app_record: App,
|
||||||
queue_manager: ApplicationQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
model_config: ModelConfigEntity,
|
model_config: ModelConfigEntity,
|
||||||
dataset_config: DatasetEntity,
|
dataset_config: DatasetEntity,
|
||||||
show_retrieve_source: bool,
|
show_retrieve_source: bool,
|
||||||
@ -254,7 +254,7 @@ class BasicApplicationRunner(AppRunner):
|
|||||||
and dataset_config.retrieve_config.query_variable):
|
and dataset_config.retrieve_config.query_variable):
|
||||||
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
|
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
|
||||||
|
|
||||||
dataset_retrieval = DatasetRetrievalFeature()
|
dataset_retrieval = DatasetRetrieval()
|
||||||
return dataset_retrieval.retrieve(
|
return dataset_retrieval.retrieve(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@ -1,15 +1,15 @@
|
|||||||
from core.apps.config_validators.dataset import DatasetValidator
|
from core.app.validators.dataset_retrieval import DatasetValidator
|
||||||
from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator
|
from core.app.validators.external_data_fetch import ExternalDataFetchValidator
|
||||||
from core.apps.config_validators.file_upload import FileUploadValidator
|
from core.app.validators.file_upload import FileUploadValidator
|
||||||
from core.apps.config_validators.model import ModelValidator
|
from core.app.validators.model_validator import ModelValidator
|
||||||
from core.apps.config_validators.moderation import ModerationValidator
|
from core.app.validators.moderation import ModerationValidator
|
||||||
from core.apps.config_validators.opening_statement import OpeningStatementValidator
|
from core.app.validators.opening_statement import OpeningStatementValidator
|
||||||
from core.apps.config_validators.prompt import PromptValidator
|
from core.app.validators.prompt import PromptValidator
|
||||||
from core.apps.config_validators.retriever_resource import RetrieverResourceValidator
|
from core.app.validators.retriever_resource import RetrieverResourceValidator
|
||||||
from core.apps.config_validators.speech_to_text import SpeechToTextValidator
|
from core.app.validators.speech_to_text import SpeechToTextValidator
|
||||||
from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator
|
from core.app.validators.suggested_questions import SuggestedQuestionsValidator
|
||||||
from core.apps.config_validators.text_to_speech import TextToSpeechValidator
|
from core.app.validators.text_to_speech import TextToSpeechValidator
|
||||||
from core.apps.config_validators.user_input_form import UserInputFormValidator
|
from core.app.validators.user_input_form import UserInputFormValidator
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ class ChatAppConfigValidator:
|
|||||||
related_config_keys.extend(current_related_config_keys)
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
# external data tools validation
|
# external data tools validation
|
||||||
config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config)
|
config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config)
|
||||||
related_config_keys.extend(current_related_config_keys)
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
# file upload validation
|
# file upload validation
|
||||||
266
api/core/app/completion/app_runner.py
Normal file
266
api/core/app/completion/app_runner.py
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from core.app.base_app_runner import AppRunner
|
||||||
|
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||||
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
|
from core.entities.application_entities import (
|
||||||
|
ApplicationGenerateEntity,
|
||||||
|
DatasetEntity,
|
||||||
|
InvokeFrom,
|
||||||
|
ModelConfigEntity,
|
||||||
|
)
|
||||||
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
|
from core.model_manager import ModelInstance
|
||||||
|
from core.moderation.base import ModerationException
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import App, AppMode, Conversation, Message
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionAppRunner(AppRunner):
|
||||||
|
"""
|
||||||
|
Completion Application Runner
|
||||||
|
"""
|
||||||
|
|
||||||
|
def run(self, application_generate_entity: ApplicationGenerateEntity,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
conversation: Conversation,
|
||||||
|
message: Message) -> None:
|
||||||
|
"""
|
||||||
|
Run application
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param queue_manager: application queue manager
|
||||||
|
:param conversation: conversation
|
||||||
|
:param message: message
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
|
||||||
|
if not app_record:
|
||||||
|
raise ValueError("App not found")
|
||||||
|
|
||||||
|
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
|
||||||
|
|
||||||
|
inputs = application_generate_entity.inputs
|
||||||
|
query = application_generate_entity.query
|
||||||
|
files = application_generate_entity.files
|
||||||
|
|
||||||
|
# Pre-calculate the number of tokens of the prompt messages,
|
||||||
|
# and return the rest number of tokens by model context token size limit and max token size limit.
|
||||||
|
# If the rest number of tokens is not enough, raise exception.
|
||||||
|
# Include: prompt template, inputs, query(optional), files(optional)
|
||||||
|
# Not Include: memory, external data, dataset context
|
||||||
|
self.get_pre_calculate_rest_tokens(
|
||||||
|
app_record=app_record,
|
||||||
|
model_config=app_orchestration_config.model_config,
|
||||||
|
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||||
|
inputs=inputs,
|
||||||
|
files=files,
|
||||||
|
query=query
|
||||||
|
)
|
||||||
|
|
||||||
|
memory = None
|
||||||
|
if application_generate_entity.conversation_id:
|
||||||
|
# get memory of conversation (read-only)
|
||||||
|
model_instance = ModelInstance(
|
||||||
|
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
|
||||||
|
model=app_orchestration_config.model_config.model
|
||||||
|
)
|
||||||
|
|
||||||
|
memory = TokenBufferMemory(
|
||||||
|
conversation=conversation,
|
||||||
|
model_instance=model_instance
|
||||||
|
)
|
||||||
|
|
||||||
|
# organize all inputs and template to prompt messages
|
||||||
|
# Include: prompt template, inputs, query(optional), files(optional)
|
||||||
|
# memory(optional)
|
||||||
|
prompt_messages, stop = self.organize_prompt_messages(
|
||||||
|
app_record=app_record,
|
||||||
|
model_config=app_orchestration_config.model_config,
|
||||||
|
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||||
|
inputs=inputs,
|
||||||
|
files=files,
|
||||||
|
query=query,
|
||||||
|
memory=memory
|
||||||
|
)
|
||||||
|
|
||||||
|
# moderation
|
||||||
|
try:
|
||||||
|
# process sensitive_word_avoidance
|
||||||
|
_, inputs, query = self.moderation_for_inputs(
|
||||||
|
app_id=app_record.id,
|
||||||
|
tenant_id=application_generate_entity.tenant_id,
|
||||||
|
app_orchestration_config_entity=app_orchestration_config,
|
||||||
|
inputs=inputs,
|
||||||
|
query=query,
|
||||||
|
)
|
||||||
|
except ModerationException as e:
|
||||||
|
self.direct_output(
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
app_orchestration_config=app_orchestration_config,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
text=str(e),
|
||||||
|
stream=application_generate_entity.stream
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if query:
|
||||||
|
# annotation reply
|
||||||
|
annotation_reply = self.query_app_annotations_to_reply(
|
||||||
|
app_record=app_record,
|
||||||
|
message=message,
|
||||||
|
query=query,
|
||||||
|
user_id=application_generate_entity.user_id,
|
||||||
|
invoke_from=application_generate_entity.invoke_from
|
||||||
|
)
|
||||||
|
|
||||||
|
if annotation_reply:
|
||||||
|
queue_manager.publish_annotation_reply(
|
||||||
|
message_annotation_id=annotation_reply.id,
|
||||||
|
pub_from=PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
self.direct_output(
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
app_orchestration_config=app_orchestration_config,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
text=annotation_reply.content,
|
||||||
|
stream=application_generate_entity.stream
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# fill in variable inputs from external data tools if exists
|
||||||
|
external_data_tools = app_orchestration_config.external_data_variables
|
||||||
|
if external_data_tools:
|
||||||
|
inputs = self.fill_in_inputs_from_external_data_tools(
|
||||||
|
tenant_id=app_record.tenant_id,
|
||||||
|
app_id=app_record.id,
|
||||||
|
external_data_tools=external_data_tools,
|
||||||
|
inputs=inputs,
|
||||||
|
query=query
|
||||||
|
)
|
||||||
|
|
||||||
|
# get context from datasets
|
||||||
|
context = None
|
||||||
|
if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids:
|
||||||
|
context = self.retrieve_dataset_context(
|
||||||
|
tenant_id=app_record.tenant_id,
|
||||||
|
app_record=app_record,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
model_config=app_orchestration_config.model_config,
|
||||||
|
show_retrieve_source=app_orchestration_config.show_retrieve_source,
|
||||||
|
dataset_config=app_orchestration_config.dataset,
|
||||||
|
message=message,
|
||||||
|
inputs=inputs,
|
||||||
|
query=query,
|
||||||
|
user_id=application_generate_entity.user_id,
|
||||||
|
invoke_from=application_generate_entity.invoke_from,
|
||||||
|
memory=memory
|
||||||
|
)
|
||||||
|
|
||||||
|
# reorganize all inputs and template to prompt messages
|
||||||
|
# Include: prompt template, inputs, query(optional), files(optional)
|
||||||
|
# memory(optional), external data, dataset context(optional)
|
||||||
|
prompt_messages, stop = self.organize_prompt_messages(
|
||||||
|
app_record=app_record,
|
||||||
|
model_config=app_orchestration_config.model_config,
|
||||||
|
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||||
|
inputs=inputs,
|
||||||
|
files=files,
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
memory=memory
|
||||||
|
)
|
||||||
|
|
||||||
|
# check hosting moderation
|
||||||
|
hosting_moderation_result = self.check_hosting_moderation(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
prompt_messages=prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
if hosting_moderation_result:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
|
||||||
|
self.recale_llm_max_tokens(
|
||||||
|
model_config=app_orchestration_config.model_config,
|
||||||
|
prompt_messages=prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invoke model
|
||||||
|
model_instance = ModelInstance(
|
||||||
|
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
|
||||||
|
model=app_orchestration_config.model_config.model
|
||||||
|
)
|
||||||
|
|
||||||
|
invoke_result = model_instance.invoke_llm(
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
model_parameters=app_orchestration_config.model_config.parameters,
|
||||||
|
stop=stop,
|
||||||
|
stream=application_generate_entity.stream,
|
||||||
|
user=application_generate_entity.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# handle invoke result
|
||||||
|
self._handle_invoke_result(
|
||||||
|
invoke_result=invoke_result,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
stream=application_generate_entity.stream
|
||||||
|
)
|
||||||
|
|
||||||
|
def retrieve_dataset_context(self, tenant_id: str,
|
||||||
|
app_record: App,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
model_config: ModelConfigEntity,
|
||||||
|
dataset_config: DatasetEntity,
|
||||||
|
show_retrieve_source: bool,
|
||||||
|
message: Message,
|
||||||
|
inputs: dict,
|
||||||
|
query: str,
|
||||||
|
user_id: str,
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Retrieve dataset context
|
||||||
|
:param tenant_id: tenant id
|
||||||
|
:param app_record: app record
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param model_config: model config
|
||||||
|
:param dataset_config: dataset config
|
||||||
|
:param show_retrieve_source: show retrieve source
|
||||||
|
:param message: message
|
||||||
|
:param inputs: inputs
|
||||||
|
:param query: query
|
||||||
|
:param user_id: user id
|
||||||
|
:param invoke_from: invoke from
|
||||||
|
:param memory: memory
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
hit_callback = DatasetIndexToolCallbackHandler(
|
||||||
|
queue_manager,
|
||||||
|
app_record.id,
|
||||||
|
message.id,
|
||||||
|
user_id,
|
||||||
|
invoke_from
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
if (app_record.mode == AppMode.COMPLETION.value and dataset_config
|
||||||
|
and dataset_config.retrieve_config.query_variable):
|
||||||
|
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
|
||||||
|
|
||||||
|
dataset_retrieval = DatasetRetrieval()
|
||||||
|
return dataset_retrieval.retrieve(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
model_config=model_config,
|
||||||
|
config=dataset_config,
|
||||||
|
query=query,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
show_retrieve_source=show_retrieve_source,
|
||||||
|
hit_callback=hit_callback,
|
||||||
|
memory=memory
|
||||||
|
)
|
||||||
|
|
||||||
@ -1,12 +1,12 @@
|
|||||||
from core.apps.config_validators.dataset import DatasetValidator
|
from core.app.validators.dataset_retrieval import DatasetValidator
|
||||||
from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator
|
from core.app.validators.external_data_fetch import ExternalDataFetchValidator
|
||||||
from core.apps.config_validators.file_upload import FileUploadValidator
|
from core.app.validators.file_upload import FileUploadValidator
|
||||||
from core.apps.config_validators.model import ModelValidator
|
from core.app.validators.model_validator import ModelValidator
|
||||||
from core.apps.config_validators.moderation import ModerationValidator
|
from core.app.validators.moderation import ModerationValidator
|
||||||
from core.apps.config_validators.more_like_this import MoreLikeThisValidator
|
from core.app.validators.more_like_this import MoreLikeThisValidator
|
||||||
from core.apps.config_validators.prompt import PromptValidator
|
from core.app.validators.prompt import PromptValidator
|
||||||
from core.apps.config_validators.text_to_speech import TextToSpeechValidator
|
from core.app.validators.text_to_speech import TextToSpeechValidator
|
||||||
from core.apps.config_validators.user_input_form import UserInputFormValidator
|
from core.app.validators.user_input_form import UserInputFormValidator
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
@ -32,7 +32,7 @@ class CompletionAppConfigValidator:
|
|||||||
related_config_keys.extend(current_related_config_keys)
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
# external data tools validation
|
# external data tools validation
|
||||||
config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config)
|
config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config)
|
||||||
related_config_keys.extend(current_related_config_keys)
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
# file upload validation
|
# file upload validation
|
||||||
@ -6,8 +6,8 @@ from typing import Optional, Union, cast
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
|
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
|
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
|
||||||
from core.entities.queue_entities import (
|
from core.entities.queue_entities import (
|
||||||
AnnotationReplyEvent,
|
AnnotationReplyEvent,
|
||||||
@ -35,7 +35,7 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.prompt.prompt_template import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
from events.message_event import message_was_created
|
from events.message_event import message_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -59,7 +59,7 @@ class GenerateTaskPipeline:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, application_generate_entity: ApplicationGenerateEntity,
|
def __init__(self, application_generate_entity: ApplicationGenerateEntity,
|
||||||
queue_manager: ApplicationQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message) -> None:
|
message: Message) -> None:
|
||||||
"""
|
"""
|
||||||
@ -633,7 +633,7 @@ class GenerateTaskPipeline:
|
|||||||
|
|
||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
|
def _init_output_moderation(self) -> Optional[OutputModeration]:
|
||||||
"""
|
"""
|
||||||
Init output moderation.
|
Init output moderation.
|
||||||
:return:
|
:return:
|
||||||
@ -642,7 +642,7 @@ class GenerateTaskPipeline:
|
|||||||
sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance
|
sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance
|
||||||
|
|
||||||
if sensitive_word_avoidance:
|
if sensitive_word_avoidance:
|
||||||
return OutputModerationHandler(
|
return OutputModeration(
|
||||||
tenant_id=self._application_generate_entity.tenant_id,
|
tenant_id=self._application_generate_entity.tenant_id,
|
||||||
app_id=self._application_generate_entity.app_id,
|
app_id=self._application_generate_entity.app_id,
|
||||||
rule=ModerationRule(
|
rule=ModerationRule(
|
||||||
0
api/core/app/validators/__init__.py
Normal file
0
api/core/app/validators/__init__.py
Normal file
@ -2,7 +2,7 @@
|
|||||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||||
|
|
||||||
|
|
||||||
class ExternalDataToolsValidator:
|
class ExternalDataFetchValidator:
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||||
"""
|
"""
|
||||||
0
api/core/app/workflow/__init__.py
Normal file
0
api/core/app/workflow/__init__.py
Normal file
@ -1,6 +1,6 @@
|
|||||||
from core.apps.config_validators.file_upload import FileUploadValidator
|
from core.app.validators.file_upload import FileUploadValidator
|
||||||
from core.apps.config_validators.moderation import ModerationValidator
|
from core.app.validators.moderation import ModerationValidator
|
||||||
from core.apps.config_validators.text_to_speech import TextToSpeechValidator
|
from core.app.validators.text_to_speech import TextToSpeechValidator
|
||||||
|
|
||||||
|
|
||||||
class WorkflowAppConfigValidator:
|
class WorkflowAppConfigValidator:
|
||||||
@ -1,82 +0,0 @@
|
|||||||
from core.apps.config_validators.agent import AgentValidator
|
|
||||||
from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator
|
|
||||||
from core.apps.config_validators.file_upload import FileUploadValidator
|
|
||||||
from core.apps.config_validators.model import ModelValidator
|
|
||||||
from core.apps.config_validators.moderation import ModerationValidator
|
|
||||||
from core.apps.config_validators.opening_statement import OpeningStatementValidator
|
|
||||||
from core.apps.config_validators.prompt import PromptValidator
|
|
||||||
from core.apps.config_validators.retriever_resource import RetrieverResourceValidator
|
|
||||||
from core.apps.config_validators.speech_to_text import SpeechToTextValidator
|
|
||||||
from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator
|
|
||||||
from core.apps.config_validators.text_to_speech import TextToSpeechValidator
|
|
||||||
from core.apps.config_validators.user_input_form import UserInputFormValidator
|
|
||||||
from models.model import AppMode
|
|
||||||
|
|
||||||
|
|
||||||
class AgentChatAppConfigValidator:
|
|
||||||
@classmethod
|
|
||||||
def config_validate(cls, tenant_id: str, config: dict) -> dict:
|
|
||||||
"""
|
|
||||||
Validate for agent chat app model config
|
|
||||||
|
|
||||||
:param tenant_id: tenant id
|
|
||||||
:param config: app model config args
|
|
||||||
"""
|
|
||||||
app_mode = AppMode.AGENT_CHAT
|
|
||||||
|
|
||||||
related_config_keys = []
|
|
||||||
|
|
||||||
# model
|
|
||||||
config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
|
||||||
|
|
||||||
# user_input_form
|
|
||||||
config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
|
||||||
|
|
||||||
# external data tools validation
|
|
||||||
config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
|
||||||
|
|
||||||
# file upload validation
|
|
||||||
config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
|
||||||
|
|
||||||
# prompt
|
|
||||||
config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
|
||||||
|
|
||||||
# agent_mode
|
|
||||||
config, current_related_config_keys = AgentValidator.validate_and_set_defaults(tenant_id, config)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
|
||||||
|
|
||||||
# opening_statement
|
|
||||||
config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
|
||||||
|
|
||||||
# suggested_questions_after_answer
|
|
||||||
config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
|
||||||
|
|
||||||
# speech_to_text
|
|
||||||
config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
|
||||||
|
|
||||||
# text_to_speech
|
|
||||||
config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
|
||||||
|
|
||||||
# return retriever resource
|
|
||||||
config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
|
||||||
|
|
||||||
# moderation validation
|
|
||||||
config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
|
||||||
|
|
||||||
related_config_keys = list(set(related_config_keys))
|
|
||||||
|
|
||||||
# Filter out extra parameters
|
|
||||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
|
||||||
|
|
||||||
return filtered_config
|
|
||||||
@ -1,81 +0,0 @@
|
|||||||
import uuid
|
|
||||||
|
|
||||||
from core.apps.config_validators.dataset import DatasetValidator
|
|
||||||
from core.entities.agent_entities import PlanningStrategy
|
|
||||||
|
|
||||||
OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
|
|
||||||
|
|
||||||
|
|
||||||
class AgentValidator:
|
|
||||||
@classmethod
|
|
||||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
|
||||||
"""
|
|
||||||
Validate and set defaults for agent feature
|
|
||||||
|
|
||||||
:param tenant_id: tenant ID
|
|
||||||
:param config: app model config args
|
|
||||||
"""
|
|
||||||
if not config.get("agent_mode"):
|
|
||||||
config["agent_mode"] = {
|
|
||||||
"enabled": False,
|
|
||||||
"tools": []
|
|
||||||
}
|
|
||||||
|
|
||||||
if not isinstance(config["agent_mode"], dict):
|
|
||||||
raise ValueError("agent_mode must be of object type")
|
|
||||||
|
|
||||||
if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
|
|
||||||
config["agent_mode"]["enabled"] = False
|
|
||||||
|
|
||||||
if not isinstance(config["agent_mode"]["enabled"], bool):
|
|
||||||
raise ValueError("enabled in agent_mode must be of boolean type")
|
|
||||||
|
|
||||||
if not config["agent_mode"].get("strategy"):
|
|
||||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
|
||||||
|
|
||||||
if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
|
|
||||||
raise ValueError("strategy in agent_mode must be in the specified strategy list")
|
|
||||||
|
|
||||||
if not config["agent_mode"].get("tools"):
|
|
||||||
config["agent_mode"]["tools"] = []
|
|
||||||
|
|
||||||
if not isinstance(config["agent_mode"]["tools"], list):
|
|
||||||
raise ValueError("tools in agent_mode must be a list of objects")
|
|
||||||
|
|
||||||
for tool in config["agent_mode"]["tools"]:
|
|
||||||
key = list(tool.keys())[0]
|
|
||||||
if key in OLD_TOOLS:
|
|
||||||
# old style, use tool name as key
|
|
||||||
tool_item = tool[key]
|
|
||||||
|
|
||||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
|
||||||
tool_item["enabled"] = False
|
|
||||||
|
|
||||||
if not isinstance(tool_item["enabled"], bool):
|
|
||||||
raise ValueError("enabled in agent_mode.tools must be of boolean type")
|
|
||||||
|
|
||||||
if key == "dataset":
|
|
||||||
if 'id' not in tool_item:
|
|
||||||
raise ValueError("id is required in dataset")
|
|
||||||
|
|
||||||
try:
|
|
||||||
uuid.UUID(tool_item["id"])
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError("id in dataset must be of UUID type")
|
|
||||||
|
|
||||||
if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]):
|
|
||||||
raise ValueError("Dataset ID does not exist, please check your permission.")
|
|
||||||
else:
|
|
||||||
# latest style, use key-value pair
|
|
||||||
if "enabled" not in tool or not tool["enabled"]:
|
|
||||||
tool["enabled"] = False
|
|
||||||
if "provider_type" not in tool:
|
|
||||||
raise ValueError("provider_type is required in agent_mode.tools")
|
|
||||||
if "provider_id" not in tool:
|
|
||||||
raise ValueError("provider_id is required in agent_mode.tools")
|
|
||||||
if "tool_name" not in tool:
|
|
||||||
raise ValueError("tool_name is required in agent_mode.tools")
|
|
||||||
if "tool_parameters" not in tool:
|
|
||||||
raise ValueError("tool_parameters is required in agent_mode.tools")
|
|
||||||
|
|
||||||
return config, ["agent_mode"]
|
|
||||||
@ -7,7 +7,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen
|
|||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult
|
||||||
|
|
||||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||||
from core.entities.application_entities import ModelConfigEntity
|
from core.entities.application_entities import ModelConfigEntity
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
|
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
|
||||||
@ -22,7 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
|||||||
raise_error: bool = True
|
raise_error: bool = True
|
||||||
|
|
||||||
def __init__(self, model_config: ModelConfigEntity,
|
def __init__(self, model_config: ModelConfigEntity,
|
||||||
queue_manager: ApplicationQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
message: Message,
|
message: Message,
|
||||||
message_chain: MessageChain) -> None:
|
message_chain: MessageChain) -> None:
|
||||||
"""Initialize callback handler."""
|
"""Initialize callback handler."""
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.entities.application_entities import InvokeFrom
|
from core.entities.application_entities import InvokeFrom
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -10,7 +10,7 @@ from models.model import DatasetRetrieverResource
|
|||||||
class DatasetIndexToolCallbackHandler:
|
class DatasetIndexToolCallbackHandler:
|
||||||
"""Callback handler for dataset tool."""
|
"""Callback handler for dataset tool."""
|
||||||
|
|
||||||
def __init__(self, queue_manager: ApplicationQueueManager,
|
def __init__(self, queue_manager: AppQueueManager,
|
||||||
app_id: str,
|
app_id: str,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from core.external_data_tool.factory import ExternalDataToolFactory
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ExternalDataFetchFeature:
|
class ExternalDataFetch:
|
||||||
def fetch(self, tenant_id: str,
|
def fetch(self, tenant_id: str,
|
||||||
app_id: str,
|
app_id: str,
|
||||||
external_data_tools: list[ExternalDataVariableEntity],
|
external_data_tools: list[ExternalDataVariableEntity],
|
||||||
@ -13,7 +13,7 @@ from sqlalchemy.orm.exc import ObjectDeletedError
|
|||||||
|
|
||||||
from core.docstore.dataset_docstore import DatasetDocumentStore
|
from core.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
from core.errors.error import ProviderTokenNotInitError
|
from core.errors.error import ProviderTokenNotInitError
|
||||||
from core.generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.model_runtime.entities.model_entities import ModelType, PriceType
|
from core.model_runtime.entities.model_entities import ModelType, PriceType
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
|||||||
0
api/core/llm_generator/__init__.py
Normal file
0
api/core/llm_generator/__init__.py
Normal file
@ -7,10 +7,10 @@ from core.model_manager import ModelManager
|
|||||||
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
||||||
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||||
from core.prompt.prompt_template import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
|
from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
|
||||||
|
|
||||||
|
|
||||||
class LLMGenerator:
|
class LLMGenerator:
|
||||||
0
api/core/llm_generator/output_parser/__init__.py
Normal file
0
api/core/llm_generator/output_parser/__init__.py
Normal file
@ -2,7 +2,7 @@ from typing import Any
|
|||||||
|
|
||||||
from langchain.schema import BaseOutputParser, OutputParserException
|
from langchain.schema import BaseOutputParser, OutputParserException
|
||||||
|
|
||||||
from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE
|
from core.llm_generator.prompts import RULE_CONFIG_GENERATE_TEMPLATE
|
||||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||||
|
|
||||||
|
|
||||||
@ -4,7 +4,7 @@ from typing import Any
|
|||||||
|
|
||||||
from langchain.schema import BaseOutputParser
|
from langchain.schema import BaseOutputParser
|
||||||
|
|
||||||
from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||||
|
|
||||||
|
|
||||||
class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
|
class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
|
||||||
@ -7,7 +7,7 @@ from core.moderation.factory import ModerationFactory
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ModerationFeature:
|
class InputModeration:
|
||||||
def check(self, app_id: str,
|
def check(self, app_id: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
app_orchestration_config_entity: AppOrchestrationConfigEntity,
|
app_orchestration_config_entity: AppOrchestrationConfigEntity,
|
||||||
@ -6,7 +6,7 @@ from typing import Any, Optional
|
|||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.application_queue_manager import PublishFrom
|
from core.app.app_queue_manager import PublishFrom
|
||||||
from core.moderation.base import ModerationAction, ModerationOutputsResult
|
from core.moderation.base import ModerationAction, ModerationOutputsResult
|
||||||
from core.moderation.factory import ModerationFactory
|
from core.moderation.factory import ModerationFactory
|
||||||
|
|
||||||
@ -18,7 +18,7 @@ class ModerationRule(BaseModel):
|
|||||||
config: dict[str, Any]
|
config: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class OutputModerationHandler(BaseModel):
|
class OutputModeration(BaseModel):
|
||||||
DEFAULT_BUFFER_SIZE: int = 300
|
DEFAULT_BUFFER_SIZE: int = 300
|
||||||
|
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
0
api/core/prompt/__init__.py
Normal file
0
api/core/prompt/__init__.py
Normal file
@ -15,7 +15,7 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.prompt.prompt_template import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from core.prompt.prompt_transform import PromptTransform
|
from core.prompt.prompt_transform import PromptTransform
|
||||||
from core.prompt.simple_prompt_transform import ModelMode
|
from core.prompt.simple_prompt_transform import ModelMode
|
||||||
|
|
||||||
|
|||||||
0
api/core/prompt/prompt_templates/__init__.py
Normal file
0
api/core/prompt/prompt_templates/__init__.py
Normal file
@ -15,7 +15,7 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.prompt.prompt_template import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from core.prompt.prompt_transform import PromptTransform
|
from core.prompt.prompt_transform import PromptTransform
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
@ -275,7 +275,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
return prompt_file_contents[prompt_file_name]
|
return prompt_file_contents[prompt_file_name]
|
||||||
|
|
||||||
# Get the absolute path of the subdirectory
|
# Get the absolute path of the subdirectory
|
||||||
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts')
|
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates')
|
||||||
json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')
|
json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')
|
||||||
|
|
||||||
# Open the JSON file and read its content
|
# Open the JSON file and read its content
|
||||||
|
|||||||
0
api/core/prompt/utils/__init__.py
Normal file
0
api/core/prompt/utils/__init__.py
Normal file
@ -9,7 +9,7 @@ import pandas as pd
|
|||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from werkzeug.datastructures import FileStorage
|
from werkzeug.datastructures import FileStorage
|
||||||
|
|
||||||
from core.generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
|
|||||||
0
api/core/rag/retrieval/__init__.py
Normal file
0
api/core/rag/retrieval/__init__.py
Normal file
0
api/core/rag/retrieval/agent/__init__.py
Normal file
0
api/core/rag/retrieval/agent/__init__.py
Normal file
@ -7,8 +7,8 @@ from langchain.schema.language_model import BaseLanguageModel
|
|||||||
|
|
||||||
from core.entities.application_entities import ModelConfigEntity
|
from core.entities.application_entities import ModelConfigEntity
|
||||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||||
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
|
from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
|
||||||
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
|
from core.rag.retrieval.agent.fake_llm import FakeLLM
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
|
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ from pydantic import root_validator
|
|||||||
|
|
||||||
from core.entities.application_entities import ModelConfigEntity
|
from core.entities.application_entities import ModelConfigEntity
|
||||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||||
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
|
from core.rag.retrieval.agent.fake_llm import FakeLLM
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
|||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
from core.entities.application_entities import ModelConfigEntity
|
from core.entities.application_entities import ModelConfigEntity
|
||||||
from core.features.dataset_retrieval.agent.llm_chain import LLMChain
|
from core.rag.retrieval.agent.llm_chain import LLMChain
|
||||||
|
|
||||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||||
@ -10,10 +10,10 @@ from pydantic import BaseModel, Extra
|
|||||||
from core.entities.agent_entities import PlanningStrategy
|
from core.entities.agent_entities import PlanningStrategy
|
||||||
from core.entities.application_entities import ModelConfigEntity
|
from core.entities.application_entities import ModelConfigEntity
|
||||||
from core.entities.message_entities import prompt_messages_to_lc_messages
|
from core.entities.message_entities import prompt_messages_to_lc_messages
|
||||||
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
|
from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
|
||||||
from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
||||||
from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
|
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||||
from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
|
from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
|
||||||
from core.helper import moderation
|
from core.helper import moderation
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
@ -5,7 +5,7 @@ from langchain.tools import BaseTool
|
|||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
from core.entities.agent_entities import PlanningStrategy
|
from core.entities.agent_entities import PlanningStrategy
|
||||||
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
|
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
|
||||||
from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
|
from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.entities.model_entities import ModelFeature
|
from core.model_runtime.entities.model_entities import ModelFeature
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
@ -15,7 +15,7 @@ from extensions.ext_database import db
|
|||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
class DatasetRetrievalFeature:
|
class DatasetRetrieval:
|
||||||
def retrieve(self, tenant_id: str,
|
def retrieve(self, tenant_id: str,
|
||||||
model_config: ModelConfigEntity,
|
model_config: ModelConfigEntity,
|
||||||
config: DatasetEntity,
|
config: DatasetEntity,
|
||||||
@ -4,7 +4,7 @@ from langchain.tools import BaseTool
|
|||||||
|
|
||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
|
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
|
||||||
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
|
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
@ -30,7 +30,7 @@ class DatasetRetrieverTool(Tool):
|
|||||||
if retrieve_config is None:
|
if retrieve_config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
feature = DatasetRetrievalFeature()
|
feature = DatasetRetrieval()
|
||||||
|
|
||||||
# save original retrieve strategy, and set retrieve strategy to SINGLE
|
# save original retrieve strategy, and set retrieve strategy to SINGLE
|
||||||
# Agent only support SINGLE mode
|
# Agent only support SINGLE mode
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from core.generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from events.message_event import message_was_created
|
from events.message_event import message_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
|||||||
@ -310,22 +310,28 @@ class AppModelConfig(db.Model):
|
|||||||
|
|
||||||
def from_model_config_dict(self, model_config: dict):
|
def from_model_config_dict(self, model_config: dict):
|
||||||
self.opening_statement = model_config['opening_statement']
|
self.opening_statement = model_config['opening_statement']
|
||||||
self.suggested_questions = json.dumps(model_config['suggested_questions'])
|
self.suggested_questions = json.dumps(model_config['suggested_questions']) \
|
||||||
self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer'])
|
if model_config.get('suggested_questions') else None
|
||||||
|
self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \
|
||||||
|
if model_config.get('suggested_questions_after_answer') else None
|
||||||
self.speech_to_text = json.dumps(model_config['speech_to_text']) \
|
self.speech_to_text = json.dumps(model_config['speech_to_text']) \
|
||||||
if model_config.get('speech_to_text') else None
|
if model_config.get('speech_to_text') else None
|
||||||
self.text_to_speech = json.dumps(model_config['text_to_speech']) \
|
self.text_to_speech = json.dumps(model_config['text_to_speech']) \
|
||||||
if model_config.get('text_to_speech') else None
|
if model_config.get('text_to_speech') else None
|
||||||
self.more_like_this = json.dumps(model_config['more_like_this'])
|
self.more_like_this = json.dumps(model_config['more_like_this']) \
|
||||||
|
if model_config.get('more_like_this') else None
|
||||||
self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \
|
self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \
|
||||||
if model_config.get('sensitive_word_avoidance') else None
|
if model_config.get('sensitive_word_avoidance') else None
|
||||||
self.external_data_tools = json.dumps(model_config['external_data_tools']) \
|
self.external_data_tools = json.dumps(model_config['external_data_tools']) \
|
||||||
if model_config.get('external_data_tools') else None
|
if model_config.get('external_data_tools') else None
|
||||||
self.model = json.dumps(model_config['model'])
|
self.model = json.dumps(model_config['model']) \
|
||||||
self.user_input_form = json.dumps(model_config['user_input_form'])
|
if model_config.get('model') else None
|
||||||
|
self.user_input_form = json.dumps(model_config['user_input_form']) \
|
||||||
|
if model_config.get('user_input_form') else None
|
||||||
self.dataset_query_variable = model_config.get('dataset_query_variable')
|
self.dataset_query_variable = model_config.get('dataset_query_variable')
|
||||||
self.pre_prompt = model_config['pre_prompt']
|
self.pre_prompt = model_config['pre_prompt']
|
||||||
self.agent_mode = json.dumps(model_config['agent_mode'])
|
self.agent_mode = json.dumps(model_config['agent_mode']) \
|
||||||
|
if model_config.get('agent_mode') else None
|
||||||
self.retriever_resource = json.dumps(model_config['retriever_resource']) \
|
self.retriever_resource = json.dumps(model_config['retriever_resource']) \
|
||||||
if model_config.get('retriever_resource') else None
|
if model_config.get('retriever_resource') else None
|
||||||
self.prompt_type = model_config.get('prompt_type', 'simple')
|
self.prompt_type = model_config.get('prompt_type', 'simple')
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
from core.prompt.advanced_prompt_templates import (
|
from core.prompt.prompt_templates.advanced_prompt_templates import (
|
||||||
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
|
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||||
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from core.apps.app_config_validators.advanced_chat_app import AdvancedChatAppConfigValidator
|
from core.app.advanced_chat.config_validator import AdvancedChatAppConfigValidator
|
||||||
from core.apps.app_config_validators.agent_chat_app import AgentChatAppConfigValidator
|
from core.app.agent_chat.config_validator import AgentChatAppConfigValidator
|
||||||
from core.apps.app_config_validators.chat_app import ChatAppConfigValidator
|
from core.app.chat.config_validator import ChatAppConfigValidator
|
||||||
from core.apps.app_config_validators.completion_app import CompletionAppConfigValidator
|
from core.app.completion.config_validator import CompletionAppConfigValidator
|
||||||
from core.apps.app_config_validators.workflow_app import WorkflowAppConfigValidator
|
from core.app.workflow.config_validator import WorkflowAppConfigValidator
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -4,8 +4,8 @@ from typing import Any, Union
|
|||||||
|
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
|
|
||||||
from core.application_manager import ApplicationManager
|
from core.app.app_manager import AppManager
|
||||||
from core.apps.config_validators.model import ModelValidator
|
from core.app.validators.model_validator import ModelValidator
|
||||||
from core.entities.application_entities import InvokeFrom
|
from core.entities.application_entities import InvokeFrom
|
||||||
from core.file.message_file_parser import MessageFileParser
|
from core.file.message_file_parser import MessageFileParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -137,7 +137,7 @@ class CompletionService:
|
|||||||
user
|
user
|
||||||
)
|
)
|
||||||
|
|
||||||
application_manager = ApplicationManager()
|
application_manager = AppManager()
|
||||||
return application_manager.generate(
|
return application_manager.generate(
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
@ -193,7 +193,7 @@ class CompletionService:
|
|||||||
message.files, app_model_config
|
message.files, app_model_config
|
||||||
)
|
)
|
||||||
|
|
||||||
application_manager = ApplicationManager()
|
application_manager = AppManager()
|
||||||
return application_manager.generate(
|
return application_manager.generate(
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from core.generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from core.generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.application_manager import ApplicationManager
|
from core.app.app_manager import AppManager
|
||||||
from core.entities.application_entities import (
|
from core.entities.application_entities import (
|
||||||
DatasetEntity,
|
DatasetEntity,
|
||||||
DatasetRetrieveConfigEntity,
|
DatasetRetrieveConfigEntity,
|
||||||
@ -111,7 +111,7 @@ class WorkflowConverter:
|
|||||||
new_app_mode = self._get_new_app_mode(app_model)
|
new_app_mode = self._get_new_app_mode(app_model)
|
||||||
|
|
||||||
# convert app model config
|
# convert app model config
|
||||||
application_manager = ApplicationManager()
|
application_manager = AppManager()
|
||||||
app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict(
|
app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict(
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
app_model_config_dict=app_model_config.to_dict(),
|
app_model_config_dict=app_model_config.to_dict(),
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from core.file.file_obj import FileObj, FileType, FileTransferMethod
|
|||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
|
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
|
||||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||||
from core.prompt.prompt_template import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from models.model import Conversation
|
from models.model import Conversation
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user