From b75cd2514e81b8d5e923abf795e66bd2599d8d71 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 16:29:55 +0800 Subject: [PATCH] optimize db connections --- api/controllers/console/app/app.py | 62 ++++---- api/controllers/console/app/model_config.py | 135 +++++++++--------- .../easy_ui_based_app/dataset/manager.py | 3 +- .../app/apps/advanced_chat/app_generator.py | 2 - api/core/app/apps/advanced_chat/app_runner.py | 2 + api/core/app/apps/agent_chat/app_generator.py | 2 +- api/core/app/apps/agent_chat/app_runner.py | 4 +- api/core/app/apps/chat/app_generator.py | 2 +- api/core/app/apps/completion/app_generator.py | 2 +- api/core/app/apps/completion/app_runner.py | 2 + .../app/apps/message_based_app_generator.py | 2 - api/core/app/apps/workflow/app_runner.py | 2 + api/core/tools/tool_manager.py | 2 +- api/models/model.py | 2 +- 14 files changed, 116 insertions(+), 108 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index ef3c3bd6ae..baa44e5ba8 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,3 +1,5 @@ +import json + from flask_login import current_user from flask_restful import Resource, inputs, marshal_with, reqparse from werkzeug.exceptions import Forbidden, BadRequest @@ -6,6 +8,8 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from core.agent.entities import AgentToolEntity +from extensions.ext_database import db from fields.app_fields import ( app_detail_fields, app_detail_fields_with_site, @@ -14,10 +18,8 @@ from fields.app_fields import ( from libs.login import login_required from services.app_service import AppService from models.model import App, AppModelConfig, AppMode -from services.workflow_service import WorkflowService from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager -from core.entities.application_entities import AgentToolEntity ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow'] @@ -108,36 +110,38 @@ class AppApi(Resource): def get(self, app_model): """Get app detail""" # get original app model config - model_config: AppModelConfig = app_model.app_model_config - agent_mode = model_config.agent_mode_dict - # decrypt agent tool parameters if it's secret-input - for tool in agent_mode.get('tools') or []: - agent_tool_entity = AgentToolEntity(**tool) - # get tool - tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, - agent_tool=agent_tool_entity, - agent_callback=None - ) - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) + if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + model_config: AppModelConfig = app_model.app_model_config + agent_mode = model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + for tool in agent_mode.get('tools') or []: + agent_tool_entity = AgentToolEntity(**tool) + # get tool + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) - # get decrypted parameters - if agent_tool_entity.tool_parameters: - parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) - masked_parameter = manager.mask_tool_parameters(parameters or {}) - else: - masked_parameter = {} + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + masked_parameter = {} - # override tool parameters - tool['tool_parameters'] = masked_parameter + # override tool parameters + tool['tool_parameters'] = masked_parameter - # override agent mode - model_config.agent_mode = json.dumps(agent_mode) + # override agent mode + model_config.agent_mode = json.dumps(agent_mode) + db.session.commit() return app_model diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 9d3cbd8d97..94b07761da 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -8,7 +8,7 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.entities.application_entities import AgentToolEntity +from core.agent.entities import AgentToolEntity from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated @@ -38,81 +38,82 @@ class ModelConfigResource(Resource): ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) - # get original app model config - original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( - AppModelConfig.id == app.app_model_config_id - ).first() - agent_mode = original_app_model_config.agent_mode_dict - # decrypt agent tool parameters if it's secret-input - parameter_map = {} - masked_parameter_map = {} - tool_map = {} - for tool in agent_mode.get('tools') or []: - agent_tool_entity = AgentToolEntity(**tool) - # get tool - tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, - agent_tool=agent_tool_entity, - agent_callback=None - ) - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) - - # get decrypted parameters - if agent_tool_entity.tool_parameters: - parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) - masked_parameter = manager.mask_tool_parameters(parameters or {}) - else: - parameters = {} - masked_parameter = {} - - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' - masked_parameter_map[key] = masked_parameter - parameter_map[key] = parameters - tool_map[key] = tool_runtime - - # encrypt agent tool parameters if it's secret-input - agent_mode = new_app_model_config.agent_mode_dict - for tool in agent_mode.get('tools') or []: - agent_tool_entity = AgentToolEntity(**tool) - - # get tool - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' - if key in tool_map: - tool_runtime = tool_map[key] - else: + if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + # get original app model config + original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( + AppModelConfig.id == app_model.app_model_config_id + ).first() + agent_mode = original_app_model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + parameter_map = {} + masked_parameter_map = {} + tool_map = {} + for tool in agent_mode.get('tools') or []: + agent_tool_entity = AgentToolEntity(**tool) + # get tool tool_runtime = ToolManager.get_agent_tool_runtime( tenant_id=current_user.current_tenant_id, agent_tool=agent_tool_entity, agent_callback=None ) - - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) - manager.delete_tool_parameters_cache() + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) - # override parameters if it equals to masked parameters - if agent_tool_entity.tool_parameters: - if key not in masked_parameter_map: - continue + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + parameters = {} + masked_parameter = {} - if agent_tool_entity.tool_parameters == masked_parameter_map[key]: - agent_tool_entity.tool_parameters = parameter_map[key] + key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + masked_parameter_map[key] = masked_parameter + parameter_map[key] = parameters + tool_map[key] = tool_runtime - # encrypt parameters - if agent_tool_entity.tool_parameters: - tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + # encrypt agent tool parameters if it's secret-input + agent_mode = new_app_model_config.agent_mode_dict + for tool in agent_mode.get('tools') or []: + agent_tool_entity = AgentToolEntity(**tool) - # update app model config - new_app_model_config.agent_mode = json.dumps(agent_mode) + # get tool + key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + if key in tool_map: + tool_runtime = tool_map[key] + else: + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) + manager.delete_tool_parameters_cache() + + # override parameters if it equals to masked parameters + if agent_tool_entity.tool_parameters: + if key not in masked_parameter_map: + continue + + if agent_tool_entity.tool_parameters == masked_parameter_map[key]: + agent_tool_entity.tool_parameters = parameter_map[key] + + # encrypt parameters + if agent_tool_entity.tool_parameters: + tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + + # update app model config + new_app_model_config.agent_mode = json.dumps(agent_mode) db.session.add(new_app_model_config) db.session.flush() diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 4c08f62d27..c10aa98dba 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -123,7 +123,8 @@ class DatasetConfigManager: if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") - need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get("datasets") + need_manual_query_datasets = (config.get("dataset_configs") + and config["dataset_configs"].get("datasets", {}).get("datasets")) if need_manual_query_datasets and app_mode == AppMode.COMPLETION: # Only check when mode is completion diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index b1bc839966..1a33a3230b 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -153,8 +153,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) - db.session.close() - # chatbot app runner = AdvancedChatAppRunner() runner.run( diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 3279e00355..c42620b92f 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -72,6 +72,8 @@ class AdvancedChatAppRunner(AppRunner): ): return + db.session.close() + # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 700a340c96..cc9b0785f5 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -193,4 +193,4 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 2e142c63f1..0dc8a1e218 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -201,8 +201,8 @@ class AgentChatAppRunner(AppRunner): if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - db.session.refresh(conversation) - db.session.refresh(message) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + message = db.session.query(Message).filter(Message.id == message.id).first() db.session.close() # start agent runner diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 317d045c04..58287ba658 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -193,4 +193,4 @@ class ChatAppGenerator(MessageBasedAppGenerator): logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index b948938aac..fb62469720 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -182,7 +182,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() def generate_more_like_this(self, app_model: App, message_id: str, diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 04adf77be5..649d73d961 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -160,6 +160,8 @@ class CompletionAppRunner(AppRunner): model=application_generate_entity.model_config.model ) + db.session.close() + invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=application_generate_entity.model_config.parameters, diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 5d0f4bc63a..5e676c40bd 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -64,8 +64,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): else: logger.exception(e) raise e - finally: - db.session.remove() def _get_conversation_by_user(self, app_model: App, conversation_id: str, user: Union[Account, EndUser]) -> Conversation: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 59a385cb38..2d032fcdcb 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -57,6 +57,8 @@ class WorkflowAppRunner: ): return + db.session.close() + # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 2ac8f27bab..24b2f287c1 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,8 +5,8 @@ import mimetypes from os import listdir, path from typing import Any, Union +from core.agent.entities import AgentToolEntity from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler -from core.entities.application_entities import AgentToolEntity from core.model_runtime.entities.message_entities import PromptMessage from core.provider_manager import ProviderManager from core.tools.entities.common_entities import I18nObject diff --git a/api/models/model.py b/api/models/model.py index f891c68ed1..a7ac32f8ff 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -322,7 +322,7 @@ class AppModelConfig(db.Model): } def from_model_config_dict(self, model_config: dict): - self.opening_statement = model_config['opening_statement'] + self.opening_statement = model_config.get('opening_statement') self.suggested_questions = json.dumps(model_config['suggested_questions']) \ if model_config.get('suggested_questions') else None self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \