mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor: migrate session.query to select API in core tools (#34814)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
f5ea61e93e
commit
b5acc8e392
@ -11,6 +11,7 @@ from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
@ -166,13 +167,7 @@ class ToolFileManager:
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == id).limit(1))
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@ -190,13 +185,7 @@ class ToolFileManager:
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
message_file: MessageFile | None = (
|
||||
session.query(MessageFile)
|
||||
.where(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
message_file: MessageFile | None = session.scalar(select(MessageFile).where(MessageFile.id == id).limit(1))
|
||||
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
@ -210,13 +199,7 @@ class ToolFileManager:
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@ -234,13 +217,7 @@ class ToolFileManager:
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
|
||||
|
||||
if not tool_file:
|
||||
return None, None
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import Mapping
|
||||
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from pydantic import Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
@ -96,10 +97,10 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
:param app: the app
|
||||
:return: the tool
|
||||
"""
|
||||
workflow: Workflow | None = (
|
||||
session.query(Workflow)
|
||||
workflow: Workflow | None = session.scalar(
|
||||
select(Workflow)
|
||||
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
@ -217,13 +218,13 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
return self.tools
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
db_provider: WorkflowToolProvider | None = (
|
||||
session.query(WorkflowToolProvider)
|
||||
db_provider: WorkflowToolProvider | None = session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == self.provider_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not db_provider:
|
||||
|
||||
@ -129,7 +129,7 @@ def test_get_file_binary_returns_none_when_not_found() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -144,7 +144,7 @@ def test_get_file_binary_returns_bytes_when_found() -> None:
|
||||
manager = ToolFileManager()
|
||||
tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain")
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = tool_file
|
||||
session.scalar.return_value = tool_file
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
@ -160,11 +160,7 @@ def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = None
|
||||
second_query.where.return_value.first.return_value = None
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
session.scalar.side_effect = [None, None]
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -179,11 +175,7 @@ def test_get_file_binary_by_message_file_id_when_url_is_none() -> None:
|
||||
manager = ToolFileManager()
|
||||
message_file = SimpleNamespace(url=None)
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = message_file
|
||||
second_query.where.return_value.first.return_value = None
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
session.scalar.side_effect = [message_file, None]
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -199,11 +191,7 @@ def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None:
|
||||
message_file = SimpleNamespace(url="https://x/files/tools/tool123.png")
|
||||
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = message_file
|
||||
second_query.where.return_value.first.return_value = tool_file
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
session.scalar.side_effect = [message_file, tool_file]
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
@ -219,7 +207,7 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -242,7 +230,7 @@ def test_get_file_generator_returns_stream_when_found() -> None:
|
||||
size=12,
|
||||
)
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = tool_file
|
||||
session.scalar.return_value = tool_file
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
|
||||
@ -43,7 +43,7 @@ def test_get_db_provider_tool_builds_entity():
|
||||
controller = _controller()
|
||||
session = Mock()
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={})
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
session.scalar.return_value = workflow
|
||||
app = SimpleNamespace(id="app-1")
|
||||
db_provider = SimpleNamespace(
|
||||
id="provider-1",
|
||||
@ -136,7 +136,7 @@ def test_from_db_builds_controller():
|
||||
parameter_configurations=[],
|
||||
)
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
session.scalar.return_value = db_provider
|
||||
session.get.side_effect = [app, user]
|
||||
fake_cm = MagicMock()
|
||||
fake_cm.__enter__.return_value = session
|
||||
@ -163,7 +163,7 @@ def test_get_tools_returns_empty_when_provider_missing():
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
session_cls.return_value.__enter__.return_value = session
|
||||
|
||||
assert controller.get_tools("tenant-1") == []
|
||||
@ -189,7 +189,7 @@ def test_get_tools_raises_when_app_missing():
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
session.scalar.return_value = db_provider
|
||||
session.get.return_value = None
|
||||
session_cls.return_value.__enter__.return_value = session
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user