refactor: migrate session.query to select API in core tools (#34814)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
aliworksx08 2026-04-09 00:44:49 -05:00 committed by GitHub
parent f5ea61e93e
commit b5acc8e392
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 23 additions and 57 deletions

View File

@ -11,6 +11,7 @@ from uuid import uuid4
import httpx
from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type
from sqlalchemy import select
from configs import dify_config
from core.db.session_factory import session_factory
@ -166,13 +167,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with session_factory.create_session() as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
ToolFile.id == id,
)
.first()
)
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == id).limit(1))
if not tool_file:
return None
@ -190,13 +185,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with session_factory.create_session() as session:
message_file: MessageFile | None = (
session.query(MessageFile)
.where(
MessageFile.id == id,
)
.first()
)
message_file: MessageFile | None = session.scalar(select(MessageFile).where(MessageFile.id == id).limit(1))
# Check if message_file is not None
if message_file is not None:
@ -210,13 +199,7 @@ class ToolFileManager:
else:
tool_file_id = None
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
ToolFile.id == tool_file_id,
)
.first()
)
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
if not tool_file:
return None
@ -234,13 +217,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with session_factory.create_session() as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
ToolFile.id == tool_file_id,
)
.first()
)
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
if not tool_file:
return None, None

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping
from graphon.variables.input_entities import VariableEntity, VariableEntityType
from pydantic import Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@ -96,10 +97,10 @@ class WorkflowToolProviderController(ToolProviderController):
:param app: the app
:return: the tool
"""
workflow: Workflow | None = (
session.query(Workflow)
workflow: Workflow | None = session.scalar(
select(Workflow)
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
.limit(1)
)
if not workflow:
@ -217,13 +218,13 @@ class WorkflowToolProviderController(ToolProviderController):
return self.tools
with Session(db.engine, expire_on_commit=False) as session, session.begin():
db_provider: WorkflowToolProvider | None = (
session.query(WorkflowToolProvider)
db_provider: WorkflowToolProvider | None = session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == self.provider_id,
)
.first()
.limit(1)
)
if not db_provider:

View File

@ -129,7 +129,7 @@ def test_get_file_binary_returns_none_when_not_found() -> None:
# Arrange
manager = ToolFileManager()
session = Mock()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
# Act
with _patch_session_factory(session):
@ -144,7 +144,7 @@ def test_get_file_binary_returns_bytes_when_found() -> None:
manager = ToolFileManager()
tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain")
session = Mock()
session.query.return_value.where.return_value.first.return_value = tool_file
session.scalar.return_value = tool_file
# Act
with patch("core.tools.tool_file_manager.storage") as storage:
@ -160,11 +160,7 @@ def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None:
# Arrange
manager = ToolFileManager()
session = Mock()
first_query = Mock()
second_query = Mock()
first_query.where.return_value.first.return_value = None
second_query.where.return_value.first.return_value = None
session.query.side_effect = [first_query, second_query]
session.scalar.side_effect = [None, None]
# Act
with _patch_session_factory(session):
@ -179,11 +175,7 @@ def test_get_file_binary_by_message_file_id_when_url_is_none() -> None:
manager = ToolFileManager()
message_file = SimpleNamespace(url=None)
session = Mock()
first_query = Mock()
second_query = Mock()
first_query.where.return_value.first.return_value = message_file
second_query.where.return_value.first.return_value = None
session.query.side_effect = [first_query, second_query]
session.scalar.side_effect = [message_file, None]
# Act
with _patch_session_factory(session):
@ -199,11 +191,7 @@ def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None:
message_file = SimpleNamespace(url="https://x/files/tools/tool123.png")
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
session = Mock()
first_query = Mock()
second_query = Mock()
first_query.where.return_value.first.return_value = message_file
second_query.where.return_value.first.return_value = tool_file
session.query.side_effect = [first_query, second_query]
session.scalar.side_effect = [message_file, tool_file]
# Act
with patch("core.tools.tool_file_manager.storage") as storage:
@ -219,7 +207,7 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None:
# Arrange
manager = ToolFileManager()
session = Mock()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
# Act
with _patch_session_factory(session):
@ -242,7 +230,7 @@ def test_get_file_generator_returns_stream_when_found() -> None:
size=12,
)
session = Mock()
session.query.return_value.where.return_value.first.return_value = tool_file
session.scalar.return_value = tool_file
# Act
with patch("core.tools.tool_file_manager.storage") as storage:

View File

@ -43,7 +43,7 @@ def test_get_db_provider_tool_builds_entity():
controller = _controller()
session = Mock()
workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={})
session.query.return_value.where.return_value.first.return_value = workflow
session.scalar.return_value = workflow
app = SimpleNamespace(id="app-1")
db_provider = SimpleNamespace(
id="provider-1",
@ -136,7 +136,7 @@ def test_from_db_builds_controller():
parameter_configurations=[],
)
session = _mock_session_with_begin()
session.query.return_value.where.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
session.get.side_effect = [app, user]
fake_cm = MagicMock()
fake_cm.__enter__.return_value = session
@ -163,7 +163,7 @@ def test_get_tools_returns_empty_when_provider_missing():
mock_db.engine = object()
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
session = _mock_session_with_begin()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
session_cls.return_value.__enter__.return_value = session
assert controller.get_tools("tenant-1") == []
@ -189,7 +189,7 @@ def test_get_tools_raises_when_app_missing():
mock_db.engine = object()
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
session = _mock_session_with_begin()
session.query.return_value.where.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
session.get.return_value = None
session_cls.return_value.__enter__.return_value = session
with pytest.raises(ValueError, match="app not found"):