mirror of
https://github.com/langgenius/dify.git
synced 2026-04-16 02:16:57 +08:00
refactor: use sessionmaker in tool_label_manager.py (#34895)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
711fe6ba2c
commit
f7c6270f74
@ -1,4 +1,5 @@
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
@ -19,10 +20,18 @@ class ToolLabelManager:
|
||||
return list(set(tool_labels))
|
||||
|
||||
@classmethod
|
||||
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
|
||||
def update_tool_labels(
|
||||
cls, controller: ToolProviderController, labels: list[str], session: Session | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Update tool labels
|
||||
|
||||
:param controller: tool provider controller
|
||||
:param labels: list of tool labels
|
||||
:param session: database session, if None, a new session will be created
|
||||
:return: None
|
||||
"""
|
||||
|
||||
labels = cls.filter_tool_labels(labels)
|
||||
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
@ -30,26 +39,46 @@ class ToolLabelManager:
|
||||
else:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
if session is not None:
|
||||
cls._update_tool_labels_logics(session, provider_id, controller, labels)
|
||||
else:
|
||||
with sessionmaker(db.engine).begin() as _session:
|
||||
cls._update_tool_labels_logics(_session, provider_id, controller, labels)
|
||||
|
||||
@classmethod
|
||||
def _update_tool_labels_logics(
|
||||
cls, session: Session, provider_id: str, controller: ToolProviderController, labels: list[str]
|
||||
) -> None:
|
||||
"""
|
||||
Update tool labels logics
|
||||
|
||||
:param session: database session
|
||||
:param provider_id: tool provider ID
|
||||
:param controller: tool provider controller
|
||||
:param labels: list of tool labels
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# delete old labels
|
||||
db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id))
|
||||
_ = session.execute(
|
||||
delete(ToolLabelBinding).where(
|
||||
ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type
|
||||
)
|
||||
)
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
db.session.add(
|
||||
ToolLabelBinding(
|
||||
tool_id=provider_id,
|
||||
tool_type=controller.provider_type,
|
||||
label_name=label,
|
||||
)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
session.add(ToolLabelBinding(tool_id=provider_id, tool_type=controller.provider_type, label_name=label))
|
||||
|
||||
@classmethod
|
||||
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
|
||||
"""
|
||||
Get tool labels
|
||||
|
||||
:param controller: tool provider controller
|
||||
:return: list of tool labels (str)
|
||||
"""
|
||||
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
elif isinstance(controller, BuiltinToolProviderController):
|
||||
@ -60,9 +89,11 @@ class ToolLabelManager:
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type,
|
||||
)
|
||||
labels = db.session.scalars(stmt).all()
|
||||
|
||||
return list(labels)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
labels: list[str] = list(_session.scalars(stmt).all())
|
||||
|
||||
return labels
|
||||
|
||||
@classmethod
|
||||
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
|
||||
@ -78,16 +109,22 @@ class ToolLabelManager:
|
||||
if not tool_providers:
|
||||
return {}
|
||||
|
||||
provider_ids: list[str] = []
|
||||
provider_types: set[str] = set()
|
||||
|
||||
for controller in tool_providers:
|
||||
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
provider_ids = []
|
||||
for controller in tool_providers:
|
||||
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
||||
provider_ids.append(controller.provider_id)
|
||||
provider_types.add(controller.provider_type)
|
||||
|
||||
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
|
||||
labels: list[ToolLabelBinding] = []
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
stmt = select(ToolLabelBinding).where(
|
||||
ToolLabelBinding.tool_id.in_(provider_ids), ToolLabelBinding.tool_type.in_(list(provider_types))
|
||||
)
|
||||
labels = list(_session.scalars(stmt).all())
|
||||
|
||||
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
|
||||
|
||||
|
||||
@ -139,62 +139,82 @@ class WorkflowToolManageService:
|
||||
:param labels: labels
|
||||
:return: the updated tool
|
||||
"""
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.name == name,
|
||||
WorkflowToolProvider.id != workflow_tool_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
existing_workflow_tool_provider: WorkflowToolProvider | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
# query if the name exists for other tools
|
||||
existing_workflow_tool_provider = _session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.name == name,
|
||||
WorkflowToolProvider.id != workflow_tool_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# if the name exists raise error
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} already exists")
|
||||
|
||||
workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.limit(1)
|
||||
)
|
||||
# query the workflow tool provider
|
||||
workflow_tool_provider: WorkflowToolProvider | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
workflow_tool_provider = _session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# if not found raise error
|
||||
if workflow_tool_provider is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
app: App | None = db.session.scalar(
|
||||
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
# query the app
|
||||
app: App | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
app = _session.scalar(
|
||||
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
|
||||
# if not found raise error
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
|
||||
|
||||
# query the workflow
|
||||
workflow: Workflow | None = app.workflow
|
||||
|
||||
# if not found raise error
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
|
||||
|
||||
# check if workflow configuration is synced
|
||||
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict)
|
||||
|
||||
workflow_tool_provider.name = name
|
||||
workflow_tool_provider.label = label
|
||||
workflow_tool_provider.icon = json.dumps(icon)
|
||||
workflow_tool_provider.description = description
|
||||
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
|
||||
workflow_tool_provider.privacy_policy = privacy_policy
|
||||
workflow_tool_provider.version = workflow.version
|
||||
workflow_tool_provider.updated_at = datetime.now()
|
||||
with sessionmaker(db.engine).begin() as _session:
|
||||
_session.add(workflow_tool_provider)
|
||||
|
||||
try:
|
||||
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
# update workflow tool provider
|
||||
workflow_tool_provider.name = name
|
||||
workflow_tool_provider.label = label
|
||||
workflow_tool_provider.icon = json.dumps(icon)
|
||||
workflow_tool_provider.description = description
|
||||
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
|
||||
workflow_tool_provider.privacy_policy = privacy_policy
|
||||
workflow_tool_provider.version = workflow.version
|
||||
workflow_tool_provider.updated_at = datetime.now()
|
||||
|
||||
db.session.commit()
|
||||
try:
|
||||
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
if labels is not None:
|
||||
ToolLabelManager.update_tool_labels(
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
||||
)
|
||||
if labels is not None:
|
||||
ToolLabelManager.update_tool_labels(
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
|
||||
labels,
|
||||
session=_session,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import PropertyMock, patch
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -12,11 +12,13 @@ from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
|
||||
# Create a mock class for testing abstract/base classes
|
||||
class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
return None
|
||||
|
||||
|
||||
# Factory function to create a "lightweight" controller for testing
|
||||
def _api_controller(provider_id: str = "api-1") -> ApiToolProviderController:
|
||||
controller = object.__new__(ApiToolProviderController)
|
||||
controller.provider_id = provider_id
|
||||
@ -29,6 +31,7 @@ def _workflow_controller(provider_id: str = "wf-1") -> WorkflowToolProviderContr
|
||||
return controller
|
||||
|
||||
|
||||
# Test pure logic: filtering and deduplication
|
||||
def test_tool_label_manager_filter_tool_labels():
|
||||
filtered = ToolLabelManager.filter_tool_labels(["search", "search", "invalid", "news"])
|
||||
assert set(filtered) == {"search", "news"}
|
||||
@ -36,22 +39,68 @@ def test_tool_label_manager_filter_tool_labels():
|
||||
|
||||
|
||||
def test_tool_label_manager_update_tool_labels_db():
|
||||
"""
|
||||
Test the database update logic for tool labels.
|
||||
Focus: Verify that labels are filtered, de-duplicated, and safely handled within a database session.
|
||||
"""
|
||||
# 1. Setup expected data from the controller
|
||||
controller = _api_controller("api-1")
|
||||
with patch("core.tools.tool_label_manager.db") as mock_db:
|
||||
expected_id = controller.provider_id
|
||||
expected_type = controller.provider_type
|
||||
|
||||
# 2. Patching External Dependencies
|
||||
# - We patch 'db' to prevent Flask from trying to access a real database.
|
||||
# - We patch 'sessionmaker' to intercept and control the creation of SQLAlchemy sessions.
|
||||
with (
|
||||
patch("core.tools.tool_label_manager.db"),
|
||||
patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker,
|
||||
):
|
||||
# 3. Constructing the "Mocking Chain"
|
||||
# In the business logic, we use: with sessionmaker(db.engine).begin() as _session:
|
||||
# We need to link our 'mock_session' to the end of this complex context manager chain:
|
||||
# Step A: sessionmaker(db.engine) -> returns an object (mock_sessionmaker.return_value)
|
||||
# Step B: .begin() -> returns a context manager (begin.return_value)
|
||||
# Step C: with ... as _session: -> calls __enter__(), and _session gets the __enter__.return_value
|
||||
mock_session = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# 4. Trigger the logic under test
|
||||
# Input: ["search", "search", "invalid"]
|
||||
# Logic:
|
||||
# - "invalid" should be filtered out (not in default_tool_label_name_list).
|
||||
# - The duplicate "search" should be merged (unique labels).
|
||||
ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"])
|
||||
|
||||
mock_db.session.execute.assert_called_once()
|
||||
# only one valid unique label should be inserted.
|
||||
assert mock_db.session.add.call_count == 1
|
||||
mock_db.session.commit.assert_called_once()
|
||||
# 5. Behavior Assertion: DELETE operation
|
||||
# Verify that the manager first attempts to clear existing labels for this specific tool.
|
||||
# This ensures the update is idempotent.
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
# 6. Behavior Assertion: INSERT operation
|
||||
# Verify that only ONE valid label ("search") was added after filtering and deduplication.
|
||||
# If call_count == 1, it proves filter_tool_labels() worked as expected.
|
||||
assert mock_session.add.call_count == 1
|
||||
|
||||
# 7. State Assertion: Data Integrity & Isolation
|
||||
# Inspect the actual object passed to session.add() to ensure it has correct properties.
|
||||
# This confirms that the data isolation (tool_id + tool_type) we refactored is active.
|
||||
call_args = mock_session.add.call_args
|
||||
added_label = call_args[0][0] # Retrieve the ToolLabelBinding instance
|
||||
|
||||
assert added_label.label_name == "search", "The label name should be 'search' after filtering."
|
||||
assert added_label.tool_id == expected_id, "The tool_id must match the provider_id for correct binding."
|
||||
assert added_label.tool_type == expected_type, "Isolation failed: tool_type must be verified during update."
|
||||
|
||||
|
||||
# Test error handling
|
||||
def test_tool_label_manager_update_tool_labels_unsupported():
|
||||
with pytest.raises(ValueError, match="Unsupported tool type"):
|
||||
ToolLabelManager.update_tool_labels(object(), ["search"]) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# Test retrieval logic
|
||||
def test_tool_label_manager_get_tool_labels_for_builtin_and_db():
|
||||
# Mocking a property (@property) using PropertyMock
|
||||
with patch.object(
|
||||
_ConcreteBuiltinToolProviderController,
|
||||
"tool_labels",
|
||||
@ -62,29 +111,67 @@ def test_tool_label_manager_get_tool_labels_for_builtin_and_db():
|
||||
assert ToolLabelManager.get_tool_labels(builtin) == ["search", "news"]
|
||||
|
||||
api = _api_controller("api-1")
|
||||
with patch("core.tools.tool_label_manager.db") as mock_db:
|
||||
mock_db.session.scalars.return_value.all.return_value = ["search", "news"]
|
||||
labels = ToolLabelManager.get_tool_labels(api)
|
||||
assert labels == ["search", "news"]
|
||||
with (
|
||||
patch("core.tools.tool_label_manager.db"),
|
||||
patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Inject mock data into the query result: session.scalars(stmt).all()
|
||||
mock_session.scalars.return_value.all.return_value = ["search", "news"]
|
||||
|
||||
labels = ToolLabelManager.get_tool_labels(api)
|
||||
assert labels == ["search", "news"]
|
||||
|
||||
|
||||
def test_tool_label_manager_get_tool_labels_unsupported():
|
||||
"""
|
||||
Negative Test: Ensure get_tool_labels raises ValueError for unsupported controller types.
|
||||
This protects the internal API contract against accidental regressions during refactoring.
|
||||
"""
|
||||
# Passing a generic object() which doesn't match Api, Workflow, or Builtin controllers.
|
||||
with pytest.raises(ValueError, match="Unsupported tool type"):
|
||||
ToolLabelManager.get_tool_labels(object()) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# Test batch processing and mapping
|
||||
def test_tool_label_manager_get_tools_labels_batch():
|
||||
assert ToolLabelManager.get_tools_labels([]) == {}
|
||||
|
||||
api = _api_controller("api-1")
|
||||
wf = _workflow_controller("wf-1")
|
||||
|
||||
# SimpleNamespace is a quick way to simulate SQLAlchemy row objects
|
||||
records = [
|
||||
SimpleNamespace(tool_id="api-1", label_name="search"),
|
||||
SimpleNamespace(tool_id="api-1", label_name="news"),
|
||||
SimpleNamespace(tool_id="wf-1", label_name="utilities"),
|
||||
]
|
||||
with patch("core.tools.tool_label_manager.db") as mock_db:
|
||||
mock_db.session.scalars.return_value.all.return_value = records
|
||||
|
||||
with (
|
||||
patch("core.tools.tool_label_manager.db"),
|
||||
patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Simulating the batch query result
|
||||
mock_session.scalars.return_value.all.return_value = records
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels([api, wf])
|
||||
|
||||
# Verify the final dictionary mapping
|
||||
assert labels == {"api-1": ["search", "news"], "wf-1": ["utilities"]}
|
||||
|
||||
|
||||
def test_tool_label_manager_get_tools_labels_unsupported():
|
||||
"""
|
||||
Negative Test: Ensure get_tools_labels raises ValueError if the list contains
|
||||
unsupported controller types, even alongside valid ones.
|
||||
"""
|
||||
api = _api_controller("api-1")
|
||||
|
||||
# Passing a list with one valid controller and one invalid object()
|
||||
with pytest.raises(ValueError, match="Unsupported tool type"):
|
||||
ToolLabelManager.get_tools_labels([api, object()]) # type: ignore[list-item]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user