refactor: use sessionmaker in tool_label_manager.py (#34895)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
HeYinKazune 2026-04-14 16:23:29 +09:00 committed by GitHub
parent 711fe6ba2c
commit f7c6270f74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 209 additions and 65 deletions

View File

@ -1,4 +1,5 @@
from sqlalchemy import delete, select
from sqlalchemy.orm import Session, sessionmaker
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@ -19,10 +20,18 @@ class ToolLabelManager:
return list(set(tool_labels))
@classmethod
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
def update_tool_labels(
cls, controller: ToolProviderController, labels: list[str], session: Session | None = None
) -> None:
"""
Update tool labels
:param controller: tool provider controller
:param labels: list of tool labels
:param session: database session, if None, a new session will be created
:return: None
"""
labels = cls.filter_tool_labels(labels)
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
@ -30,26 +39,46 @@ class ToolLabelManager:
else:
raise ValueError("Unsupported tool type")
if session is not None:
cls._update_tool_labels_logics(session, provider_id, controller, labels)
else:
with sessionmaker(db.engine).begin() as _session:
cls._update_tool_labels_logics(_session, provider_id, controller, labels)
@classmethod
def _update_tool_labels_logics(
cls, session: Session, provider_id: str, controller: ToolProviderController, labels: list[str]
) -> None:
"""
Update tool labels logics
:param session: database session
:param provider_id: tool provider ID
:param controller: tool provider controller
:param labels: list of tool labels
:return: None
"""
# delete old labels
db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id))
_ = session.execute(
delete(ToolLabelBinding).where(
ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type
)
)
# insert new labels
for label in labels:
db.session.add(
ToolLabelBinding(
tool_id=provider_id,
tool_type=controller.provider_type,
label_name=label,
)
)
db.session.commit()
session.add(ToolLabelBinding(tool_id=provider_id, tool_type=controller.provider_type, label_name=label))
@classmethod
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
"""
Get tool labels
:param controller: tool provider controller
:return: list of tool labels (str)
"""
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id
elif isinstance(controller, BuiltinToolProviderController):
@ -60,9 +89,11 @@ class ToolLabelManager:
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type,
)
labels = db.session.scalars(stmt).all()
return list(labels)
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
labels: list[str] = list(_session.scalars(stmt).all())
return labels
@classmethod
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
@ -78,16 +109,22 @@ class ToolLabelManager:
if not tool_providers:
return {}
provider_ids: list[str] = []
provider_types: set[str] = set()
for controller in tool_providers:
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
raise ValueError("Unsupported tool type")
provider_ids = []
for controller in tool_providers:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
provider_ids.append(controller.provider_id)
provider_types.add(controller.provider_type)
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
labels: list[ToolLabelBinding] = []
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
stmt = select(ToolLabelBinding).where(
ToolLabelBinding.tool_id.in_(provider_ids), ToolLabelBinding.tool_type.in_(list(provider_types))
)
labels = list(_session.scalars(stmt).all())
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}

View File

@ -139,62 +139,82 @@ class WorkflowToolManageService:
:param labels: labels
:return: the updated tool
"""
# check if the name is unique
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
)
.limit(1)
)
existing_workflow_tool_provider: WorkflowToolProvider | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
# query if the name exists for other tools
existing_workflow_tool_provider = _session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
)
.limit(1)
)
# if the name exists raise error
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.limit(1)
)
# query the workflow tool provider
workflow_tool_provider: WorkflowToolProvider | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
workflow_tool_provider = _session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.limit(1)
)
# if not found raise error
if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
)
# query the app
app: App | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
app = _session.scalar(
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
)
# if not found raise error
if app is None:
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
# query the workflow
workflow: Workflow | None = app.workflow
# if not found raise error
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
# check if workflow configuration is synced
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict)
workflow_tool_provider.name = name
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)
workflow_tool_provider.description = description
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
workflow_tool_provider.privacy_policy = privacy_policy
workflow_tool_provider.version = workflow.version
workflow_tool_provider.updated_at = datetime.now()
with sessionmaker(db.engine).begin() as _session:
_session.add(workflow_tool_provider)
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
# update workflow tool provider
workflow_tool_provider.name = name
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)
workflow_tool_provider.description = description
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
workflow_tool_provider.privacy_policy = privacy_policy
workflow_tool_provider.version = workflow.version
workflow_tool_provider.updated_at = datetime.now()
db.session.commit()
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
labels,
session=_session,
)
return {"result": "success"}

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from types import SimpleNamespace
from typing import Any
from unittest.mock import PropertyMock, patch
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@ -12,11 +12,13 @@ from core.tools.tool_label_manager import ToolLabelManager
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create a mock class for testing abstract/base classes
class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
return None
# Factory function to create a "lightweight" controller for testing
def _api_controller(provider_id: str = "api-1") -> ApiToolProviderController:
controller = object.__new__(ApiToolProviderController)
controller.provider_id = provider_id
@ -29,6 +31,7 @@ def _workflow_controller(provider_id: str = "wf-1") -> WorkflowToolProviderContr
return controller
# Test pure logic: filtering and deduplication
def test_tool_label_manager_filter_tool_labels():
filtered = ToolLabelManager.filter_tool_labels(["search", "search", "invalid", "news"])
assert set(filtered) == {"search", "news"}
@ -36,22 +39,68 @@ def test_tool_label_manager_filter_tool_labels():
def test_tool_label_manager_update_tool_labels_db():
"""
Test the database update logic for tool labels.
Focus: Verify that labels are filtered, de-duplicated, and safely handled within a database session.
"""
# 1. Setup expected data from the controller
controller = _api_controller("api-1")
with patch("core.tools.tool_label_manager.db") as mock_db:
expected_id = controller.provider_id
expected_type = controller.provider_type
# 2. Patching External Dependencies
# - We patch 'db' to prevent Flask from trying to access a real database.
# - We patch 'sessionmaker' to intercept and control the creation of SQLAlchemy sessions.
with (
patch("core.tools.tool_label_manager.db"),
patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker,
):
# 3. Constructing the "Mocking Chain"
# In the business logic, we use: with sessionmaker(db.engine).begin() as _session:
# We need to link our 'mock_session' to the end of this complex context manager chain:
# Step A: sessionmaker(db.engine) -> returns an object (mock_sessionmaker.return_value)
# Step B: .begin() -> returns a context manager (begin.return_value)
# Step C: with ... as _session: -> calls __enter__(), and _session gets the __enter__.return_value
mock_session = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
# 4. Trigger the logic under test
# Input: ["search", "search", "invalid"]
# Logic:
# - "invalid" should be filtered out (not in default_tool_label_name_list).
# - The duplicate "search" should be merged (unique labels).
ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"])
mock_db.session.execute.assert_called_once()
# only one valid unique label should be inserted.
assert mock_db.session.add.call_count == 1
mock_db.session.commit.assert_called_once()
# 5. Behavior Assertion: DELETE operation
# Verify that the manager first attempts to clear existing labels for this specific tool.
# This ensures the update is idempotent.
mock_session.execute.assert_called_once()
# 6. Behavior Assertion: INSERT operation
# Verify that only ONE valid label ("search") was added after filtering and deduplication.
# If call_count == 1, it proves filter_tool_labels() worked as expected.
assert mock_session.add.call_count == 1
# 7. State Assertion: Data Integrity & Isolation
# Inspect the actual object passed to session.add() to ensure it has correct properties.
# This confirms that the data isolation (tool_id + tool_type) we refactored is active.
call_args = mock_session.add.call_args
added_label = call_args[0][0] # Retrieve the ToolLabelBinding instance
assert added_label.label_name == "search", "The label name should be 'search' after filtering."
assert added_label.tool_id == expected_id, "The tool_id must match the provider_id for correct binding."
assert added_label.tool_type == expected_type, "Isolation failed: tool_type must be verified during update."
# Test error handling
def test_tool_label_manager_update_tool_labels_unsupported():
with pytest.raises(ValueError, match="Unsupported tool type"):
ToolLabelManager.update_tool_labels(object(), ["search"]) # type: ignore[arg-type]
# Test retrieval logic
def test_tool_label_manager_get_tool_labels_for_builtin_and_db():
# Mocking a property (@property) using PropertyMock
with patch.object(
_ConcreteBuiltinToolProviderController,
"tool_labels",
@ -62,29 +111,67 @@ def test_tool_label_manager_get_tool_labels_for_builtin_and_db():
assert ToolLabelManager.get_tool_labels(builtin) == ["search", "news"]
api = _api_controller("api-1")
with patch("core.tools.tool_label_manager.db") as mock_db:
mock_db.session.scalars.return_value.all.return_value = ["search", "news"]
labels = ToolLabelManager.get_tool_labels(api)
assert labels == ["search", "news"]
with (
patch("core.tools.tool_label_manager.db"),
patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker,
):
mock_session = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
# Inject mock data into the query result: session.scalars(stmt).all()
mock_session.scalars.return_value.all.return_value = ["search", "news"]
labels = ToolLabelManager.get_tool_labels(api)
assert labels == ["search", "news"]
def test_tool_label_manager_get_tool_labels_unsupported():
"""
Negative Test: Ensure get_tool_labels raises ValueError for unsupported controller types.
This protects the internal API contract against accidental regressions during refactoring.
"""
# Passing a generic object() which doesn't match Api, Workflow, or Builtin controllers.
with pytest.raises(ValueError, match="Unsupported tool type"):
ToolLabelManager.get_tool_labels(object()) # type: ignore[arg-type]
# Test batch processing and mapping
def test_tool_label_manager_get_tools_labels_batch():
assert ToolLabelManager.get_tools_labels([]) == {}
api = _api_controller("api-1")
wf = _workflow_controller("wf-1")
# SimpleNamespace is a quick way to simulate SQLAlchemy row objects
records = [
SimpleNamespace(tool_id="api-1", label_name="search"),
SimpleNamespace(tool_id="api-1", label_name="news"),
SimpleNamespace(tool_id="wf-1", label_name="utilities"),
]
with patch("core.tools.tool_label_manager.db") as mock_db:
mock_db.session.scalars.return_value.all.return_value = records
with (
patch("core.tools.tool_label_manager.db"),
patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker,
):
mock_session = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
# Simulating the batch query result
mock_session.scalars.return_value.all.return_value = records
labels = ToolLabelManager.get_tools_labels([api, wf])
# Verify the final dictionary mapping
assert labels == {"api-1": ["search", "news"], "wf-1": ["utilities"]}
def test_tool_label_manager_get_tools_labels_unsupported():
"""
Negative Test: Ensure get_tools_labels raises ValueError if the list contains
unsupported controller types, even alongside valid ones.
"""
api = _api_controller("api-1")
# Passing a list with one valid controller and one invalid object()
with pytest.raises(ValueError, match="Unsupported tool type"):
ToolLabelManager.get_tools_labels([api, object()]) # type: ignore[list-item]