From de3978fdbb7a0b41883afd493af4abee718f651f Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 13:19:17 +0800 Subject: [PATCH] optimize db connections --- api/config.py | 2 ++ api/core/app/apps/advanced_chat/app_generator.py | 13 ++++++++++--- .../apps/advanced_chat/generate_task_pipeline.py | 2 ++ api/core/app/apps/message_based_app_generator.py | 8 ++++++++ .../app/apps/workflow/generate_task_pipeline.py | 2 ++ .../apps/workflow_based_generate_task_pipeline.py | 11 +++++++++++ api/core/workflow/workflow_engine_manager.py | 5 +++++ 7 files changed, 40 insertions(+), 3 deletions(-) diff --git a/api/config.py b/api/config.py index a6bc731b82..a4ec6fcef9 100644 --- a/api/config.py +++ b/api/config.py @@ -27,6 +27,7 @@ DEFAULTS = { 'CHECK_UPDATE_URL': 'https://updates.dify.ai', 'DEPLOY_ENV': 'PRODUCTION', 'SQLALCHEMY_POOL_SIZE': 30, + 'SQLALCHEMY_MAX_OVERFLOW': 10, 'SQLALCHEMY_POOL_RECYCLE': 3600, 'SQLALCHEMY_ECHO': 'False', 'SENTRY_TRACES_SAMPLE_RATE': 1.0, @@ -148,6 +149,7 @@ class Config: self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}{db_extras}" self.SQLALCHEMY_ENGINE_OPTIONS = { 'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')), + 'max_overflow': int(get_env('SQLALCHEMY_MAX_OVERFLOW')), 'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE')) } diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index a0f197ec37..50b561dfe6 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -95,6 +95,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): extras=extras ) + workflow = db.session.query(Workflow).filter(Workflow.id == workflow.id).first() + user = (db.session.query(Account).filter(Account.id == user.id).first() + if isinstance(user, Account) + else db.session.query(EndUser).filter(EndUser.id == user.id).first()) + db.session.close() + # init generate records ( conversation, @@ -153,6 +159,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + db.session.close() + # chatbot app runner = AdvancedChatAppRunner() runner.run( @@ -177,7 +185,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, workflow: Workflow, @@ -198,6 +206,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :return: """ # init generate task pipeline + generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, workflow=workflow, @@ -216,5 +225,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): else: logger.exception(e) raise e - # finally: - # db.session.remove() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 048b429304..6991b8704a 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -122,6 +122,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._output_moderation_handler = self._init_output_moderation() self._stream = stream + db.session.close() + def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 0e76c96ff7..be7538ea07 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -177,6 +177,9 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add(conversation) db.session.commit() + conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + db.session.close() + message = Message( app_id=app_config.app_id, model_provider=model_provider, @@ -204,6 +207,9 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add(message) db.session.commit() + message = db.session.query(Message).filter(Message.id == message.id).first() + db.session.close() + for file in application_generate_entity.files: message_file = MessageFile( message_id=message.id, @@ -218,6 +224,8 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add(message_file) db.session.commit() + db.session.close() + return conversation, message def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 26e4769fa6..2c2f941bee 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -99,6 +99,8 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._output_moderation_handler = self._init_output_moderation() self._stream = stream + db.session.close() + def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. diff --git a/api/core/app/apps/workflow_based_generate_task_pipeline.py b/api/core/app/apps/workflow_based_generate_task_pipeline.py index 3e9a7b9e1f..640159bae3 100644 --- a/api/core/app/apps/workflow_based_generate_task_pipeline.py +++ b/api/core/app/apps/workflow_based_generate_task_pipeline.py @@ -61,6 +61,9 @@ class WorkflowBasedGenerateTaskPipeline: db.session.add(workflow_run) db.session.commit() + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run.id).first() + db.session.close() + return workflow_run def _workflow_run_success(self, workflow_run: WorkflowRun, @@ -85,6 +88,7 @@ class WorkflowBasedGenerateTaskPipeline: workflow_run.finished_at = datetime.utcnow() db.session.commit() + db.session.close() return workflow_run @@ -112,6 +116,7 @@ class WorkflowBasedGenerateTaskPipeline: workflow_run.finished_at = datetime.utcnow() db.session.commit() + db.session.close() return workflow_run @@ -151,6 +156,10 @@ class WorkflowBasedGenerateTaskPipeline: db.session.add(workflow_node_execution) db.session.commit() + workflow_node_execution = (db.session.query(WorkflowNodeExecution) + .filter(WorkflowNodeExecution.id == workflow_node_execution.id).first()) + db.session.close() + return workflow_node_execution def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, @@ -179,6 +188,7 @@ class WorkflowBasedGenerateTaskPipeline: workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() + db.session.close() return workflow_node_execution @@ -198,5 +208,6 @@ class WorkflowBasedGenerateTaskPipeline: workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() + db.session.close() return workflow_node_execution diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 0b96717de7..50f79df1f0 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -19,6 +19,7 @@ from core.workflow.nodes.start.start_node import StartNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode +from extensions.ext_database import db from models.workflow import ( Workflow, WorkflowNodeExecutionStatus, @@ -282,6 +283,8 @@ class WorkflowEngineManager: predecessor_node_id=predecessor_node.node_id if predecessor_node else None ) + db.session.close() + workflow_nodes_and_result = WorkflowNodeAndResult( node=node, result=None @@ -339,6 +342,8 @@ class WorkflowEngineManager: if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) + db.session.close() + def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, node: BaseNode, node_run_result: NodeRunResult) -> None: