chore: example of make db.session pass from parameter. (#37561)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-06-18 11:16:09 +09:00 committed by GitHub
parent 0fa43973b8
commit 4304044905
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 274 additions and 206 deletions

View File

@ -286,7 +286,7 @@ class AgentAppListApi(Resource):
status="normal",
)
app_pagination = AppService().get_paginate_apps(current_user.id, current_tenant_id, params)
app_pagination = AppService().get_paginate_apps(current_user.id, current_tenant_id, params, db.session)
if app_pagination is None:
empty = AgentAppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json")

View File

@ -594,7 +594,7 @@ class AppListApi(Resource):
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params, db.session)
if not app_pagination:
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200
@ -661,7 +661,7 @@ class StarredAppListApi(Resource):
is_created_by_me=args.is_created_by_me,
)
app_pagination = AppService().get_paginate_starred_apps(current_user_id, current_tenant_id, params)
app_pagination = AppService().get_paginate_starred_apps(current_user_id, current_tenant_id, params, db.session)
if not app_pagination:
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200

View File

@ -409,6 +409,7 @@ class DatasetListApi(Resource):
datasets, total = DatasetService.get_datasets(
query.page,
query.limit,
db.session,
current_tenant_id,
current_user,
query.keyword,

View File

@ -122,7 +122,7 @@ class TagListApi(Resource):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type), db.session)
response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@ -146,9 +146,9 @@ class TagUpdateDeleteApi(Resource):
raise Forbidden()
payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str)
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str, db.session)
binding_count = TagService.get_tag_binding_count(tag_id_str)
binding_count = TagService.get_tag_binding_count(tag_id_str, db.session)
response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
@ -164,7 +164,7 @@ class TagUpdateDeleteApi(Resource):
def delete(self, tag_id: UUID):
tag_id_str = str(tag_id)
TagService.delete_tag(tag_id_str)
TagService.delete_tag(tag_id_str, db.session)
return "", 204
@ -189,7 +189,8 @@ def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
),
db.session,
)
return {"result": "success"}, 200
@ -203,7 +204,8 @@ def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
),
db.session,
)
return {"result": "success"}, 200

View File

@ -126,6 +126,7 @@ class CustomizedSnippetsApi(Resource):
snippet_service = _snippet_service()
snippets, total, has_more = snippet_service.get_snippets(
tenant_id=current_tenant_id,
session=db.session,
page=query.page,
limit=query.limit,
keyword=query.keyword,

View File

@ -174,7 +174,7 @@ class AppListApi(Resource):
tag_ids: list[str] | None = None
if query.tag:
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag, db.session)
if not tags:
return empty
tag_ids = [tag.id for tag in tags]
@ -191,7 +191,7 @@ class AppListApi(Resource):
openapi_visible=True,
)
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params)
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params, db.session)
if pagination is None:
return empty

View File

@ -770,7 +770,7 @@ class DatasetTagsApi(DatasetApiResource):
raise Forbidden()
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE), db.session)
response = dump_response(
KnowledgeTagResponse,
@ -808,9 +808,9 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
tag_id = payload.tag_id
tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id)
tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id, db.session)
binding_count = TagService.get_tag_binding_count(tag_id)
binding_count = TagService.get_tag_binding_count(tag_id, db.session)
response = dump_response(
KnowledgeTagResponse,
@ -840,7 +840,7 @@ class DatasetTagsApi(DatasetApiResource):
def delete(self, _):
"""Delete a knowledge type tag."""
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
TagService.delete_tag(payload.tag_id, db.session)
return "", 204
@ -873,7 +873,8 @@ class DatasetTagBindingApi(DatasetApiResource):
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE),
db.session,
)
return "", 204
@ -907,7 +908,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE),
db.session,
)
return "", 204
@ -942,6 +944,8 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
dataset_id = kwargs.get("dataset_id")
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
tags = TagService.get_tags_by_target_id(
"knowledge", current_user.current_tenant_id, str(dataset_id), db.session
)
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200

View File

@ -75,7 +75,7 @@ class CreateAppParams(BaseModel):
class AppService:
@staticmethod
def _build_app_list_filters(
user_id: str, tenant_id: str, params: AppListBaseParams
user_id: str, tenant_id: str, params: AppListBaseParams, session: scoped_session
) -> list[sa.ColumnElement[bool]]:
filters = [App.tenant_id == tenant_id, App.is_universal == False]
@ -115,7 +115,7 @@ class AppService:
escaped_name = escape_like_pattern(name)
filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
if params.tag_ids and len(params.tag_ids) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, match_all=True)
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, session, match_all=True)
if target_ids and len(target_ids) > 0:
filters.append(App.id.in_(target_ids))
else:
@ -197,7 +197,9 @@ class AppService:
).scalars()
)
def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None:
def get_paginate_apps(
self, user_id: str, tenant_id: str, params: AppListParams, session: scoped_session
) -> Pagination | None:
"""
Get app list with pagination, filters, and explicit sort order.
:param user_id: user id
@ -205,7 +207,7 @@ class AppService:
:param params: query parameters
:return:
"""
filters = self._build_app_list_filters(user_id, tenant_id, params)
filters = self._build_app_list_filters(user_id, tenant_id, params, session)
if not filters:
return None
@ -231,12 +233,12 @@ class AppService:
return app_models
def get_paginate_starred_apps(
self, user_id: str, tenant_id: str, params: StarredAppListParams
self, user_id: str, tenant_id: str, params: StarredAppListParams, session: scoped_session
) -> Pagination | None:
"""
Get apps starred by the current account with pagination, filters, and explicit sort order.
"""
filters = self._build_app_list_filters(user_id, tenant_id, params)
filters = self._build_app_list_filters(user_id, tenant_id, params, session)
if not filters:
return None

View File

@ -13,7 +13,7 @@ import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from redis.exceptions import LockNotOwnedError
from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
@ -235,7 +235,9 @@ class _EstimateArgs(BaseModel):
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
def get_datasets(
page, per_page, session: scoped_session, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False
):
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id)
if user:
@ -295,6 +297,7 @@ class DatasetService:
"knowledge",
tenant_id,
tag_ids,
session,
match_all=True,
)
else:

View File

@ -6,7 +6,7 @@ from datetime import UTC, datetime
from typing import Any
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from core.db import session_factory
from core.workflow.node_factory import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@ -192,6 +192,7 @@ class SnippetService:
self,
*,
tenant_id: str,
session: scoped_session,
page: int = 1,
limit: int = 20,
keyword: str | None = None,
@ -229,20 +230,19 @@ class SnippetService:
stmt = stmt.where(CustomizedSnippet.created_by.in_(creators))
if tag_ids:
target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, match_all=True)
target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, session, match_all=True)
if target_ids:
stmt = stmt.where(CustomizedSnippet.id.in_(target_ids))
else:
return [], 0, False
with self._session_scope() as session:
# Get total count
count_stmt = select(func.count()).select_from(stmt.subquery())
total = session.scalar(count_stmt) or 0
# Get total count
count_stmt = select(func.count()).select_from(stmt.subquery())
total = session.scalar(count_stmt) or 0
# Apply pagination
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
snippets = list(session.scalars(stmt).all())
# Apply pagination
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
snippets = list(session.scalars(stmt).all())
has_more = len(snippets) > limit
if has_more:

View File

@ -6,10 +6,9 @@ from flask_login import current_user
from pydantic import BaseModel, Field
from sqlalchemy import delete, func, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, scoped_session
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import TagType
from models.model import App, Tag, TagBinding
@ -56,7 +55,7 @@ class TagService:
@staticmethod
def get_target_ids_by_tag_ids(
tag_type: str, current_tenant_id: str, tag_ids: list[str], *, match_all: bool = False
tag_type: str, current_tenant_id: str, tag_ids: list[str], session: scoped_session, *, match_all: bool = False
):
"""
Return target IDs bound to tags for the given tenant and resource type.
@ -70,7 +69,7 @@ class TagService:
return []
# Deduplicate repeated query params so match_all counts each requested tag once.
requested_tag_ids = list(dict.fromkeys(tag_ids))
tags = db.session.scalars(
tags = session.scalars(
select(Tag.id).where(
Tag.id.in_(requested_tag_ids),
Tag.tenant_id == current_tenant_id,
@ -86,13 +85,13 @@ class TagService:
if match_all:
if len(tag_ids) != len(requested_tag_ids):
return []
return db.session.scalars(
return session.scalars(
select(TagBinding.target_id)
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.group_by(TagBinding.target_id)
.having(func.count(sa.distinct(TagBinding.tag_id)) == len(tag_ids))
).all()
tag_bindings = db.session.scalars(
tag_bindings = session.scalars(
select(TagBinding.target_id).where(
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
)
@ -100,11 +99,11 @@ class TagService:
return tag_bindings
@staticmethod
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str, session: scoped_session):
if not tag_type or not tag_name:
return []
tags = list(
db.session.scalars(
session.scalars(
select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
).all()
)
@ -113,8 +112,8 @@ class TagService:
return tags
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
tags = db.session.scalars(
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str, session: scoped_session):
tags = session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
@ -128,8 +127,8 @@ class TagService:
return tags or []
@staticmethod
def save_tags(payload: SaveTagPayload) -> Tag:
if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name):
def save_tags(payload: SaveTagPayload, session: scoped_session) -> Tag:
if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name, session):
raise ValueError("Tag name already exists")
tag = Tag(
name=payload.name,
@ -138,17 +137,17 @@ class TagService:
tenant_id=current_user.current_tenant_id,
)
tag.id = str(uuid.uuid4())
db.session.add(tag)
db.session.commit()
session.add(tag)
session.commit()
return tag
@staticmethod
def update_tags(payload: UpdateTagPayload, tag_id: str) -> Tag:
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
def update_tags(payload: UpdateTagPayload, tag_id: str, session: scoped_session) -> Tag:
tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
if not tag:
raise NotFound("Tag not found")
if payload.name != tag.name:
existing = db.session.scalar(
existing = session.scalar(
select(Tag)
.where(
Tag.name == payload.name,
@ -161,31 +160,31 @@ class TagService:
if existing:
raise ValueError("Tag name already exists")
tag.name = payload.name
db.session.commit()
session.commit()
return tag
@staticmethod
def get_tag_binding_count(tag_id: str) -> int:
count = db.session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0
def get_tag_binding_count(tag_id: str, session: scoped_session) -> int:
count = session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0
return count
@staticmethod
def delete_tag(tag_id: str):
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
def delete_tag(tag_id: str, session: scoped_session):
tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
if not tag:
raise NotFound("Tag not found")
db.session.delete(tag)
session.delete(tag)
# delete tag binding
tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
tag_bindings = session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)
db.session.commit()
session.delete(tag_binding)
session.commit()
@staticmethod
def save_tag_binding(payload: TagBindingCreatePayload):
TagService.check_target_exists(payload.type, payload.target_id)
valid_tag_ids = db.session.scalars(
def save_tag_binding(payload: TagBindingCreatePayload, session: scoped_session):
TagService.check_target_exists(payload.type, payload.target_id, session)
valid_tag_ids = session.scalars(
select(Tag.id).where(
Tag.id.in_(payload.tag_ids),
Tag.tenant_id == current_user.current_tenant_id,
@ -193,7 +192,7 @@ class TagService:
)
).all()
for tag_id in valid_tag_ids:
tag_binding = db.session.scalar(
tag_binding = session.scalar(
select(TagBinding)
.where(TagBinding.tag_id == tag_id, TagBinding.target_id == payload.target_id)
.limit(1)
@ -206,15 +205,15 @@ class TagService:
tenant_id=current_user.current_tenant_id,
created_by=current_user.id,
)
db.session.add(new_tag_binding)
db.session.commit()
session.add(new_tag_binding)
session.commit()
@staticmethod
def delete_tag_binding(payload: TagBindingDeletePayload):
TagService.check_target_exists(payload.type, payload.target_id)
def delete_tag_binding(payload: TagBindingDeletePayload, session: scoped_session):
TagService.check_target_exists(payload.type, payload.target_id, session)
result = cast(
CursorResult,
db.session.execute(
session.execute(
delete(TagBinding).where(
TagBinding.target_id == payload.target_id,
TagBinding.tag_id.in_(payload.tag_ids),
@ -230,12 +229,12 @@ class TagService:
)
if result.rowcount:
db.session.commit()
session.commit()
@staticmethod
def check_target_exists(type: str, target_id: str):
def check_target_exists(type: str, target_id: str, session: scoped_session):
if type == "knowledge":
dataset = db.session.scalar(
dataset = session.scalar(
select(Dataset)
.where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
.limit(1)
@ -243,13 +242,13 @@ class TagService:
if not dataset:
raise NotFound("Dataset not found")
elif type == "app":
app = db.session.scalar(
app = session.scalar(
select(App).where(App.tenant_id == current_user.current_tenant_id, App.id == target_id).limit(1)
)
if not app:
raise NotFound("App not found")
elif type == "snippet":
snippet = db.session.scalar(
snippet = session.scalar(
select(CustomizedSnippet)
.where(CustomizedSnippet.tenant_id == current_user.current_tenant_id, CustomizedSnippet.id == target_id)
.limit(1)

View File

@ -16,7 +16,7 @@ since these test controller-level behavior.
import uuid
from contextlib import ExitStack
from datetime import UTC, datetime
from unittest.mock import Mock, PropertyMock, patch
from unittest.mock import ANY, Mock, PropertyMock, patch
import pytest
from flask import Flask
@ -1129,7 +1129,7 @@ class TestDatasetTagsApiPatch:
assert status == 200
assert response == {"id": "tag-1", "name": "Updated Tag", "type": "knowledge", "binding_count": "5"}
mock_tag_svc.update_tags.assert_called_once()
update_payload, tag_id = mock_tag_svc.update_tags.call_args.args
update_payload, tag_id, session = mock_tag_svc.update_tags.call_args.args
assert update_payload.name == "Updated Tag"
assert tag_id == "tag-1"
@ -1184,7 +1184,7 @@ class TestDatasetTagsApiDelete:
result = api.delete(_=None)
assert result == ("", 204)
mock_tag_svc.delete_tag.assert_called_once_with("tag-1")
mock_tag_svc.delete_tag.assert_called_once_with("tag-1", ANY)
@patch("libs.login.current_user")
def test_delete_tag_forbidden(self, mock_current_user, app: Flask):
@ -1233,7 +1233,7 @@ class TestDatasetTagsBindingStatusApi:
assert status_code == 200
assert response["data"] == [{"id": "tag_1", "name": "Test Tag"}]
assert response["total"] == 1
mock_tag_svc.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123")
mock_tag_svc.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123", ANY)
class TestDatasetTagBindingApiPost:
@ -1266,7 +1266,8 @@ class TestDatasetTagBindingApiPost:
from services.tag_service import TagBindingCreatePayload
mock_tag_svc.save_tag_binding.assert_called_once_with(
TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE)
TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE),
ANY,
)
@patch("controllers.service_api.dataset.dataset.current_user")
@ -1317,7 +1318,8 @@ class TestDatasetTagUnbindingApiPost:
from services.tag_service import TagBindingDeletePayload
mock_tag_svc.delete_tag_binding.assert_called_once_with(
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE)
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE),
ANY,
)
@patch("controllers.service_api.dataset.dataset.TagService")
@ -1347,7 +1349,8 @@ class TestDatasetTagUnbindingApiPost:
from services.tag_service import TagBindingDeletePayload
mock_tag_svc.delete_tag_binding.assert_called_once_with(
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE)
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE),
ANY,
)
@patch("controllers.service_api.dataset.dataset.current_user")

View File

@ -234,7 +234,7 @@ class TestAppService:
# Get paginated apps
params = AppListParams(page=1, limit=10, mode="chat")
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers)
# Verify pagination results
assert paginated_apps is not None
@ -295,7 +295,7 @@ class TestAppService:
db_session_with_containers.commit()
last_modified_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
)
assert last_modified_apps is not None
assert [app.name for app in last_modified_apps.items] == [
@ -305,7 +305,10 @@ class TestAppService:
]
recently_created_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", sort_by="recently_created")
account.id,
tenant.id,
AppListParams(page=1, limit=10, mode="chat", sort_by="recently_created"),
db_session_with_containers,
)
assert recently_created_apps is not None
assert [app.name for app in recently_created_apps.items] == [
@ -315,7 +318,10 @@ class TestAppService:
]
earliest_created_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created")
account.id,
tenant.id,
AppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created"),
db_session_with_containers,
)
assert earliest_created_apps is not None
assert [app.name for app in earliest_created_apps.items] == [
@ -366,7 +372,7 @@ class TestAppService:
assert star_count == 1
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
)
assert paginated_apps is not None
starred_by_app_id = {app.id: app.is_starred for app in paginated_apps.items}
@ -377,7 +383,7 @@ class TestAppService:
db_session_with_containers.commit()
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
)
assert paginated_apps is not None
starred_by_app_id = {app.id: app.is_starred for app in paginated_apps.items}
@ -442,7 +448,7 @@ class TestAppService:
db_session_with_containers.commit()
last_modified_apps = app_service.get_paginate_starred_apps(
account.id, tenant.id, StarredAppListParams(page=1, limit=10, mode="chat")
account.id, tenant.id, StarredAppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
)
assert last_modified_apps is not None
assert [app.name for app in last_modified_apps.items] == [
@ -457,6 +463,7 @@ class TestAppService:
account.id,
tenant.id,
StarredAppListParams(page=1, limit=10, mode="chat", sort_by="recently_created"),
db_session_with_containers,
)
assert recently_created_apps is not None
assert [app.name for app in recently_created_apps.items] == [
@ -469,6 +476,7 @@ class TestAppService:
account.id,
tenant.id,
StarredAppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created"),
db_session_with_containers,
)
assert earliest_created_apps is not None
assert [app.name for app in earliest_created_apps.items] == [
@ -522,20 +530,25 @@ class TestAppService:
completion_app = app_service.create_app(tenant.id, completion_app_params, account)
# Test filter by mode
chat_apps = app_service.get_paginate_apps(account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"))
chat_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
)
assert len(chat_apps.items) == 1
assert chat_apps.items[0].mode == "chat"
# Test filter by name
filtered_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat")
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat"), db_session_with_containers
)
assert len(filtered_apps.items) == 1
assert "Chat" in filtered_apps.items[0].name
# Test filter by created_by_me
my_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True)
account.id,
tenant.id,
AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True),
db_session_with_containers,
)
assert len(my_apps.items) == 1
@ -588,6 +601,7 @@ class TestAppService:
first_account.id,
tenant.id,
AppListParams(page=1, limit=10, mode="chat", creator_ids=[second_account.id]),
db_session_with_containers,
)
assert filtered_apps is not None
@ -635,10 +649,12 @@ class TestAppService:
# Test with tag filter
params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["tag1", "tag2"])
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers)
# Verify tag service was called
mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"], match_all=True)
mock_tag_service.assert_called_once_with(
"app", tenant.id, ["tag1", "tag2"], db_session_with_containers, match_all=True
)
# Verify results
assert paginated_apps is not None
@ -651,7 +667,7 @@ class TestAppService:
params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["nonexistent_tag"])
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers)
# Should return None when no apps match tag filter
assert paginated_apps is None
@ -1467,7 +1483,7 @@ class TestAppService:
# Test 1: Search with % character
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10)
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10), db_session_with_containers
)
assert paginated_apps is not None
assert paginated_apps.total == 1
@ -1476,7 +1492,10 @@ class TestAppService:
# Test 2: Search with _ character
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(name="test_data", mode="chat", page=1, limit=10)
account.id,
tenant.id,
AppListParams(name="test_data", mode="chat", page=1, limit=10),
db_session_with_containers,
)
assert paginated_apps is not None
assert paginated_apps.total == 1
@ -1485,7 +1504,10 @@ class TestAppService:
# Test 3: Search with \ character
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10)
account.id,
tenant.id,
AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10),
db_session_with_containers,
)
assert paginated_apps is not None
assert paginated_apps.total == 1
@ -1494,7 +1516,7 @@ class TestAppService:
# Test 4: Search with % should NOT match 100% (verifies escaping works)
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10)
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10), db_session_with_containers
)
assert paginated_apps is not None
assert paginated_apps.total == 1

View File

@ -227,7 +227,7 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id)
datasets, total = DatasetService.get_datasets(page, per_page, db_session_with_containers, tenant_id=tenant.id)
# Assert
assert len(datasets) == 5
@ -257,7 +257,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, search=search)
datasets, total = DatasetService.get_datasets(
page, per_page, db_session_with_containers, tenant_id=tenant.id, search=search
)
# Assert
assert len(datasets) == 1
@ -301,7 +303,9 @@ class TestDatasetServiceGetDatasets:
tag_ids = [tag_1.id, tag_2.id]
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids)
datasets, total = DatasetService.get_datasets(
page, per_page, db_session_with_containers, tenant_id=tenant.id, tag_ids=tag_ids
)
# Assert
assert len(datasets) == 1
@ -326,7 +330,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids)
datasets, total = DatasetService.get_datasets(
page, per_page, db_session_with_containers, tenant_id=tenant.id, tag_ids=tag_ids
)
# Assert
# When tag_ids is empty, tag filtering is skipped, so normal query results are returned
@ -356,7 +362,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, user=None)
datasets, total = DatasetService.get_datasets(
page, per_page, db_session_with_containers, tenant_id=tenant.id, user=None
)
# Assert
assert len(datasets) == 1
@ -384,6 +392,7 @@ class TestDatasetServiceGetDatasets:
datasets, total = DatasetService.get_datasets(
page=1,
per_page=20,
session=db_session_with_containers,
tenant_id=tenant.id,
user=owner,
include_all=True,
@ -408,7 +417,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
datasets, total = DatasetService.get_datasets(
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user
)
# Assert
assert len(datasets) == 1
@ -432,7 +443,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
datasets, total = DatasetService.get_datasets(
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user
)
# Assert
assert len(datasets) == 1
@ -459,7 +472,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
datasets, total = DatasetService.get_datasets(
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user
)
# Assert
assert len(datasets) == 1
@ -486,7 +501,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator)
datasets, total = DatasetService.get_datasets(
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=operator
)
# Assert
assert len(datasets) == 1
@ -509,7 +526,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator)
datasets, total = DatasetService.get_datasets(
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=operator
)
# Assert
assert datasets == []

View File

@ -449,7 +449,7 @@ class TestTagService:
# Act: Execute the method under test
tag_ids = [tag.id for tag in tags]
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids)
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -485,7 +485,7 @@ class TestTagService:
)
# Act: Execute the method under test with empty tag IDs
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, [])
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, [], db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -533,13 +533,19 @@ class TestTagService:
# Act: Execute the method under test
tag_ids = [tag.id for tag in tags]
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, match_all=True)
result = TagService.get_target_ids_by_tag_ids(
"knowledge", tenant.id, tag_ids, db_session_with_containers, match_all=True
)
# Assert: Verify the expected outcomes
assert result == [dataset_with_all_tags.id]
missing_tag_result = TagService.get_target_ids_by_tag_ids(
"knowledge", tenant.id, [tags[0].id, str(uuid.uuid4())], match_all=True
"knowledge",
tenant.id,
[tags[0].id, str(uuid.uuid4())],
db_session_with_containers,
match_all=True,
)
assert missing_tag_result == []
@ -565,7 +571,9 @@ class TestTagService:
non_existent_tag_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
# Act: Execute the method under test
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, non_existent_tag_ids)
result = TagService.get_target_ids_by_tag_ids(
"knowledge", tenant.id, non_existent_tag_ids, db_session_with_containers
)
# Assert: Verify the expected outcomes
assert result is not None
@ -599,7 +607,7 @@ class TestTagService:
db_session_with_containers.commit()
# Act: Execute the method under test
result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag")
result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag", db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -625,7 +633,7 @@ class TestTagService:
)
# Act: Execute the method under test with non-existent tag name
result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag")
result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag", db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -650,8 +658,8 @@ class TestTagService:
)
# Act: Execute the method under test with empty parameters
result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag")
result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "")
result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag", db_session_with_containers)
result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "", db_session_with_containers)
# Assert: Verify the expected outcomes
assert result_empty_type is not None
@ -688,7 +696,7 @@ class TestTagService:
)
# Act: Execute the method under test
result = TagService.get_tags_by_target_id("app", tenant.id, app.id)
result = TagService.get_tags_by_target_id("app", tenant.id, app.id, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -720,7 +728,7 @@ class TestTagService:
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
# Act: Execute the method under test
result = TagService.get_tags_by_target_id("app", tenant.id, app.id)
result = TagService.get_tags_by_target_id("app", tenant.id, app.id, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -745,7 +753,7 @@ class TestTagService:
tag_args = SaveTagPayload(name="test_tag_name", type="knowledge")
# Act: Execute the method under test
result = TagService.save_tags(tag_args)
result = TagService.save_tags(tag_args, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -783,11 +791,11 @@ class TestTagService:
# Create first tag
tag_args = SaveTagPayload(name="duplicate_tag", type="app")
TagService.save_tags(tag_args)
TagService.save_tags(tag_args, db_session_with_containers)
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
TagService.save_tags(tag_args)
TagService.save_tags(tag_args, db_session_with_containers)
assert "Tag name already exists" in str(exc_info.value)
def test_update_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
@ -807,13 +815,13 @@ class TestTagService:
# Create a tag to update
tag_args = SaveTagPayload(name="original_name", type="knowledge")
tag = TagService.save_tags(tag_args)
tag = TagService.save_tags(tag_args, db_session_with_containers)
# Update args
update_args = UpdateTagPayload(name="updated_name")
# Act: Execute the method under test
result = TagService.update_tags(update_args, tag.id)
result = TagService.update_tags(update_args, tag.id, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -854,7 +862,7 @@ class TestTagService:
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
TagService.update_tags(update_args, non_existent_tag_id)
TagService.update_tags(update_args, non_existent_tag_id, db_session_with_containers)
assert "Tag not found" in str(exc_info.value)
def test_update_tags_duplicate_name_error(
@ -875,17 +883,17 @@ class TestTagService:
# Create two tags
tag1_args = SaveTagPayload(name="first_tag", type="app")
tag1 = TagService.save_tags(tag1_args)
tag1 = TagService.save_tags(tag1_args, db_session_with_containers)
tag2_args = SaveTagPayload(name="second_tag", type="app")
tag2 = TagService.save_tags(tag2_args)
tag2 = TagService.save_tags(tag2_args, db_session_with_containers)
# Try to update second tag with first tag's name
update_args = UpdateTagPayload(name="first_tag")
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
TagService.update_tags(update_args, tag2.id)
TagService.update_tags(update_args, tag2.id, db_session_with_containers)
assert "Tag name already exists" in str(exc_info.value)
def test_get_tag_binding_count_success(
@ -917,8 +925,8 @@ class TestTagService:
)
# Act: Execute the method under test
result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id)
result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id)
result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id, db_session_with_containers)
result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result_tag_with_bindings == 1
@ -946,7 +954,7 @@ class TestTagService:
non_existent_tag_id = str(uuid.uuid4())
# Act: Execute the method under test
result = TagService.get_tag_binding_count(non_existent_tag_id)
result = TagService.get_tag_binding_count(non_existent_tag_id, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result == 0
@ -986,7 +994,7 @@ class TestTagService:
assert binding_before is not None
# Act: Execute the method under test
TagService.delete_tag(tag.id)
TagService.delete_tag(tag.id, db_session_with_containers)
# Assert: Verify the expected outcomes
# Verify tag was deleted
@ -1018,7 +1026,7 @@ class TestTagService:
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
TagService.delete_tag(non_existent_tag_id)
TagService.delete_tag(non_existent_tag_id, db_session_with_containers)
assert "Tag not found" in str(exc_info.value)
def test_save_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
@ -1048,7 +1056,7 @@ class TestTagService:
binding_payload = TagBindingCreatePayload(
type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags]
)
TagService.save_tag_binding(binding_payload)
TagService.save_tag_binding(binding_payload, db_session_with_containers)
# Assert: Verify the expected outcomes
@ -1090,10 +1098,10 @@ class TestTagService:
# Create first binding
binding_payload = TagBindingCreatePayload(type="app", target_id=app.id, tag_ids=[tag.id])
TagService.save_tag_binding(binding_payload)
TagService.save_tag_binding(binding_payload, db_session_with_containers)
# Act: Try to create duplicate binding
TagService.save_tag_binding(binding_payload)
TagService.save_tag_binding(binding_payload, db_session_with_containers)
# Assert: Verify the expected outcomes
@ -1173,7 +1181,7 @@ class TestTagService:
delete_payload = TagBindingDeletePayload(
type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags]
)
TagService.delete_tag_binding(delete_payload)
TagService.delete_tag_binding(delete_payload, db_session_with_containers)
# Assert: Verify the expected outcomes
# Verify tag bindings were deleted
@ -1209,7 +1217,7 @@ class TestTagService:
# Act: Try to delete non-existent binding
delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_ids=[tag.id])
TagService.delete_tag_binding(delete_payload)
TagService.delete_tag_binding(delete_payload, db_session_with_containers)
# Assert: Verify the expected outcomes
# No error should be raised, and database state should remain unchanged
@ -1240,7 +1248,7 @@ class TestTagService:
dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id)
# Act: Execute the method under test
TagService.check_target_exists("knowledge", dataset.id)
TagService.check_target_exists("knowledge", dataset.id, db_session_with_containers)
# Assert: Verify the expected outcomes
# No exception should be raised for existing dataset
@ -1268,7 +1276,7 @@ class TestTagService:
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
TagService.check_target_exists("knowledge", non_existent_dataset_id)
TagService.check_target_exists("knowledge", non_existent_dataset_id, db_session_with_containers)
assert "Dataset not found" in str(exc_info.value)
def test_check_target_exists_app_success(
@ -1292,7 +1300,7 @@ class TestTagService:
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
# Act: Execute the method under test
TagService.check_target_exists("app", app.id)
TagService.check_target_exists("app", app.id, db_session_with_containers)
# Assert: Verify the expected outcomes
# No exception should be raised for existing app
@ -1320,7 +1328,7 @@ class TestTagService:
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
TagService.check_target_exists("app", non_existent_app_id)
TagService.check_target_exists("app", non_existent_app_id, db_session_with_containers)
assert "App not found" in str(exc_info.value)
def test_check_target_exists_invalid_type(
@ -1346,5 +1354,5 @@ class TestTagService:
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
TagService.check_target_exists("invalid_type", non_existent_target_id)
TagService.check_target_exists("invalid_type", non_existent_target_id, db_session_with_containers)
assert "Invalid binding type" in str(exc_info.value)

View File

@ -191,7 +191,7 @@ def test_agent_app_list_and_create_use_agent_route(
def get_app(self, app_obj: object) -> object:
return app_obj
def get_paginate_apps(self, user_id: str, tenant_id: str, params) -> object:
def get_paginate_apps(self, user_id: str, tenant_id: str, params, session) -> object:
captured["list"] = {"user_id": user_id, "tenant_id": tenant_id, "params": params}
return SimpleNamespace(
page=1,

View File

@ -2,15 +2,8 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from sqlalchemy.orm import Session
class SessionMatcher:
def __eq__(self, other):
return isinstance(other, Session)
from flask import Flask
from sqlalchemy.orm import Session, scoped_session
from werkzeug.exceptions import Forbidden
import controllers.console.tag.tags as module
@ -27,6 +20,11 @@ from models.enums import TagType
from services.tag_service import UpdateTagPayload
class SessionMatcher:
def __eq__(self, other):
return isinstance(other, Session | scoped_session)
def unwrap(func):
"""
Recursively unwrap decorated functions.
@ -193,9 +191,10 @@ class TestTagUpdateDeleteApi:
result, status = method(api, admin_user, "tag-1")
assert status == 200
update_payload, tag_id = update_tags_mock.call_args.args
update_payload, tag_id, session = update_tags_mock.call_args.args
assert update_payload == UpdateTagPayload(name="updated")
assert tag_id == "tag-1"
assert session == module.db.session
assert result["binding_count"] == "3"
def test_patch_forbidden(self, app: Flask, readonly_user, payload_patch):
@ -221,7 +220,7 @@ class TestTagUpdateDeleteApi:
):
result, status = method(api, "tag-1")
delete_mock.assert_called_once_with("tag-1")
delete_mock.assert_called_once_with("tag-1", module.db.session)
assert status == 204

View File

@ -1,6 +1,6 @@
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock
from unittest.mock import ANY, Mock
import pytest
from werkzeug.exceptions import NotFound
@ -94,6 +94,7 @@ def test_list_snippets_returns_pagination(app, monkeypatch):
}
get_snippets.assert_called_once_with(
tenant_id="tenant-1",
session=ANY,
page=2,
limit=10,
keyword=None,

View File

@ -100,7 +100,7 @@ def test_agent_soul_has_model():
assert agent_soul_has_model(AgentSoulConfig()) is False
def test_load_workflow_composer_returns_empty_state(monkeypatch):
def test_load_workflow_composer_returns_empty_state(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(AgentComposerService, "_get_draft_workflow", lambda **kwargs: SimpleNamespace(id="workflow-1"))
monkeypatch.setattr(AgentComposerService, "_get_workflow_binding", lambda **kwargs: None)
@ -118,7 +118,7 @@ def test_load_workflow_composer_returns_empty_state(monkeypatch):
assert files_output["array_item"] == {"type": "file", "description": None, "children": []}
def test_load_workflow_composer_serializes_existing_binding(monkeypatch):
def test_load_workflow_composer_serializes_existing_binding(monkeypatch: pytest.MonkeyPatch):
binding = SimpleNamespace(
agent_id="agent-1",
binding_type=WorkflowAgentBindingType.ROSTER_AGENT,
@ -220,7 +220,7 @@ def test_save_workflow_composer_rejects_agent_app_variant():
)
def test_save_agent_app_composer_creates_agent_when_missing(monkeypatch):
def test_save_agent_app_composer_creates_agent_when_missing(monkeypatch: pytest.MonkeyPatch):
fake_session = FakeSession(scalar=[None])
created_version = SimpleNamespace(id="version-1")
@ -249,7 +249,7 @@ def test_save_agent_app_composer_creates_agent_when_missing(monkeypatch):
assert fake_session.commits == 1
def test_save_agent_app_composer_updates_current_version(monkeypatch):
def test_save_agent_app_composer_updates_current_version(monkeypatch: pytest.MonkeyPatch):
agent = SimpleNamespace(id="agent-1", active_config_snapshot_id="version-1", updated_by=None)
fake_session = FakeSession(scalar=[agent])
updated = {}
@ -283,7 +283,7 @@ def test_save_agent_app_composer_updates_current_version(monkeypatch):
assert fake_session.commits == 1
def test_agent_app_composer_candidates_and_impact(monkeypatch):
def test_agent_app_composer_candidates_and_impact(monkeypatch: pytest.MonkeyPatch):
bindings = [
SimpleNamespace(app_id="app-1", workflow_id="workflow-1", node_id="node-1"),
SimpleNamespace(app_id="app-1", workflow_id="workflow-1", node_id="node-2"),
@ -316,7 +316,7 @@ def test_agent_app_composer_candidates_and_impact(monkeypatch):
assert impact["bindings"][1]["node_id"] == "node-2"
def test_serialize_workflow_state_changes_lock_and_save_options(monkeypatch):
def test_serialize_workflow_state_changes_lock_and_save_options(monkeypatch: pytest.MonkeyPatch):
binding = WorkflowAgentNodeBinding(
id="binding-1",
tenant_id="tenant-1",
@ -342,7 +342,7 @@ def test_serialize_workflow_state_changes_lock_and_save_options(monkeypatch):
assert effective_names == ["text", "files", "json"]
def test_serialize_workflow_state_passes_user_declared_outputs_through_effective(monkeypatch):
def test_serialize_workflow_state_passes_user_declared_outputs_through_effective(monkeypatch: pytest.MonkeyPatch):
binding = WorkflowAgentNodeBinding(
id="binding-1",
tenant_id="tenant-1",
@ -369,7 +369,7 @@ def test_serialize_workflow_state_passes_user_declared_outputs_through_effective
assert effective[0]["required"] is True
def test_composer_save_helpers_create_and_rebind_agents(monkeypatch):
def test_composer_save_helpers_create_and_rebind_agents(monkeypatch: pytest.MonkeyPatch):
fake_session = FakeSession()
monkeypatch.setattr(composer_service.db, "session", fake_session)
workflow_agent = SimpleNamespace(id="inline-agent-1", active_config_snapshot_id="inline-version-1")
@ -611,7 +611,7 @@ def test_composer_create_agents_syncs_active_config_has_model(monkeypatch):
assert roster_agent.active_config_has_model is True
def test_composer_version_helpers_and_lookup_errors(monkeypatch):
def test_composer_version_helpers_and_lookup_errors(monkeypatch: pytest.MonkeyPatch):
fake_session = FakeSession(
scalar=[
1,
@ -670,7 +670,7 @@ def test_composer_version_helpers_and_lookup_errors(monkeypatch):
assert workflow.id == "workflow-1"
def test_composer_current_version_and_error_paths(monkeypatch):
def test_composer_current_version_and_error_paths(monkeypatch: pytest.MonkeyPatch):
fake_session = FakeSession(scalar=[2])
monkeypatch.setattr(composer_service.db, "session", fake_session)
payload = ComposerSavePayload.model_validate(
@ -717,7 +717,7 @@ def test_composer_current_version_and_error_paths(monkeypatch):
)
def test_roster_list_and_invite_options(monkeypatch):
def test_roster_list_and_invite_options(monkeypatch: pytest.MonkeyPatch):
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
updated_at = datetime(2026, 1, 3, 3, 4, 5, tzinfo=UTC)
version_created_at = datetime(2026, 1, 4, 3, 4, 5, tzinfo=UTC)
@ -790,7 +790,7 @@ def test_roster_list_and_invite_options(monkeypatch):
assert invited["data"][0]["existing_node_ids"] == ["node-1"]
def test_invite_options_uses_db_filtered_pagination(monkeypatch):
def test_invite_options_uses_db_filtered_pagination(monkeypatch: pytest.MonkeyPatch):
configured_agent = Agent(
id="agent-2",
tenant_id="tenant-1",
@ -915,7 +915,7 @@ def test_published_references_include_app_display_fields_and_sort_by_updated_at(
assert references[0]["workflow_version"] == "published-recent"
def test_roster_update_archive_versions_and_detail(monkeypatch):
def test_roster_update_archive_versions_and_detail(monkeypatch: pytest.MonkeyPatch):
listed_version = AgentConfigSnapshot(id="version-2", agent_id="agent-1", version=2)
listed_version_created_at = datetime(2026, 1, 5, 3, 4, 5, tzinfo=UTC)
listed_version.created_at = listed_version_created_at
@ -973,7 +973,7 @@ def test_roster_update_archive_versions_and_detail(monkeypatch):
assert detail["revisions"][0]["created_at"] == int(revision_created_at.timestamp())
def test_roster_create_detail_and_lookup_helpers(monkeypatch):
def test_roster_create_detail_and_lookup_helpers(monkeypatch: pytest.MonkeyPatch):
fake_session = FakeSession(
scalar=[
SimpleNamespace(id="agent-1"),
@ -1040,9 +1040,7 @@ def test_agent_app_visible_versions_exclude_draft_saves():
def test_app_list_all_excludes_agent_apps_by_default():
filters = AppService._build_app_list_filters(
"account-1",
"tenant-1",
AppListParams(mode="all"),
"account-1", "tenant-1", AppListParams(mode="all"), FakeSession(scalar=None, scalars=None)
)
sql = " ".join(str(filter_) for filter_ in filters)
@ -2173,7 +2171,7 @@ class TestWorkflowAgentDraftBindingSync:
assert session.deleted == [stale_binding]
def test_dataset_rows_filters_malformed_ids(monkeypatch):
def test_dataset_rows_filters_malformed_ids(monkeypatch: pytest.MonkeyPatch):
"""Mention ids are user-editable text: a non-UUID id must read as missing
(placeholder semantics), never reach the UUID-typed dataset query (E2E 500)."""
captured = {}
@ -2197,7 +2195,7 @@ def test_dataset_rows_filters_malformed_ids(monkeypatch):
assert captured == {}
def test_workspace_dify_tools_returns_provider_and_tool_granularities(monkeypatch):
def test_workspace_dify_tools_returns_provider_and_tool_granularities(monkeypatch: pytest.MonkeyPatch):
"""The slash-menu Tools tab needs both selection granularities: a provider
hosts many tools (like an MCP server), so candidates return one
provider-level entry (id = <provider>/*, = all tools) plus one per tool."""
@ -2268,7 +2266,7 @@ def _patch_drive_keys(monkeypatch, existing_keys):
return captured
def test_drive_ref_findings_reports_missing_keys(monkeypatch):
def test_drive_ref_findings_reports_missing_keys(monkeypatch: pytest.MonkeyPatch):
_patch_drive_keys(monkeypatch, existing_keys=["tender-analyzer/SKILL.md"])
findings = AgentComposerService._drive_ref_findings(
@ -2279,7 +2277,7 @@ def test_drive_ref_findings_reports_missing_keys(monkeypatch):
assert str(findings[0]["message"]).startswith("file_ref_dangling: ")
def test_drive_ref_findings_clean_when_all_keys_exist(monkeypatch):
def test_drive_ref_findings_clean_when_all_keys_exist(monkeypatch: pytest.MonkeyPatch):
_patch_drive_keys(monkeypatch, existing_keys=["tender-analyzer/SKILL.md", "files/sample.pdf"])
assert (
@ -2288,7 +2286,7 @@ def test_drive_ref_findings_clean_when_all_keys_exist(monkeypatch):
)
def test_drive_ref_findings_skips_refs_without_drive_keys(monkeypatch):
def test_drive_ref_findings_skips_refs_without_drive_keys(monkeypatch: pytest.MonkeyPatch):
# No drive-backed ref at all -> no DB roundtrip, no findings.
soul = _drive_soul(
skills_files={"skills": [{"id": "legacy", "name": "Legacy"}], "files": [{"name": "u.pdf", "file_id": "u-1"}]}
@ -2297,7 +2295,7 @@ def test_drive_ref_findings_skips_refs_without_drive_keys(monkeypatch):
assert findings == []
def test_require_drive_refs_resolved_raises_with_stable_code(monkeypatch):
def test_require_drive_refs_resolved_raises_with_stable_code(monkeypatch: pytest.MonkeyPatch):
from services.agent.errors import InvalidComposerConfigError
_patch_drive_keys(monkeypatch, existing_keys=[])
@ -2308,7 +2306,7 @@ def test_require_drive_refs_resolved_raises_with_stable_code(monkeypatch):
)
def test_collect_validation_findings_appends_drive_findings_with_agent_context(monkeypatch):
def test_collect_validation_findings_appends_drive_findings_with_agent_context(monkeypatch: pytest.MonkeyPatch):
from services.entities.agent_entities import ComposerSavePayload
_patch_drive_keys(monkeypatch, existing_keys=[])
@ -2334,7 +2332,7 @@ def test_collect_validation_findings_appends_drive_findings_with_agent_context(m
# ── ENG-625 D5: soul-first ref removal ───────────────────────────────────────
def _patch_remove_drive_refs_env(monkeypatch, *, soul_dict):
def _patch_remove_drive_refs_env(monkeypatch: pytest.MonkeyPatch, *, soul_dict):
"""Wire the classmethod's collaborators so soul editing + versioning is observable."""
from types import SimpleNamespace
@ -2359,7 +2357,7 @@ def _patch_remove_drive_refs_env(monkeypatch, *, soul_dict):
return agent, captured, committed
def test_remove_drive_refs_drops_skill_by_slug_and_versions(monkeypatch):
def test_remove_drive_refs_drops_skill_by_slug_and_versions(monkeypatch: pytest.MonkeyPatch):
soul_dict = {
"skills_files": {
"skills": [
@ -2383,7 +2381,7 @@ def test_remove_drive_refs_drops_skill_by_slug_and_versions(monkeypatch):
assert committed.get("committed") is True
def test_remove_drive_refs_is_noop_when_ref_absent(monkeypatch):
def test_remove_drive_refs_is_noop_when_ref_absent(monkeypatch: pytest.MonkeyPatch):
soul_dict = {"skills_files": {"skills": [], "files": []}}
agent, captured, committed = _patch_remove_drive_refs_env(monkeypatch, soul_dict=soul_dict)
@ -2397,7 +2395,7 @@ def test_remove_drive_refs_is_noop_when_ref_absent(monkeypatch):
assert committed == {}
def test_remove_drive_refs_drops_file_by_key(monkeypatch):
def test_remove_drive_refs_drops_file_by_key(monkeypatch: pytest.MonkeyPatch):
soul_dict = {
"skills_files": {
"skills": [],
@ -2417,7 +2415,7 @@ def test_remove_drive_refs_drops_file_by_key(monkeypatch):
assert [f.drive_key for f in captured["agent_soul"].skills_files.files] == ["files/keep.pdf"]
def test_add_drive_file_ref_adds_or_replaces_file_and_versions(monkeypatch):
def test_add_drive_file_ref_adds_or_replaces_file_and_versions(monkeypatch: pytest.MonkeyPatch):
soul_dict = {
"skills_files": {
"skills": [],
@ -2444,7 +2442,7 @@ def test_add_drive_file_ref_adds_or_replaces_file_and_versions(monkeypatch):
assert committed.get("committed") is True
def test_add_drive_file_ref_syncs_workflow_binding_snapshot(monkeypatch):
def test_add_drive_file_ref_syncs_workflow_binding_snapshot(monkeypatch: pytest.MonkeyPatch):
binding = SimpleNamespace(agent_id="agent-1", current_snapshot_id="snap-1", updated_by=None)
_patch_remove_drive_refs_env(monkeypatch, soul_dict={"skills_files": {"skills": [], "files": []}})
monkeypatch.setattr(
@ -2473,7 +2471,7 @@ def test_remove_drive_refs_requires_exactly_one_scope():
# ── ENG-623/625: resolver helpers + save-path drive guard ────────────────────
def test_resolve_bound_agent_id_queries_active_roster_agent(monkeypatch):
def test_resolve_bound_agent_id_queries_active_roster_agent(monkeypatch: pytest.MonkeyPatch):
from types import SimpleNamespace
import services.agent.composer_service as module
@ -2482,7 +2480,7 @@ def test_resolve_bound_agent_id_queries_active_roster_agent(monkeypatch):
assert AgentComposerService.resolve_bound_agent_id(tenant_id="t-1", app_id="app-1") == "agent-9"
def test_resolve_workflow_node_agent_id_degrades_without_workflow_or_binding(monkeypatch):
def test_resolve_workflow_node_agent_id_degrades_without_workflow_or_binding(monkeypatch: pytest.MonkeyPatch):
from types import SimpleNamespace
def boom(cls, **kwargs):
@ -2505,7 +2503,7 @@ def test_resolve_workflow_node_agent_id_degrades_without_workflow_or_binding(mon
assert AgentComposerService.resolve_workflow_node_agent_id(tenant_id="t", app_id="a", node_id="n") == "agent-7"
def test_remove_drive_refs_returns_none_without_agent_or_snapshot(monkeypatch):
def test_remove_drive_refs_returns_none_without_agent_or_snapshot(monkeypatch: pytest.MonkeyPatch):
from types import SimpleNamespace
import services.agent.composer_service as module
@ -2518,7 +2516,7 @@ def test_remove_drive_refs_returns_none_without_agent_or_snapshot(monkeypatch):
assert AgentComposerService.remove_drive_refs(tenant_id="t", agent_id="a", account_id="u", skill_slug="s") is None
def test_save_workflow_composer_guards_drive_refs_for_existing_agent_strategies(monkeypatch):
def test_save_workflow_composer_guards_drive_refs_for_existing_agent_strategies(monkeypatch: pytest.MonkeyPatch):
from types import SimpleNamespace
from services.entities.agent_entities import ComposerSavePayload

View File

@ -94,14 +94,15 @@ def test_validate_snippet_graph_forbidden_nodes_raises_with_node_details() -> No
def test_get_snippets_returns_empty_when_tag_filter_has_no_targets(monkeypatch: pytest.MonkeyPatch) -> None:
session = _SessionWithoutNameLookup()
get_target_ids = Mock(return_value=[])
monkeypatch.setattr("services.snippet_service.TagService.get_target_ids_by_tag_ids", get_target_ids)
service = SnippetService.__new__(SnippetService)
result = service.get_snippets(tenant_id="tenant-1", tag_ids=["tag-1"])
result = service.get_snippets(tenant_id="tenant-1", session=session, tag_ids=["tag-1"])
assert result == ([], 0, False)
get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], match_all=True)
get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], session, match_all=True)
def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPatch) -> None:
@ -124,6 +125,7 @@ def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPa
result, total, has_more = service.get_snippets(
tenant_id="tenant-1",
session=session,
page=2,
limit=2,
keyword="search",
@ -135,7 +137,7 @@ def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPa
assert result == snippets[:2]
assert total == 3
assert has_more is True
get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], match_all=True)
get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], session, match_all=True)
session.scalar.assert_called_once()
session.scalars.assert_called_once()

View File

@ -1,6 +1,7 @@
from types import SimpleNamespace
import pytest
from pytest_mock import MockerFixture
from werkzeug.exceptions import NotFound
from models.enums import TagType
@ -8,19 +9,19 @@ from services.tag_service import TagBindingCreatePayload, TagBindingDeletePayloa
@pytest.fixture
def current_user(mocker):
def current_user(mocker: MockerFixture):
user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1")
mocker.patch("services.tag_service.current_user", user)
return user
@pytest.fixture
def db_session(mocker):
mock_db = mocker.patch("services.tag_service.db")
def db_session(mocker: MockerFixture):
mock_db = mocker.Mock()
return mock_db.session
def test_save_tag_binding_only_creates_bindings_for_valid_snippet_tags(mocker, current_user, db_session):
def test_save_tag_binding_only_creates_bindings_for_valid_snippet_tags(mocker: MockerFixture, current_user, db_session):
mocker.patch("services.tag_service.TagService.check_target_exists")
db_session.scalars.return_value.all.return_value = ["tag-1"]
db_session.scalar.return_value = None
@ -30,7 +31,8 @@ def test_save_tag_binding_only_creates_bindings_for_valid_snippet_tags(mocker, c
tag_ids=["tag-1", "tag-from-other-tenant"],
target_id="snippet-1",
type=TagType.SNIPPET,
)
),
db_session,
)
db_session.add.assert_called_once()
@ -42,7 +44,7 @@ def test_save_tag_binding_only_creates_bindings_for_valid_snippet_tags(mocker, c
db_session.commit.assert_called_once()
def test_delete_tag_binding_limits_deletion_to_valid_snippet_tags(mocker, current_user, db_session):
def test_delete_tag_binding_limits_deletion_to_valid_snippet_tags(mocker: MockerFixture, current_user, db_session):
mocker.patch("services.tag_service.TagService.check_target_exists")
db_session.execute.return_value = SimpleNamespace(rowcount=1)
@ -51,14 +53,15 @@ def test_delete_tag_binding_limits_deletion_to_valid_snippet_tags(mocker, curren
tag_ids=["tag-1", "tag-from-other-tenant"],
target_id="snippet-1",
type=TagType.SNIPPET,
)
),
db_session,
)
db_session.execute.assert_called_once()
db_session.commit.assert_called_once()
def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker, current_user, db_session):
def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker: MockerFixture, current_user, db_session):
mocker.patch("services.tag_service.TagService.check_target_exists")
db_session.execute.return_value = SimpleNamespace(rowcount=0)
@ -67,7 +70,8 @@ def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker, current
tag_ids=["tag-1"],
target_id="snippet-1",
type=TagType.SNIPPET,
)
),
db_session,
)
db_session.execute.assert_called_once()
@ -75,7 +79,7 @@ def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker, current
def test_get_target_ids_by_tag_ids_returns_empty_without_query_for_empty_input(db_session):
result = TagService.get_target_ids_by_tag_ids(TagType.SNIPPET, "tenant-1", [])
result = TagService.get_target_ids_by_tag_ids(TagType.SNIPPET, "tenant-1", [], db_session)
assert result == []
db_session.scalars.assert_not_called()
@ -84,7 +88,7 @@ def test_get_target_ids_by_tag_ids_returns_empty_without_query_for_empty_input(d
def test_check_target_exists_accepts_existing_snippet(current_user, db_session):
db_session.scalar.return_value = SimpleNamespace(id="snippet-1")
TagService.check_target_exists("snippet", "snippet-1")
TagService.check_target_exists("snippet", "snippet-1", db_session)
db_session.scalar.assert_called_once()
@ -93,11 +97,11 @@ def test_check_target_exists_raises_when_snippet_missing(current_user, db_sessio
db_session.scalar.return_value = None
with pytest.raises(NotFound, match="Snippet not found"):
TagService.check_target_exists("snippet", "missing-snippet")
TagService.check_target_exists("snippet", "missing-snippet", db_session)
def test_check_target_exists_raises_for_invalid_binding_type(current_user, db_session):
with pytest.raises(NotFound, match="Invalid binding type"):
TagService.check_target_exists("unknown", "target-1")
TagService.check_target_exists("unknown", "target-1", db_session)
db_session.scalar.assert_not_called()