mirror of
https://github.com/langgenius/dify.git
synced 2026-06-23 04:11:09 +08:00
chore: example of make db.session pass from parameter. (#37561)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
0fa43973b8
commit
4304044905
@ -286,7 +286,7 @@ class AgentAppListApi(Resource):
|
||||
status="normal",
|
||||
)
|
||||
|
||||
app_pagination = AppService().get_paginate_apps(current_user.id, current_tenant_id, params)
|
||||
app_pagination = AppService().get_paginate_apps(current_user.id, current_tenant_id, params, db.session)
|
||||
if app_pagination is None:
|
||||
empty = AgentAppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json")
|
||||
|
||||
@ -594,7 +594,7 @@ class AppListApi(Resource):
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
|
||||
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params, db.session)
|
||||
if not app_pagination:
|
||||
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json"), 200
|
||||
@ -661,7 +661,7 @@ class StarredAppListApi(Resource):
|
||||
is_created_by_me=args.is_created_by_me,
|
||||
)
|
||||
|
||||
app_pagination = AppService().get_paginate_starred_apps(current_user_id, current_tenant_id, params)
|
||||
app_pagination = AppService().get_paginate_starred_apps(current_user_id, current_tenant_id, params, db.session)
|
||||
if not app_pagination:
|
||||
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json"), 200
|
||||
|
||||
@ -409,6 +409,7 @@ class DatasetListApi(Resource):
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
query.page,
|
||||
query.limit,
|
||||
db.session,
|
||||
current_tenant_id,
|
||||
current_user,
|
||||
query.keyword,
|
||||
|
||||
@ -122,7 +122,7 @@ class TagListApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type), db.session)
|
||||
|
||||
response = TagResponse.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
@ -146,9 +146,9 @@ class TagUpdateDeleteApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str)
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str, db.session)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id_str)
|
||||
binding_count = TagService.get_tag_binding_count(tag_id_str, db.session)
|
||||
|
||||
response = TagResponse.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
@ -164,7 +164,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
def delete(self, tag_id: UUID):
|
||||
tag_id_str = str(tag_id)
|
||||
|
||||
TagService.delete_tag(tag_id_str)
|
||||
TagService.delete_tag(tag_id_str, db.session)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -189,7 +189,8 @@ def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
tag_ids=payload.tag_ids,
|
||||
target_id=payload.target_id,
|
||||
type=payload.type,
|
||||
)
|
||||
),
|
||||
db.session,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -203,7 +204,8 @@ def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
tag_ids=payload.tag_ids,
|
||||
target_id=payload.target_id,
|
||||
type=payload.type,
|
||||
)
|
||||
),
|
||||
db.session,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@ -126,6 +126,7 @@ class CustomizedSnippetsApi(Resource):
|
||||
snippet_service = _snippet_service()
|
||||
snippets, total, has_more = snippet_service.get_snippets(
|
||||
tenant_id=current_tenant_id,
|
||||
session=db.session,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
|
||||
@ -174,7 +174,7 @@ class AppListApi(Resource):
|
||||
|
||||
tag_ids: list[str] | None = None
|
||||
if query.tag:
|
||||
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
|
||||
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag, db.session)
|
||||
if not tags:
|
||||
return empty
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
@ -191,7 +191,7 @@ class AppListApi(Resource):
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params)
|
||||
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params, db.session)
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
|
||||
@ -770,7 +770,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE), db.session)
|
||||
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
@ -808,9 +808,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag_id = payload.tag_id
|
||||
tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id)
|
||||
tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id, db.session)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
binding_count = TagService.get_tag_binding_count(tag_id, db.session)
|
||||
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
@ -840,7 +840,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
def delete(self, _):
|
||||
"""Delete a knowledge type tag."""
|
||||
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag(payload.tag_id)
|
||||
TagService.delete_tag(payload.tag_id, db.session)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -873,7 +873,8 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
|
||||
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE),
|
||||
db.session,
|
||||
)
|
||||
|
||||
return "", 204
|
||||
@ -907,7 +908,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE),
|
||||
db.session,
|
||||
)
|
||||
|
||||
return "", 204
|
||||
@ -942,6 +944,8 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
|
||||
tags = TagService.get_tags_by_target_id(
|
||||
"knowledge", current_user.current_tenant_id, str(dataset_id), db.session
|
||||
)
|
||||
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200
|
||||
|
||||
@ -75,7 +75,7 @@ class CreateAppParams(BaseModel):
|
||||
class AppService:
|
||||
@staticmethod
|
||||
def _build_app_list_filters(
|
||||
user_id: str, tenant_id: str, params: AppListBaseParams
|
||||
user_id: str, tenant_id: str, params: AppListBaseParams, session: scoped_session
|
||||
) -> list[sa.ColumnElement[bool]]:
|
||||
filters = [App.tenant_id == tenant_id, App.is_universal == False]
|
||||
|
||||
@ -115,7 +115,7 @@ class AppService:
|
||||
escaped_name = escape_like_pattern(name)
|
||||
filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
|
||||
if params.tag_ids and len(params.tag_ids) > 0:
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, match_all=True)
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, session, match_all=True)
|
||||
if target_ids and len(target_ids) > 0:
|
||||
filters.append(App.id.in_(target_ids))
|
||||
else:
|
||||
@ -197,7 +197,9 @@ class AppService:
|
||||
).scalars()
|
||||
)
|
||||
|
||||
def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None:
|
||||
def get_paginate_apps(
|
||||
self, user_id: str, tenant_id: str, params: AppListParams, session: scoped_session
|
||||
) -> Pagination | None:
|
||||
"""
|
||||
Get app list with pagination, filters, and explicit sort order.
|
||||
:param user_id: user id
|
||||
@ -205,7 +207,7 @@ class AppService:
|
||||
:param params: query parameters
|
||||
:return:
|
||||
"""
|
||||
filters = self._build_app_list_filters(user_id, tenant_id, params)
|
||||
filters = self._build_app_list_filters(user_id, tenant_id, params, session)
|
||||
if not filters:
|
||||
return None
|
||||
|
||||
@ -231,12 +233,12 @@ class AppService:
|
||||
return app_models
|
||||
|
||||
def get_paginate_starred_apps(
|
||||
self, user_id: str, tenant_id: str, params: StarredAppListParams
|
||||
self, user_id: str, tenant_id: str, params: StarredAppListParams, session: scoped_session
|
||||
) -> Pagination | None:
|
||||
"""
|
||||
Get apps starred by the current account with pagination, filters, and explicit sort order.
|
||||
"""
|
||||
filters = self._build_app_list_filters(user_id, tenant_id, params)
|
||||
filters = self._build_app_list_filters(user_id, tenant_id, params, session)
|
||||
if not filters:
|
||||
return None
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ import sqlalchemy as sa
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
|
||||
from redis.exceptions import LockNotOwnedError
|
||||
from sqlalchemy import delete, exists, func, select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
@ -235,7 +235,9 @@ class _EstimateArgs(BaseModel):
|
||||
|
||||
class DatasetService:
|
||||
@staticmethod
|
||||
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
|
||||
def get_datasets(
|
||||
page, per_page, session: scoped_session, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False
|
||||
):
|
||||
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id)
|
||||
|
||||
if user:
|
||||
@ -295,6 +297,7 @@ class DatasetService:
|
||||
"knowledge",
|
||||
tenant_id,
|
||||
tag_ids,
|
||||
session,
|
||||
match_all=True,
|
||||
)
|
||||
else:
|
||||
|
||||
@ -6,7 +6,7 @@ from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
||||
|
||||
from core.db import session_factory
|
||||
from core.workflow.node_factory import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
@ -192,6 +192,7 @@ class SnippetService:
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
session: scoped_session,
|
||||
page: int = 1,
|
||||
limit: int = 20,
|
||||
keyword: str | None = None,
|
||||
@ -229,20 +230,19 @@ class SnippetService:
|
||||
stmt = stmt.where(CustomizedSnippet.created_by.in_(creators))
|
||||
|
||||
if tag_ids:
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, match_all=True)
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, session, match_all=True)
|
||||
if target_ids:
|
||||
stmt = stmt.where(CustomizedSnippet.id.in_(target_ids))
|
||||
else:
|
||||
return [], 0, False
|
||||
|
||||
with self._session_scope() as session:
|
||||
# Get total count
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
total = session.scalar(count_stmt) or 0
|
||||
# Get total count
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
total = session.scalar(count_stmt) or 0
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
|
||||
snippets = list(session.scalars(stmt).all())
|
||||
# Apply pagination
|
||||
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
|
||||
snippets = list(session.scalars(stmt).all())
|
||||
|
||||
has_more = len(snippets) > limit
|
||||
if has_more:
|
||||
|
||||
@ -6,10 +6,9 @@ from flask_login import current_user
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.enums import TagType
|
||||
from models.model import App, Tag, TagBinding
|
||||
@ -56,7 +55,7 @@ class TagService:
|
||||
|
||||
@staticmethod
|
||||
def get_target_ids_by_tag_ids(
|
||||
tag_type: str, current_tenant_id: str, tag_ids: list[str], *, match_all: bool = False
|
||||
tag_type: str, current_tenant_id: str, tag_ids: list[str], session: scoped_session, *, match_all: bool = False
|
||||
):
|
||||
"""
|
||||
Return target IDs bound to tags for the given tenant and resource type.
|
||||
@ -70,7 +69,7 @@ class TagService:
|
||||
return []
|
||||
# Deduplicate repeated query params so match_all counts each requested tag once.
|
||||
requested_tag_ids = list(dict.fromkeys(tag_ids))
|
||||
tags = db.session.scalars(
|
||||
tags = session.scalars(
|
||||
select(Tag.id).where(
|
||||
Tag.id.in_(requested_tag_ids),
|
||||
Tag.tenant_id == current_tenant_id,
|
||||
@ -86,13 +85,13 @@ class TagService:
|
||||
if match_all:
|
||||
if len(tag_ids) != len(requested_tag_ids):
|
||||
return []
|
||||
return db.session.scalars(
|
||||
return session.scalars(
|
||||
select(TagBinding.target_id)
|
||||
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
|
||||
.group_by(TagBinding.target_id)
|
||||
.having(func.count(sa.distinct(TagBinding.tag_id)) == len(tag_ids))
|
||||
).all()
|
||||
tag_bindings = db.session.scalars(
|
||||
tag_bindings = session.scalars(
|
||||
select(TagBinding.target_id).where(
|
||||
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
|
||||
)
|
||||
@ -100,11 +99,11 @@ class TagService:
|
||||
return tag_bindings
|
||||
|
||||
@staticmethod
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str, session: scoped_session):
|
||||
if not tag_type or not tag_name:
|
||||
return []
|
||||
tags = list(
|
||||
db.session.scalars(
|
||||
session.scalars(
|
||||
select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
).all()
|
||||
)
|
||||
@ -113,8 +112,8 @@ class TagService:
|
||||
return tags
|
||||
|
||||
@staticmethod
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
|
||||
tags = db.session.scalars(
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str, session: scoped_session):
|
||||
tags = session.scalars(
|
||||
select(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
@ -128,8 +127,8 @@ class TagService:
|
||||
return tags or []
|
||||
|
||||
@staticmethod
|
||||
def save_tags(payload: SaveTagPayload) -> Tag:
|
||||
if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name):
|
||||
def save_tags(payload: SaveTagPayload, session: scoped_session) -> Tag:
|
||||
if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name, session):
|
||||
raise ValueError("Tag name already exists")
|
||||
tag = Tag(
|
||||
name=payload.name,
|
||||
@ -138,17 +137,17 @@ class TagService:
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
tag.id = str(uuid.uuid4())
|
||||
db.session.add(tag)
|
||||
db.session.commit()
|
||||
session.add(tag)
|
||||
session.commit()
|
||||
return tag
|
||||
|
||||
@staticmethod
|
||||
def update_tags(payload: UpdateTagPayload, tag_id: str) -> Tag:
|
||||
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
|
||||
def update_tags(payload: UpdateTagPayload, tag_id: str, session: scoped_session) -> Tag:
|
||||
tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
|
||||
if not tag:
|
||||
raise NotFound("Tag not found")
|
||||
if payload.name != tag.name:
|
||||
existing = db.session.scalar(
|
||||
existing = session.scalar(
|
||||
select(Tag)
|
||||
.where(
|
||||
Tag.name == payload.name,
|
||||
@ -161,31 +160,31 @@ class TagService:
|
||||
if existing:
|
||||
raise ValueError("Tag name already exists")
|
||||
tag.name = payload.name
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
return tag
|
||||
|
||||
@staticmethod
|
||||
def get_tag_binding_count(tag_id: str) -> int:
|
||||
count = db.session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0
|
||||
def get_tag_binding_count(tag_id: str, session: scoped_session) -> int:
|
||||
count = session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def delete_tag(tag_id: str):
|
||||
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
|
||||
def delete_tag(tag_id: str, session: scoped_session):
|
||||
tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
|
||||
if not tag:
|
||||
raise NotFound("Tag not found")
|
||||
db.session.delete(tag)
|
||||
session.delete(tag)
|
||||
# delete tag binding
|
||||
tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
|
||||
tag_bindings = session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
|
||||
if tag_bindings:
|
||||
for tag_binding in tag_bindings:
|
||||
db.session.delete(tag_binding)
|
||||
db.session.commit()
|
||||
session.delete(tag_binding)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def save_tag_binding(payload: TagBindingCreatePayload):
|
||||
TagService.check_target_exists(payload.type, payload.target_id)
|
||||
valid_tag_ids = db.session.scalars(
|
||||
def save_tag_binding(payload: TagBindingCreatePayload, session: scoped_session):
|
||||
TagService.check_target_exists(payload.type, payload.target_id, session)
|
||||
valid_tag_ids = session.scalars(
|
||||
select(Tag.id).where(
|
||||
Tag.id.in_(payload.tag_ids),
|
||||
Tag.tenant_id == current_user.current_tenant_id,
|
||||
@ -193,7 +192,7 @@ class TagService:
|
||||
)
|
||||
).all()
|
||||
for tag_id in valid_tag_ids:
|
||||
tag_binding = db.session.scalar(
|
||||
tag_binding = session.scalar(
|
||||
select(TagBinding)
|
||||
.where(TagBinding.tag_id == tag_id, TagBinding.target_id == payload.target_id)
|
||||
.limit(1)
|
||||
@ -206,15 +205,15 @@ class TagService:
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.session.add(new_tag_binding)
|
||||
db.session.commit()
|
||||
session.add(new_tag_binding)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def delete_tag_binding(payload: TagBindingDeletePayload):
|
||||
TagService.check_target_exists(payload.type, payload.target_id)
|
||||
def delete_tag_binding(payload: TagBindingDeletePayload, session: scoped_session):
|
||||
TagService.check_target_exists(payload.type, payload.target_id, session)
|
||||
result = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
delete(TagBinding).where(
|
||||
TagBinding.target_id == payload.target_id,
|
||||
TagBinding.tag_id.in_(payload.tag_ids),
|
||||
@ -230,12 +229,12 @@ class TagService:
|
||||
)
|
||||
|
||||
if result.rowcount:
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def check_target_exists(type: str, target_id: str):
|
||||
def check_target_exists(type: str, target_id: str, session: scoped_session):
|
||||
if type == "knowledge":
|
||||
dataset = db.session.scalar(
|
||||
dataset = session.scalar(
|
||||
select(Dataset)
|
||||
.where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
|
||||
.limit(1)
|
||||
@ -243,13 +242,13 @@ class TagService:
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found")
|
||||
elif type == "app":
|
||||
app = db.session.scalar(
|
||||
app = session.scalar(
|
||||
select(App).where(App.tenant_id == current_user.current_tenant_id, App.id == target_id).limit(1)
|
||||
)
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
elif type == "snippet":
|
||||
snippet = db.session.scalar(
|
||||
snippet = session.scalar(
|
||||
select(CustomizedSnippet)
|
||||
.where(CustomizedSnippet.tenant_id == current_user.current_tenant_id, CustomizedSnippet.id == target_id)
|
||||
.limit(1)
|
||||
|
||||
@ -16,7 +16,7 @@ since these test controller-level behavior.
|
||||
import uuid
|
||||
from contextlib import ExitStack
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
from unittest.mock import ANY, Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -1129,7 +1129,7 @@ class TestDatasetTagsApiPatch:
|
||||
assert status == 200
|
||||
assert response == {"id": "tag-1", "name": "Updated Tag", "type": "knowledge", "binding_count": "5"}
|
||||
mock_tag_svc.update_tags.assert_called_once()
|
||||
update_payload, tag_id = mock_tag_svc.update_tags.call_args.args
|
||||
update_payload, tag_id, session = mock_tag_svc.update_tags.call_args.args
|
||||
assert update_payload.name == "Updated Tag"
|
||||
assert tag_id == "tag-1"
|
||||
|
||||
@ -1184,7 +1184,7 @@ class TestDatasetTagsApiDelete:
|
||||
result = api.delete(_=None)
|
||||
|
||||
assert result == ("", 204)
|
||||
mock_tag_svc.delete_tag.assert_called_once_with("tag-1")
|
||||
mock_tag_svc.delete_tag.assert_called_once_with("tag-1", ANY)
|
||||
|
||||
@patch("libs.login.current_user")
|
||||
def test_delete_tag_forbidden(self, mock_current_user, app: Flask):
|
||||
@ -1233,7 +1233,7 @@ class TestDatasetTagsBindingStatusApi:
|
||||
assert status_code == 200
|
||||
assert response["data"] == [{"id": "tag_1", "name": "Test Tag"}]
|
||||
assert response["total"] == 1
|
||||
mock_tag_svc.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123")
|
||||
mock_tag_svc.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123", ANY)
|
||||
|
||||
|
||||
class TestDatasetTagBindingApiPost:
|
||||
@ -1266,7 +1266,8 @@ class TestDatasetTagBindingApiPost:
|
||||
from services.tag_service import TagBindingCreatePayload
|
||||
|
||||
mock_tag_svc.save_tag_binding.assert_called_once_with(
|
||||
TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE)
|
||||
TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE),
|
||||
ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
@ -1317,7 +1318,8 @@ class TestDatasetTagUnbindingApiPost:
|
||||
from services.tag_service import TagBindingDeletePayload
|
||||
|
||||
mock_tag_svc.delete_tag_binding.assert_called_once_with(
|
||||
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE)
|
||||
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE),
|
||||
ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.TagService")
|
||||
@ -1347,7 +1349,8 @@ class TestDatasetTagUnbindingApiPost:
|
||||
from services.tag_service import TagBindingDeletePayload
|
||||
|
||||
mock_tag_svc.delete_tag_binding.assert_called_once_with(
|
||||
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE)
|
||||
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE),
|
||||
ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
|
||||
@ -234,7 +234,7 @@ class TestAppService:
|
||||
# Get paginated apps
|
||||
params = AppListParams(page=1, limit=10, mode="chat")
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers)
|
||||
|
||||
# Verify pagination results
|
||||
assert paginated_apps is not None
|
||||
@ -295,7 +295,7 @@ class TestAppService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
last_modified_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
|
||||
)
|
||||
assert last_modified_apps is not None
|
||||
assert [app.name for app in last_modified_apps.items] == [
|
||||
@ -305,7 +305,10 @@ class TestAppService:
|
||||
]
|
||||
|
||||
recently_created_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", sort_by="recently_created")
|
||||
account.id,
|
||||
tenant.id,
|
||||
AppListParams(page=1, limit=10, mode="chat", sort_by="recently_created"),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert recently_created_apps is not None
|
||||
assert [app.name for app in recently_created_apps.items] == [
|
||||
@ -315,7 +318,10 @@ class TestAppService:
|
||||
]
|
||||
|
||||
earliest_created_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created")
|
||||
account.id,
|
||||
tenant.id,
|
||||
AppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created"),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert earliest_created_apps is not None
|
||||
assert [app.name for app in earliest_created_apps.items] == [
|
||||
@ -366,7 +372,7 @@ class TestAppService:
|
||||
assert star_count == 1
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
|
||||
)
|
||||
assert paginated_apps is not None
|
||||
starred_by_app_id = {app.id: app.is_starred for app in paginated_apps.items}
|
||||
@ -377,7 +383,7 @@ class TestAppService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
|
||||
)
|
||||
assert paginated_apps is not None
|
||||
starred_by_app_id = {app.id: app.is_starred for app in paginated_apps.items}
|
||||
@ -442,7 +448,7 @@ class TestAppService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
last_modified_apps = app_service.get_paginate_starred_apps(
|
||||
account.id, tenant.id, StarredAppListParams(page=1, limit=10, mode="chat")
|
||||
account.id, tenant.id, StarredAppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
|
||||
)
|
||||
assert last_modified_apps is not None
|
||||
assert [app.name for app in last_modified_apps.items] == [
|
||||
@ -457,6 +463,7 @@ class TestAppService:
|
||||
account.id,
|
||||
tenant.id,
|
||||
StarredAppListParams(page=1, limit=10, mode="chat", sort_by="recently_created"),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert recently_created_apps is not None
|
||||
assert [app.name for app in recently_created_apps.items] == [
|
||||
@ -469,6 +476,7 @@ class TestAppService:
|
||||
account.id,
|
||||
tenant.id,
|
||||
StarredAppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created"),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert earliest_created_apps is not None
|
||||
assert [app.name for app in earliest_created_apps.items] == [
|
||||
@ -522,20 +530,25 @@ class TestAppService:
|
||||
completion_app = app_service.create_app(tenant.id, completion_app_params, account)
|
||||
|
||||
# Test filter by mode
|
||||
chat_apps = app_service.get_paginate_apps(account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"))
|
||||
chat_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
|
||||
)
|
||||
assert len(chat_apps.items) == 1
|
||||
assert chat_apps.items[0].mode == "chat"
|
||||
|
||||
# Test filter by name
|
||||
filtered_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat")
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat"), db_session_with_containers
|
||||
)
|
||||
assert len(filtered_apps.items) == 1
|
||||
assert "Chat" in filtered_apps.items[0].name
|
||||
|
||||
# Test filter by created_by_me
|
||||
my_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True)
|
||||
account.id,
|
||||
tenant.id,
|
||||
AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert len(my_apps.items) == 1
|
||||
|
||||
@ -588,6 +601,7 @@ class TestAppService:
|
||||
first_account.id,
|
||||
tenant.id,
|
||||
AppListParams(page=1, limit=10, mode="chat", creator_ids=[second_account.id]),
|
||||
db_session_with_containers,
|
||||
)
|
||||
|
||||
assert filtered_apps is not None
|
||||
@ -635,10 +649,12 @@ class TestAppService:
|
||||
# Test with tag filter
|
||||
params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["tag1", "tag2"])
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers)
|
||||
|
||||
# Verify tag service was called
|
||||
mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"], match_all=True)
|
||||
mock_tag_service.assert_called_once_with(
|
||||
"app", tenant.id, ["tag1", "tag2"], db_session_with_containers, match_all=True
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert paginated_apps is not None
|
||||
@ -651,7 +667,7 @@ class TestAppService:
|
||||
|
||||
params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["nonexistent_tag"])
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers)
|
||||
|
||||
# Should return None when no apps match tag filter
|
||||
assert paginated_apps is None
|
||||
@ -1467,7 +1483,7 @@ class TestAppService:
|
||||
|
||||
# Test 1: Search with % character
|
||||
paginated_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10)
|
||||
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10), db_session_with_containers
|
||||
)
|
||||
assert paginated_apps is not None
|
||||
assert paginated_apps.total == 1
|
||||
@ -1476,7 +1492,10 @@ class TestAppService:
|
||||
|
||||
# Test 2: Search with _ character
|
||||
paginated_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(name="test_data", mode="chat", page=1, limit=10)
|
||||
account.id,
|
||||
tenant.id,
|
||||
AppListParams(name="test_data", mode="chat", page=1, limit=10),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert paginated_apps is not None
|
||||
assert paginated_apps.total == 1
|
||||
@ -1485,7 +1504,10 @@ class TestAppService:
|
||||
|
||||
# Test 3: Search with \ character
|
||||
paginated_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10)
|
||||
account.id,
|
||||
tenant.id,
|
||||
AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert paginated_apps is not None
|
||||
assert paginated_apps.total == 1
|
||||
@ -1494,7 +1516,7 @@ class TestAppService:
|
||||
|
||||
# Test 4: Search with % should NOT match 100% (verifies escaping works)
|
||||
paginated_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10)
|
||||
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10), db_session_with_containers
|
||||
)
|
||||
assert paginated_apps is not None
|
||||
assert paginated_apps.total == 1
|
||||
|
||||
@ -227,7 +227,7 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id)
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, db_session_with_containers, tenant_id=tenant.id)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 5
|
||||
@ -257,7 +257,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, search=search)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, per_page, db_session_with_containers, tenant_id=tenant.id, search=search
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -301,7 +303,9 @@ class TestDatasetServiceGetDatasets:
|
||||
tag_ids = [tag_1.id, tag_2.id]
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, per_page, db_session_with_containers, tenant_id=tenant.id, tag_ids=tag_ids
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -326,7 +330,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, per_page, db_session_with_containers, tenant_id=tenant.id, tag_ids=tag_ids
|
||||
)
|
||||
|
||||
# Assert
|
||||
# When tag_ids is empty, tag filtering is skipped, so normal query results are returned
|
||||
@ -356,7 +362,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, user=None)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, per_page, db_session_with_containers, tenant_id=tenant.id, user=None
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -384,6 +392,7 @@ class TestDatasetServiceGetDatasets:
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1,
|
||||
per_page=20,
|
||||
session=db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
user=owner,
|
||||
include_all=True,
|
||||
@ -408,7 +417,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -432,7 +443,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -459,7 +472,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -486,7 +501,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=operator
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -509,7 +526,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=operator
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert datasets == []
|
||||
|
||||
@ -449,7 +449,7 @@ class TestTagService:
|
||||
|
||||
# Act: Execute the method under test
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids)
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -485,7 +485,7 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test with empty tag IDs
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, [])
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, [], db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -533,13 +533,19 @@ class TestTagService:
|
||||
|
||||
# Act: Execute the method under test
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, match_all=True)
|
||||
result = TagService.get_target_ids_by_tag_ids(
|
||||
"knowledge", tenant.id, tag_ids, db_session_with_containers, match_all=True
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result == [dataset_with_all_tags.id]
|
||||
|
||||
missing_tag_result = TagService.get_target_ids_by_tag_ids(
|
||||
"knowledge", tenant.id, [tags[0].id, str(uuid.uuid4())], match_all=True
|
||||
"knowledge",
|
||||
tenant.id,
|
||||
[tags[0].id, str(uuid.uuid4())],
|
||||
db_session_with_containers,
|
||||
match_all=True,
|
||||
)
|
||||
assert missing_tag_result == []
|
||||
|
||||
@ -565,7 +571,9 @@ class TestTagService:
|
||||
non_existent_tag_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, non_existent_tag_ids)
|
||||
result = TagService.get_target_ids_by_tag_ids(
|
||||
"knowledge", tenant.id, non_existent_tag_ids, db_session_with_containers
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -599,7 +607,7 @@ class TestTagService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag")
|
||||
result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag", db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -625,7 +633,7 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test with non-existent tag name
|
||||
result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag")
|
||||
result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag", db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -650,8 +658,8 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test with empty parameters
|
||||
result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag")
|
||||
result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "")
|
||||
result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag", db_session_with_containers)
|
||||
result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "", db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result_empty_type is not None
|
||||
@ -688,7 +696,7 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_tags_by_target_id("app", tenant.id, app.id)
|
||||
result = TagService.get_tags_by_target_id("app", tenant.id, app.id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -720,7 +728,7 @@ class TestTagService:
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_tags_by_target_id("app", tenant.id, app.id)
|
||||
result = TagService.get_tags_by_target_id("app", tenant.id, app.id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -745,7 +753,7 @@ class TestTagService:
|
||||
tag_args = SaveTagPayload(name="test_tag_name", type="knowledge")
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.save_tags(tag_args)
|
||||
result = TagService.save_tags(tag_args, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -783,11 +791,11 @@ class TestTagService:
|
||||
|
||||
# Create first tag
|
||||
tag_args = SaveTagPayload(name="duplicate_tag", type="app")
|
||||
TagService.save_tags(tag_args)
|
||||
TagService.save_tags(tag_args, db_session_with_containers)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
TagService.save_tags(tag_args)
|
||||
TagService.save_tags(tag_args, db_session_with_containers)
|
||||
assert "Tag name already exists" in str(exc_info.value)
|
||||
|
||||
def test_update_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
@ -807,13 +815,13 @@ class TestTagService:
|
||||
|
||||
# Create a tag to update
|
||||
tag_args = SaveTagPayload(name="original_name", type="knowledge")
|
||||
tag = TagService.save_tags(tag_args)
|
||||
tag = TagService.save_tags(tag_args, db_session_with_containers)
|
||||
|
||||
# Update args
|
||||
update_args = UpdateTagPayload(name="updated_name")
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.update_tags(update_args, tag.id)
|
||||
result = TagService.update_tags(update_args, tag.id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -854,7 +862,7 @@ class TestTagService:
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
TagService.update_tags(update_args, non_existent_tag_id)
|
||||
TagService.update_tags(update_args, non_existent_tag_id, db_session_with_containers)
|
||||
assert "Tag not found" in str(exc_info.value)
|
||||
|
||||
def test_update_tags_duplicate_name_error(
|
||||
@ -875,17 +883,17 @@ class TestTagService:
|
||||
|
||||
# Create two tags
|
||||
tag1_args = SaveTagPayload(name="first_tag", type="app")
|
||||
tag1 = TagService.save_tags(tag1_args)
|
||||
tag1 = TagService.save_tags(tag1_args, db_session_with_containers)
|
||||
|
||||
tag2_args = SaveTagPayload(name="second_tag", type="app")
|
||||
tag2 = TagService.save_tags(tag2_args)
|
||||
tag2 = TagService.save_tags(tag2_args, db_session_with_containers)
|
||||
|
||||
# Try to update second tag with first tag's name
|
||||
update_args = UpdateTagPayload(name="first_tag")
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
TagService.update_tags(update_args, tag2.id)
|
||||
TagService.update_tags(update_args, tag2.id, db_session_with_containers)
|
||||
assert "Tag name already exists" in str(exc_info.value)
|
||||
|
||||
def test_get_tag_binding_count_success(
|
||||
@ -917,8 +925,8 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id)
|
||||
result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id)
|
||||
result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id, db_session_with_containers)
|
||||
result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result_tag_with_bindings == 1
|
||||
@ -946,7 +954,7 @@ class TestTagService:
|
||||
non_existent_tag_id = str(uuid.uuid4())
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_tag_binding_count(non_existent_tag_id)
|
||||
result = TagService.get_tag_binding_count(non_existent_tag_id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result == 0
|
||||
@ -986,7 +994,7 @@ class TestTagService:
|
||||
assert binding_before is not None
|
||||
|
||||
# Act: Execute the method under test
|
||||
TagService.delete_tag(tag.id)
|
||||
TagService.delete_tag(tag.id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify tag was deleted
|
||||
@ -1018,7 +1026,7 @@ class TestTagService:
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
TagService.delete_tag(non_existent_tag_id)
|
||||
TagService.delete_tag(non_existent_tag_id, db_session_with_containers)
|
||||
assert "Tag not found" in str(exc_info.value)
|
||||
|
||||
def test_save_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
@ -1048,7 +1056,7 @@ class TestTagService:
|
||||
binding_payload = TagBindingCreatePayload(
|
||||
type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags]
|
||||
)
|
||||
TagService.save_tag_binding(binding_payload)
|
||||
TagService.save_tag_binding(binding_payload, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
||||
@ -1090,10 +1098,10 @@ class TestTagService:
|
||||
|
||||
# Create first binding
|
||||
binding_payload = TagBindingCreatePayload(type="app", target_id=app.id, tag_ids=[tag.id])
|
||||
TagService.save_tag_binding(binding_payload)
|
||||
TagService.save_tag_binding(binding_payload, db_session_with_containers)
|
||||
|
||||
# Act: Try to create duplicate binding
|
||||
TagService.save_tag_binding(binding_payload)
|
||||
TagService.save_tag_binding(binding_payload, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
||||
@ -1173,7 +1181,7 @@ class TestTagService:
|
||||
delete_payload = TagBindingDeletePayload(
|
||||
type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags]
|
||||
)
|
||||
TagService.delete_tag_binding(delete_payload)
|
||||
TagService.delete_tag_binding(delete_payload, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify tag bindings were deleted
|
||||
@ -1209,7 +1217,7 @@ class TestTagService:
|
||||
|
||||
# Act: Try to delete non-existent binding
|
||||
delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_ids=[tag.id])
|
||||
TagService.delete_tag_binding(delete_payload)
|
||||
TagService.delete_tag_binding(delete_payload, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# No error should be raised, and database state should remain unchanged
|
||||
@ -1240,7 +1248,7 @@ class TestTagService:
|
||||
dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id)
|
||||
|
||||
# Act: Execute the method under test
|
||||
TagService.check_target_exists("knowledge", dataset.id)
|
||||
TagService.check_target_exists("knowledge", dataset.id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# No exception should be raised for existing dataset
|
||||
@ -1268,7 +1276,7 @@ class TestTagService:
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
TagService.check_target_exists("knowledge", non_existent_dataset_id)
|
||||
TagService.check_target_exists("knowledge", non_existent_dataset_id, db_session_with_containers)
|
||||
assert "Dataset not found" in str(exc_info.value)
|
||||
|
||||
def test_check_target_exists_app_success(
|
||||
@ -1292,7 +1300,7 @@ class TestTagService:
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
|
||||
|
||||
# Act: Execute the method under test
|
||||
TagService.check_target_exists("app", app.id)
|
||||
TagService.check_target_exists("app", app.id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# No exception should be raised for existing app
|
||||
@ -1320,7 +1328,7 @@ class TestTagService:
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
TagService.check_target_exists("app", non_existent_app_id)
|
||||
TagService.check_target_exists("app", non_existent_app_id, db_session_with_containers)
|
||||
assert "App not found" in str(exc_info.value)
|
||||
|
||||
def test_check_target_exists_invalid_type(
|
||||
@ -1346,5 +1354,5 @@ class TestTagService:
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
TagService.check_target_exists("invalid_type", non_existent_target_id)
|
||||
TagService.check_target_exists("invalid_type", non_existent_target_id, db_session_with_containers)
|
||||
assert "Invalid binding type" in str(exc_info.value)
|
||||
|
||||
@ -191,7 +191,7 @@ def test_agent_app_list_and_create_use_agent_route(
|
||||
def get_app(self, app_obj: object) -> object:
|
||||
return app_obj
|
||||
|
||||
def get_paginate_apps(self, user_id: str, tenant_id: str, params) -> object:
|
||||
def get_paginate_apps(self, user_id: str, tenant_id: str, params, session) -> object:
|
||||
captured["list"] = {"user_id": user_id, "tenant_id": tenant_id, "params": params}
|
||||
return SimpleNamespace(
|
||||
page=1,
|
||||
|
||||
@ -2,15 +2,8 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class SessionMatcher:
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Session)
|
||||
|
||||
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import controllers.console.tag.tags as module
|
||||
@ -27,6 +20,11 @@ from models.enums import TagType
|
||||
from services.tag_service import UpdateTagPayload
|
||||
|
||||
|
||||
class SessionMatcher:
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Session | scoped_session)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""
|
||||
Recursively unwrap decorated functions.
|
||||
@ -193,9 +191,10 @@ class TestTagUpdateDeleteApi:
|
||||
result, status = method(api, admin_user, "tag-1")
|
||||
|
||||
assert status == 200
|
||||
update_payload, tag_id = update_tags_mock.call_args.args
|
||||
update_payload, tag_id, session = update_tags_mock.call_args.args
|
||||
assert update_payload == UpdateTagPayload(name="updated")
|
||||
assert tag_id == "tag-1"
|
||||
assert session == module.db.session
|
||||
assert result["binding_count"] == "3"
|
||||
|
||||
def test_patch_forbidden(self, app: Flask, readonly_user, payload_patch):
|
||||
@ -221,7 +220,7 @@ class TestTagUpdateDeleteApi:
|
||||
):
|
||||
result, status = method(api, "tag-1")
|
||||
|
||||
delete_mock.assert_called_once_with("tag-1")
|
||||
delete_mock.assert_called_once_with("tag-1", module.db.session)
|
||||
assert status == 204
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import ANY, Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -94,6 +94,7 @@ def test_list_snippets_returns_pagination(app, monkeypatch):
|
||||
}
|
||||
get_snippets.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
session=ANY,
|
||||
page=2,
|
||||
limit=10,
|
||||
keyword=None,
|
||||
|
||||
@ -100,7 +100,7 @@ def test_agent_soul_has_model():
|
||||
assert agent_soul_has_model(AgentSoulConfig()) is False
|
||||
|
||||
|
||||
def test_load_workflow_composer_returns_empty_state(monkeypatch):
|
||||
def test_load_workflow_composer_returns_empty_state(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(AgentComposerService, "_get_draft_workflow", lambda **kwargs: SimpleNamespace(id="workflow-1"))
|
||||
monkeypatch.setattr(AgentComposerService, "_get_workflow_binding", lambda **kwargs: None)
|
||||
|
||||
@ -118,7 +118,7 @@ def test_load_workflow_composer_returns_empty_state(monkeypatch):
|
||||
assert files_output["array_item"] == {"type": "file", "description": None, "children": []}
|
||||
|
||||
|
||||
def test_load_workflow_composer_serializes_existing_binding(monkeypatch):
|
||||
def test_load_workflow_composer_serializes_existing_binding(monkeypatch: pytest.MonkeyPatch):
|
||||
binding = SimpleNamespace(
|
||||
agent_id="agent-1",
|
||||
binding_type=WorkflowAgentBindingType.ROSTER_AGENT,
|
||||
@ -220,7 +220,7 @@ def test_save_workflow_composer_rejects_agent_app_variant():
|
||||
)
|
||||
|
||||
|
||||
def test_save_agent_app_composer_creates_agent_when_missing(monkeypatch):
|
||||
def test_save_agent_app_composer_creates_agent_when_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
fake_session = FakeSession(scalar=[None])
|
||||
created_version = SimpleNamespace(id="version-1")
|
||||
|
||||
@ -249,7 +249,7 @@ def test_save_agent_app_composer_creates_agent_when_missing(monkeypatch):
|
||||
assert fake_session.commits == 1
|
||||
|
||||
|
||||
def test_save_agent_app_composer_updates_current_version(monkeypatch):
|
||||
def test_save_agent_app_composer_updates_current_version(monkeypatch: pytest.MonkeyPatch):
|
||||
agent = SimpleNamespace(id="agent-1", active_config_snapshot_id="version-1", updated_by=None)
|
||||
fake_session = FakeSession(scalar=[agent])
|
||||
updated = {}
|
||||
@ -283,7 +283,7 @@ def test_save_agent_app_composer_updates_current_version(monkeypatch):
|
||||
assert fake_session.commits == 1
|
||||
|
||||
|
||||
def test_agent_app_composer_candidates_and_impact(monkeypatch):
|
||||
def test_agent_app_composer_candidates_and_impact(monkeypatch: pytest.MonkeyPatch):
|
||||
bindings = [
|
||||
SimpleNamespace(app_id="app-1", workflow_id="workflow-1", node_id="node-1"),
|
||||
SimpleNamespace(app_id="app-1", workflow_id="workflow-1", node_id="node-2"),
|
||||
@ -316,7 +316,7 @@ def test_agent_app_composer_candidates_and_impact(monkeypatch):
|
||||
assert impact["bindings"][1]["node_id"] == "node-2"
|
||||
|
||||
|
||||
def test_serialize_workflow_state_changes_lock_and_save_options(monkeypatch):
|
||||
def test_serialize_workflow_state_changes_lock_and_save_options(monkeypatch: pytest.MonkeyPatch):
|
||||
binding = WorkflowAgentNodeBinding(
|
||||
id="binding-1",
|
||||
tenant_id="tenant-1",
|
||||
@ -342,7 +342,7 @@ def test_serialize_workflow_state_changes_lock_and_save_options(monkeypatch):
|
||||
assert effective_names == ["text", "files", "json"]
|
||||
|
||||
|
||||
def test_serialize_workflow_state_passes_user_declared_outputs_through_effective(monkeypatch):
|
||||
def test_serialize_workflow_state_passes_user_declared_outputs_through_effective(monkeypatch: pytest.MonkeyPatch):
|
||||
binding = WorkflowAgentNodeBinding(
|
||||
id="binding-1",
|
||||
tenant_id="tenant-1",
|
||||
@ -369,7 +369,7 @@ def test_serialize_workflow_state_passes_user_declared_outputs_through_effective
|
||||
assert effective[0]["required"] is True
|
||||
|
||||
|
||||
def test_composer_save_helpers_create_and_rebind_agents(monkeypatch):
|
||||
def test_composer_save_helpers_create_and_rebind_agents(monkeypatch: pytest.MonkeyPatch):
|
||||
fake_session = FakeSession()
|
||||
monkeypatch.setattr(composer_service.db, "session", fake_session)
|
||||
workflow_agent = SimpleNamespace(id="inline-agent-1", active_config_snapshot_id="inline-version-1")
|
||||
@ -611,7 +611,7 @@ def test_composer_create_agents_syncs_active_config_has_model(monkeypatch):
|
||||
assert roster_agent.active_config_has_model is True
|
||||
|
||||
|
||||
def test_composer_version_helpers_and_lookup_errors(monkeypatch):
|
||||
def test_composer_version_helpers_and_lookup_errors(monkeypatch: pytest.MonkeyPatch):
|
||||
fake_session = FakeSession(
|
||||
scalar=[
|
||||
1,
|
||||
@ -670,7 +670,7 @@ def test_composer_version_helpers_and_lookup_errors(monkeypatch):
|
||||
assert workflow.id == "workflow-1"
|
||||
|
||||
|
||||
def test_composer_current_version_and_error_paths(monkeypatch):
|
||||
def test_composer_current_version_and_error_paths(monkeypatch: pytest.MonkeyPatch):
|
||||
fake_session = FakeSession(scalar=[2])
|
||||
monkeypatch.setattr(composer_service.db, "session", fake_session)
|
||||
payload = ComposerSavePayload.model_validate(
|
||||
@ -717,7 +717,7 @@ def test_composer_current_version_and_error_paths(monkeypatch):
|
||||
)
|
||||
|
||||
|
||||
def test_roster_list_and_invite_options(monkeypatch):
|
||||
def test_roster_list_and_invite_options(monkeypatch: pytest.MonkeyPatch):
|
||||
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
|
||||
updated_at = datetime(2026, 1, 3, 3, 4, 5, tzinfo=UTC)
|
||||
version_created_at = datetime(2026, 1, 4, 3, 4, 5, tzinfo=UTC)
|
||||
@ -790,7 +790,7 @@ def test_roster_list_and_invite_options(monkeypatch):
|
||||
assert invited["data"][0]["existing_node_ids"] == ["node-1"]
|
||||
|
||||
|
||||
def test_invite_options_uses_db_filtered_pagination(monkeypatch):
|
||||
def test_invite_options_uses_db_filtered_pagination(monkeypatch: pytest.MonkeyPatch):
|
||||
configured_agent = Agent(
|
||||
id="agent-2",
|
||||
tenant_id="tenant-1",
|
||||
@ -915,7 +915,7 @@ def test_published_references_include_app_display_fields_and_sort_by_updated_at(
|
||||
assert references[0]["workflow_version"] == "published-recent"
|
||||
|
||||
|
||||
def test_roster_update_archive_versions_and_detail(monkeypatch):
|
||||
def test_roster_update_archive_versions_and_detail(monkeypatch: pytest.MonkeyPatch):
|
||||
listed_version = AgentConfigSnapshot(id="version-2", agent_id="agent-1", version=2)
|
||||
listed_version_created_at = datetime(2026, 1, 5, 3, 4, 5, tzinfo=UTC)
|
||||
listed_version.created_at = listed_version_created_at
|
||||
@ -973,7 +973,7 @@ def test_roster_update_archive_versions_and_detail(monkeypatch):
|
||||
assert detail["revisions"][0]["created_at"] == int(revision_created_at.timestamp())
|
||||
|
||||
|
||||
def test_roster_create_detail_and_lookup_helpers(monkeypatch):
|
||||
def test_roster_create_detail_and_lookup_helpers(monkeypatch: pytest.MonkeyPatch):
|
||||
fake_session = FakeSession(
|
||||
scalar=[
|
||||
SimpleNamespace(id="agent-1"),
|
||||
@ -1040,9 +1040,7 @@ def test_agent_app_visible_versions_exclude_draft_saves():
|
||||
|
||||
def test_app_list_all_excludes_agent_apps_by_default():
|
||||
filters = AppService._build_app_list_filters(
|
||||
"account-1",
|
||||
"tenant-1",
|
||||
AppListParams(mode="all"),
|
||||
"account-1", "tenant-1", AppListParams(mode="all"), FakeSession(scalar=None, scalars=None)
|
||||
)
|
||||
sql = " ".join(str(filter_) for filter_ in filters)
|
||||
|
||||
@ -2173,7 +2171,7 @@ class TestWorkflowAgentDraftBindingSync:
|
||||
assert session.deleted == [stale_binding]
|
||||
|
||||
|
||||
def test_dataset_rows_filters_malformed_ids(monkeypatch):
|
||||
def test_dataset_rows_filters_malformed_ids(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Mention ids are user-editable text: a non-UUID id must read as missing
|
||||
(placeholder semantics), never reach the UUID-typed dataset query (E2E 500)."""
|
||||
captured = {}
|
||||
@ -2197,7 +2195,7 @@ def test_dataset_rows_filters_malformed_ids(monkeypatch):
|
||||
assert captured == {}
|
||||
|
||||
|
||||
def test_workspace_dify_tools_returns_provider_and_tool_granularities(monkeypatch):
|
||||
def test_workspace_dify_tools_returns_provider_and_tool_granularities(monkeypatch: pytest.MonkeyPatch):
|
||||
"""The slash-menu Tools tab needs both selection granularities: a provider
|
||||
hosts many tools (like an MCP server), so candidates return one
|
||||
provider-level entry (id = <provider>/*, = all tools) plus one per tool."""
|
||||
@ -2268,7 +2266,7 @@ def _patch_drive_keys(monkeypatch, existing_keys):
|
||||
return captured
|
||||
|
||||
|
||||
def test_drive_ref_findings_reports_missing_keys(monkeypatch):
|
||||
def test_drive_ref_findings_reports_missing_keys(monkeypatch: pytest.MonkeyPatch):
|
||||
_patch_drive_keys(monkeypatch, existing_keys=["tender-analyzer/SKILL.md"])
|
||||
|
||||
findings = AgentComposerService._drive_ref_findings(
|
||||
@ -2279,7 +2277,7 @@ def test_drive_ref_findings_reports_missing_keys(monkeypatch):
|
||||
assert str(findings[0]["message"]).startswith("file_ref_dangling: ")
|
||||
|
||||
|
||||
def test_drive_ref_findings_clean_when_all_keys_exist(monkeypatch):
|
||||
def test_drive_ref_findings_clean_when_all_keys_exist(monkeypatch: pytest.MonkeyPatch):
|
||||
_patch_drive_keys(monkeypatch, existing_keys=["tender-analyzer/SKILL.md", "files/sample.pdf"])
|
||||
|
||||
assert (
|
||||
@ -2288,7 +2286,7 @@ def test_drive_ref_findings_clean_when_all_keys_exist(monkeypatch):
|
||||
)
|
||||
|
||||
|
||||
def test_drive_ref_findings_skips_refs_without_drive_keys(monkeypatch):
|
||||
def test_drive_ref_findings_skips_refs_without_drive_keys(monkeypatch: pytest.MonkeyPatch):
|
||||
# No drive-backed ref at all -> no DB roundtrip, no findings.
|
||||
soul = _drive_soul(
|
||||
skills_files={"skills": [{"id": "legacy", "name": "Legacy"}], "files": [{"name": "u.pdf", "file_id": "u-1"}]}
|
||||
@ -2297,7 +2295,7 @@ def test_drive_ref_findings_skips_refs_without_drive_keys(monkeypatch):
|
||||
assert findings == []
|
||||
|
||||
|
||||
def test_require_drive_refs_resolved_raises_with_stable_code(monkeypatch):
|
||||
def test_require_drive_refs_resolved_raises_with_stable_code(monkeypatch: pytest.MonkeyPatch):
|
||||
from services.agent.errors import InvalidComposerConfigError
|
||||
|
||||
_patch_drive_keys(monkeypatch, existing_keys=[])
|
||||
@ -2308,7 +2306,7 @@ def test_require_drive_refs_resolved_raises_with_stable_code(monkeypatch):
|
||||
)
|
||||
|
||||
|
||||
def test_collect_validation_findings_appends_drive_findings_with_agent_context(monkeypatch):
|
||||
def test_collect_validation_findings_appends_drive_findings_with_agent_context(monkeypatch: pytest.MonkeyPatch):
|
||||
from services.entities.agent_entities import ComposerSavePayload
|
||||
|
||||
_patch_drive_keys(monkeypatch, existing_keys=[])
|
||||
@ -2334,7 +2332,7 @@ def test_collect_validation_findings_appends_drive_findings_with_agent_context(m
|
||||
# ── ENG-625 D5: soul-first ref removal ───────────────────────────────────────
|
||||
|
||||
|
||||
def _patch_remove_drive_refs_env(monkeypatch, *, soul_dict):
|
||||
def _patch_remove_drive_refs_env(monkeypatch: pytest.MonkeyPatch, *, soul_dict):
|
||||
"""Wire the classmethod's collaborators so soul editing + versioning is observable."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
@ -2359,7 +2357,7 @@ def _patch_remove_drive_refs_env(monkeypatch, *, soul_dict):
|
||||
return agent, captured, committed
|
||||
|
||||
|
||||
def test_remove_drive_refs_drops_skill_by_slug_and_versions(monkeypatch):
|
||||
def test_remove_drive_refs_drops_skill_by_slug_and_versions(monkeypatch: pytest.MonkeyPatch):
|
||||
soul_dict = {
|
||||
"skills_files": {
|
||||
"skills": [
|
||||
@ -2383,7 +2381,7 @@ def test_remove_drive_refs_drops_skill_by_slug_and_versions(monkeypatch):
|
||||
assert committed.get("committed") is True
|
||||
|
||||
|
||||
def test_remove_drive_refs_is_noop_when_ref_absent(monkeypatch):
|
||||
def test_remove_drive_refs_is_noop_when_ref_absent(monkeypatch: pytest.MonkeyPatch):
|
||||
soul_dict = {"skills_files": {"skills": [], "files": []}}
|
||||
agent, captured, committed = _patch_remove_drive_refs_env(monkeypatch, soul_dict=soul_dict)
|
||||
|
||||
@ -2397,7 +2395,7 @@ def test_remove_drive_refs_is_noop_when_ref_absent(monkeypatch):
|
||||
assert committed == {}
|
||||
|
||||
|
||||
def test_remove_drive_refs_drops_file_by_key(monkeypatch):
|
||||
def test_remove_drive_refs_drops_file_by_key(monkeypatch: pytest.MonkeyPatch):
|
||||
soul_dict = {
|
||||
"skills_files": {
|
||||
"skills": [],
|
||||
@ -2417,7 +2415,7 @@ def test_remove_drive_refs_drops_file_by_key(monkeypatch):
|
||||
assert [f.drive_key for f in captured["agent_soul"].skills_files.files] == ["files/keep.pdf"]
|
||||
|
||||
|
||||
def test_add_drive_file_ref_adds_or_replaces_file_and_versions(monkeypatch):
|
||||
def test_add_drive_file_ref_adds_or_replaces_file_and_versions(monkeypatch: pytest.MonkeyPatch):
|
||||
soul_dict = {
|
||||
"skills_files": {
|
||||
"skills": [],
|
||||
@ -2444,7 +2442,7 @@ def test_add_drive_file_ref_adds_or_replaces_file_and_versions(monkeypatch):
|
||||
assert committed.get("committed") is True
|
||||
|
||||
|
||||
def test_add_drive_file_ref_syncs_workflow_binding_snapshot(monkeypatch):
|
||||
def test_add_drive_file_ref_syncs_workflow_binding_snapshot(monkeypatch: pytest.MonkeyPatch):
|
||||
binding = SimpleNamespace(agent_id="agent-1", current_snapshot_id="snap-1", updated_by=None)
|
||||
_patch_remove_drive_refs_env(monkeypatch, soul_dict={"skills_files": {"skills": [], "files": []}})
|
||||
monkeypatch.setattr(
|
||||
@ -2473,7 +2471,7 @@ def test_remove_drive_refs_requires_exactly_one_scope():
|
||||
# ── ENG-623/625: resolver helpers + save-path drive guard ────────────────────
|
||||
|
||||
|
||||
def test_resolve_bound_agent_id_queries_active_roster_agent(monkeypatch):
|
||||
def test_resolve_bound_agent_id_queries_active_roster_agent(monkeypatch: pytest.MonkeyPatch):
|
||||
from types import SimpleNamespace
|
||||
|
||||
import services.agent.composer_service as module
|
||||
@ -2482,7 +2480,7 @@ def test_resolve_bound_agent_id_queries_active_roster_agent(monkeypatch):
|
||||
assert AgentComposerService.resolve_bound_agent_id(tenant_id="t-1", app_id="app-1") == "agent-9"
|
||||
|
||||
|
||||
def test_resolve_workflow_node_agent_id_degrades_without_workflow_or_binding(monkeypatch):
|
||||
def test_resolve_workflow_node_agent_id_degrades_without_workflow_or_binding(monkeypatch: pytest.MonkeyPatch):
|
||||
from types import SimpleNamespace
|
||||
|
||||
def boom(cls, **kwargs):
|
||||
@ -2505,7 +2503,7 @@ def test_resolve_workflow_node_agent_id_degrades_without_workflow_or_binding(mon
|
||||
assert AgentComposerService.resolve_workflow_node_agent_id(tenant_id="t", app_id="a", node_id="n") == "agent-7"
|
||||
|
||||
|
||||
def test_remove_drive_refs_returns_none_without_agent_or_snapshot(monkeypatch):
|
||||
def test_remove_drive_refs_returns_none_without_agent_or_snapshot(monkeypatch: pytest.MonkeyPatch):
|
||||
from types import SimpleNamespace
|
||||
|
||||
import services.agent.composer_service as module
|
||||
@ -2518,7 +2516,7 @@ def test_remove_drive_refs_returns_none_without_agent_or_snapshot(monkeypatch):
|
||||
assert AgentComposerService.remove_drive_refs(tenant_id="t", agent_id="a", account_id="u", skill_slug="s") is None
|
||||
|
||||
|
||||
def test_save_workflow_composer_guards_drive_refs_for_existing_agent_strategies(monkeypatch):
|
||||
def test_save_workflow_composer_guards_drive_refs_for_existing_agent_strategies(monkeypatch: pytest.MonkeyPatch):
|
||||
from types import SimpleNamespace
|
||||
|
||||
from services.entities.agent_entities import ComposerSavePayload
|
||||
|
||||
@ -94,14 +94,15 @@ def test_validate_snippet_graph_forbidden_nodes_raises_with_node_details() -> No
|
||||
|
||||
|
||||
def test_get_snippets_returns_empty_when_tag_filter_has_no_targets(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session = _SessionWithoutNameLookup()
|
||||
get_target_ids = Mock(return_value=[])
|
||||
monkeypatch.setattr("services.snippet_service.TagService.get_target_ids_by_tag_ids", get_target_ids)
|
||||
service = SnippetService.__new__(SnippetService)
|
||||
|
||||
result = service.get_snippets(tenant_id="tenant-1", tag_ids=["tag-1"])
|
||||
result = service.get_snippets(tenant_id="tenant-1", session=session, tag_ids=["tag-1"])
|
||||
|
||||
assert result == ([], 0, False)
|
||||
get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], match_all=True)
|
||||
get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], session, match_all=True)
|
||||
|
||||
|
||||
def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -124,6 +125,7 @@ def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPa
|
||||
|
||||
result, total, has_more = service.get_snippets(
|
||||
tenant_id="tenant-1",
|
||||
session=session,
|
||||
page=2,
|
||||
limit=2,
|
||||
keyword="search",
|
||||
@ -135,7 +137,7 @@ def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPa
|
||||
assert result == snippets[:2]
|
||||
assert total == 3
|
||||
assert has_more is True
|
||||
get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], match_all=True)
|
||||
get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], session, match_all=True)
|
||||
session.scalar.assert_called_once()
|
||||
session.scalars.assert_called_once()
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from models.enums import TagType
|
||||
@ -8,19 +9,19 @@ from services.tag_service import TagBindingCreatePayload, TagBindingDeletePayloa
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user(mocker):
|
||||
def current_user(mocker: MockerFixture):
|
||||
user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1")
|
||||
mocker.patch("services.tag_service.current_user", user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session(mocker):
|
||||
mock_db = mocker.patch("services.tag_service.db")
|
||||
def db_session(mocker: MockerFixture):
|
||||
mock_db = mocker.Mock()
|
||||
return mock_db.session
|
||||
|
||||
|
||||
def test_save_tag_binding_only_creates_bindings_for_valid_snippet_tags(mocker, current_user, db_session):
|
||||
def test_save_tag_binding_only_creates_bindings_for_valid_snippet_tags(mocker: MockerFixture, current_user, db_session):
|
||||
mocker.patch("services.tag_service.TagService.check_target_exists")
|
||||
db_session.scalars.return_value.all.return_value = ["tag-1"]
|
||||
db_session.scalar.return_value = None
|
||||
@ -30,7 +31,8 @@ def test_save_tag_binding_only_creates_bindings_for_valid_snippet_tags(mocker, c
|
||||
tag_ids=["tag-1", "tag-from-other-tenant"],
|
||||
target_id="snippet-1",
|
||||
type=TagType.SNIPPET,
|
||||
)
|
||||
),
|
||||
db_session,
|
||||
)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
@ -42,7 +44,7 @@ def test_save_tag_binding_only_creates_bindings_for_valid_snippet_tags(mocker, c
|
||||
db_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_tag_binding_limits_deletion_to_valid_snippet_tags(mocker, current_user, db_session):
|
||||
def test_delete_tag_binding_limits_deletion_to_valid_snippet_tags(mocker: MockerFixture, current_user, db_session):
|
||||
mocker.patch("services.tag_service.TagService.check_target_exists")
|
||||
db_session.execute.return_value = SimpleNamespace(rowcount=1)
|
||||
|
||||
@ -51,14 +53,15 @@ def test_delete_tag_binding_limits_deletion_to_valid_snippet_tags(mocker, curren
|
||||
tag_ids=["tag-1", "tag-from-other-tenant"],
|
||||
target_id="snippet-1",
|
||||
type=TagType.SNIPPET,
|
||||
)
|
||||
),
|
||||
db_session,
|
||||
)
|
||||
|
||||
db_session.execute.assert_called_once()
|
||||
db_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker, current_user, db_session):
|
||||
def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker: MockerFixture, current_user, db_session):
|
||||
mocker.patch("services.tag_service.TagService.check_target_exists")
|
||||
db_session.execute.return_value = SimpleNamespace(rowcount=0)
|
||||
|
||||
@ -67,7 +70,8 @@ def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker, current
|
||||
tag_ids=["tag-1"],
|
||||
target_id="snippet-1",
|
||||
type=TagType.SNIPPET,
|
||||
)
|
||||
),
|
||||
db_session,
|
||||
)
|
||||
|
||||
db_session.execute.assert_called_once()
|
||||
@ -75,7 +79,7 @@ def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker, current
|
||||
|
||||
|
||||
def test_get_target_ids_by_tag_ids_returns_empty_without_query_for_empty_input(db_session):
|
||||
result = TagService.get_target_ids_by_tag_ids(TagType.SNIPPET, "tenant-1", [])
|
||||
result = TagService.get_target_ids_by_tag_ids(TagType.SNIPPET, "tenant-1", [], db_session)
|
||||
|
||||
assert result == []
|
||||
db_session.scalars.assert_not_called()
|
||||
@ -84,7 +88,7 @@ def test_get_target_ids_by_tag_ids_returns_empty_without_query_for_empty_input(d
|
||||
def test_check_target_exists_accepts_existing_snippet(current_user, db_session):
|
||||
db_session.scalar.return_value = SimpleNamespace(id="snippet-1")
|
||||
|
||||
TagService.check_target_exists("snippet", "snippet-1")
|
||||
TagService.check_target_exists("snippet", "snippet-1", db_session)
|
||||
|
||||
db_session.scalar.assert_called_once()
|
||||
|
||||
@ -93,11 +97,11 @@ def test_check_target_exists_raises_when_snippet_missing(current_user, db_sessio
|
||||
db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(NotFound, match="Snippet not found"):
|
||||
TagService.check_target_exists("snippet", "missing-snippet")
|
||||
TagService.check_target_exists("snippet", "missing-snippet", db_session)
|
||||
|
||||
|
||||
def test_check_target_exists_raises_for_invalid_binding_type(current_user, db_session):
|
||||
with pytest.raises(NotFound, match="Invalid binding type"):
|
||||
TagService.check_target_exists("unknown", "target-1")
|
||||
TagService.check_target_exists("unknown", "target-1", db_session)
|
||||
|
||||
db_session.scalar.assert_not_called()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user