From 799db69e4f334a20cbbfad540b518bffc4b698d9 Mon Sep 17 00:00:00 2001
From: takatost
Date: Thu, 29 Feb 2024 17:33:52 +0800
Subject: [PATCH] refactor app

---
 api/controllers/console/app/completion.py | 6 +-
 api/controllers/console/app/generator.py | 2 +-
 api/controllers/console/explore/completion.py | 6 +-
 api/controllers/service_api/app/completion.py | 6 +-
 api/controllers/web/completion.py | 6 +-
 api/core/{app_runner => agent}/__init__.py | 0
 .../base_agent_runner.py} | 8 +-
 .../cot_agent_runner.py} | 6 +-
 .../fc_agent_runner.py} | 6 +-
 api/core/{apps => app}/__init__.py | 0
 .../advanced_chat}/__init__.py | 0
 .../advanced_chat/config_validator.py} | 14 +-
 .../agent_chat}/__init__.py | 0
 .../agent_chat/app_runner.py} | 19 +-
 api/core/app/agent_chat/config_validator.py | 162 +++++++
 api/core/app/app_manager.py | 382 +++++++++++++++
 .../app_orchestration_config_converter.py} | 434 +-----------------------
 .../app_queue_manager.py} | 6 +-
 .../app_runner.py => app/base_app_runner.py} | 26 +-
 api/core/{features => app/chat}/__init__.py | 0
 .../chat/app_runner.py} | 16 +-
 .../chat/config_validator.py} | 26 +-
 .../completion}/__init__.py | 0
 api/core/app/completion/app_runner.py | 266 +++++++++++
 .../completion/config_validator.py} | 20 +-
 .../agent => app/features}/__init__.py | 0
 .../features/annotation_reply}/__init__.py | 0
 .../annotation_reply}/annotation_reply.py | 0
 .../features/hosting_moderation/__init__.py | 0
 .../hosting_moderation}/hosting_moderation.py | 0
 .../generate_task_pipeline.py | 12 +-
 api/core/app/validators/__init__.py | 0
 .../validators/dataset_retrieval.py} | 0
 .../validators/external_data_fetch.py} | 2 +-
 .../validators}/file_upload.py | 0
 .../validators/model_validator.py} | 0
 .../validators}/moderation.py | 0
 .../validators}/more_like_this.py | 0
 .../validators}/opening_statement.py | 0
 .../validators}/prompt.py | 0
 .../validators}/retriever_resource.py | 0
 .../validators}/speech_to_text.py | 0
 .../validators}/suggested_questions.py | 0
 .../validators}/text_to_speech.py | 0
 .../validators}/user_input_form.py | 0
 api/core/app/workflow/__init__.py | 0
 .../workflow/config_validator.py} | 6 +-
 .../app_config_validators/agent_chat_app.py | 82 ----
 api/core/apps/config_validators/agent.py | 81 ----
 .../agent_loop_gather_callback_handler.py | 4 +-
 .../index_tool_callback_handler.py | 4 +-
 .../external_data_fetch.py | 2 +-
 api/core/indexing_runner.py | 2 +-
 api/core/llm_generator/__init__.py | 0
 .../llm_generator.py | 8 +-
 .../llm_generator/output_parser/__init__.py | 0
 .../output_parser/rule_config_generator.py | 2 +-
 .../suggested_questions_after_answer.py | 2 +-
 api/core/{prompt => llm_generator}/prompts.py | 0
 .../input_moderation.py} | 2 +-
 .../output_moderation.py} | 4 +-
 api/core/prompt/__init__.py | 0
 api/core/prompt/advanced_prompt_transform.py | 2 +-
 api/core/prompt/prompt_templates/__init__.py | 0
 .../advanced_prompt_templates.py | 0
 .../baichuan_chat.json | 0
 .../baichuan_completion.json | 0
 .../common_chat.json | 0
 .../common_completion.json | 0
 api/core/prompt/simple_prompt_transform.py | 4 +-
 api/core/prompt/utils/__init__.py | 0
 .../prompt_template_parser.py} | 0
 .../processor/qa_index_processor.py | 2 +-
 api/core/rag/retrieval/__init__.py | 0
 api/core/rag/retrieval/agent/__init__.py | 0
 .../retrieval}/agent/agent_llm_callback.py | 0
 .../retrieval}/agent/fake_llm.py | 0
 .../retrieval}/agent/llm_chain.py | 4 +-
 .../agent/multi_dataset_router_agent.py | 2 +-
 .../retrieval/agent/output_parser/__init__.py | 0
 .../agent/output_parser/structured_chat.py | 0
 .../structed_multi_dataset_router_agent.py | 2 +-
 .../agent_based_dataset_executor.py | 8 +-
 .../retrieval}/dataset_retrieval.py | 4 +-
 api/core/tools/tool/dataset_retriever_tool.py | 4 +-
 ...rsation_name_when_first_message_created.py | 2 +-
 api/models/model.py | 18 +-
 .../advanced_prompt_template_service.py | 2 +-
 api/services/app_model_config_service.py | 10 +-
 api/services/completion_service.py | 8 +-
 api/services/conversation_service.py | 2 +-
 api/services/message_service.py | 2 +-
 api/services/workflow/workflow_converter.py | 4 +-
 .../prompt/test_advanced_prompt_transform.py | 2 +-
 94 files changed, 991 insertions(+), 721 deletions(-)
 rename api/core/{app_runner => agent}/__init__.py (100%)
 rename api/core/{features/assistant_base_runner.py => agent/base_agent_runner.py} (99%)
 rename api/core/{features/assistant_cot_runner.py => agent/cot_agent_runner.py} (99%)
 rename api/core/{features/assistant_fc_runner.py => agent/fc_agent_runner.py} (98%)
 rename api/core/{apps => app}/__init__.py (100%)
 rename api/core/{apps/app_config_validators => app/advanced_chat}/__init__.py (100%)
 rename api/core/{apps/app_config_validators/advanced_chat_app.py => app/advanced_chat/config_validator.py} (77%)
 rename api/core/{apps/config_validators => app/agent_chat}/__init__.py (100%)
 rename api/core/{app_runner/assistant_app_runner.py => app/agent_chat/app_runner.py} (95%)
 create mode 100644 api/core/app/agent_chat/config_validator.py
 create mode 100644 api/core/app/app_manager.py
 rename api/core/{application_manager.py => app/app_orchestration_config_converter.py} (52%)
 rename api/core/{application_queue_manager.py => app/app_queue_manager.py} (97%)
 rename api/core/{app_runner/app_runner.py => app/base_app_runner.py} (94%)
 rename api/core/{features => app/chat}/__init__.py (100%)
 rename api/core/{app_runner/basic_app_runner.py => app/chat/app_runner.py} (95%)
 rename api/core/{apps/app_config_validators/chat_app.py => app/chat/config_validator.py} (75%)
 rename api/core/{features/dataset_retrieval => app/completion}/__init__.py (100%)
 create mode 100644 api/core/app/completion/app_runner.py
 rename api/core/{apps/app_config_validators/completion_app.py => app/completion/config_validator.py} (76%)
 rename api/core/{features/dataset_retrieval/agent => app/features}/__init__.py (100%)
 rename api/core/{features/dataset_retrieval/agent/output_parser => app/features/annotation_reply}/__init__.py (100%)
 rename api/core/{features => app/features/annotation_reply}/annotation_reply.py (100%)
 create mode 100644 api/core/app/features/hosting_moderation/__init__.py
 rename api/core/{features => app/features/hosting_moderation}/hosting_moderation.py (100%)
 rename api/core/{app_runner => app}/generate_task_pipeline.py (98%)
 create mode 100644 api/core/app/validators/__init__.py
 rename api/core/{apps/config_validators/dataset.py => app/validators/dataset_retrieval.py} (100%)
 rename api/core/{apps/config_validators/external_data_tools.py => app/validators/external_data_fetch.py} (97%)
 rename api/core/{apps/config_validators => app/validators}/file_upload.py (100%)
 rename api/core/{apps/config_validators/model.py => app/validators/model_validator.py} (100%)
 rename api/core/{apps/config_validators => app/validators}/moderation.py (100%)
 rename api/core/{apps/config_validators => app/validators}/more_like_this.py (100%)
 rename api/core/{apps/config_validators => app/validators}/opening_statement.py (100%)
 rename api/core/{apps/config_validators => app/validators}/prompt.py (100%)
 rename api/core/{apps/config_validators =>
app/validators}/retriever_resource.py (100%) rename api/core/{apps/config_validators => app/validators}/speech_to_text.py (100%) rename api/core/{apps/config_validators => app/validators}/suggested_questions.py (100%) rename api/core/{apps/config_validators => app/validators}/text_to_speech.py (100%) rename api/core/{apps/config_validators => app/validators}/user_input_form.py (100%) create mode 100644 api/core/app/workflow/__init__.py rename api/core/{apps/app_config_validators/workflow_app.py => app/workflow/config_validator.py} (83%) delete mode 100644 api/core/apps/app_config_validators/agent_chat_app.py delete mode 100644 api/core/apps/config_validators/agent.py rename api/core/{features => external_data_tool}/external_data_fetch.py (98%) create mode 100644 api/core/llm_generator/__init__.py rename api/core/{generator => llm_generator}/llm_generator.py (93%) create mode 100644 api/core/llm_generator/output_parser/__init__.py rename api/core/{prompt => llm_generator}/output_parser/rule_config_generator.py (94%) rename api/core/{prompt => llm_generator}/output_parser/suggested_questions_after_answer.py (87%) rename api/core/{prompt => llm_generator}/prompts.py (100%) rename api/core/{features/moderation.py => moderation/input_moderation.py} (98%) rename api/core/{app_runner/moderation_handler.py => moderation/output_moderation.py} (97%) create mode 100644 api/core/prompt/__init__.py create mode 100644 api/core/prompt/prompt_templates/__init__.py rename api/core/prompt/{ => prompt_templates}/advanced_prompt_templates.py (100%) rename api/core/prompt/{generate_prompts => prompt_templates}/baichuan_chat.json (100%) rename api/core/prompt/{generate_prompts => prompt_templates}/baichuan_completion.json (100%) rename api/core/prompt/{generate_prompts => prompt_templates}/common_chat.json (100%) rename api/core/prompt/{generate_prompts => prompt_templates}/common_completion.json (100%) create mode 100644 api/core/prompt/utils/__init__.py rename api/core/prompt/{prompt_template.py => utils/prompt_template_parser.py} (100%) create mode 100644 api/core/rag/retrieval/__init__.py create mode 100644 api/core/rag/retrieval/agent/__init__.py rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/agent_llm_callback.py (100%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/fake_llm.py (100%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/llm_chain.py (91%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/multi_dataset_router_agent.py (98%) create mode 100644 api/core/rag/retrieval/agent/output_parser/__init__.py rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/output_parser/structured_chat.py (100%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/structed_multi_dataset_router_agent.py (99%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent_based_dataset_executor.py (92%) rename api/core/{features/dataset_retrieval => rag/retrieval}/dataset_retrieval.py (98%) diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index e62475308f..0632c0439b 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -21,7 +21,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.application_queue_manager import 
ApplicationQueueManager +from core.app.app_queue_manager import AppQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -94,7 +94,7 @@ class CompletionMessageStopApi(Resource): def post(self, app_model, task_id): account = flask_login.current_user - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) return {'result': 'success'}, 200 @@ -172,7 +172,7 @@ class ChatMessageStopApi(Resource): def post(self, app_model, task_id): account = flask_login.current_user - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) return {'result': 'success'}, 200 diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 3ec932b5f1..ee02fc1846 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -11,7 +11,7 @@ from controllers.console.app.error import ( from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 6406d5b3b0..22ea4bbac2 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -21,7 +21,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource -from core.application_queue_manager import ApplicationQueueManager +from core.app.app_queue_manager import AppQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -90,7 +90,7 @@ class CompletionStopApi(InstalledAppResource): if app_model.mode != 'completion': raise NotCompletionAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {'result': 'success'}, 200 @@ -154,7 +154,7 @@ class ChatStopApi(InstalledAppResource): if app_model.mode != 'chat': raise NotChatAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {'result': 'success'}, 200 diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index c6cfb24378..fd4ce831b3 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -19,7 +19,7 @@ from controllers.service_api.app.error import ( ProviderQuotaExceededError, ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token 
-from core.application_queue_manager import ApplicationQueueManager +from core.app.app_queue_manager import AppQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -85,7 +85,7 @@ class CompletionStopApi(Resource): if app_model.mode != 'completion': raise AppUnavailableError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {'result': 'success'}, 200 @@ -147,7 +147,7 @@ class ChatStopApi(Resource): if app_model.mode != 'chat': raise NotChatAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {'result': 'success'}, 200 diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 61d4f8c362..fd94ec7646 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -20,7 +20,7 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource -from core.application_queue_manager import ApplicationQueueManager +from core.app.app_queue_manager import AppQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -84,7 +84,7 @@ class CompletionStopApi(WebApiResource): if app_model.mode != 'completion': raise NotCompletionAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {'result': 'success'}, 200 @@ -144,7 +144,7 @@ class ChatStopApi(WebApiResource): if app_model.mode != 'chat': raise NotChatAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {'result': 'success'}, 200 diff --git a/api/core/app_runner/__init__.py b/api/core/agent/__init__.py similarity index 100% rename from api/core/app_runner/__init__.py rename to api/core/agent/__init__.py diff --git a/api/core/features/assistant_base_runner.py b/api/core/agent/base_agent_runner.py similarity index 99% rename from api/core/features/assistant_base_runner.py rename to api/core/agent/base_agent_runner.py index 1d9541070f..0658124d14 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -5,8 +5,8 @@ from datetime import datetime from mimetypes import guess_extension from typing import Optional, Union, cast -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager +from core.app.base_app_runner import AppRunner +from core.app.app_queue_manager import AppQueueManager from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( @@ -48,13 +48,13 @@ from models.tools import ToolConversationVariables logger = logging.getLogger(__name__) -class BaseAssistantApplicationRunner(AppRunner): +class BaseAgentRunner(AppRunner): def 
__init__(self, tenant_id: str, application_generate_entity: ApplicationGenerateEntity, app_orchestration_config: AppOrchestrationConfigEntity, model_config: ModelConfigEntity, config: AgentEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, message: Message, user_id: str, memory: Optional[TokenBufferMemory] = None, diff --git a/api/core/features/assistant_cot_runner.py b/api/core/agent/cot_agent_runner.py similarity index 99% rename from api/core/features/assistant_cot_runner.py rename to api/core/agent/cot_agent_runner.py index 3762ddcf62..152e445795 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -3,9 +3,9 @@ import re from collections.abc import Generator from typing import Literal, Union -from core.application_queue_manager import PublishFrom +from core.app.app_queue_manager import PublishFrom from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit -from core.features.assistant_base_runner import BaseAssistantApplicationRunner +from core.agent.base_agent_runner import BaseAgentRunner from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -262,7 +262,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): tool_call_args = json.loads(tool_call_args) except json.JSONDecodeError: pass - + tool_response = tool_instance.invoke( user_id=self.user_id, tool_parameters=tool_call_args diff --git a/api/core/features/assistant_fc_runner.py b/api/core/agent/fc_agent_runner.py similarity index 98% rename from api/core/features/assistant_fc_runner.py rename to api/core/agent/fc_agent_runner.py index 391e040c53..0cf0d3762c 100644 --- a/api/core/features/assistant_fc_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -3,8 +3,8 @@ import logging from collections.abc import Generator from typing import Any, Union -from core.application_queue_manager import PublishFrom -from core.features.assistant_base_runner import BaseAssistantApplicationRunner +from core.app.app_queue_manager import PublishFrom +from core.agent.base_agent_runner import BaseAgentRunner from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -26,7 +26,7 @@ from models.model import Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) -class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): +class FunctionCallAgentRunner(BaseAgentRunner): def run(self, conversation: Conversation, message: Message, query: str, diff --git a/api/core/apps/__init__.py b/api/core/app/__init__.py similarity index 100% rename from api/core/apps/__init__.py rename to api/core/app/__init__.py diff --git a/api/core/apps/app_config_validators/__init__.py b/api/core/app/advanced_chat/__init__.py similarity index 100% rename from api/core/apps/app_config_validators/__init__.py rename to api/core/app/advanced_chat/__init__.py diff --git a/api/core/apps/app_config_validators/advanced_chat_app.py b/api/core/app/advanced_chat/config_validator.py similarity index 77% rename from api/core/apps/app_config_validators/advanced_chat_app.py rename to api/core/app/advanced_chat/config_validator.py index dc7664b844..39c00c028e 100644 --- a/api/core/apps/app_config_validators/advanced_chat_app.py +++ 
b/api/core/app/advanced_chat/config_validator.py @@ -1,10 +1,10 @@ -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.opening_statement import OpeningStatementValidator -from core.apps.config_validators.retriever_resource import RetrieverResourceValidator -from core.apps.config_validators.speech_to_text import SpeechToTextValidator -from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.opening_statement import OpeningStatementValidator +from core.app.validators.retriever_resource import RetrieverResourceValidator +from core.app.validators.speech_to_text import SpeechToTextValidator +from core.app.validators.suggested_questions import SuggestedQuestionsValidator +from core.app.validators.text_to_speech import TextToSpeechValidator class AdvancedChatAppConfigValidator: diff --git a/api/core/apps/config_validators/__init__.py b/api/core/app/agent_chat/__init__.py similarity index 100% rename from api/core/apps/config_validators/__init__.py rename to api/core/app/agent_chat/__init__.py diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app/agent_chat/app_runner.py similarity index 95% rename from api/core/app_runner/assistant_app_runner.py rename to api/core/app/agent_chat/app_runner.py index 655a5a1c7c..b046e935a5 100644 --- a/api/core/app_runner/assistant_app_runner.py +++ b/api/core/app/agent_chat/app_runner.py @@ -1,11 +1,11 @@ import logging from typing import cast -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.base_app_runner import AppRunner +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity -from core.features.assistant_cot_runner import AssistantCotApplicationRunner -from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner +from core.agent.cot_agent_runner import CotAgentRunner +from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage @@ -19,12 +19,13 @@ from models.tools import ToolConversationVariables logger = logging.getLogger(__name__) -class AssistantApplicationRunner(AppRunner): + +class AgentChatAppRunner(AppRunner): """ - Assistant Application Runner + Agent Application Runner """ def run(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: """ @@ -201,7 +202,7 @@ class AssistantApplicationRunner(AppRunner): # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: - assistant_cot_runner = AssistantCotApplicationRunner( + assistant_cot_runner = CotAgentRunner( tenant_id=application_generate_entity.tenant_id, application_generate_entity=application_generate_entity, app_orchestration_config=app_orchestration_config, @@ -223,7 +224,7 @@ class AssistantApplicationRunner(AppRunner): 
inputs=inputs, ) elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: - assistant_fc_runner = AssistantFunctionCallApplicationRunner( + assistant_fc_runner = FunctionCallAgentRunner( tenant_id=application_generate_entity.tenant_id, application_generate_entity=application_generate_entity, app_orchestration_config=app_orchestration_config, diff --git a/api/core/app/agent_chat/config_validator.py b/api/core/app/agent_chat/config_validator.py new file mode 100644 index 0000000000..6596b19f99 --- /dev/null +++ b/api/core/app/agent_chat/config_validator.py @@ -0,0 +1,162 @@ +import uuid + +from core.entities.agent_entities import PlanningStrategy +from core.app.validators.dataset_retrieval import DatasetValidator +from core.app.validators.external_data_fetch import ExternalDataFetchValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.model_validator import ModelValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.opening_statement import OpeningStatementValidator +from core.app.validators.prompt import PromptValidator +from core.app.validators.retriever_resource import RetrieverResourceValidator +from core.app.validators.speech_to_text import SpeechToTextValidator +from core.app.validators.suggested_questions import SuggestedQuestionsValidator +from core.app.validators.text_to_speech import TextToSpeechValidator +from core.app.validators.user_input_form import UserInputFormValidator +from models.model import AppMode + + +OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] + + +class AgentChatAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for agent chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.AGENT_CHAT + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # external data tools validation + config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # agent_mode + config, current_related_config_keys = cls.validate_agent_mode_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) + 
related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config + + @classmethod + def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + """ + Validate agent_mode and set defaults for agent feature + + :param tenant_id: tenant ID + :param config: app model config args + """ + if not config.get("agent_mode"): + config["agent_mode"] = { + "enabled": False, + "tools": [] + } + + if not isinstance(config["agent_mode"], dict): + raise ValueError("agent_mode must be of object type") + + if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: + config["agent_mode"]["enabled"] = False + + if not isinstance(config["agent_mode"]["enabled"], bool): + raise ValueError("enabled in agent_mode must be of boolean type") + + if not config["agent_mode"].get("strategy"): + config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + + if config["agent_mode"]["strategy"] not in [member.value for member in + list(PlanningStrategy.__members__.values())]: + raise ValueError("strategy in agent_mode must be in the specified strategy list") + + if not config["agent_mode"].get("tools"): + config["agent_mode"]["tools"] = [] + + if not isinstance(config["agent_mode"]["tools"], list): + raise ValueError("tools in agent_mode must be a list of objects") + + for tool in config["agent_mode"]["tools"]: + key = list(tool.keys())[0] + if key in OLD_TOOLS: + # old style, use tool name as key + tool_item = tool[key] + + if "enabled" not in tool_item or not tool_item["enabled"]: + tool_item["enabled"] = False + + if not isinstance(tool_item["enabled"], bool): + raise ValueError("enabled in agent_mode.tools must be of boolean type") + + if key == "dataset": + if 'id' not in tool_item: + raise ValueError("id is required in dataset") + + try: + uuid.UUID(tool_item["id"]) + except ValueError: + raise ValueError("id in dataset must be of UUID type") + + if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]): + raise ValueError("Dataset ID does not exist, please check your permission.") + else: + # latest style, use key-value pair + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + if "provider_type" not in tool: + raise ValueError("provider_type is required in agent_mode.tools") + if "provider_id" not in tool: + raise ValueError("provider_id is required in agent_mode.tools") + if "tool_name" not in tool: + raise ValueError("tool_name is required in agent_mode.tools") + if "tool_parameters" not in tool: + raise ValueError("tool_parameters is required in agent_mode.tools") + + return config, ["agent_mode"] diff --git a/api/core/app/app_manager.py b/api/core/app/app_manager.py new file mode 100644 index 0000000000..0819ed864b --- /dev/null +++ b/api/core/app/app_manager.py @@ -0,0 +1,382 
@@ +import json +import logging +import threading +import uuid +from collections.abc import Generator +from typing import Any, Optional, Union, cast + +from flask import Flask, current_app +from pydantic import ValidationError + +from core.app.app_orchestration_config_converter import AppOrchestrationConfigConverter +from core.app.agent_chat.app_runner import AgentChatAppRunner +from core.app.chat.app_runner import ChatAppRunner +from core.app.generate_task_pipeline import GenerateTaskPipeline +from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.entities.application_entities import ( + ApplicationGenerateEntity, + InvokeFrom, +) +from core.file.file_obj import FileObj +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from extensions.ext_database import db +from models.account import Account +from models.model import App, Conversation, EndUser, Message, MessageFile + +logger = logging.getLogger(__name__) + + +class AppManager: + """ + This class is responsible for managing application + """ + + def generate(self, tenant_id: str, + app_id: str, + app_model_config_id: str, + app_model_config_dict: dict, + app_model_config_override: bool, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + inputs: dict[str, str], + query: Optional[str] = None, + files: Optional[list[FileObj]] = None, + conversation: Optional[Conversation] = None, + stream: bool = False, + extras: Optional[dict[str, Any]] = None) \ + -> Union[dict, Generator]: + """ + Generate App response. + + :param tenant_id: workspace ID + :param app_id: app ID + :param app_model_config_id: app model config id + :param app_model_config_dict: app model config dict + :param app_model_config_override: app model config override + :param user: account or end user + :param invoke_from: invoke from source + :param inputs: inputs + :param query: query + :param files: file obj list + :param conversation: conversation + :param stream: is stream + :param extras: extras + """ + # init task id + task_id = str(uuid.uuid4()) + + # init application generate entity + application_generate_entity = ApplicationGenerateEntity( + task_id=task_id, + tenant_id=tenant_id, + app_id=app_id, + app_model_config_id=app_model_config_id, + app_model_config_dict=app_model_config_dict, + app_orchestration_config_entity=AppOrchestrationConfigConverter.convert_from_app_model_config_dict( + tenant_id=tenant_id, + app_model_config_dict=app_model_config_dict + ), + app_model_config_override=app_model_config_override, + conversation_id=conversation.id if conversation else None, + inputs=conversation.inputs if conversation else inputs, + query=query.replace('\x00', '') if query else None, + files=files if files else [], + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + if not stream and application_generate_entity.app_orchestration_config_entity.agent: + raise ValueError("Agent app is not supported in blocking mode.") + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + 
conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'conversation_id': conversation.id, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: ApplicationGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + if application_generate_entity.app_orchestration_config_entity.agent: + # agent app + runner = AgentChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + else: + # basic app + runner = ChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + except ConversationTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() + + def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + stream: bool = False) -> Union[dict, Generator]: + """ + Handle response. 
+ :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation: conversation + :param message: message + :param stream: is stream + :return: + """ + # init generate task pipeline + generate_task_pipeline = GenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + + try: + return generate_task_pipeline.process(stream=stream) + except ValueError as e: + if e.args[0] == "I/O operation on closed file.": # ignore this error + raise ConversationTaskStoppedException() + else: + logger.exception(e) + raise e + finally: + db.session.remove() + + def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \ + -> tuple[Conversation, Message]: + """ + Initialize generate records + :param application_generate_entity: application generate entity + :return: + """ + app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity + + model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + model_schema = model_type_instance.get_model_schema( + model=app_orchestration_config_entity.model_config.model, + credentials=app_orchestration_config_entity.model_config.credentials + ) + + app_record = (db.session.query(App) + .filter(App.id == application_generate_entity.app_id).first()) + + app_mode = app_record.mode + + # get from source + end_user_id = None + account_id = None + if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + from_source = 'api' + end_user_id = application_generate_entity.user_id + else: + from_source = 'console' + account_id = application_generate_entity.user_id + + override_model_configs = None + if application_generate_entity.app_model_config_override: + override_model_configs = application_generate_entity.app_model_config_dict + + introduction = '' + if app_mode == 'chat': + # get conversation introduction + introduction = self._get_conversation_introduction(application_generate_entity) + + if not application_generate_entity.conversation_id: + conversation = Conversation( + app_id=app_record.id, + app_model_config_id=application_generate_entity.app_model_config_id, + model_provider=app_orchestration_config_entity.model_config.provider, + model_id=app_orchestration_config_entity.model_config.model, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + mode=app_mode, + name='New conversation', + inputs=application_generate_entity.inputs, + introduction=introduction, + system_instruction="", + system_instruction_tokens=0, + status='normal', + from_source=from_source, + from_end_user_id=end_user_id, + from_account_id=account_id, + ) + + db.session.add(conversation) + db.session.commit() + else: + conversation = ( + db.session.query(Conversation) + .filter( + Conversation.id == application_generate_entity.conversation_id, + Conversation.app_id == app_record.id + ).first() + ) + + currency = model_schema.pricing.currency if model_schema.pricing else 'USD' + + message = Message( + app_id=app_record.id, + model_provider=app_orchestration_config_entity.model_config.provider, + model_id=app_orchestration_config_entity.model_config.model, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + conversation_id=conversation.id, 
+ inputs=application_generate_entity.inputs, + query=application_generate_entity.query or "", + message="", + message_tokens=0, + message_unit_price=0, + message_price_unit=0, + answer="", + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + provider_response_latency=0, + total_price=0, + currency=currency, + from_source=from_source, + from_end_user_id=end_user_id, + from_account_id=account_id, + agent_based=app_orchestration_config_entity.agent is not None + ) + + db.session.add(message) + db.session.commit() + + for file in application_generate_entity.files: + message_file = MessageFile( + message_id=message.id, + type=file.type.value, + transfer_method=file.transfer_method.value, + belongs_to='user', + url=file.url, + upload_file_id=file.upload_file_id, + created_by_role=('account' if account_id else 'end_user'), + created_by=account_id or end_user_id, + ) + db.session.add(message_file) + db.session.commit() + + return conversation, message + + def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str: + """ + Get conversation introduction + :param application_generate_entity: application generate entity + :return: conversation introduction + """ + app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity + introduction = app_orchestration_config_entity.opening_statement + + if introduction: + try: + inputs = application_generate_entity.inputs + prompt_template = PromptTemplateParser(template=introduction) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + introduction = prompt_template.format(prompt_inputs) + except KeyError: + pass + + return introduction + + def _get_conversation(self, conversation_id: str) -> Conversation: + """ + Get conversation by conversation id + :param conversation_id: conversation id + :return: conversation + """ + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id) + .first() + ) + + return conversation + + def _get_message(self, message_id: str) -> Message: + """ + Get message by message id + :param message_id: message id + :return: message + """ + message = ( + db.session.query(Message) + .filter(Message.id == message_id) + .first() + ) + + return message diff --git a/api/core/application_manager.py b/api/core/app/app_orchestration_config_converter.py similarity index 52% rename from api/core/application_manager.py rename to api/core/app/app_orchestration_config_converter.py index ea0c85427d..ddf49949a3 100644 --- a/api/core/application_manager.py +++ b/api/core/app/app_orchestration_config_converter.py @@ -1,241 +1,21 @@ -import json -import logging -import threading -import uuid -from collections.abc import Generator -from typing import Any, Optional, Union, cast +from typing import cast -from flask import Flask, current_app -from pydantic import ValidationError - -from core.app_runner.assistant_app_runner import AssistantApplicationRunner -from core.app_runner.basic_app_runner import BasicApplicationRunner -from core.app_runner.generate_task_pipeline import GenerateTaskPipeline -from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom -from core.entities.application_entities import ( - AdvancedChatPromptTemplateEntity, - AdvancedCompletionPromptTemplateEntity, - AgentEntity, - AgentPromptEntity, - AgentToolEntity, - ApplicationGenerateEntity, - AppOrchestrationConfigEntity, - DatasetEntity, - DatasetRetrieveConfigEntity, - 
ExternalDataVariableEntity, - FileUploadEntity, - InvokeFrom, - ModelConfigEntity, - PromptTemplateEntity, - SensitiveWordAvoidanceEntity, - TextToSpeechEntity, - VariableEntity, -) +from core.entities.application_entities import AppOrchestrationConfigEntity, SensitiveWordAvoidanceEntity, \ + TextToSpeechEntity, DatasetRetrieveConfigEntity, DatasetEntity, AgentPromptEntity, AgentEntity, AgentToolEntity, \ + ExternalDataVariableEntity, VariableEntity, AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity, \ + AdvancedChatPromptTemplateEntity, ModelConfigEntity, FileUploadEntity from core.entities.model_entities import ModelStatus -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.file.file_obj import FileObj +from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.prompt_template import PromptTemplateParser from core.provider_manager import ProviderManager from core.tools.prompt.template import REACT_PROMPT_TEMPLATES -from extensions.ext_database import db -from models.account import Account -from models.model import App, Conversation, EndUser, Message, MessageFile - -logger = logging.getLogger(__name__) -class ApplicationManager: - """ - This class is responsible for managing application - """ - - def generate(self, tenant_id: str, - app_id: str, - app_model_config_id: str, - app_model_config_dict: dict, - app_model_config_override: bool, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - inputs: dict[str, str], - query: Optional[str] = None, - files: Optional[list[FileObj]] = None, - conversation: Optional[Conversation] = None, - stream: bool = False, - extras: Optional[dict[str, Any]] = None) \ - -> Union[dict, Generator]: - """ - Generate App response. 
- - :param tenant_id: workspace ID - :param app_id: app ID - :param app_model_config_id: app model config id - :param app_model_config_dict: app model config dict - :param app_model_config_override: app model config override - :param user: account or end user - :param invoke_from: invoke from source - :param inputs: inputs - :param query: query - :param files: file obj list - :param conversation: conversation - :param stream: is stream - :param extras: extras - """ - # init task id - task_id = str(uuid.uuid4()) - - # init application generate entity - application_generate_entity = ApplicationGenerateEntity( - task_id=task_id, - tenant_id=tenant_id, - app_id=app_id, - app_model_config_id=app_model_config_id, - app_model_config_dict=app_model_config_dict, - app_orchestration_config_entity=self.convert_from_app_model_config_dict( - tenant_id=tenant_id, - app_model_config_dict=app_model_config_dict - ), - app_model_config_override=app_model_config_override, - conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else inputs, - query=query.replace('\x00', '') if query else None, - files=files if files else [], - user_id=user.id, - stream=stream, - invoke_from=invoke_from, - extras=extras - ) - - if not stream and application_generate_entity.app_orchestration_config_entity.agent: - raise ValueError("Agent app is not supported in blocking mode.") - - # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) - - # init queue manager - queue_manager = ApplicationQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id - ) - - # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) - - worker_thread.start() - - # return response or stream generator - return self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - stream=stream - ) - - def _generate_worker(self, flask_app: Flask, - application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, - conversation_id: str, - message_id: str) -> None: - """ - Generate worker in a new thread. 
- :param flask_app: Flask app - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation_id: conversation ID - :param message_id: message ID - :return: - """ - with flask_app.app_context(): - try: - # get conversation and message - conversation = self._get_conversation(conversation_id) - message = self._get_message(message_id) - - if application_generate_entity.app_orchestration_config_entity.agent: - # agent app - runner = AssistantApplicationRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - else: - # basic app - runner = BasicApplicationRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - except ConversationTaskStoppedException: - pass - except InvokeAuthorizationError: - queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER - ) - except ValidationError as e: - logger.exception("Validation Error when generating") - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except (ValueError, InvokeError) as e: - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except Exception as e: - logger.exception("Unknown Error when generating") - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - finally: - db.session.close() - - def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, - conversation: Conversation, - message: Message, - stream: bool = False) -> Union[dict, Generator]: - """ - Handle response. 
- :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param stream: is stream - :return: - """ - # init generate task pipeline - generate_task_pipeline = GenerateTaskPipeline( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - - try: - return generate_task_pipeline.process(stream=stream) - except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error - raise ConversationTaskStoppedException() - else: - logger.exception(e) - raise e - - def convert_from_app_model_config_dict(self, tenant_id: str, +class AppOrchestrationConfigConverter: + @classmethod + def convert_from_app_model_config_dict(cls, tenant_id: str, app_model_config_dict: dict, skip_check: bool = False) \ -> AppOrchestrationConfigEntity: @@ -394,7 +174,7 @@ class ApplicationManager: ) properties['variables'] = [] - + # variables and external_data_tools for variable in copy_app_model_config_dict.get('user_input_form', []): typ = list(variable.keys())[0] @@ -444,7 +224,7 @@ class ApplicationManager: show_retrieve_source = True properties['show_retrieve_source'] = show_retrieve_source - + dataset_ids = [] if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}): datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', { @@ -452,26 +232,23 @@ class ApplicationManager: 'datasets': [] }) - for dataset in datasets.get('datasets', []): keys = list(dataset.keys()) if len(keys) == 0 or keys[0] != 'dataset': continue dataset = dataset['dataset'] - + if 'enabled' not in dataset or not dataset['enabled']: continue - + dataset_id = dataset.get('id', None) if dataset_id: dataset_ids.append(dataset_id) - else: - datasets = {'strategy': 'router', 'datasets': []} if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \ and 'enabled' in copy_app_model_config_dict['agent_mode'] \ and copy_app_model_config_dict['agent_mode']['enabled']: - + agent_dict = copy_app_model_config_dict.get('agent_mode', {}) agent_strategy = agent_dict.get('strategy', 'cot') @@ -515,7 +292,7 @@ class ApplicationManager: dataset_id = tool_item['id'] dataset_ids.append(dataset_id) - + if 'strategy' in copy_app_model_config_dict['agent_mode'] and \ copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']: agent_prompt = agent_dict.get('prompt', None) or {} @@ -523,13 +300,18 @@ class ApplicationManager: model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion') if model_mode == 'completion': agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), - next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']), + first_prompt=agent_prompt.get('first_prompt', + REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), + next_iteration=agent_prompt.get('next_iteration', + REACT_PROMPT_TEMPLATES['english']['completion'][ + 'agent_scratchpad']), ) else: agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), - next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), + first_prompt=agent_prompt.get('first_prompt', + 
REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), + next_iteration=agent_prompt.get('next_iteration', + REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), ) properties['agent'] = AgentEntity( @@ -551,7 +333,7 @@ class ApplicationManager: dataset_ids=dataset_ids, retrieve_config=DatasetRetrieveConfigEntity( query_variable=query_variable, - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( dataset_configs['retrieval_model'] ) ) @@ -624,169 +406,3 @@ class ApplicationManager: ) return AppOrchestrationConfigEntity(**properties) - - def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \ - -> tuple[Conversation, Message]: - """ - Initialize generate records - :param application_generate_entity: application generate entity - :return: - """ - app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity - - model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - model_schema = model_type_instance.get_model_schema( - model=app_orchestration_config_entity.model_config.model, - credentials=app_orchestration_config_entity.model_config.credentials - ) - - app_record = (db.session.query(App) - .filter(App.id == application_generate_entity.app_id).first()) - - app_mode = app_record.mode - - # get from source - end_user_id = None - account_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - from_source = 'api' - end_user_id = application_generate_entity.user_id - else: - from_source = 'console' - account_id = application_generate_entity.user_id - - override_model_configs = None - if application_generate_entity.app_model_config_override: - override_model_configs = application_generate_entity.app_model_config_dict - - introduction = '' - if app_mode == 'chat': - # get conversation introduction - introduction = self._get_conversation_introduction(application_generate_entity) - - if not application_generate_entity.conversation_id: - conversation = Conversation( - app_id=app_record.id, - app_model_config_id=application_generate_entity.app_model_config_id, - model_provider=app_orchestration_config_entity.model_config.provider, - model_id=app_orchestration_config_entity.model_config.model, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - mode=app_mode, - name='New conversation', - inputs=application_generate_entity.inputs, - introduction=introduction, - system_instruction="", - system_instruction_tokens=0, - status='normal', - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - ) - - db.session.add(conversation) - db.session.commit() - db.session.refresh(conversation) - else: - conversation = ( - db.session.query(Conversation) - .filter( - Conversation.id == application_generate_entity.conversation_id, - Conversation.app_id == app_record.id - ).first() - ) - - currency = model_schema.pricing.currency if model_schema.pricing else 'USD' - - message = Message( - app_id=app_record.id, - model_provider=app_orchestration_config_entity.model_config.provider, - model_id=app_orchestration_config_entity.model_config.model, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - conversation_id=conversation.id, - 
inputs=application_generate_entity.inputs, - query=application_generate_entity.query or "", - message="", - message_tokens=0, - message_unit_price=0, - message_price_unit=0, - answer="", - answer_tokens=0, - answer_unit_price=0, - answer_price_unit=0, - provider_response_latency=0, - total_price=0, - currency=currency, - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - agent_based=app_orchestration_config_entity.agent is not None - ) - - db.session.add(message) - db.session.commit() - db.session.refresh(message) - - for file in application_generate_entity.files: - message_file = MessageFile( - message_id=message.id, - type=file.type.value, - transfer_method=file.transfer_method.value, - belongs_to='user', - url=file.url, - upload_file_id=file.upload_file_id, - created_by_role=('account' if account_id else 'end_user'), - created_by=account_id or end_user_id, - ) - db.session.add(message_file) - db.session.commit() - - return conversation, message - - def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str: - """ - Get conversation introduction - :param application_generate_entity: application generate entity - :return: conversation introduction - """ - app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity - introduction = app_orchestration_config_entity.opening_statement - - if introduction: - try: - inputs = application_generate_entity.inputs - prompt_template = PromptTemplateParser(template=introduction) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - introduction = prompt_template.format(prompt_inputs) - except KeyError: - pass - - return introduction - - def _get_conversation(self, conversation_id: str) -> Conversation: - """ - Get conversation by conversation id - :param conversation_id: conversation id - :return: conversation - """ - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) - - return conversation - - def _get_message(self, message_id: str) -> Message: - """ - Get message by message id - :param message_id: message id - :return: message - """ - message = ( - db.session.query(Message) - .filter(Message.id == message_id) - .first() - ) - - return message diff --git a/api/core/application_queue_manager.py b/api/core/app/app_queue_manager.py similarity index 97% rename from api/core/application_queue_manager.py rename to api/core/app/app_queue_manager.py index 9590a1e726..c09cae3245 100644 --- a/api/core/application_queue_manager.py +++ b/api/core/app/app_queue_manager.py @@ -32,7 +32,7 @@ class PublishFrom(Enum): TASK_PIPELINE = 2 -class ApplicationQueueManager: +class AppQueueManager: def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, @@ -50,7 +50,7 @@ class ApplicationQueueManager: self._message_id = str(message_id) user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") + redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") q = queue.Queue() @@ -239,7 +239,7 @@ class ApplicationQueueManager: Check if task is stopped :return: """ - stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id) + stopped_cache_key = 
AppQueueManager._generate_stopped_cache_key(self._task_id) result = redis_client.get(stopped_cache_key) if result is not None: return True diff --git a/api/core/app_runner/app_runner.py b/api/core/app/base_app_runner.py similarity index 94% rename from api/core/app_runner/app_runner.py rename to api/core/app/base_app_runner.py index 95f2f568dc..788e3f91a3 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app/base_app_runner.py @@ -2,7 +2,7 @@ import time from collections.abc import Generator from typing import Optional, Union, cast -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import ( ApplicationGenerateEntity, AppOrchestrationConfigEntity, @@ -11,10 +11,10 @@ from core.entities.application_entities import ( ModelConfigEntity, PromptTemplateEntity, ) -from core.features.annotation_reply import AnnotationReplyFeature -from core.features.external_data_fetch import ExternalDataFetchFeature -from core.features.hosting_moderation import HostingModerationFeature -from core.features.moderation import ModerationFeature +from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature +from core.external_data_tool.external_data_fetch import ExternalDataFetch +from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature +from core.moderation.input_moderation import InputModeration from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -169,7 +169,7 @@ class AppRunner: return prompt_messages, stop - def direct_output(self, queue_manager: ApplicationQueueManager, + def direct_output(self, queue_manager: AppQueueManager, app_orchestration_config: AppOrchestrationConfigEntity, prompt_messages: list, text: str, @@ -210,7 +210,7 @@ class AppRunner: ) def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, stream: bool, agent: bool = False) -> None: """ @@ -234,7 +234,7 @@ class AppRunner: ) def _handle_invoke_result_direct(self, invoke_result: LLMResult, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, agent: bool) -> None: """ Handle invoke result direct @@ -248,7 +248,7 @@ class AppRunner: ) def _handle_invoke_result_stream(self, invoke_result: Generator, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, agent: bool) -> None: """ Handle invoke result @@ -306,7 +306,7 @@ class AppRunner: :param query: query :return: """ - moderation_feature = ModerationFeature() + moderation_feature = InputModeration() return moderation_feature.check( app_id=app_id, tenant_id=tenant_id, @@ -316,7 +316,7 @@ class AppRunner: ) def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, prompt_messages: list[PromptMessage]) -> bool: """ Check hosting moderation @@ -358,7 +358,7 @@ class AppRunner: :param query: the query :return: the filled inputs """ - external_data_fetch_feature = ExternalDataFetchFeature() + external_data_fetch_feature = ExternalDataFetch() return external_data_fetch_feature.fetch( tenant_id=tenant_id, app_id=app_id, @@ -388,4 +388,4 @@ class AppRunner: query=query, user_id=user_id, 
invoke_from=invoke_from - ) \ No newline at end of file + ) diff --git a/api/core/features/__init__.py b/api/core/app/chat/__init__.py similarity index 100% rename from api/core/features/__init__.py rename to api/core/app/chat/__init__.py diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app/chat/app_runner.py similarity index 95% rename from api/core/app_runner/basic_app_runner.py rename to api/core/app/chat/app_runner.py index 0e0fe6e3bf..a1613e37a2 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app/chat/app_runner.py @@ -1,8 +1,8 @@ import logging from typing import Optional -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.base_app_runner import AppRunner +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( ApplicationGenerateEntity, @@ -10,7 +10,7 @@ from core.entities.application_entities import ( InvokeFrom, ModelConfigEntity, ) -from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException @@ -20,13 +20,13 @@ from models.model import App, AppMode, Conversation, Message logger = logging.getLogger(__name__) -class BasicApplicationRunner(AppRunner): +class ChatAppRunner(AppRunner): """ - Basic Application Runner + Chat Application Runner """ def run(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: """ @@ -215,7 +215,7 @@ class BasicApplicationRunner(AppRunner): def retrieve_dataset_context(self, tenant_id: str, app_record: App, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, model_config: ModelConfigEntity, dataset_config: DatasetEntity, show_retrieve_source: bool, @@ -254,7 +254,7 @@ class BasicApplicationRunner(AppRunner): and dataset_config.retrieve_config.query_variable): query = inputs.get(dataset_config.retrieve_config.query_variable, "") - dataset_retrieval = DatasetRetrievalFeature() + dataset_retrieval = DatasetRetrieval() return dataset_retrieval.retrieve( tenant_id=tenant_id, model_config=model_config, diff --git a/api/core/apps/app_config_validators/chat_app.py b/api/core/app/chat/config_validator.py similarity index 75% rename from api/core/apps/app_config_validators/chat_app.py rename to api/core/app/chat/config_validator.py index 83c792e610..adb8408e28 100644 --- a/api/core/apps/app_config_validators/chat_app.py +++ b/api/core/app/chat/config_validator.py @@ -1,15 +1,15 @@ -from core.apps.config_validators.dataset import DatasetValidator -from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.model import ModelValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.opening_statement import OpeningStatementValidator -from core.apps.config_validators.prompt import PromptValidator -from core.apps.config_validators.retriever_resource import RetrieverResourceValidator -from 
core.apps.config_validators.speech_to_text import SpeechToTextValidator -from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator -from core.apps.config_validators.user_input_form import UserInputFormValidator +from core.app.validators.dataset_retrieval import DatasetValidator +from core.app.validators.external_data_fetch import ExternalDataFetchValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.model_validator import ModelValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.opening_statement import OpeningStatementValidator +from core.app.validators.prompt import PromptValidator +from core.app.validators.retriever_resource import RetrieverResourceValidator +from core.app.validators.speech_to_text import SpeechToTextValidator +from core.app.validators.suggested_questions import SuggestedQuestionsValidator +from core.app.validators.text_to_speech import TextToSpeechValidator +from core.app.validators.user_input_form import UserInputFormValidator from models.model import AppMode @@ -35,7 +35,7 @@ class ChatAppConfigValidator: related_config_keys.extend(current_related_config_keys) # external data tools validation - config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) related_config_keys.extend(current_related_config_keys) # file upload validation diff --git a/api/core/features/dataset_retrieval/__init__.py b/api/core/app/completion/__init__.py similarity index 100% rename from api/core/features/dataset_retrieval/__init__.py rename to api/core/app/completion/__init__.py diff --git a/api/core/app/completion/app_runner.py b/api/core/app/completion/app_runner.py new file mode 100644 index 0000000000..34c6a5156f --- /dev/null +++ b/api/core/app/completion/app_runner.py @@ -0,0 +1,266 @@ +import logging +from typing import Optional + +from core.app.base_app_runner import AppRunner +from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.entities.application_entities import ( + ApplicationGenerateEntity, + DatasetEntity, + InvokeFrom, + ModelConfigEntity, +) +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.moderation.base import ModerationException +from extensions.ext_database import db +from models.model import App, AppMode, Conversation, Message + +logger = logging.getLogger(__name__) + + +class CompletionAppRunner(AppRunner): + """ + Completion Application Runner + """ + + def run(self, application_generate_entity: ApplicationGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message) -> None: + """ + Run application + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param conversation: conversation + :param message: message + :return: + """ + app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() + if not app_record: + raise ValueError("App not found") + + app_orchestration_config = 
application_generate_entity.app_orchestration_config_entity + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + # Pre-calculate the number of tokens of the prompt messages, + # and return the rest number of tokens by model context token size limit and max token size limit. + # If the rest number of tokens is not enough, raise exception. + # Include: prompt template, inputs, query(optional), files(optional) + # Not Include: memory, external data, dataset context + self.get_pre_calculate_rest_tokens( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query + ) + + memory = None + if application_generate_entity.conversation_id: + # get memory of conversation (read-only) + model_instance = ModelInstance( + provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, + model=app_orchestration_config.model_config.model + ) + + memory = TokenBufferMemory( + conversation=conversation, + model_instance=model_instance + ) + + # organize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + # memory(optional) + prompt_messages, stop = self.organize_prompt_messages( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query, + memory=memory + ) + + # moderation + try: + # process sensitive_word_avoidance + _, inputs, query = self.moderation_for_inputs( + app_id=app_record.id, + tenant_id=application_generate_entity.tenant_id, + app_orchestration_config_entity=app_orchestration_config, + inputs=inputs, + query=query, + ) + except ModerationException as e: + self.direct_output( + queue_manager=queue_manager, + app_orchestration_config=app_orchestration_config, + prompt_messages=prompt_messages, + text=str(e), + stream=application_generate_entity.stream + ) + return + + if query: + # annotation reply + annotation_reply = self.query_app_annotations_to_reply( + app_record=app_record, + message=message, + query=query, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from + ) + + if annotation_reply: + queue_manager.publish_annotation_reply( + message_annotation_id=annotation_reply.id, + pub_from=PublishFrom.APPLICATION_MANAGER + ) + self.direct_output( + queue_manager=queue_manager, + app_orchestration_config=app_orchestration_config, + prompt_messages=prompt_messages, + text=annotation_reply.content, + stream=application_generate_entity.stream + ) + return + + # fill in variable inputs from external data tools if exists + external_data_tools = app_orchestration_config.external_data_variables + if external_data_tools: + inputs = self.fill_in_inputs_from_external_data_tools( + tenant_id=app_record.tenant_id, + app_id=app_record.id, + external_data_tools=external_data_tools, + inputs=inputs, + query=query + ) + + # get context from datasets + context = None + if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: + context = self.retrieve_dataset_context( + tenant_id=app_record.tenant_id, + app_record=app_record, + queue_manager=queue_manager, + model_config=app_orchestration_config.model_config, + show_retrieve_source=app_orchestration_config.show_retrieve_source, + 
dataset_config=app_orchestration_config.dataset, + message=message, + inputs=inputs, + query=query, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + memory=memory + ) + + # reorganize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + # memory(optional), external data, dataset context(optional) + prompt_messages, stop = self.organize_prompt_messages( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query, + context=context, + memory=memory + ) + + # check hosting moderation + hosting_moderation_result = self.check_hosting_moderation( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + prompt_messages=prompt_messages + ) + + if hosting_moderation_result: + return + + # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit + self.recale_llm_max_tokens( + model_config=app_orchestration_config.model_config, + prompt_messages=prompt_messages + ) + + # Invoke model + model_instance = ModelInstance( + provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, + model=app_orchestration_config.model_config.model + ) + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=app_orchestration_config.model_config.parameters, + stop=stop, + stream=application_generate_entity.stream, + user=application_generate_entity.user_id, + ) + + # handle invoke result + self._handle_invoke_result( + invoke_result=invoke_result, + queue_manager=queue_manager, + stream=application_generate_entity.stream + ) + + def retrieve_dataset_context(self, tenant_id: str, + app_record: App, + queue_manager: AppQueueManager, + model_config: ModelConfigEntity, + dataset_config: DatasetEntity, + show_retrieve_source: bool, + message: Message, + inputs: dict, + query: str, + user_id: str, + invoke_from: InvokeFrom, + memory: Optional[TokenBufferMemory] = None) -> Optional[str]: + """ + Retrieve dataset context + :param tenant_id: tenant id + :param app_record: app record + :param queue_manager: queue manager + :param model_config: model config + :param dataset_config: dataset config + :param show_retrieve_source: show retrieve source + :param message: message + :param inputs: inputs + :param query: query + :param user_id: user id + :param invoke_from: invoke from + :param memory: memory + :return: + """ + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager, + app_record.id, + message.id, + user_id, + invoke_from + ) + + # TODO + if (app_record.mode == AppMode.COMPLETION.value and dataset_config + and dataset_config.retrieve_config.query_variable): + query = inputs.get(dataset_config.retrieve_config.query_variable, "") + + dataset_retrieval = DatasetRetrieval() + return dataset_retrieval.retrieve( + tenant_id=tenant_id, + model_config=model_config, + config=dataset_config, + query=query, + invoke_from=invoke_from, + show_retrieve_source=show_retrieve_source, + hit_callback=hit_callback, + memory=memory + ) + \ No newline at end of file diff --git a/api/core/apps/app_config_validators/completion_app.py b/api/core/app/completion/config_validator.py similarity index 76% rename from api/core/apps/app_config_validators/completion_app.py rename to api/core/app/completion/config_validator.py index 00371f8d05..7cc35efd64 100644 --- 
a/api/core/apps/app_config_validators/completion_app.py +++ b/api/core/app/completion/config_validator.py @@ -1,12 +1,12 @@ -from core.apps.config_validators.dataset import DatasetValidator -from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.model import ModelValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.more_like_this import MoreLikeThisValidator -from core.apps.config_validators.prompt import PromptValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator -from core.apps.config_validators.user_input_form import UserInputFormValidator +from core.app.validators.dataset_retrieval import DatasetValidator +from core.app.validators.external_data_fetch import ExternalDataFetchValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.model_validator import ModelValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.more_like_this import MoreLikeThisValidator +from core.app.validators.prompt import PromptValidator +from core.app.validators.text_to_speech import TextToSpeechValidator +from core.app.validators.user_input_form import UserInputFormValidator from models.model import AppMode @@ -32,7 +32,7 @@ class CompletionAppConfigValidator: related_config_keys.extend(current_related_config_keys) # external data tools validation - config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) related_config_keys.extend(current_related_config_keys) # file upload validation diff --git a/api/core/features/dataset_retrieval/agent/__init__.py b/api/core/app/features/__init__.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/__init__.py rename to api/core/app/features/__init__.py diff --git a/api/core/features/dataset_retrieval/agent/output_parser/__init__.py b/api/core/app/features/annotation_reply/__init__.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/output_parser/__init__.py rename to api/core/app/features/annotation_reply/__init__.py diff --git a/api/core/features/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py similarity index 100% rename from api/core/features/annotation_reply.py rename to api/core/app/features/annotation_reply/annotation_reply.py diff --git a/api/core/app/features/hosting_moderation/__init__.py b/api/core/app/features/hosting_moderation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/features/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py similarity index 100% rename from api/core/features/hosting_moderation.py rename to api/core/app/features/hosting_moderation/hosting_moderation.py diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app/generate_task_pipeline.py similarity index 98% rename from api/core/app_runner/generate_task_pipeline.py rename to api/core/app/generate_task_pipeline.py index 1cc56483ad..6d52fa7348 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app/generate_task_pipeline.py @@ -6,8 +6,8 @@ from typing import Optional, Union, cast from pydantic import BaseModel -from 
core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.moderation.output_moderation import ModerationRule, OutputModeration +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom from core.entities.queue_entities import ( AnnotationReplyEvent, @@ -35,7 +35,7 @@ from core.model_runtime.entities.message_entities import ( from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created from extensions.ext_database import db @@ -59,7 +59,7 @@ class GenerateTaskPipeline: """ def __init__(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: """ @@ -633,7 +633,7 @@ class GenerateTaskPipeline: return prompts - def _init_output_moderation(self) -> Optional[OutputModerationHandler]: + def _init_output_moderation(self) -> Optional[OutputModeration]: """ Init output moderation. :return: @@ -642,7 +642,7 @@ class GenerateTaskPipeline: sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance if sensitive_word_avoidance: - return OutputModerationHandler( + return OutputModeration( tenant_id=self._application_generate_entity.tenant_id, app_id=self._application_generate_entity.app_id, rule=ModerationRule( diff --git a/api/core/app/validators/__init__.py b/api/core/app/validators/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/apps/config_validators/dataset.py b/api/core/app/validators/dataset_retrieval.py similarity index 100% rename from api/core/apps/config_validators/dataset.py rename to api/core/app/validators/dataset_retrieval.py diff --git a/api/core/apps/config_validators/external_data_tools.py b/api/core/app/validators/external_data_fetch.py similarity index 97% rename from api/core/apps/config_validators/external_data_tools.py rename to api/core/app/validators/external_data_fetch.py index 02ecc8d715..5910aa17e7 100644 --- a/api/core/apps/config_validators/external_data_tools.py +++ b/api/core/app/validators/external_data_fetch.py @@ -2,7 +2,7 @@ from core.external_data_tool.factory import ExternalDataToolFactory -class ExternalDataToolsValidator: +class ExternalDataFetchValidator: @classmethod def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: """ diff --git a/api/core/apps/config_validators/file_upload.py b/api/core/app/validators/file_upload.py similarity index 100% rename from api/core/apps/config_validators/file_upload.py rename to api/core/app/validators/file_upload.py diff --git a/api/core/apps/config_validators/model.py b/api/core/app/validators/model_validator.py similarity index 100% rename from api/core/apps/config_validators/model.py rename to api/core/app/validators/model_validator.py diff --git a/api/core/apps/config_validators/moderation.py b/api/core/app/validators/moderation.py similarity index 100% rename from 
api/core/apps/config_validators/moderation.py rename to api/core/app/validators/moderation.py diff --git a/api/core/apps/config_validators/more_like_this.py b/api/core/app/validators/more_like_this.py similarity index 100% rename from api/core/apps/config_validators/more_like_this.py rename to api/core/app/validators/more_like_this.py diff --git a/api/core/apps/config_validators/opening_statement.py b/api/core/app/validators/opening_statement.py similarity index 100% rename from api/core/apps/config_validators/opening_statement.py rename to api/core/app/validators/opening_statement.py diff --git a/api/core/apps/config_validators/prompt.py b/api/core/app/validators/prompt.py similarity index 100% rename from api/core/apps/config_validators/prompt.py rename to api/core/app/validators/prompt.py diff --git a/api/core/apps/config_validators/retriever_resource.py b/api/core/app/validators/retriever_resource.py similarity index 100% rename from api/core/apps/config_validators/retriever_resource.py rename to api/core/app/validators/retriever_resource.py diff --git a/api/core/apps/config_validators/speech_to_text.py b/api/core/app/validators/speech_to_text.py similarity index 100% rename from api/core/apps/config_validators/speech_to_text.py rename to api/core/app/validators/speech_to_text.py diff --git a/api/core/apps/config_validators/suggested_questions.py b/api/core/app/validators/suggested_questions.py similarity index 100% rename from api/core/apps/config_validators/suggested_questions.py rename to api/core/app/validators/suggested_questions.py diff --git a/api/core/apps/config_validators/text_to_speech.py b/api/core/app/validators/text_to_speech.py similarity index 100% rename from api/core/apps/config_validators/text_to_speech.py rename to api/core/app/validators/text_to_speech.py diff --git a/api/core/apps/config_validators/user_input_form.py b/api/core/app/validators/user_input_form.py similarity index 100% rename from api/core/apps/config_validators/user_input_form.py rename to api/core/app/validators/user_input_form.py diff --git a/api/core/app/workflow/__init__.py b/api/core/app/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/apps/app_config_validators/workflow_app.py b/api/core/app/workflow/config_validator.py similarity index 83% rename from api/core/apps/app_config_validators/workflow_app.py rename to api/core/app/workflow/config_validator.py index 545d3d79a3..b76eabaeb5 100644 --- a/api/core/apps/app_config_validators/workflow_app.py +++ b/api/core/app/workflow/config_validator.py @@ -1,6 +1,6 @@ -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.text_to_speech import TextToSpeechValidator class WorkflowAppConfigValidator: diff --git a/api/core/apps/app_config_validators/agent_chat_app.py b/api/core/apps/app_config_validators/agent_chat_app.py deleted file mode 100644 index d507fae685..0000000000 --- a/api/core/apps/app_config_validators/agent_chat_app.py +++ /dev/null @@ -1,82 +0,0 @@ -from core.apps.config_validators.agent import AgentValidator -from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator -from core.apps.config_validators.file_upload import FileUploadValidator -from 
core.apps.config_validators.model import ModelValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.opening_statement import OpeningStatementValidator -from core.apps.config_validators.prompt import PromptValidator -from core.apps.config_validators.retriever_resource import RetrieverResourceValidator -from core.apps.config_validators.speech_to_text import SpeechToTextValidator -from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator -from core.apps.config_validators.user_input_form import UserInputFormValidator -from models.model import AppMode - - -class AgentChatAppConfigValidator: - @classmethod - def config_validate(cls, tenant_id: str, config: dict) -> dict: - """ - Validate for agent chat app model config - - :param tenant_id: tenant id - :param config: app model config args - """ - app_mode = AppMode.AGENT_CHAT - - related_config_keys = [] - - # model - config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # user_input_form - config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # external data tools validation - config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # file upload validation - config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # prompt - config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) - related_config_keys.extend(current_related_config_keys) - - # agent_mode - config, current_related_config_keys = AgentValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # opening_statement - config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # suggested_questions_after_answer - config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # speech_to_text - config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # text_to_speech - config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # return retriever resource - config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - related_config_keys = list(set(related_config_keys)) - - # Filter out extra parameters - filtered_config = {key: config.get(key) for key in related_config_keys} - - return filtered_config diff --git a/api/core/apps/config_validators/agent.py b/api/core/apps/config_validators/agent.py deleted file mode 100644 index 
b445aedbf8..0000000000 --- a/api/core/apps/config_validators/agent.py +++ /dev/null @@ -1,81 +0,0 @@ -import uuid - -from core.apps.config_validators.dataset import DatasetValidator -from core.entities.agent_entities import PlanningStrategy - -OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] - - -class AgentValidator: - @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: - """ - Validate and set defaults for agent feature - - :param tenant_id: tenant ID - :param config: app model config args - """ - if not config.get("agent_mode"): - config["agent_mode"] = { - "enabled": False, - "tools": [] - } - - if not isinstance(config["agent_mode"], dict): - raise ValueError("agent_mode must be of object type") - - if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: - config["agent_mode"]["enabled"] = False - - if not isinstance(config["agent_mode"]["enabled"], bool): - raise ValueError("enabled in agent_mode must be of boolean type") - - if not config["agent_mode"].get("strategy"): - config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value - - if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: - raise ValueError("strategy in agent_mode must be in the specified strategy list") - - if not config["agent_mode"].get("tools"): - config["agent_mode"]["tools"] = [] - - if not isinstance(config["agent_mode"]["tools"], list): - raise ValueError("tools in agent_mode must be a list of objects") - - for tool in config["agent_mode"]["tools"]: - key = list(tool.keys())[0] - if key in OLD_TOOLS: - # old style, use tool name as key - tool_item = tool[key] - - if "enabled" not in tool_item or not tool_item["enabled"]: - tool_item["enabled"] = False - - if not isinstance(tool_item["enabled"], bool): - raise ValueError("enabled in agent_mode.tools must be of boolean type") - - if key == "dataset": - if 'id' not in tool_item: - raise ValueError("id is required in dataset") - - try: - uuid.UUID(tool_item["id"]) - except ValueError: - raise ValueError("id in dataset must be of UUID type") - - if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]): - raise ValueError("Dataset ID does not exist, please check your permission.") - else: - # latest style, use key-value pair - if "enabled" not in tool or not tool["enabled"]: - tool["enabled"] = False - if "provider_type" not in tool: - raise ValueError("provider_type is required in agent_mode.tools") - if "provider_id" not in tool: - raise ValueError("provider_id is required in agent_mode.tools") - if "tool_name" not in tool: - raise ValueError("tool_name is required in agent_mode.tools") - if "tool_parameters" not in tool: - raise ValueError("tool_parameters is required in agent_mode.tools") - - return config, ["agent_mode"] diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py index 1d25b8ab69..8a340a8b81 100644 --- a/api/core/callback_handler/agent_loop_gather_callback_handler.py +++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py @@ -7,7 +7,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from 
core.app.app_queue_manager import AppQueueManager, PublishFrom from core.callback_handler.entity.agent_loop import AgentLoop from core.entities.application_entities import ModelConfigEntity from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult @@ -22,7 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): raise_error: bool = True def __init__(self, model_config: ModelConfigEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, message: Message, message_chain: MessageChain) -> None: """Initialize callback handler.""" diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 879c9df69d..e49a09d4c4 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,5 +1,5 @@ -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import InvokeFrom from core.rag.models.document import Document from extensions.ext_database import db @@ -10,7 +10,7 @@ from models.model import DatasetRetrieverResource class DatasetIndexToolCallbackHandler: """Callback handler for dataset tool.""" - def __init__(self, queue_manager: ApplicationQueueManager, + def __init__(self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, diff --git a/api/core/features/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py similarity index 98% rename from api/core/features/external_data_fetch.py rename to api/core/external_data_tool/external_data_fetch.py index ef37f05528..64c7d1e859 100644 --- a/api/core/features/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -11,7 +11,7 @@ from core.external_data_tool.factory import ExternalDataToolFactory logger = logging.getLogger(__name__) -class ExternalDataFetchFeature: +class ExternalDataFetch: def fetch(self, tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index dd46aa27dc..01a8ea3a5d 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -13,7 +13,7 @@ from sqlalchemy.orm.exc import ObjectDeletedError from core.docstore.dataset_docstore import DatasetDocumentStore from core.errors.error import ProviderTokenNotInitError -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/llm_generator/__init__.py b/api/core/llm_generator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/generator/llm_generator.py b/api/core/llm_generator/llm_generator.py similarity index 93% rename from api/core/generator/llm_generator.py rename to api/core/llm_generator/llm_generator.py index 072b02dc94..6ce70df703 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -7,10 +7,10 @@ from core.model_manager import ModelManager from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import 
ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser -from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser -from core.prompt.prompt_template import PromptTemplateParser -from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT +from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser +from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT class LLMGenerator: diff --git a/api/core/llm_generator/output_parser/__init__.py b/api/core/llm_generator/output_parser/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/prompt/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py similarity index 94% rename from api/core/prompt/output_parser/rule_config_generator.py rename to api/core/llm_generator/output_parser/rule_config_generator.py index 619555ce2e..b95653f69c 100644 --- a/api/core/prompt/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -2,7 +2,7 @@ from typing import Any from langchain.schema import BaseOutputParser, OutputParserException -from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE +from core.llm_generator.prompts import RULE_CONFIG_GENERATE_TEMPLATE from libs.json_in_md_parser import parse_and_check_json_markdown diff --git a/api/core/prompt/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py similarity index 87% rename from api/core/prompt/output_parser/suggested_questions_after_answer.py rename to api/core/llm_generator/output_parser/suggested_questions_after_answer.py index e37142ec91..ad30bcfa07 100644 --- a/api/core/prompt/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -4,7 +4,7 @@ from typing import Any from langchain.schema import BaseOutputParser -from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT +from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): diff --git a/api/core/prompt/prompts.py b/api/core/llm_generator/prompts.py similarity index 100% rename from api/core/prompt/prompts.py rename to api/core/llm_generator/prompts.py diff --git a/api/core/features/moderation.py b/api/core/moderation/input_moderation.py similarity index 98% rename from api/core/features/moderation.py rename to api/core/moderation/input_moderation.py index a9d65f56e8..2129c58d8d 100644 --- a/api/core/features/moderation.py +++ b/api/core/moderation/input_moderation.py @@ -7,7 +7,7 @@ from core.moderation.factory import ModerationFactory logger = logging.getLogger(__name__) -class ModerationFeature: +class InputModeration: def check(self, app_id: str, tenant_id: str, app_orchestration_config_entity: AppOrchestrationConfigEntity, diff --git a/api/core/app_runner/moderation_handler.py b/api/core/moderation/output_moderation.py similarity index 97% rename from 
api/core/app_runner/moderation_handler.py rename to api/core/moderation/output_moderation.py index b2098344c8..749ee431e8 100644 --- a/api/core/app_runner/moderation_handler.py +++ b/api/core/moderation/output_moderation.py @@ -6,7 +6,7 @@ from typing import Any, Optional from flask import Flask, current_app from pydantic import BaseModel -from core.application_queue_manager import PublishFrom +from core.app.app_queue_manager import PublishFrom from core.moderation.base import ModerationAction, ModerationOutputsResult from core.moderation.factory import ModerationFactory @@ -18,7 +18,7 @@ class ModerationRule(BaseModel): config: dict[str, Any] -class OutputModerationHandler(BaseModel): +class OutputModeration(BaseModel): DEFAULT_BUFFER_SIZE: int = 300 tenant_id: str diff --git a/api/core/prompt/__init__.py b/api/core/prompt/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 7519971ce7..6178453920 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -15,7 +15,7 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.prompt_transform import PromptTransform from core.prompt.simple_prompt_transform import ModelMode diff --git a/api/core/prompt/prompt_templates/__init__.py b/api/core/prompt/prompt_templates/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/prompt/advanced_prompt_templates.py b/api/core/prompt/prompt_templates/advanced_prompt_templates.py similarity index 100% rename from api/core/prompt/advanced_prompt_templates.py rename to api/core/prompt/prompt_templates/advanced_prompt_templates.py diff --git a/api/core/prompt/generate_prompts/baichuan_chat.json b/api/core/prompt/prompt_templates/baichuan_chat.json similarity index 100% rename from api/core/prompt/generate_prompts/baichuan_chat.json rename to api/core/prompt/prompt_templates/baichuan_chat.json diff --git a/api/core/prompt/generate_prompts/baichuan_completion.json b/api/core/prompt/prompt_templates/baichuan_completion.json similarity index 100% rename from api/core/prompt/generate_prompts/baichuan_completion.json rename to api/core/prompt/prompt_templates/baichuan_completion.json diff --git a/api/core/prompt/generate_prompts/common_chat.json b/api/core/prompt/prompt_templates/common_chat.json similarity index 100% rename from api/core/prompt/generate_prompts/common_chat.json rename to api/core/prompt/prompt_templates/common_chat.json diff --git a/api/core/prompt/generate_prompts/common_completion.json b/api/core/prompt/prompt_templates/common_completion.json similarity index 100% rename from api/core/prompt/generate_prompts/common_completion.json rename to api/core/prompt/prompt_templates/common_completion.json diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index fcae0dc786..f3a03b01c7 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -15,7 +15,7 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from 
core.prompt.prompt_transform import PromptTransform from models.model import AppMode @@ -275,7 +275,7 @@ class SimplePromptTransform(PromptTransform): return prompt_file_contents[prompt_file_name] # Get the absolute path of the subdirectory - prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts') + prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates') json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') # Open the JSON file and read its content diff --git a/api/core/prompt/utils/__init__.py b/api/core/prompt/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/prompt/prompt_template.py b/api/core/prompt/utils/prompt_template_parser.py similarity index 100% rename from api/core/prompt/prompt_template.py rename to api/core/prompt/utils/prompt_template_parser.py diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 0d81c419d6..139bfe15f3 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -9,7 +9,7 @@ import pandas as pd from flask import Flask, current_app from werkzeug.datastructures import FileStorage -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/core/rag/retrieval/__init__.py b/api/core/rag/retrieval/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/retrieval/agent/__init__.py b/api/core/rag/retrieval/agent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/features/dataset_retrieval/agent/agent_llm_callback.py b/api/core/rag/retrieval/agent/agent_llm_callback.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/agent_llm_callback.py rename to api/core/rag/retrieval/agent/agent_llm_callback.py diff --git a/api/core/features/dataset_retrieval/agent/fake_llm.py b/api/core/rag/retrieval/agent/fake_llm.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/fake_llm.py rename to api/core/rag/retrieval/agent/fake_llm.py diff --git a/api/core/features/dataset_retrieval/agent/llm_chain.py b/api/core/rag/retrieval/agent/llm_chain.py similarity index 91% rename from api/core/features/dataset_retrieval/agent/llm_chain.py rename to api/core/rag/retrieval/agent/llm_chain.py index e5155e15a0..d07ee0a582 100644 --- a/api/core/features/dataset_retrieval/agent/llm_chain.py +++ b/api/core/rag/retrieval/agent/llm_chain.py @@ -7,8 +7,8 @@ from langchain.schema.language_model import BaseLanguageModel from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages -from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback -from core.features.dataset_retrieval.agent.fake_llm import FakeLLM +from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback +from core.rag.retrieval.agent.fake_llm import FakeLLM from core.model_manager import ModelInstance diff --git a/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py similarity index 
98% rename from api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py rename to api/core/rag/retrieval/agent/multi_dataset_router_agent.py index 59923202fd..8cc2e29743 100644 --- a/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py @@ -12,7 +12,7 @@ from pydantic import root_validator from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages -from core.features.dataset_retrieval.agent.fake_llm import FakeLLM +from core.rag.retrieval.agent.fake_llm import FakeLLM from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessageTool diff --git a/api/core/rag/retrieval/agent/output_parser/__init__.py b/api/core/rag/retrieval/agent/output_parser/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py b/api/core/rag/retrieval/agent/output_parser/structured_chat.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py rename to api/core/rag/retrieval/agent/output_parser/structured_chat.py diff --git a/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py similarity index 99% rename from api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py rename to api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py index e69302bfd6..4d7d33038b 100644 --- a/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py @@ -13,7 +13,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.tools import BaseTool from core.entities.application_entities import ModelConfigEntity -from core.features.dataset_retrieval.agent.llm_chain import LLMChain +from core.rag.retrieval.agent.llm_chain import LLMChain FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. 
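For readers tracking the renames in this refactor, the following minimal Python sketch (illustrative only, not taken verbatim from any single hunk) collects the old import paths removed in this patch and the new paths that replace them, as shown in the -/+ lines of the hunks; class renames follow the module moves.

    # Old paths (removed in this patch) are shown as comments above their replacements.
    # from core.application_queue_manager import ApplicationQueueManager, PublishFrom
    from core.app.app_queue_manager import AppQueueManager, PublishFrom
    # from core.application_manager import ApplicationManager
    from core.app.app_manager import AppManager
    # from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
    from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
    # from core.features.moderation import ModerationFeature
    from core.moderation.input_moderation import InputModeration
    # from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
    from core.moderation.output_moderation import ModerationRule, OutputModeration
    # from core.features.external_data_fetch import ExternalDataFetchFeature
    from core.external_data_tool.external_data_fetch import ExternalDataFetch
    # from core.generator.llm_generator import LLMGenerator
    from core.llm_generator.llm_generator import LLMGenerator
    # from core.prompt.prompt_template import PromptTemplateParser
    from core.prompt.utils.prompt_template_parser import PromptTemplateParser

    # Class renames track the module moves, e.g. in dataset_retriever_tool.py:
    # feature = DatasetRetrievalFeature()   # before this refactor
    feature = DatasetRetrieval()            # after this refactor
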
diff --git a/api/core/features/dataset_retrieval/agent_based_dataset_executor.py b/api/core/rag/retrieval/agent_based_dataset_executor.py similarity index 92% rename from api/core/features/dataset_retrieval/agent_based_dataset_executor.py rename to api/core/rag/retrieval/agent_based_dataset_executor.py index 588ccc91f5..f1ccf986e9 100644 --- a/api/core/features/dataset_retrieval/agent_based_dataset_executor.py +++ b/api/core/rag/retrieval/agent_based_dataset_executor.py @@ -10,10 +10,10 @@ from pydantic import BaseModel, Extra from core.entities.agent_entities import PlanningStrategy from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import prompt_messages_to_lc_messages -from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback -from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent -from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser -from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent +from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback +from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent +from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser +from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent from core.helper import moderation from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.errors.invoke import InvokeError diff --git a/api/core/features/dataset_retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py similarity index 98% rename from api/core/features/dataset_retrieval/dataset_retrieval.py rename to api/core/rag/retrieval/dataset_retrieval.py index 3e54d8644d..07682389d6 100644 --- a/api/core/features/dataset_retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -5,7 +5,7 @@ from langchain.tools import BaseTool from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity -from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor +from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -15,7 +15,7 @@ from extensions.ext_database import db from models.dataset import Dataset -class DatasetRetrievalFeature: +class DatasetRetrieval: def retrieve(self, tenant_id: str, model_config: ModelConfigEntity, config: DatasetEntity, diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 30128c4dca..629ed23613 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -4,7 +4,7 @@ from langchain.tools import BaseTool from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom 
-from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
+from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
 from core.tools.tool.tool import Tool
@@ -30,7 +30,7 @@ class DatasetRetrieverTool(Tool):
         if retrieve_config is None:
             return []
-        feature = DatasetRetrievalFeature()
+        feature = DatasetRetrieval()
         # save original retrieve strategy, and set retrieve strategy to SINGLE
         # Agent only support SINGLE mode
diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py
index 74dc8d5112..f5f3ba2540 100644
--- a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py
+++ b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py
@@ -1,4 +1,4 @@
-from core.generator.llm_generator import LLMGenerator
+from core.llm_generator.llm_generator import LLMGenerator
 from events.message_event import message_was_created
 from extensions.ext_database import db
diff --git a/api/models/model.py b/api/models/model.py
index 8d286d3482..235f77abc3 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -310,22 +310,28 @@ class AppModelConfig(db.Model):
     def from_model_config_dict(self, model_config: dict):
         self.opening_statement = model_config['opening_statement']
-        self.suggested_questions = json.dumps(model_config['suggested_questions'])
-        self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer'])
+        self.suggested_questions = json.dumps(model_config['suggested_questions']) \
+            if model_config.get('suggested_questions') else None
+        self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \
+            if model_config.get('suggested_questions_after_answer') else None
         self.speech_to_text = json.dumps(model_config['speech_to_text']) \
             if model_config.get('speech_to_text') else None
         self.text_to_speech = json.dumps(model_config['text_to_speech']) \
             if model_config.get('text_to_speech') else None
-        self.more_like_this = json.dumps(model_config['more_like_this'])
+        self.more_like_this = json.dumps(model_config['more_like_this']) \
+            if model_config.get('more_like_this') else None
         self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \
             if model_config.get('sensitive_word_avoidance') else None
         self.external_data_tools = json.dumps(model_config['external_data_tools']) \
             if model_config.get('external_data_tools') else None
-        self.model = json.dumps(model_config['model'])
-        self.user_input_form = json.dumps(model_config['user_input_form'])
+        self.model = json.dumps(model_config['model']) \
+            if model_config.get('model') else None
+        self.user_input_form = json.dumps(model_config['user_input_form']) \
+            if model_config.get('user_input_form') else None
         self.dataset_query_variable = model_config.get('dataset_query_variable')
         self.pre_prompt = model_config['pre_prompt']
-        self.agent_mode = json.dumps(model_config['agent_mode'])
+        self.agent_mode = json.dumps(model_config['agent_mode']) \
+            if model_config.get('agent_mode') else None
         self.retriever_resource = json.dumps(model_config['retriever_resource']) \
             if model_config.get('retriever_resource') else None
         self.prompt_type = model_config.get('prompt_type', 'simple')
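In the from_model_config_dict hunk above, every optional section now goes through the same guarded expression, so a partial config (for example one with no agent_mode) is stored as NULL instead of raising KeyError. A minimal standalone sketch of that pattern, using a hypothetical helper name that the patch itself does not introduce:

import json
from typing import Optional

def dump_optional_section(model_config: dict, key: str) -> Optional[str]:
    # Mirrors the inline pattern in the hunk above:
    # json.dumps(model_config[key]) if model_config.get(key) else None
    value = model_config.get(key)
    return json.dumps(value) if value else None

partial_config = {"model": {"provider": "openai", "name": "gpt-4"}}
print(dump_optional_section(partial_config, "model"))       # '{"provider": "openai", "name": "gpt-4"}'
print(dump_optional_section(partial_config, "agent_mode"))  # None -- key absent, no KeyError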
diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py
index 1e893e0eca..213df26222 100644
--- a/api/services/advanced_prompt_template_service.py
+++ b/api/services/advanced_prompt_template_service.py
@@ -1,7 +1,7 @@
 import copy
-from core.prompt.advanced_prompt_templates import (
+from core.prompt.prompt_templates.advanced_prompt_templates import (
     BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
     BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
     BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py
index c1e0ecebe8..789d74ed2c 100644
--- a/api/services/app_model_config_service.py
+++ b/api/services/app_model_config_service.py
@@ -1,8 +1,8 @@
-from core.apps.app_config_validators.advanced_chat_app import AdvancedChatAppConfigValidator
-from core.apps.app_config_validators.agent_chat_app import AgentChatAppConfigValidator
-from core.apps.app_config_validators.chat_app import ChatAppConfigValidator
-from core.apps.app_config_validators.completion_app import CompletionAppConfigValidator
-from core.apps.app_config_validators.workflow_app import WorkflowAppConfigValidator
+from core.app.advanced_chat.config_validator import AdvancedChatAppConfigValidator
+from core.app.agent_chat.config_validator import AgentChatAppConfigValidator
+from core.app.chat.config_validator import ChatAppConfigValidator
+from core.app.completion.config_validator import CompletionAppConfigValidator
+from core.app.workflow.config_validator import WorkflowAppConfigValidator
 from models.model import AppMode
diff --git a/api/services/completion_service.py b/api/services/completion_service.py
index 9acd62b997..8a9639e521 100644
--- a/api/services/completion_service.py
+++ b/api/services/completion_service.py
@@ -4,8 +4,8 @@ from typing import Any, Union
 from sqlalchemy import and_
-from core.application_manager import ApplicationManager
-from core.apps.config_validators.model import ModelValidator
+from core.app.app_manager import AppManager
+from core.app.validators.model_validator import ModelValidator
 from core.entities.application_entities import InvokeFrom
 from core.file.message_file_parser import MessageFileParser
 from extensions.ext_database import db
@@ -137,7 +137,7 @@ class CompletionService:
             user
         )
-        application_manager = ApplicationManager()
+        application_manager = AppManager()
         return application_manager.generate(
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
@@ -193,7 +193,7 @@ class CompletionService:
             message.files,
             app_model_config
         )
-        application_manager = ApplicationManager()
+        application_manager = AppManager()
         return application_manager.generate(
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index ac3df380b2..1a0213799e 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -1,6 +1,6 @@
 from typing import Optional, Union
-from core.generator.llm_generator import LLMGenerator
+from core.llm_generator.llm_generator import LLMGenerator
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.account import Account
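The app_model_config_service.py hunk above swaps the flat core.apps.app_config_validators package for one config_validator module per app package under core.app. The diff does not show the service body, so the dispatch below is only an assumed illustration of how a mode-to-validator lookup could be wired, with stand-in classes rather than the project's real ones:

from enum import Enum

class AppMode(str, Enum):
    # Stand-in for models.model.AppMode; values are illustrative.
    CHAT = "chat"
    COMPLETION = "completion"

class ChatAppConfigValidator:
    @classmethod
    def validate(cls, config: dict) -> dict:
        # Placeholder for the real per-mode validation logic.
        return config

class CompletionAppConfigValidator(ChatAppConfigValidator):
    pass

# One validator per app mode, mirroring the one-module-per-mode layout.
VALIDATORS = {
    AppMode.CHAT: ChatAppConfigValidator,
    AppMode.COMPLETION: CompletionAppConfigValidator,
}

def validate_config(mode: AppMode, config: dict) -> dict:
    return VALIDATORS[mode].validate(config)

print(validate_config(AppMode.CHAT, {"opening_statement": "Hi"}))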
diff --git a/api/services/message_service.py b/api/services/message_service.py
index ad2ff60f6b..20918a8781 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -1,7 +1,7 @@
 import json
 from typing import Optional, Union
-from core.generator.llm_generator import LLMGenerator
+from core.llm_generator.llm_generator import LLMGenerator
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index fb6cf1fd5a..f384855e7a 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -1,7 +1,7 @@
 import json
 from typing import Optional
-from core.application_manager import ApplicationManager
+from core.app.app_manager import AppManager
 from core.entities.application_entities import (
     DatasetEntity,
     DatasetRetrieveConfigEntity,
@@ -111,7 +111,7 @@ class WorkflowConverter:
         new_app_mode = self._get_new_app_mode(app_model)
         # convert app model config
-        application_manager = ApplicationManager()
+        application_manager = AppManager()
         app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict(
             tenant_id=app_model.tenant_id,
             app_model_config_dict=app_model_config.to_dict(),
diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
index 95f1e30b44..69acb23681 100644
--- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
@@ -8,7 +8,7 @@ from core.file.file_obj import FileObj, FileType, FileTransferMethod
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
-from core.prompt.prompt_template import PromptTemplateParser
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from models.model import Conversation