From 4304044905cbf466b1847c445fe82dc4187be0a5 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 18 Jun 2026 11:16:09 +0900 Subject: [PATCH] chore: example of make db.session pass from parameter. (#37561) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/agent/roster.py | 2 +- api/controllers/console/app/app.py | 4 +- api/controllers/console/datasets/datasets.py | 1 + api/controllers/console/tag/tags.py | 14 ++-- api/controllers/console/workspace/snippets.py | 1 + api/controllers/openapi/apps.py | 4 +- .../service_api/dataset/dataset.py | 18 +++-- api/services/app_service.py | 14 ++-- api/services/dataset_service.py | 7 +- api/services/snippet_service.py | 18 ++--- api/services/tag_service.py | 79 +++++++++---------- .../service_api/dataset/test_dataset.py | 17 ++-- .../services/test_app_service.py | 56 +++++++++---- .../test_dataset_service_retrieval.py | 39 ++++++--- .../services/test_tag_service.py | 78 ++++++++++-------- .../console/agent/test_agent_controllers.py | 2 +- .../controllers/console/tag/test_tags.py | 19 +++-- .../console/workspace/test_snippets.py | 3 +- .../services/agent/test_agent_services.py | 66 ++++++++-------- .../services/test_snippet_service.py | 8 +- .../unit_tests/services/test_tag_service.py | 30 ++++--- 21 files changed, 274 insertions(+), 206 deletions(-) diff --git a/api/controllers/console/agent/roster.py b/api/controllers/console/agent/roster.py index bd34914909c..ce75f9f1d06 100644 --- a/api/controllers/console/agent/roster.py +++ b/api/controllers/console/agent/roster.py @@ -286,7 +286,7 @@ class AgentAppListApi(Resource): status="normal", ) - app_pagination = AppService().get_paginate_apps(current_user.id, current_tenant_id, params) + app_pagination = AppService().get_paginate_apps(current_user.id, current_tenant_id, params, db.session) if app_pagination is None: empty = AgentAppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[]) return empty.model_dump(mode="json") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 2fb3a402962..3a8d0c1eb56 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -594,7 +594,7 @@ class AppListApi(Resource): # get app list app_service = AppService() - app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params) + app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params, db.session) if not app_pagination: empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[]) return empty.model_dump(mode="json"), 200 @@ -661,7 +661,7 @@ class StarredAppListApi(Resource): is_created_by_me=args.is_created_by_me, ) - app_pagination = AppService().get_paginate_starred_apps(current_user_id, current_tenant_id, params) + app_pagination = AppService().get_paginate_starred_apps(current_user_id, current_tenant_id, params, db.session) if not app_pagination: empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[]) return empty.model_dump(mode="json"), 200 diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index ca2ef5d6e2d..623c02631c0 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -409,6 +409,7 @@ class DatasetListApi(Resource): datasets, total = DatasetService.get_datasets( query.page, query.limit, + db.session, current_tenant_id, current_user, query.keyword, diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 82a713f1c6f..38e7395ccf8 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -122,7 +122,7 @@ class TagListApi(Resource): raise Forbidden() payload = TagBasePayload.model_validate(console_ns.payload or {}) - tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type)) + tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type), db.session) response = TagResponse.model_validate( {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} @@ -146,9 +146,9 @@ class TagUpdateDeleteApi(Resource): raise Forbidden() payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {}) - tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str) + tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str, db.session) - binding_count = TagService.get_tag_binding_count(tag_id_str) + binding_count = TagService.get_tag_binding_count(tag_id_str, db.session) response = TagResponse.model_validate( {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} @@ -164,7 +164,7 @@ class TagUpdateDeleteApi(Resource): def delete(self, tag_id: UUID): tag_id_str = str(tag_id) - TagService.delete_tag(tag_id_str) + TagService.delete_tag(tag_id_str, db.session) return "", 204 @@ -189,7 +189,8 @@ def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]: tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type, - ) + ), + db.session, ) return {"result": "success"}, 200 @@ -203,7 +204,8 @@ def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]: tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type, - ) + ), + db.session, ) return {"result": "success"}, 200 diff --git a/api/controllers/console/workspace/snippets.py b/api/controllers/console/workspace/snippets.py index 4bd493d25e9..7fcca1f79e8 100644 --- a/api/controllers/console/workspace/snippets.py +++ b/api/controllers/console/workspace/snippets.py @@ -126,6 +126,7 @@ class CustomizedSnippetsApi(Resource): snippet_service = _snippet_service() snippets, total, has_more = snippet_service.get_snippets( tenant_id=current_tenant_id, + session=db.session, page=query.page, limit=query.limit, keyword=query.keyword, diff --git a/api/controllers/openapi/apps.py b/api/controllers/openapi/apps.py index 84b1610d5f5..c4796313c0b 100644 --- a/api/controllers/openapi/apps.py +++ b/api/controllers/openapi/apps.py @@ -174,7 +174,7 @@ class AppListApi(Resource): tag_ids: list[str] | None = None if query.tag: - tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag) + tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag, db.session) if not tags: return empty tag_ids = [tag.id for tag in tags] @@ -191,7 +191,7 @@ class AppListApi(Resource): openapi_visible=True, ) - pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params) + pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params, db.session) if pagination is None: return empty diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 305acae4a6e..0ca5c5bbf6b 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -770,7 +770,7 @@ class DatasetTagsApi(DatasetApiResource): raise Forbidden() payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) - tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE)) + tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE), db.session) response = dump_response( KnowledgeTagResponse, @@ -808,9 +808,9 @@ class DatasetTagsApi(DatasetApiResource): payload = TagUpdatePayload.model_validate(service_api_ns.payload or {}) tag_id = payload.tag_id - tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id) + tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id, db.session) - binding_count = TagService.get_tag_binding_count(tag_id) + binding_count = TagService.get_tag_binding_count(tag_id, db.session) response = dump_response( KnowledgeTagResponse, @@ -840,7 +840,7 @@ class DatasetTagsApi(DatasetApiResource): def delete(self, _): """Delete a knowledge type tag.""" payload = TagDeletePayload.model_validate(service_api_ns.payload or {}) - TagService.delete_tag(payload.tag_id) + TagService.delete_tag(payload.tag_id, db.session) return "", 204 @@ -873,7 +873,8 @@ class DatasetTagBindingApi(DatasetApiResource): payload = TagBindingPayload.model_validate(service_api_ns.payload or {}) TagService.save_tag_binding( - TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE) + TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE), + db.session, ) return "", 204 @@ -907,7 +908,8 @@ class DatasetTagUnbindingApi(DatasetApiResource): payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {}) TagService.delete_tag_binding( - TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE) + TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE), + db.session, ) return "", 204 @@ -942,6 +944,8 @@ class DatasetTagsBindingStatusApi(DatasetApiResource): dataset_id = kwargs.get("dataset_id") assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None - tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id)) + tags = TagService.get_tags_by_target_id( + "knowledge", current_user.current_tenant_id, str(dataset_id), db.session + ) tags_list = [{"id": tag.id, "name": tag.name} for tag in tags] return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200 diff --git a/api/services/app_service.py b/api/services/app_service.py index c435a672520..18f4e70df4b 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -75,7 +75,7 @@ class CreateAppParams(BaseModel): class AppService: @staticmethod def _build_app_list_filters( - user_id: str, tenant_id: str, params: AppListBaseParams + user_id: str, tenant_id: str, params: AppListBaseParams, session: scoped_session ) -> list[sa.ColumnElement[bool]]: filters = [App.tenant_id == tenant_id, App.is_universal == False] @@ -115,7 +115,7 @@ class AppService: escaped_name = escape_like_pattern(name) filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\")) if params.tag_ids and len(params.tag_ids) > 0: - target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, match_all=True) + target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, session, match_all=True) if target_ids and len(target_ids) > 0: filters.append(App.id.in_(target_ids)) else: @@ -197,7 +197,9 @@ class AppService: ).scalars() ) - def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None: + def get_paginate_apps( + self, user_id: str, tenant_id: str, params: AppListParams, session: scoped_session + ) -> Pagination | None: """ Get app list with pagination, filters, and explicit sort order. :param user_id: user id @@ -205,7 +207,7 @@ class AppService: :param params: query parameters :return: """ - filters = self._build_app_list_filters(user_id, tenant_id, params) + filters = self._build_app_list_filters(user_id, tenant_id, params, session) if not filters: return None @@ -231,12 +233,12 @@ class AppService: return app_models def get_paginate_starred_apps( - self, user_id: str, tenant_id: str, params: StarredAppListParams + self, user_id: str, tenant_id: str, params: StarredAppListParams, session: scoped_session ) -> Pagination | None: """ Get apps starred by the current account with pagination, filters, and explicit sort order. """ - filters = self._build_app_list_filters(user_id, tenant_id, params) + filters = self._build_app_list_filters(user_id, tenant_id, params, session) if not filters: return None diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 364d9b36b9d..e8b17a137f9 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -13,7 +13,7 @@ import sqlalchemy as sa from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator from redis.exceptions import LockNotOwnedError from sqlalchemy import delete, exists, func, select, update -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session, scoped_session, sessionmaker from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config @@ -235,7 +235,9 @@ class _EstimateArgs(BaseModel): class DatasetService: @staticmethod - def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): + def get_datasets( + page, per_page, session: scoped_session, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False + ): query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id) if user: @@ -295,6 +297,7 @@ class DatasetService: "knowledge", tenant_id, tag_ids, + session, match_all=True, ) else: diff --git a/api/services/snippet_service.py b/api/services/snippet_service.py index 9f16d412040..75282db9d4c 100644 --- a/api/services/snippet_service.py +++ b/api/services/snippet_service.py @@ -6,7 +6,7 @@ from datetime import UTC, datetime from typing import Any from sqlalchemy import delete, func, select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session, scoped_session, sessionmaker from core.db import session_factory from core.workflow.node_factory import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING @@ -192,6 +192,7 @@ class SnippetService: self, *, tenant_id: str, + session: scoped_session, page: int = 1, limit: int = 20, keyword: str | None = None, @@ -229,20 +230,19 @@ class SnippetService: stmt = stmt.where(CustomizedSnippet.created_by.in_(creators)) if tag_ids: - target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, match_all=True) + target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, session, match_all=True) if target_ids: stmt = stmt.where(CustomizedSnippet.id.in_(target_ids)) else: return [], 0, False - with self._session_scope() as session: - # Get total count - count_stmt = select(func.count()).select_from(stmt.subquery()) - total = session.scalar(count_stmt) or 0 + # Get total count + count_stmt = select(func.count()).select_from(stmt.subquery()) + total = session.scalar(count_stmt) or 0 - # Apply pagination - stmt = stmt.limit(limit + 1).offset((page - 1) * limit) - snippets = list(session.scalars(stmt).all()) + # Apply pagination + stmt = stmt.limit(limit + 1).offset((page - 1) * limit) + snippets = list(session.scalars(stmt).all()) has_more = len(snippets) > limit if has_more: diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 20f9a2c73d5..59dd5f7bb36 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -6,10 +6,9 @@ from flask_login import current_user from pydantic import BaseModel, Field from sqlalchemy import delete, func, select from sqlalchemy.engine import CursorResult -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, scoped_session from werkzeug.exceptions import NotFound -from extensions.ext_database import db from models.dataset import Dataset from models.enums import TagType from models.model import App, Tag, TagBinding @@ -56,7 +55,7 @@ class TagService: @staticmethod def get_target_ids_by_tag_ids( - tag_type: str, current_tenant_id: str, tag_ids: list[str], *, match_all: bool = False + tag_type: str, current_tenant_id: str, tag_ids: list[str], session: scoped_session, *, match_all: bool = False ): """ Return target IDs bound to tags for the given tenant and resource type. @@ -70,7 +69,7 @@ class TagService: return [] # Deduplicate repeated query params so match_all counts each requested tag once. requested_tag_ids = list(dict.fromkeys(tag_ids)) - tags = db.session.scalars( + tags = session.scalars( select(Tag.id).where( Tag.id.in_(requested_tag_ids), Tag.tenant_id == current_tenant_id, @@ -86,13 +85,13 @@ class TagService: if match_all: if len(tag_ids) != len(requested_tag_ids): return [] - return db.session.scalars( + return session.scalars( select(TagBinding.target_id) .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) .group_by(TagBinding.target_id) .having(func.count(sa.distinct(TagBinding.tag_id)) == len(tag_ids)) ).all() - tag_bindings = db.session.scalars( + tag_bindings = session.scalars( select(TagBinding.target_id).where( TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id ) @@ -100,11 +99,11 @@ class TagService: return tag_bindings @staticmethod - def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str): + def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str, session: scoped_session): if not tag_type or not tag_name: return [] tags = list( - db.session.scalars( + session.scalars( select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) ).all() ) @@ -113,8 +112,8 @@ class TagService: return tags @staticmethod - def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str): - tags = db.session.scalars( + def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str, session: scoped_session): + tags = session.scalars( select(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) .where( @@ -128,8 +127,8 @@ class TagService: return tags or [] @staticmethod - def save_tags(payload: SaveTagPayload) -> Tag: - if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name): + def save_tags(payload: SaveTagPayload, session: scoped_session) -> Tag: + if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name, session): raise ValueError("Tag name already exists") tag = Tag( name=payload.name, @@ -138,17 +137,17 @@ class TagService: tenant_id=current_user.current_tenant_id, ) tag.id = str(uuid.uuid4()) - db.session.add(tag) - db.session.commit() + session.add(tag) + session.commit() return tag @staticmethod - def update_tags(payload: UpdateTagPayload, tag_id: str) -> Tag: - tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1)) + def update_tags(payload: UpdateTagPayload, tag_id: str, session: scoped_session) -> Tag: + tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1)) if not tag: raise NotFound("Tag not found") if payload.name != tag.name: - existing = db.session.scalar( + existing = session.scalar( select(Tag) .where( Tag.name == payload.name, @@ -161,31 +160,31 @@ class TagService: if existing: raise ValueError("Tag name already exists") tag.name = payload.name - db.session.commit() + session.commit() return tag @staticmethod - def get_tag_binding_count(tag_id: str) -> int: - count = db.session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0 + def get_tag_binding_count(tag_id: str, session: scoped_session) -> int: + count = session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0 return count @staticmethod - def delete_tag(tag_id: str): - tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1)) + def delete_tag(tag_id: str, session: scoped_session): + tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1)) if not tag: raise NotFound("Tag not found") - db.session.delete(tag) + session.delete(tag) # delete tag binding - tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all() + tag_bindings = session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all() if tag_bindings: for tag_binding in tag_bindings: - db.session.delete(tag_binding) - db.session.commit() + session.delete(tag_binding) + session.commit() @staticmethod - def save_tag_binding(payload: TagBindingCreatePayload): - TagService.check_target_exists(payload.type, payload.target_id) - valid_tag_ids = db.session.scalars( + def save_tag_binding(payload: TagBindingCreatePayload, session: scoped_session): + TagService.check_target_exists(payload.type, payload.target_id, session) + valid_tag_ids = session.scalars( select(Tag.id).where( Tag.id.in_(payload.tag_ids), Tag.tenant_id == current_user.current_tenant_id, @@ -193,7 +192,7 @@ class TagService: ) ).all() for tag_id in valid_tag_ids: - tag_binding = db.session.scalar( + tag_binding = session.scalar( select(TagBinding) .where(TagBinding.tag_id == tag_id, TagBinding.target_id == payload.target_id) .limit(1) @@ -206,15 +205,15 @@ class TagService: tenant_id=current_user.current_tenant_id, created_by=current_user.id, ) - db.session.add(new_tag_binding) - db.session.commit() + session.add(new_tag_binding) + session.commit() @staticmethod - def delete_tag_binding(payload: TagBindingDeletePayload): - TagService.check_target_exists(payload.type, payload.target_id) + def delete_tag_binding(payload: TagBindingDeletePayload, session: scoped_session): + TagService.check_target_exists(payload.type, payload.target_id, session) result = cast( CursorResult, - db.session.execute( + session.execute( delete(TagBinding).where( TagBinding.target_id == payload.target_id, TagBinding.tag_id.in_(payload.tag_ids), @@ -230,12 +229,12 @@ class TagService: ) if result.rowcount: - db.session.commit() + session.commit() @staticmethod - def check_target_exists(type: str, target_id: str): + def check_target_exists(type: str, target_id: str, session: scoped_session): if type == "knowledge": - dataset = db.session.scalar( + dataset = session.scalar( select(Dataset) .where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) .limit(1) @@ -243,13 +242,13 @@ class TagService: if not dataset: raise NotFound("Dataset not found") elif type == "app": - app = db.session.scalar( + app = session.scalar( select(App).where(App.tenant_id == current_user.current_tenant_id, App.id == target_id).limit(1) ) if not app: raise NotFound("App not found") elif type == "snippet": - snippet = db.session.scalar( + snippet = session.scalar( select(CustomizedSnippet) .where(CustomizedSnippet.tenant_id == current_user.current_tenant_id, CustomizedSnippet.id == target_id) .limit(1) diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 642dd3ab62d..ac166df454f 100644 --- a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -16,7 +16,7 @@ since these test controller-level behavior. import uuid from contextlib import ExitStack from datetime import UTC, datetime -from unittest.mock import Mock, PropertyMock, patch +from unittest.mock import ANY, Mock, PropertyMock, patch import pytest from flask import Flask @@ -1129,7 +1129,7 @@ class TestDatasetTagsApiPatch: assert status == 200 assert response == {"id": "tag-1", "name": "Updated Tag", "type": "knowledge", "binding_count": "5"} mock_tag_svc.update_tags.assert_called_once() - update_payload, tag_id = mock_tag_svc.update_tags.call_args.args + update_payload, tag_id, session = mock_tag_svc.update_tags.call_args.args assert update_payload.name == "Updated Tag" assert tag_id == "tag-1" @@ -1184,7 +1184,7 @@ class TestDatasetTagsApiDelete: result = api.delete(_=None) assert result == ("", 204) - mock_tag_svc.delete_tag.assert_called_once_with("tag-1") + mock_tag_svc.delete_tag.assert_called_once_with("tag-1", ANY) @patch("libs.login.current_user") def test_delete_tag_forbidden(self, mock_current_user, app: Flask): @@ -1233,7 +1233,7 @@ class TestDatasetTagsBindingStatusApi: assert status_code == 200 assert response["data"] == [{"id": "tag_1", "name": "Test Tag"}] assert response["total"] == 1 - mock_tag_svc.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123") + mock_tag_svc.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123", ANY) class TestDatasetTagBindingApiPost: @@ -1266,7 +1266,8 @@ class TestDatasetTagBindingApiPost: from services.tag_service import TagBindingCreatePayload mock_tag_svc.save_tag_binding.assert_called_once_with( - TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE) + TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE), + ANY, ) @patch("controllers.service_api.dataset.dataset.current_user") @@ -1317,7 +1318,8 @@ class TestDatasetTagUnbindingApiPost: from services.tag_service import TagBindingDeletePayload mock_tag_svc.delete_tag_binding.assert_called_once_with( - TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE) + TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE), + ANY, ) @patch("controllers.service_api.dataset.dataset.TagService") @@ -1347,7 +1349,8 @@ class TestDatasetTagUnbindingApiPost: from services.tag_service import TagBindingDeletePayload mock_tag_svc.delete_tag_binding.assert_called_once_with( - TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE) + TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE), + ANY, ) @patch("controllers.service_api.dataset.dataset.current_user") diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index 384f83fce3e..43c254d407d 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -234,7 +234,7 @@ class TestAppService: # Get paginated apps params = AppListParams(page=1, limit=10, mode="chat") - paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params) + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers) # Verify pagination results assert paginated_apps is not None @@ -295,7 +295,7 @@ class TestAppService: db_session_with_containers.commit() last_modified_apps = app_service.get_paginate_apps( - account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat") + account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers ) assert last_modified_apps is not None assert [app.name for app in last_modified_apps.items] == [ @@ -305,7 +305,10 @@ class TestAppService: ] recently_created_apps = app_service.get_paginate_apps( - account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", sort_by="recently_created") + account.id, + tenant.id, + AppListParams(page=1, limit=10, mode="chat", sort_by="recently_created"), + db_session_with_containers, ) assert recently_created_apps is not None assert [app.name for app in recently_created_apps.items] == [ @@ -315,7 +318,10 @@ class TestAppService: ] earliest_created_apps = app_service.get_paginate_apps( - account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created") + account.id, + tenant.id, + AppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created"), + db_session_with_containers, ) assert earliest_created_apps is not None assert [app.name for app in earliest_created_apps.items] == [ @@ -366,7 +372,7 @@ class TestAppService: assert star_count == 1 paginated_apps = app_service.get_paginate_apps( - account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat") + account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers ) assert paginated_apps is not None starred_by_app_id = {app.id: app.is_starred for app in paginated_apps.items} @@ -377,7 +383,7 @@ class TestAppService: db_session_with_containers.commit() paginated_apps = app_service.get_paginate_apps( - account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat") + account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers ) assert paginated_apps is not None starred_by_app_id = {app.id: app.is_starred for app in paginated_apps.items} @@ -442,7 +448,7 @@ class TestAppService: db_session_with_containers.commit() last_modified_apps = app_service.get_paginate_starred_apps( - account.id, tenant.id, StarredAppListParams(page=1, limit=10, mode="chat") + account.id, tenant.id, StarredAppListParams(page=1, limit=10, mode="chat"), db_session_with_containers ) assert last_modified_apps is not None assert [app.name for app in last_modified_apps.items] == [ @@ -457,6 +463,7 @@ class TestAppService: account.id, tenant.id, StarredAppListParams(page=1, limit=10, mode="chat", sort_by="recently_created"), + db_session_with_containers, ) assert recently_created_apps is not None assert [app.name for app in recently_created_apps.items] == [ @@ -469,6 +476,7 @@ class TestAppService: account.id, tenant.id, StarredAppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created"), + db_session_with_containers, ) assert earliest_created_apps is not None assert [app.name for app in earliest_created_apps.items] == [ @@ -522,20 +530,25 @@ class TestAppService: completion_app = app_service.create_app(tenant.id, completion_app_params, account) # Test filter by mode - chat_apps = app_service.get_paginate_apps(account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")) + chat_apps = app_service.get_paginate_apps( + account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers + ) assert len(chat_apps.items) == 1 assert chat_apps.items[0].mode == "chat" # Test filter by name filtered_apps = app_service.get_paginate_apps( - account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat") + account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat"), db_session_with_containers ) assert len(filtered_apps.items) == 1 assert "Chat" in filtered_apps.items[0].name # Test filter by created_by_me my_apps = app_service.get_paginate_apps( - account.id, tenant.id, AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True) + account.id, + tenant.id, + AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True), + db_session_with_containers, ) assert len(my_apps.items) == 1 @@ -588,6 +601,7 @@ class TestAppService: first_account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", creator_ids=[second_account.id]), + db_session_with_containers, ) assert filtered_apps is not None @@ -635,10 +649,12 @@ class TestAppService: # Test with tag filter params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["tag1", "tag2"]) - paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params) + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers) # Verify tag service was called - mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"], match_all=True) + mock_tag_service.assert_called_once_with( + "app", tenant.id, ["tag1", "tag2"], db_session_with_containers, match_all=True + ) # Verify results assert paginated_apps is not None @@ -651,7 +667,7 @@ class TestAppService: params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["nonexistent_tag"]) - paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params) + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers) # Should return None when no apps match tag filter assert paginated_apps is None @@ -1467,7 +1483,7 @@ class TestAppService: # Test 1: Search with % character paginated_apps = app_service.get_paginate_apps( - account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10) + account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10), db_session_with_containers ) assert paginated_apps is not None assert paginated_apps.total == 1 @@ -1476,7 +1492,10 @@ class TestAppService: # Test 2: Search with _ character paginated_apps = app_service.get_paginate_apps( - account.id, tenant.id, AppListParams(name="test_data", mode="chat", page=1, limit=10) + account.id, + tenant.id, + AppListParams(name="test_data", mode="chat", page=1, limit=10), + db_session_with_containers, ) assert paginated_apps is not None assert paginated_apps.total == 1 @@ -1485,7 +1504,10 @@ class TestAppService: # Test 3: Search with \ character paginated_apps = app_service.get_paginate_apps( - account.id, tenant.id, AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10) + account.id, + tenant.id, + AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10), + db_session_with_containers, ) assert paginated_apps is not None assert paginated_apps.total == 1 @@ -1494,7 +1516,7 @@ class TestAppService: # Test 4: Search with % should NOT match 100% (verifies escaping works) paginated_apps = app_service.get_paginate_apps( - account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10) + account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10), db_session_with_containers ) assert paginated_apps is not None assert paginated_apps.total == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 4e2bf9fc103..27ab600871b 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -227,7 +227,7 @@ class TestDatasetServiceGetDatasets: ) # Act - datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id) + datasets, total = DatasetService.get_datasets(page, per_page, db_session_with_containers, tenant_id=tenant.id) # Assert assert len(datasets) == 5 @@ -257,7 +257,9 @@ class TestDatasetServiceGetDatasets: ) # Act - datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, search=search) + datasets, total = DatasetService.get_datasets( + page, per_page, db_session_with_containers, tenant_id=tenant.id, search=search + ) # Assert assert len(datasets) == 1 @@ -301,7 +303,9 @@ class TestDatasetServiceGetDatasets: tag_ids = [tag_1.id, tag_2.id] # Act - datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids) + datasets, total = DatasetService.get_datasets( + page, per_page, db_session_with_containers, tenant_id=tenant.id, tag_ids=tag_ids + ) # Assert assert len(datasets) == 1 @@ -326,7 +330,9 @@ class TestDatasetServiceGetDatasets: ) # Act - datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids) + datasets, total = DatasetService.get_datasets( + page, per_page, db_session_with_containers, tenant_id=tenant.id, tag_ids=tag_ids + ) # Assert # When tag_ids is empty, tag filtering is skipped, so normal query results are returned @@ -356,7 +362,9 @@ class TestDatasetServiceGetDatasets: ) # Act - datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, user=None) + datasets, total = DatasetService.get_datasets( + page, per_page, db_session_with_containers, tenant_id=tenant.id, user=None + ) # Assert assert len(datasets) == 1 @@ -384,6 +392,7 @@ class TestDatasetServiceGetDatasets: datasets, total = DatasetService.get_datasets( page=1, per_page=20, + session=db_session_with_containers, tenant_id=tenant.id, user=owner, include_all=True, @@ -408,7 +417,9 @@ class TestDatasetServiceGetDatasets: ) # Act - datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user) + datasets, total = DatasetService.get_datasets( + page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user + ) # Assert assert len(datasets) == 1 @@ -432,7 +443,9 @@ class TestDatasetServiceGetDatasets: ) # Act - datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user) + datasets, total = DatasetService.get_datasets( + page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user + ) # Assert assert len(datasets) == 1 @@ -459,7 +472,9 @@ class TestDatasetServiceGetDatasets: ) # Act - datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user) + datasets, total = DatasetService.get_datasets( + page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user + ) # Assert assert len(datasets) == 1 @@ -486,7 +501,9 @@ class TestDatasetServiceGetDatasets: ) # Act - datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator) + datasets, total = DatasetService.get_datasets( + page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=operator + ) # Assert assert len(datasets) == 1 @@ -509,7 +526,9 @@ class TestDatasetServiceGetDatasets: ) # Act - datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator) + datasets, total = DatasetService.get_datasets( + page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=operator + ) # Assert assert datasets == [] diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 517d5d2ed4c..197415ee6bd 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -449,7 +449,7 @@ class TestTagService: # Act: Execute the method under test tag_ids = [tag.id for tag in tags] - result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids) + result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, db_session_with_containers) # Assert: Verify the expected outcomes assert result is not None @@ -485,7 +485,7 @@ class TestTagService: ) # Act: Execute the method under test with empty tag IDs - result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, []) + result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, [], db_session_with_containers) # Assert: Verify the expected outcomes assert result is not None @@ -533,13 +533,19 @@ class TestTagService: # Act: Execute the method under test tag_ids = [tag.id for tag in tags] - result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, match_all=True) + result = TagService.get_target_ids_by_tag_ids( + "knowledge", tenant.id, tag_ids, db_session_with_containers, match_all=True + ) # Assert: Verify the expected outcomes assert result == [dataset_with_all_tags.id] missing_tag_result = TagService.get_target_ids_by_tag_ids( - "knowledge", tenant.id, [tags[0].id, str(uuid.uuid4())], match_all=True + "knowledge", + tenant.id, + [tags[0].id, str(uuid.uuid4())], + db_session_with_containers, + match_all=True, ) assert missing_tag_result == [] @@ -565,7 +571,9 @@ class TestTagService: non_existent_tag_ids = [str(uuid.uuid4()), str(uuid.uuid4())] # Act: Execute the method under test - result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, non_existent_tag_ids) + result = TagService.get_target_ids_by_tag_ids( + "knowledge", tenant.id, non_existent_tag_ids, db_session_with_containers + ) # Assert: Verify the expected outcomes assert result is not None @@ -599,7 +607,7 @@ class TestTagService: db_session_with_containers.commit() # Act: Execute the method under test - result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag") + result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag", db_session_with_containers) # Assert: Verify the expected outcomes assert result is not None @@ -625,7 +633,7 @@ class TestTagService: ) # Act: Execute the method under test with non-existent tag name - result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag") + result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag", db_session_with_containers) # Assert: Verify the expected outcomes assert result is not None @@ -650,8 +658,8 @@ class TestTagService: ) # Act: Execute the method under test with empty parameters - result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag") - result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "") + result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag", db_session_with_containers) + result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "", db_session_with_containers) # Assert: Verify the expected outcomes assert result_empty_type is not None @@ -688,7 +696,7 @@ class TestTagService: ) # Act: Execute the method under test - result = TagService.get_tags_by_target_id("app", tenant.id, app.id) + result = TagService.get_tags_by_target_id("app", tenant.id, app.id, db_session_with_containers) # Assert: Verify the expected outcomes assert result is not None @@ -720,7 +728,7 @@ class TestTagService: app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Act: Execute the method under test - result = TagService.get_tags_by_target_id("app", tenant.id, app.id) + result = TagService.get_tags_by_target_id("app", tenant.id, app.id, db_session_with_containers) # Assert: Verify the expected outcomes assert result is not None @@ -745,7 +753,7 @@ class TestTagService: tag_args = SaveTagPayload(name="test_tag_name", type="knowledge") # Act: Execute the method under test - result = TagService.save_tags(tag_args) + result = TagService.save_tags(tag_args, db_session_with_containers) # Assert: Verify the expected outcomes assert result is not None @@ -783,11 +791,11 @@ class TestTagService: # Create first tag tag_args = SaveTagPayload(name="duplicate_tag", type="app") - TagService.save_tags(tag_args) + TagService.save_tags(tag_args, db_session_with_containers) # Act & Assert: Verify proper error handling with pytest.raises(ValueError) as exc_info: - TagService.save_tags(tag_args) + TagService.save_tags(tag_args, db_session_with_containers) assert "Tag name already exists" in str(exc_info.value) def test_update_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): @@ -807,13 +815,13 @@ class TestTagService: # Create a tag to update tag_args = SaveTagPayload(name="original_name", type="knowledge") - tag = TagService.save_tags(tag_args) + tag = TagService.save_tags(tag_args, db_session_with_containers) # Update args update_args = UpdateTagPayload(name="updated_name") # Act: Execute the method under test - result = TagService.update_tags(update_args, tag.id) + result = TagService.update_tags(update_args, tag.id, db_session_with_containers) # Assert: Verify the expected outcomes assert result is not None @@ -854,7 +862,7 @@ class TestTagService: # Act & Assert: Verify proper error handling with pytest.raises(NotFound) as exc_info: - TagService.update_tags(update_args, non_existent_tag_id) + TagService.update_tags(update_args, non_existent_tag_id, db_session_with_containers) assert "Tag not found" in str(exc_info.value) def test_update_tags_duplicate_name_error( @@ -875,17 +883,17 @@ class TestTagService: # Create two tags tag1_args = SaveTagPayload(name="first_tag", type="app") - tag1 = TagService.save_tags(tag1_args) + tag1 = TagService.save_tags(tag1_args, db_session_with_containers) tag2_args = SaveTagPayload(name="second_tag", type="app") - tag2 = TagService.save_tags(tag2_args) + tag2 = TagService.save_tags(tag2_args, db_session_with_containers) # Try to update second tag with first tag's name update_args = UpdateTagPayload(name="first_tag") # Act & Assert: Verify proper error handling with pytest.raises(ValueError) as exc_info: - TagService.update_tags(update_args, tag2.id) + TagService.update_tags(update_args, tag2.id, db_session_with_containers) assert "Tag name already exists" in str(exc_info.value) def test_get_tag_binding_count_success( @@ -917,8 +925,8 @@ class TestTagService: ) # Act: Execute the method under test - result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id) - result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id) + result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id, db_session_with_containers) + result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id, db_session_with_containers) # Assert: Verify the expected outcomes assert result_tag_with_bindings == 1 @@ -946,7 +954,7 @@ class TestTagService: non_existent_tag_id = str(uuid.uuid4()) # Act: Execute the method under test - result = TagService.get_tag_binding_count(non_existent_tag_id) + result = TagService.get_tag_binding_count(non_existent_tag_id, db_session_with_containers) # Assert: Verify the expected outcomes assert result == 0 @@ -986,7 +994,7 @@ class TestTagService: assert binding_before is not None # Act: Execute the method under test - TagService.delete_tag(tag.id) + TagService.delete_tag(tag.id, db_session_with_containers) # Assert: Verify the expected outcomes # Verify tag was deleted @@ -1018,7 +1026,7 @@ class TestTagService: # Act & Assert: Verify proper error handling with pytest.raises(NotFound) as exc_info: - TagService.delete_tag(non_existent_tag_id) + TagService.delete_tag(non_existent_tag_id, db_session_with_containers) assert "Tag not found" in str(exc_info.value) def test_save_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): @@ -1048,7 +1056,7 @@ class TestTagService: binding_payload = TagBindingCreatePayload( type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags] ) - TagService.save_tag_binding(binding_payload) + TagService.save_tag_binding(binding_payload, db_session_with_containers) # Assert: Verify the expected outcomes @@ -1090,10 +1098,10 @@ class TestTagService: # Create first binding binding_payload = TagBindingCreatePayload(type="app", target_id=app.id, tag_ids=[tag.id]) - TagService.save_tag_binding(binding_payload) + TagService.save_tag_binding(binding_payload, db_session_with_containers) # Act: Try to create duplicate binding - TagService.save_tag_binding(binding_payload) + TagService.save_tag_binding(binding_payload, db_session_with_containers) # Assert: Verify the expected outcomes @@ -1173,7 +1181,7 @@ class TestTagService: delete_payload = TagBindingDeletePayload( type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags] ) - TagService.delete_tag_binding(delete_payload) + TagService.delete_tag_binding(delete_payload, db_session_with_containers) # Assert: Verify the expected outcomes # Verify tag bindings were deleted @@ -1209,7 +1217,7 @@ class TestTagService: # Act: Try to delete non-existent binding delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_ids=[tag.id]) - TagService.delete_tag_binding(delete_payload) + TagService.delete_tag_binding(delete_payload, db_session_with_containers) # Assert: Verify the expected outcomes # No error should be raised, and database state should remain unchanged @@ -1240,7 +1248,7 @@ class TestTagService: dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Act: Execute the method under test - TagService.check_target_exists("knowledge", dataset.id) + TagService.check_target_exists("knowledge", dataset.id, db_session_with_containers) # Assert: Verify the expected outcomes # No exception should be raised for existing dataset @@ -1268,7 +1276,7 @@ class TestTagService: # Act & Assert: Verify proper error handling with pytest.raises(NotFound) as exc_info: - TagService.check_target_exists("knowledge", non_existent_dataset_id) + TagService.check_target_exists("knowledge", non_existent_dataset_id, db_session_with_containers) assert "Dataset not found" in str(exc_info.value) def test_check_target_exists_app_success( @@ -1292,7 +1300,7 @@ class TestTagService: app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Act: Execute the method under test - TagService.check_target_exists("app", app.id) + TagService.check_target_exists("app", app.id, db_session_with_containers) # Assert: Verify the expected outcomes # No exception should be raised for existing app @@ -1320,7 +1328,7 @@ class TestTagService: # Act & Assert: Verify proper error handling with pytest.raises(NotFound) as exc_info: - TagService.check_target_exists("app", non_existent_app_id) + TagService.check_target_exists("app", non_existent_app_id, db_session_with_containers) assert "App not found" in str(exc_info.value) def test_check_target_exists_invalid_type( @@ -1346,5 +1354,5 @@ class TestTagService: # Act & Assert: Verify proper error handling with pytest.raises(NotFound) as exc_info: - TagService.check_target_exists("invalid_type", non_existent_target_id) + TagService.check_target_exists("invalid_type", non_existent_target_id, db_session_with_containers) assert "Invalid binding type" in str(exc_info.value) diff --git a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py index e62985d64dc..dfd51f462ea 100644 --- a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py +++ b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py @@ -191,7 +191,7 @@ def test_agent_app_list_and_create_use_agent_route( def get_app(self, app_obj: object) -> object: return app_obj - def get_paginate_apps(self, user_id: str, tenant_id: str, params) -> object: + def get_paginate_apps(self, user_id: str, tenant_id: str, params, session) -> object: captured["list"] = {"user_id": user_id, "tenant_id": tenant_id, "params": params} return SimpleNamespace( page=1, diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index dc3dd00a6c0..84a70835437 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -2,15 +2,8 @@ from types import SimpleNamespace from unittest.mock import MagicMock, PropertyMock, patch import pytest -from sqlalchemy.orm import Session - - -class SessionMatcher: - def __eq__(self, other): - return isinstance(other, Session) - - from flask import Flask +from sqlalchemy.orm import Session, scoped_session from werkzeug.exceptions import Forbidden import controllers.console.tag.tags as module @@ -27,6 +20,11 @@ from models.enums import TagType from services.tag_service import UpdateTagPayload +class SessionMatcher: + def __eq__(self, other): + return isinstance(other, Session | scoped_session) + + def unwrap(func): """ Recursively unwrap decorated functions. @@ -193,9 +191,10 @@ class TestTagUpdateDeleteApi: result, status = method(api, admin_user, "tag-1") assert status == 200 - update_payload, tag_id = update_tags_mock.call_args.args + update_payload, tag_id, session = update_tags_mock.call_args.args assert update_payload == UpdateTagPayload(name="updated") assert tag_id == "tag-1" + assert session == module.db.session assert result["binding_count"] == "3" def test_patch_forbidden(self, app: Flask, readonly_user, payload_patch): @@ -221,7 +220,7 @@ class TestTagUpdateDeleteApi: ): result, status = method(api, "tag-1") - delete_mock.assert_called_once_with("tag-1") + delete_mock.assert_called_once_with("tag-1", module.db.session) assert status == 204 diff --git a/api/tests/unit_tests/controllers/console/workspace/test_snippets.py b/api/tests/unit_tests/controllers/console/workspace/test_snippets.py index b8914fc26cb..a276a181d62 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_snippets.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_snippets.py @@ -1,6 +1,6 @@ from inspect import unwrap from types import SimpleNamespace -from unittest.mock import Mock +from unittest.mock import ANY, Mock import pytest from werkzeug.exceptions import NotFound @@ -94,6 +94,7 @@ def test_list_snippets_returns_pagination(app, monkeypatch): } get_snippets.assert_called_once_with( tenant_id="tenant-1", + session=ANY, page=2, limit=10, keyword=None, diff --git a/api/tests/unit_tests/services/agent/test_agent_services.py b/api/tests/unit_tests/services/agent/test_agent_services.py index fb0a316648a..41d3d434ba4 100644 --- a/api/tests/unit_tests/services/agent/test_agent_services.py +++ b/api/tests/unit_tests/services/agent/test_agent_services.py @@ -100,7 +100,7 @@ def test_agent_soul_has_model(): assert agent_soul_has_model(AgentSoulConfig()) is False -def test_load_workflow_composer_returns_empty_state(monkeypatch): +def test_load_workflow_composer_returns_empty_state(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(AgentComposerService, "_get_draft_workflow", lambda **kwargs: SimpleNamespace(id="workflow-1")) monkeypatch.setattr(AgentComposerService, "_get_workflow_binding", lambda **kwargs: None) @@ -118,7 +118,7 @@ def test_load_workflow_composer_returns_empty_state(monkeypatch): assert files_output["array_item"] == {"type": "file", "description": None, "children": []} -def test_load_workflow_composer_serializes_existing_binding(monkeypatch): +def test_load_workflow_composer_serializes_existing_binding(monkeypatch: pytest.MonkeyPatch): binding = SimpleNamespace( agent_id="agent-1", binding_type=WorkflowAgentBindingType.ROSTER_AGENT, @@ -220,7 +220,7 @@ def test_save_workflow_composer_rejects_agent_app_variant(): ) -def test_save_agent_app_composer_creates_agent_when_missing(monkeypatch): +def test_save_agent_app_composer_creates_agent_when_missing(monkeypatch: pytest.MonkeyPatch): fake_session = FakeSession(scalar=[None]) created_version = SimpleNamespace(id="version-1") @@ -249,7 +249,7 @@ def test_save_agent_app_composer_creates_agent_when_missing(monkeypatch): assert fake_session.commits == 1 -def test_save_agent_app_composer_updates_current_version(monkeypatch): +def test_save_agent_app_composer_updates_current_version(monkeypatch: pytest.MonkeyPatch): agent = SimpleNamespace(id="agent-1", active_config_snapshot_id="version-1", updated_by=None) fake_session = FakeSession(scalar=[agent]) updated = {} @@ -283,7 +283,7 @@ def test_save_agent_app_composer_updates_current_version(monkeypatch): assert fake_session.commits == 1 -def test_agent_app_composer_candidates_and_impact(monkeypatch): +def test_agent_app_composer_candidates_and_impact(monkeypatch: pytest.MonkeyPatch): bindings = [ SimpleNamespace(app_id="app-1", workflow_id="workflow-1", node_id="node-1"), SimpleNamespace(app_id="app-1", workflow_id="workflow-1", node_id="node-2"), @@ -316,7 +316,7 @@ def test_agent_app_composer_candidates_and_impact(monkeypatch): assert impact["bindings"][1]["node_id"] == "node-2" -def test_serialize_workflow_state_changes_lock_and_save_options(monkeypatch): +def test_serialize_workflow_state_changes_lock_and_save_options(monkeypatch: pytest.MonkeyPatch): binding = WorkflowAgentNodeBinding( id="binding-1", tenant_id="tenant-1", @@ -342,7 +342,7 @@ def test_serialize_workflow_state_changes_lock_and_save_options(monkeypatch): assert effective_names == ["text", "files", "json"] -def test_serialize_workflow_state_passes_user_declared_outputs_through_effective(monkeypatch): +def test_serialize_workflow_state_passes_user_declared_outputs_through_effective(monkeypatch: pytest.MonkeyPatch): binding = WorkflowAgentNodeBinding( id="binding-1", tenant_id="tenant-1", @@ -369,7 +369,7 @@ def test_serialize_workflow_state_passes_user_declared_outputs_through_effective assert effective[0]["required"] is True -def test_composer_save_helpers_create_and_rebind_agents(monkeypatch): +def test_composer_save_helpers_create_and_rebind_agents(monkeypatch: pytest.MonkeyPatch): fake_session = FakeSession() monkeypatch.setattr(composer_service.db, "session", fake_session) workflow_agent = SimpleNamespace(id="inline-agent-1", active_config_snapshot_id="inline-version-1") @@ -611,7 +611,7 @@ def test_composer_create_agents_syncs_active_config_has_model(monkeypatch): assert roster_agent.active_config_has_model is True -def test_composer_version_helpers_and_lookup_errors(monkeypatch): +def test_composer_version_helpers_and_lookup_errors(monkeypatch: pytest.MonkeyPatch): fake_session = FakeSession( scalar=[ 1, @@ -670,7 +670,7 @@ def test_composer_version_helpers_and_lookup_errors(monkeypatch): assert workflow.id == "workflow-1" -def test_composer_current_version_and_error_paths(monkeypatch): +def test_composer_current_version_and_error_paths(monkeypatch: pytest.MonkeyPatch): fake_session = FakeSession(scalar=[2]) monkeypatch.setattr(composer_service.db, "session", fake_session) payload = ComposerSavePayload.model_validate( @@ -717,7 +717,7 @@ def test_composer_current_version_and_error_paths(monkeypatch): ) -def test_roster_list_and_invite_options(monkeypatch): +def test_roster_list_and_invite_options(monkeypatch: pytest.MonkeyPatch): created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) updated_at = datetime(2026, 1, 3, 3, 4, 5, tzinfo=UTC) version_created_at = datetime(2026, 1, 4, 3, 4, 5, tzinfo=UTC) @@ -790,7 +790,7 @@ def test_roster_list_and_invite_options(monkeypatch): assert invited["data"][0]["existing_node_ids"] == ["node-1"] -def test_invite_options_uses_db_filtered_pagination(monkeypatch): +def test_invite_options_uses_db_filtered_pagination(monkeypatch: pytest.MonkeyPatch): configured_agent = Agent( id="agent-2", tenant_id="tenant-1", @@ -915,7 +915,7 @@ def test_published_references_include_app_display_fields_and_sort_by_updated_at( assert references[0]["workflow_version"] == "published-recent" -def test_roster_update_archive_versions_and_detail(monkeypatch): +def test_roster_update_archive_versions_and_detail(monkeypatch: pytest.MonkeyPatch): listed_version = AgentConfigSnapshot(id="version-2", agent_id="agent-1", version=2) listed_version_created_at = datetime(2026, 1, 5, 3, 4, 5, tzinfo=UTC) listed_version.created_at = listed_version_created_at @@ -973,7 +973,7 @@ def test_roster_update_archive_versions_and_detail(monkeypatch): assert detail["revisions"][0]["created_at"] == int(revision_created_at.timestamp()) -def test_roster_create_detail_and_lookup_helpers(monkeypatch): +def test_roster_create_detail_and_lookup_helpers(monkeypatch: pytest.MonkeyPatch): fake_session = FakeSession( scalar=[ SimpleNamespace(id="agent-1"), @@ -1040,9 +1040,7 @@ def test_agent_app_visible_versions_exclude_draft_saves(): def test_app_list_all_excludes_agent_apps_by_default(): filters = AppService._build_app_list_filters( - "account-1", - "tenant-1", - AppListParams(mode="all"), + "account-1", "tenant-1", AppListParams(mode="all"), FakeSession(scalar=None, scalars=None) ) sql = " ".join(str(filter_) for filter_ in filters) @@ -2173,7 +2171,7 @@ class TestWorkflowAgentDraftBindingSync: assert session.deleted == [stale_binding] -def test_dataset_rows_filters_malformed_ids(monkeypatch): +def test_dataset_rows_filters_malformed_ids(monkeypatch: pytest.MonkeyPatch): """Mention ids are user-editable text: a non-UUID id must read as missing (placeholder semantics), never reach the UUID-typed dataset query (E2E 500).""" captured = {} @@ -2197,7 +2195,7 @@ def test_dataset_rows_filters_malformed_ids(monkeypatch): assert captured == {} -def test_workspace_dify_tools_returns_provider_and_tool_granularities(monkeypatch): +def test_workspace_dify_tools_returns_provider_and_tool_granularities(monkeypatch: pytest.MonkeyPatch): """The slash-menu Tools tab needs both selection granularities: a provider hosts many tools (like an MCP server), so candidates return one provider-level entry (id = /*, = all tools) plus one per tool.""" @@ -2268,7 +2266,7 @@ def _patch_drive_keys(monkeypatch, existing_keys): return captured -def test_drive_ref_findings_reports_missing_keys(monkeypatch): +def test_drive_ref_findings_reports_missing_keys(monkeypatch: pytest.MonkeyPatch): _patch_drive_keys(monkeypatch, existing_keys=["tender-analyzer/SKILL.md"]) findings = AgentComposerService._drive_ref_findings( @@ -2279,7 +2277,7 @@ def test_drive_ref_findings_reports_missing_keys(monkeypatch): assert str(findings[0]["message"]).startswith("file_ref_dangling: ") -def test_drive_ref_findings_clean_when_all_keys_exist(monkeypatch): +def test_drive_ref_findings_clean_when_all_keys_exist(monkeypatch: pytest.MonkeyPatch): _patch_drive_keys(monkeypatch, existing_keys=["tender-analyzer/SKILL.md", "files/sample.pdf"]) assert ( @@ -2288,7 +2286,7 @@ def test_drive_ref_findings_clean_when_all_keys_exist(monkeypatch): ) -def test_drive_ref_findings_skips_refs_without_drive_keys(monkeypatch): +def test_drive_ref_findings_skips_refs_without_drive_keys(monkeypatch: pytest.MonkeyPatch): # No drive-backed ref at all -> no DB roundtrip, no findings. soul = _drive_soul( skills_files={"skills": [{"id": "legacy", "name": "Legacy"}], "files": [{"name": "u.pdf", "file_id": "u-1"}]} @@ -2297,7 +2295,7 @@ def test_drive_ref_findings_skips_refs_without_drive_keys(monkeypatch): assert findings == [] -def test_require_drive_refs_resolved_raises_with_stable_code(monkeypatch): +def test_require_drive_refs_resolved_raises_with_stable_code(monkeypatch: pytest.MonkeyPatch): from services.agent.errors import InvalidComposerConfigError _patch_drive_keys(monkeypatch, existing_keys=[]) @@ -2308,7 +2306,7 @@ def test_require_drive_refs_resolved_raises_with_stable_code(monkeypatch): ) -def test_collect_validation_findings_appends_drive_findings_with_agent_context(monkeypatch): +def test_collect_validation_findings_appends_drive_findings_with_agent_context(monkeypatch: pytest.MonkeyPatch): from services.entities.agent_entities import ComposerSavePayload _patch_drive_keys(monkeypatch, existing_keys=[]) @@ -2334,7 +2332,7 @@ def test_collect_validation_findings_appends_drive_findings_with_agent_context(m # ── ENG-625 D5: soul-first ref removal ─────────────────────────────────────── -def _patch_remove_drive_refs_env(monkeypatch, *, soul_dict): +def _patch_remove_drive_refs_env(monkeypatch: pytest.MonkeyPatch, *, soul_dict): """Wire the classmethod's collaborators so soul editing + versioning is observable.""" from types import SimpleNamespace @@ -2359,7 +2357,7 @@ def _patch_remove_drive_refs_env(monkeypatch, *, soul_dict): return agent, captured, committed -def test_remove_drive_refs_drops_skill_by_slug_and_versions(monkeypatch): +def test_remove_drive_refs_drops_skill_by_slug_and_versions(monkeypatch: pytest.MonkeyPatch): soul_dict = { "skills_files": { "skills": [ @@ -2383,7 +2381,7 @@ def test_remove_drive_refs_drops_skill_by_slug_and_versions(monkeypatch): assert committed.get("committed") is True -def test_remove_drive_refs_is_noop_when_ref_absent(monkeypatch): +def test_remove_drive_refs_is_noop_when_ref_absent(monkeypatch: pytest.MonkeyPatch): soul_dict = {"skills_files": {"skills": [], "files": []}} agent, captured, committed = _patch_remove_drive_refs_env(monkeypatch, soul_dict=soul_dict) @@ -2397,7 +2395,7 @@ def test_remove_drive_refs_is_noop_when_ref_absent(monkeypatch): assert committed == {} -def test_remove_drive_refs_drops_file_by_key(monkeypatch): +def test_remove_drive_refs_drops_file_by_key(monkeypatch: pytest.MonkeyPatch): soul_dict = { "skills_files": { "skills": [], @@ -2417,7 +2415,7 @@ def test_remove_drive_refs_drops_file_by_key(monkeypatch): assert [f.drive_key for f in captured["agent_soul"].skills_files.files] == ["files/keep.pdf"] -def test_add_drive_file_ref_adds_or_replaces_file_and_versions(monkeypatch): +def test_add_drive_file_ref_adds_or_replaces_file_and_versions(monkeypatch: pytest.MonkeyPatch): soul_dict = { "skills_files": { "skills": [], @@ -2444,7 +2442,7 @@ def test_add_drive_file_ref_adds_or_replaces_file_and_versions(monkeypatch): assert committed.get("committed") is True -def test_add_drive_file_ref_syncs_workflow_binding_snapshot(monkeypatch): +def test_add_drive_file_ref_syncs_workflow_binding_snapshot(monkeypatch: pytest.MonkeyPatch): binding = SimpleNamespace(agent_id="agent-1", current_snapshot_id="snap-1", updated_by=None) _patch_remove_drive_refs_env(monkeypatch, soul_dict={"skills_files": {"skills": [], "files": []}}) monkeypatch.setattr( @@ -2473,7 +2471,7 @@ def test_remove_drive_refs_requires_exactly_one_scope(): # ── ENG-623/625: resolver helpers + save-path drive guard ──────────────────── -def test_resolve_bound_agent_id_queries_active_roster_agent(monkeypatch): +def test_resolve_bound_agent_id_queries_active_roster_agent(monkeypatch: pytest.MonkeyPatch): from types import SimpleNamespace import services.agent.composer_service as module @@ -2482,7 +2480,7 @@ def test_resolve_bound_agent_id_queries_active_roster_agent(monkeypatch): assert AgentComposerService.resolve_bound_agent_id(tenant_id="t-1", app_id="app-1") == "agent-9" -def test_resolve_workflow_node_agent_id_degrades_without_workflow_or_binding(monkeypatch): +def test_resolve_workflow_node_agent_id_degrades_without_workflow_or_binding(monkeypatch: pytest.MonkeyPatch): from types import SimpleNamespace def boom(cls, **kwargs): @@ -2505,7 +2503,7 @@ def test_resolve_workflow_node_agent_id_degrades_without_workflow_or_binding(mon assert AgentComposerService.resolve_workflow_node_agent_id(tenant_id="t", app_id="a", node_id="n") == "agent-7" -def test_remove_drive_refs_returns_none_without_agent_or_snapshot(monkeypatch): +def test_remove_drive_refs_returns_none_without_agent_or_snapshot(monkeypatch: pytest.MonkeyPatch): from types import SimpleNamespace import services.agent.composer_service as module @@ -2518,7 +2516,7 @@ def test_remove_drive_refs_returns_none_without_agent_or_snapshot(monkeypatch): assert AgentComposerService.remove_drive_refs(tenant_id="t", agent_id="a", account_id="u", skill_slug="s") is None -def test_save_workflow_composer_guards_drive_refs_for_existing_agent_strategies(monkeypatch): +def test_save_workflow_composer_guards_drive_refs_for_existing_agent_strategies(monkeypatch: pytest.MonkeyPatch): from types import SimpleNamespace from services.entities.agent_entities import ComposerSavePayload diff --git a/api/tests/unit_tests/services/test_snippet_service.py b/api/tests/unit_tests/services/test_snippet_service.py index 7cbe773e419..008b6cfa418 100644 --- a/api/tests/unit_tests/services/test_snippet_service.py +++ b/api/tests/unit_tests/services/test_snippet_service.py @@ -94,14 +94,15 @@ def test_validate_snippet_graph_forbidden_nodes_raises_with_node_details() -> No def test_get_snippets_returns_empty_when_tag_filter_has_no_targets(monkeypatch: pytest.MonkeyPatch) -> None: + session = _SessionWithoutNameLookup() get_target_ids = Mock(return_value=[]) monkeypatch.setattr("services.snippet_service.TagService.get_target_ids_by_tag_ids", get_target_ids) service = SnippetService.__new__(SnippetService) - result = service.get_snippets(tenant_id="tenant-1", tag_ids=["tag-1"]) + result = service.get_snippets(tenant_id="tenant-1", session=session, tag_ids=["tag-1"]) assert result == ([], 0, False) - get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], match_all=True) + get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], session, match_all=True) def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPatch) -> None: @@ -124,6 +125,7 @@ def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPa result, total, has_more = service.get_snippets( tenant_id="tenant-1", + session=session, page=2, limit=2, keyword="search", @@ -135,7 +137,7 @@ def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPa assert result == snippets[:2] assert total == 3 assert has_more is True - get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], match_all=True) + get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], session, match_all=True) session.scalar.assert_called_once() session.scalars.assert_called_once() diff --git a/api/tests/unit_tests/services/test_tag_service.py b/api/tests/unit_tests/services/test_tag_service.py index 282b32a7e55..73df7cc2673 100644 --- a/api/tests/unit_tests/services/test_tag_service.py +++ b/api/tests/unit_tests/services/test_tag_service.py @@ -1,6 +1,7 @@ from types import SimpleNamespace import pytest +from pytest_mock import MockerFixture from werkzeug.exceptions import NotFound from models.enums import TagType @@ -8,19 +9,19 @@ from services.tag_service import TagBindingCreatePayload, TagBindingDeletePayloa @pytest.fixture -def current_user(mocker): +def current_user(mocker: MockerFixture): user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1") mocker.patch("services.tag_service.current_user", user) return user @pytest.fixture -def db_session(mocker): - mock_db = mocker.patch("services.tag_service.db") +def db_session(mocker: MockerFixture): + mock_db = mocker.Mock() return mock_db.session -def test_save_tag_binding_only_creates_bindings_for_valid_snippet_tags(mocker, current_user, db_session): +def test_save_tag_binding_only_creates_bindings_for_valid_snippet_tags(mocker: MockerFixture, current_user, db_session): mocker.patch("services.tag_service.TagService.check_target_exists") db_session.scalars.return_value.all.return_value = ["tag-1"] db_session.scalar.return_value = None @@ -30,7 +31,8 @@ def test_save_tag_binding_only_creates_bindings_for_valid_snippet_tags(mocker, c tag_ids=["tag-1", "tag-from-other-tenant"], target_id="snippet-1", type=TagType.SNIPPET, - ) + ), + db_session, ) db_session.add.assert_called_once() @@ -42,7 +44,7 @@ def test_save_tag_binding_only_creates_bindings_for_valid_snippet_tags(mocker, c db_session.commit.assert_called_once() -def test_delete_tag_binding_limits_deletion_to_valid_snippet_tags(mocker, current_user, db_session): +def test_delete_tag_binding_limits_deletion_to_valid_snippet_tags(mocker: MockerFixture, current_user, db_session): mocker.patch("services.tag_service.TagService.check_target_exists") db_session.execute.return_value = SimpleNamespace(rowcount=1) @@ -51,14 +53,15 @@ def test_delete_tag_binding_limits_deletion_to_valid_snippet_tags(mocker, curren tag_ids=["tag-1", "tag-from-other-tenant"], target_id="snippet-1", type=TagType.SNIPPET, - ) + ), + db_session, ) db_session.execute.assert_called_once() db_session.commit.assert_called_once() -def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker, current_user, db_session): +def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker: MockerFixture, current_user, db_session): mocker.patch("services.tag_service.TagService.check_target_exists") db_session.execute.return_value = SimpleNamespace(rowcount=0) @@ -67,7 +70,8 @@ def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker, current tag_ids=["tag-1"], target_id="snippet-1", type=TagType.SNIPPET, - ) + ), + db_session, ) db_session.execute.assert_called_once() @@ -75,7 +79,7 @@ def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker, current def test_get_target_ids_by_tag_ids_returns_empty_without_query_for_empty_input(db_session): - result = TagService.get_target_ids_by_tag_ids(TagType.SNIPPET, "tenant-1", []) + result = TagService.get_target_ids_by_tag_ids(TagType.SNIPPET, "tenant-1", [], db_session) assert result == [] db_session.scalars.assert_not_called() @@ -84,7 +88,7 @@ def test_get_target_ids_by_tag_ids_returns_empty_without_query_for_empty_input(d def test_check_target_exists_accepts_existing_snippet(current_user, db_session): db_session.scalar.return_value = SimpleNamespace(id="snippet-1") - TagService.check_target_exists("snippet", "snippet-1") + TagService.check_target_exists("snippet", "snippet-1", db_session) db_session.scalar.assert_called_once() @@ -93,11 +97,11 @@ def test_check_target_exists_raises_when_snippet_missing(current_user, db_sessio db_session.scalar.return_value = None with pytest.raises(NotFound, match="Snippet not found"): - TagService.check_target_exists("snippet", "missing-snippet") + TagService.check_target_exists("snippet", "missing-snippet", db_session) def test_check_target_exists_raises_for_invalid_binding_type(current_user, db_session): with pytest.raises(NotFound, match="Invalid binding type"): - TagService.check_target_exists("unknown", "target-1") + TagService.check_target_exists("unknown", "target-1", db_session) db_session.scalar.assert_not_called()