From b5acc8e3925000ff3755474896a84a7082541e49 Mon Sep 17 00:00:00 2001 From: aliworksx08 <57456290+aliworksx08@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:44:49 -0500 Subject: [PATCH] refactor: migrate session.query to select API in core tools (#34814) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/tools/tool_file_manager.py | 33 +++---------------- api/core/tools/workflow_as_tool/provider.py | 13 ++++---- .../core/tools/test_tool_file_manager.py | 26 ++++----------- .../tools/workflow_as_tool/test_provider.py | 8 ++--- 4 files changed, 23 insertions(+), 57 deletions(-) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 7ac29cf069..a59d167a0a 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -11,6 +11,7 @@ from uuid import uuid4 import httpx from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type +from sqlalchemy import select from configs import dify_config from core.db.session_factory import session_factory @@ -166,13 +167,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ with session_factory.create_session() as session: - tool_file: ToolFile | None = ( - session.query(ToolFile) - .where( - ToolFile.id == id, - ) - .first() - ) + tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == id).limit(1)) if not tool_file: return None @@ -190,13 +185,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ with session_factory.create_session() as session: - message_file: MessageFile | None = ( - session.query(MessageFile) - .where( - MessageFile.id == id, - ) - .first() - ) + message_file: MessageFile | None = session.scalar(select(MessageFile).where(MessageFile.id == id).limit(1)) # Check if message_file is not None if message_file is not None: @@ -210,13 +199,7 @@ class ToolFileManager: else: tool_file_id = None - tool_file: ToolFile | None = ( - session.query(ToolFile) - .where( - ToolFile.id == tool_file_id, - ) - .first() - ) + tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1)) if not tool_file: return None @@ -234,13 +217,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ with session_factory.create_session() as session: - tool_file: ToolFile | None = ( - session.query(ToolFile) - .where( - ToolFile.id == tool_file_id, - ) - .first() - ) + tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1)) if not tool_file: return None, None diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index f48b24be30..a01004448a 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -4,6 +4,7 @@ from collections.abc import Mapping from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import Field +from sqlalchemy import select from sqlalchemy.orm import Session from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -96,10 +97,10 @@ class WorkflowToolProviderController(ToolProviderController): :param app: the app :return: the tool """ - workflow: Workflow | None = ( - session.query(Workflow) + workflow: Workflow | None = session.scalar( + select(Workflow) .where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) - .first() + .limit(1) ) if not workflow: @@ -217,13 +218,13 @@ class WorkflowToolProviderController(ToolProviderController): return self.tools with Session(db.engine, expire_on_commit=False) as session, session.begin(): - db_provider: WorkflowToolProvider | None = ( - session.query(WorkflowToolProvider) + db_provider: WorkflowToolProvider | None = session.scalar( + select(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == self.provider_id, ) - .first() + .limit(1) ) if not db_provider: diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index 7fcebde3c5..2889cb9db1 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -129,7 +129,7 @@ def test_get_file_binary_returns_none_when_not_found() -> None: # Arrange manager = ToolFileManager() session = Mock() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None # Act with _patch_session_factory(session): @@ -144,7 +144,7 @@ def test_get_file_binary_returns_bytes_when_found() -> None: manager = ToolFileManager() tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain") session = Mock() - session.query.return_value.where.return_value.first.return_value = tool_file + session.scalar.return_value = tool_file # Act with patch("core.tools.tool_file_manager.storage") as storage: @@ -160,11 +160,7 @@ def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None: # Arrange manager = ToolFileManager() session = Mock() - first_query = Mock() - second_query = Mock() - first_query.where.return_value.first.return_value = None - second_query.where.return_value.first.return_value = None - session.query.side_effect = [first_query, second_query] + session.scalar.side_effect = [None, None] # Act with _patch_session_factory(session): @@ -179,11 +175,7 @@ def test_get_file_binary_by_message_file_id_when_url_is_none() -> None: manager = ToolFileManager() message_file = SimpleNamespace(url=None) session = Mock() - first_query = Mock() - second_query = Mock() - first_query.where.return_value.first.return_value = message_file - second_query.where.return_value.first.return_value = None - session.query.side_effect = [first_query, second_query] + session.scalar.side_effect = [message_file, None] # Act with _patch_session_factory(session): @@ -199,11 +191,7 @@ def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None: message_file = SimpleNamespace(url="https://x/files/tools/tool123.png") tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") session = Mock() - first_query = Mock() - second_query = Mock() - first_query.where.return_value.first.return_value = message_file - second_query.where.return_value.first.return_value = tool_file - session.query.side_effect = [first_query, second_query] + session.scalar.side_effect = [message_file, tool_file] # Act with patch("core.tools.tool_file_manager.storage") as storage: @@ -219,7 +207,7 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None: # Arrange manager = ToolFileManager() session = Mock() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None # Act with _patch_session_factory(session): @@ -242,7 +230,7 @@ def test_get_file_generator_returns_stream_when_found() -> None: size=12, ) session = Mock() - session.query.return_value.where.return_value.first.return_value = tool_file + session.scalar.return_value = tool_file # Act with patch("core.tools.tool_file_manager.storage") as storage: diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py index 2607861b59..4767480a5a 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -43,7 +43,7 @@ def test_get_db_provider_tool_builds_entity(): controller = _controller() session = Mock() workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={}) - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow app = SimpleNamespace(id="app-1") db_provider = SimpleNamespace( id="provider-1", @@ -136,7 +136,7 @@ def test_from_db_builds_controller(): parameter_configurations=[], ) session = _mock_session_with_begin() - session.query.return_value.where.return_value.first.return_value = db_provider + session.scalar.return_value = db_provider session.get.side_effect = [app, user] fake_cm = MagicMock() fake_cm.__enter__.return_value = session @@ -163,7 +163,7 @@ def test_get_tools_returns_empty_when_provider_missing(): mock_db.engine = object() with patch("core.tools.workflow_as_tool.provider.Session") as session_cls: session = _mock_session_with_begin() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None session_cls.return_value.__enter__.return_value = session assert controller.get_tools("tenant-1") == [] @@ -189,7 +189,7 @@ def test_get_tools_raises_when_app_missing(): mock_db.engine = object() with patch("core.tools.workflow_as_tool.provider.Session") as session_cls: session = _mock_session_with_begin() - session.query.return_value.where.return_value.first.return_value = db_provider + session.scalar.return_value = db_provider session.get.return_value = None session_cls.return_value.__enter__.return_value = session with pytest.raises(ValueError, match="app not found"):