From 7f7156b325ef3179219e4e0d68d5003dc9f68a83 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 15 Aug 2025 13:09:54 +0800 Subject: [PATCH] refactor: improve session management in ToolManager --- api/core/tools/tool_manager.py | 6 +- .../tools/builtin_tools_manage_service.py | 81 +++++++++---------- 2 files changed, 44 insertions(+), 43 deletions(-) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 2737bcfb16..4023560afe 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast import sqlalchemy as sa from pydantic import TypeAdapter +from sqlalchemy.orm import Session from yarl import URL import contexts @@ -617,8 +618,9 @@ class ToolManager: WHERE tenant_id = :tenant_id ORDER BY tenant_id, provider, is_default DESC, created_at DESC """ - ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()] - return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() + with Session(db.engine).no_autoflush as session: + ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()] + return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() @classmethod def list_providers_from_api( diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 862ac30780..fa0fbee8fd 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -546,54 +546,53 @@ class BuiltinToolManageService: # get all builtin providers provider_controllers = ToolManager.list_builtin_providers(tenant_id) - with db.session.no_autoflush: - # get all user added providers - db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id) + # get all user added providers + db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id) - # rewrite db_providers - for db_provider in db_providers: - db_provider.provider = str(ToolProviderID(db_provider.provider)) + # rewrite db_providers + for db_provider in db_providers: + db_provider.provider = str(ToolProviderID(db_provider.provider)) - # find provider - def find_provider(provider): - return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) + # find provider + def find_provider(provider): + return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) - result: list[ToolProviderApiEntity] = [] + result: list[ToolProviderApiEntity] = [] - for provider_controller in provider_controllers: - try: - # handle include, exclude - if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore - data=provider_controller, - name_func=lambda x: x.identity.name, - ): - continue + for provider_controller in provider_controllers: + try: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore + data=provider_controller, + name_func=lambda x: x.identity.name, + ): + continue - # convert provider controller to user provider - user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( - provider_controller=provider_controller, - db_provider=find_provider(provider_controller.entity.identity.name), - decrypt_credentials=True, + # convert provider controller to user provider + user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( + provider_controller=provider_controller, + db_provider=find_provider(provider_controller.entity.identity.name), + decrypt_credentials=True, + ) + + # add icon + ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) + + tools = provider_controller.get_tools() + for tool in tools or []: + user_builtin_provider.tools.append( + ToolTransformService.convert_tool_entity_to_api_entity( + tenant_id=tenant_id, + tool=tool, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) ) - # add icon - ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) - - tools = provider_controller.get_tools() - for tool in tools or []: - user_builtin_provider.tools.append( - ToolTransformService.convert_tool_entity_to_api_entity( - tenant_id=tenant_id, - tool=tool, - labels=ToolLabelManager.get_tool_labels(provider_controller), - ) - ) - - result.append(user_builtin_provider) - except Exception as e: - raise e + result.append(user_builtin_provider) + except Exception as e: + raise e return BuiltinToolProviderSort.sort(result)