Merge branch 'main' into feat/workflow-run-history-infinite-scroll

This commit is contained in:
Benjamin 2026-04-03 17:51:35 +08:00 committed by GitHub
commit 5896aa6327
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1144 changed files with 26582 additions and 12310 deletions

2
.gitignore vendored
View File

@ -212,7 +212,7 @@ api/.vscode
# pnpm
/.pnpm-store
/node_modules
node_modules
.vite-hooks/_
# plugin migrate

View File

@ -89,6 +89,12 @@ if $web_modified; then
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
fi
echo "Running knip"
if ! pnpm run knip; then
echo "Knip check failed. Please run 'pnpm run knip' to fix the errors."
exit 1
fi
echo "Running unit tests check"
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true)

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@ -71,7 +71,7 @@ class AppImportApi(Resource):
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = current_user

View File

@ -193,7 +193,7 @@ workflow_draft_variable_list_model = console_ns.model(
)
def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -210,7 +210,7 @@ def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f)
def wrapper(*args: Any, **kwargs: Any):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
return f(*args, **kwargs)
return wrapper

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import overload
from sqlalchemy import select
@ -23,14 +23,30 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model
def get_app_model(
view: Callable[..., Any] | None = None,
@overload
def get_app_model[**P, R](
view: Callable[P, R],
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
) -> Callable[P, R]: ...
@overload
def get_app_model[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: Any, **kwargs: Any):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
@ -68,14 +84,30 @@ def get_app_model(
return decorator(view)
def get_app_model_with_trial(
view: Callable[..., Any] | None = None,
@overload
def get_app_model_with_trial[**P, R](
view: Callable[P, R],
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
) -> Callable[P, R]: ...
@overload
def get_app_model_with_trial[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model_with_trial[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: Any, **kwargs: Any):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@ -158,10 +158,11 @@ class DataSourceApi(Resource):
@login_required
@account_initialization_required
def patch(self, binding_id, action: Literal["enable", "disable"]):
_, current_tenant_id = current_account_with_tenant()
binding_id = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
).scalar_one_or_none()
if data_source_binding is None:
raise NotFound("Data source binding not found.")

View File

@ -1,4 +1,5 @@
import logging
from collections.abc import Callable
from typing import Any, NoReturn
from flask import Response, request
@ -55,7 +56,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _api_prerequisite(f):
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -70,7 +71,7 @@ def _api_prerequisite(f):
@login_required
@account_initialization_required
@get_rag_pipeline
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)

View File

@ -1,9 +1,10 @@
import inspect
import logging
import time
from collections.abc import Callable
from enum import StrEnum, auto
from functools import wraps
from typing import Any, cast, overload
from typing import cast, overload
from flask import current_app, request
from flask_login import user_logged_in
@ -230,94 +231,73 @@ def cloud_edition_billing_rate_limit_check[**P, R](
return interceptor
def validate_dataset_token(
view: Callable[..., Any] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func)
def decorated(*args: Any, **kwargs: Any) -> Any:
api_token = validate_and_get_api_token("dataset")
def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
positional_parameters = [
parameter
for parameter in inspect.signature(view).parameters.values()
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
]
expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"})
# get url path dataset_id from positional args or kwargs
# Flask passes URL path parameters as positional arguments
dataset_id = None
@wraps(view)
def decorated(*args: object, **kwargs: object) -> R:
api_token = validate_and_get_api_token("dataset")
# First try to get from kwargs (explicit parameter)
dataset_id = kwargs.get("dataset_id")
# Flask may pass URL path parameters positionally, so inspect both kwargs and args.
dataset_id = kwargs.get("dataset_id")
# If not in kwargs, try to extract from positional args
if not dataset_id and args:
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
# Check if first arg is likely a class instance (has __dict__ or __class__)
if len(args) > 1 and hasattr(args[0], "__dict__"):
# This is a class method, dataset_id should be in args[1]
potential_id = args[1]
# Validate it's a string-like UUID, not another object
try:
# Try to convert to string and check if it's a valid UUID format
str_id = str(potential_id)
# Basic check: UUIDs are 36 chars with hyphens
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from class method args")
elif len(args) > 0:
# Not a class method, check if args[0] looks like a UUID
potential_id = args[0]
try:
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from positional args")
if not dataset_id and args:
potential_id = args[0]
try:
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from positional args")
# Validate dataset if dataset_id is provided
if dataset_id:
dataset_id = str(dataset_id)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
.limit(1)
if dataset_id:
dataset_id = str(dataset_id)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
tenant_account_join = db.session.execute(
select(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
).one_or_none() # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.get(Account, ta.account_id)
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
.limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
tenant_account_join = db.session.execute(
select(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
).one_or_none() # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.get(Account, ta.account_id)
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant does not exist.")
if args and isinstance(args[0], Resource):
return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs)
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
return view_func(api_token.tenant_id, *args, **kwargs)
if expects_bound_instance:
if not args:
raise TypeError("validate_dataset_token expected a bound resource instance.")
return view(args[0], api_token.tenant_id, *args[1:], **kwargs)
return decorated
return view(api_token.tenant_id, *args, **kwargs)
if view:
return decorator(view)
# if view is None, it means that the decorator is used without parentheses
# use the decorator as a function for method_decorators
return decorator
return decorated
def validate_and_get_api_token(scope: str | None = None):

View File

@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound, RequestEntityTooLarge
from controllers.trigger import bp
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
from services.trigger.webhook_service import WebhookService
from services.trigger.webhook_service import RawWebhookDataDict, WebhookService
logger = logging.getLogger(__name__)
@ -23,6 +23,7 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
webhook_id, is_debug=is_debug
)
webhook_data: RawWebhookDataDict
try:
# Use new unified extraction and validation
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)

View File

@ -3,13 +3,19 @@
import logging
import traceback
from datetime import UTC, datetime
from typing import Any
from typing import Any, TypedDict
import orjson
from configs import dify_config
class IdentityDict(TypedDict, total=False):
tenant_id: str
user_id: str
user_type: str
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
@ -84,7 +90,7 @@ class StructuredJSONFormatter(logging.Formatter):
return log_dict
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
def _extract_identity(self, record: logging.LogRecord) -> IdentityDict | None:
tenant_id = getattr(record, "tenant_id", None)
user_id = getattr(record, "user_id", None)
user_type = getattr(record, "user_type", None)
@ -92,7 +98,7 @@ class StructuredJSONFormatter(logging.Formatter):
if not any([tenant_id, user_id, user_type]):
return None
identity: dict[str, str] = {}
identity: IdentityDict = {}
if tenant_id:
identity["tenant_id"] = tenant_id
if user_id:

View File

@ -4,7 +4,7 @@ from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from datetime import timedelta
from types import TracebackType
from typing import Any, Self, cast
from typing import Any, Self
from httpx import HTTPStatusError
from pydantic import BaseModel
@ -338,12 +338,11 @@ class BaseSession[
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
validated_request = cast(ReceiveRequestT, validated_request)
responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request,
request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
@ -359,15 +358,14 @@ class BaseSession[
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
notification = cast(ReceiveNotificationT, notification)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
self._in_flight[cancelled_id].cancel()
else:
self._received_notification(notification)
self._handle_incoming(notification)
self._received_notification(notification) # type: ignore[arg-type]
self._handle_incoming(notification) # type: ignore[arg-type]
except Exception as e:
# For other validation errors, log and continue
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)

View File

@ -1,5 +1,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from flask import Flask
if TYPE_CHECKING:
from extensions.ext_login import DifyLoginManager
class DifyApp(Flask):
pass
"""Flask application type with Dify-specific extension attributes."""
login_manager: DifyLoginManager

View File

@ -1,7 +1,8 @@
import json
from typing import cast
import flask_login
from flask import Response, request
from flask import Request, Response, request
from flask_login import user_loaded_from_request, user_logged_in
from sqlalchemy import select
from werkzeug.exceptions import NotFound, Unauthorized
@ -16,13 +17,35 @@ from models import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser
from services.account_service import AccountService
login_manager = flask_login.LoginManager()
type LoginUser = Account | EndUser
class DifyLoginManager(flask_login.LoginManager):
"""Project-specific Flask-Login manager with a stable unauthorized contract.
Dify registers `unauthorized_handler` below to always return a JSON `Response`.
Overriding this method lets callers rely on that narrower return type instead of
Flask-Login's broader callback contract.
"""
def unauthorized(self) -> Response:
"""Return the registered unauthorized handler result as a Flask `Response`."""
return cast(Response, super().unauthorized())
def load_user_from_request_context(self) -> None:
"""Populate Flask-Login's request-local user cache for the current request."""
self._load_user()
login_manager = DifyLoginManager()
# Flask-Login configuration
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
def load_user_from_request(request_from_flask_login: Request) -> LoginUser | None:
"""Load user based on the request."""
del request_from_flask_login
# Skip authentication for documentation endpoints
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
return None
@ -100,10 +123,12 @@ def load_user_from_request(request_from_flask_login):
raise NotFound("End user not found.")
return end_user
return None
@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_logged_in(_sender, user):
def on_user_logged_in(_sender: object, user: LoginUser) -> None:
"""Called when a user logged in.
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
@ -114,8 +139,10 @@ def on_user_logged_in(_sender, user):
@login_manager.unauthorized_handler
def unauthorized_handler():
def unauthorized_handler() -> Response:
"""Handle unauthorized requests."""
# Keep this as a concrete `Response`; `DifyLoginManager.unauthorized()` narrows
# Flask-Login's callback contract based on this override.
return Response(
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
status=401,
@ -123,5 +150,5 @@ def unauthorized_handler():
)
def init_app(app: DifyApp):
def init_app(app: DifyApp) -> None:
login_manager.init_app(app)

View File

@ -2,19 +2,19 @@ from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from flask import current_app, g, has_request_context, request
from flask import Response, current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
from werkzeug.local import LocalProxy
from configs import dify_config
from dify_app import DifyApp
from extensions.ext_login import DifyLoginManager
from libs.token import check_csrf_token
from models import Account
if TYPE_CHECKING:
from flask.typing import ResponseReturnValue
from models.model import EndUser
@ -29,7 +29,13 @@ def _resolve_current_user() -> EndUser | Account | None:
return get_current_object() if callable(get_current_object) else user_proxy # type: ignore
def current_account_with_tenant():
def _get_login_manager() -> DifyLoginManager:
"""Return the project login manager with Dify's narrowed unauthorized contract."""
app = cast(DifyApp, current_app)
return app.login_manager
def current_account_with_tenant() -> tuple[Account, str]:
"""
Resolve the underlying account for the current user proxy and ensure tenant context exists.
Allows tests to supply plain Account mocks without the LocalProxy helper.
@ -42,7 +48,7 @@ def current_account_with_tenant():
return user, user.current_tenant_id
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]:
"""
If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are
@ -77,13 +83,16 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu
"""
@wraps(func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | Response:
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
return current_app.ensure_sync(func)(*args, **kwargs)
user = _resolve_current_user()
if user is None or not user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
# `DifyLoginManager` guarantees that the registered unauthorized handler
# is surfaced here as a concrete Flask `Response`.
unauthorized_response: Response = _get_login_manager().unauthorized()
return unauthorized_response
g._login_user = user
# we put csrf validation here for less conflicts
# TODO: maybe find a better place for it.
@ -96,7 +105,7 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu
def _get_user() -> EndUser | Account | None:
if has_request_context():
if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore
_get_login_manager().load_user_from_request_context()
return g._login_user

View File

@ -171,7 +171,7 @@ dev = [
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.57.1",
"pyrefly>=0.59.1",
]
############################################################

View File

@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import func, select
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
@ -144,22 +144,26 @@ class AccountService:
@staticmethod
def load_user(user_id: str) -> None | Account:
account = db.session.query(Account).filter_by(id=user_id).first()
account = db.session.get(Account, user_id)
if not account:
return None
if account.status == AccountStatus.BANNED:
raise Unauthorized("Account is banned.")
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
current_tenant = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
.limit(1)
)
if current_tenant:
account.set_tenant_id(current_tenant.tenant_id)
else:
available_ta = (
db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
available_ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
.limit(1)
)
if not available_ta:
return None
@ -195,7 +199,7 @@ class AccountService:
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
"""authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first()
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
raise AccountPasswordError("Invalid email or password.")
@ -371,8 +375,10 @@ class AccountService:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: AccountIntegrate | None = (
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
account_integrate: AccountIntegrate | None = db.session.scalar(
select(AccountIntegrate)
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
.limit(1)
)
if account_integrate:
@ -416,7 +422,9 @@ class AccountService:
def update_account_email(account: Account, email: str) -> Account:
"""Update account email"""
account.email = email
account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first()
account_integrate = db.session.scalar(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
)
if account_integrate:
db.session.delete(account_integrate)
db.session.add(account)
@ -818,7 +826,7 @@ class AccountService:
)
)
account = db.session.query(Account).where(Account.email == email).first()
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
return None
@ -1018,7 +1026,7 @@ class AccountService:
@staticmethod
def check_email_unique(email: str) -> bool:
return db.session.query(Account).filter_by(email=email).first() is None
return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
class TenantService:
@ -1384,10 +1392,10 @@ class RegisterService:
db.session.add(dify_setup)
db.session.commit()
except Exception as e:
db.session.query(DifySetup).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Account).delete()
db.session.query(Tenant).delete()
db.session.execute(delete(DifySetup))
db.session.execute(delete(TenantAccountJoin))
db.session.execute(delete(Account))
db.session.execute(delete(Tenant))
db.session.commit()
logger.exception("Setup account failed, email: %s, name: %s", email, name)
@ -1488,7 +1496,11 @@ class RegisterService:
TenantService.switch_tenant(account, tenant.id)
else:
TenantService.check_member_permission(tenant, inviter, account, "add")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if not ta:
TenantService.create_tenant_member(tenant, account, role)
@ -1545,21 +1557,18 @@ class RegisterService:
if not invitation_data:
return None
tenant = (
db.session.query(Tenant)
.where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
.first()
tenant = db.session.scalar(
select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1)
)
if not tenant:
return None
tenant_account = (
db.session.query(Account, TenantAccountJoin.role)
tenant_account = db.session.execute(
select(Account, TenantAccountJoin.role)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
.first()
)
).first()
if not tenant_account:
return None

View File

@ -4,6 +4,8 @@ import uuid
import pandas as pd
logger = logging.getLogger(__name__)
from typing import TypedDict
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@ -23,6 +25,27 @@ from tasks.annotation.enable_annotation_reply_task import enable_annotation_repl
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
class AnnotationJobStatusDict(TypedDict):
job_id: str
job_status: str
class EmbeddingModelDict(TypedDict):
embedding_provider_name: str
embedding_model_name: str
class AnnotationSettingDict(TypedDict):
id: str
enabled: bool
score_threshold: float
embedding_model: EmbeddingModelDict | dict
class AnnotationSettingDisabledDict(TypedDict):
enabled: bool
class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
@ -85,7 +108,7 @@ class AppAnnotationService:
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str):
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@ -109,7 +132,7 @@ class AppAnnotationService:
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def disable_app_annotation(cls, app_id: str):
def disable_app_annotation(cls, app_id: str) -> AnnotationJobStatusDict:
_, current_tenant_id = current_account_with_tenant()
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key)
@ -567,7 +590,7 @@ class AppAnnotationService:
db.session.commit()
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
def get_app_annotation_setting_by_app_id(cls, app_id: str) -> AnnotationSettingDict | AnnotationSettingDisabledDict:
_, current_tenant_id = current_account_with_tenant()
# get app info
app = (
@ -602,7 +625,9 @@ class AppAnnotationService:
return {"enabled": False}
@classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
def update_app_annotation_setting(
cls, app_id: str, annotation_setting_id: str, args: dict
) -> AnnotationSettingDict:
current_user, current_tenant_id = current_account_with_tenant()
# get app info
app = (

View File

@ -32,6 +32,11 @@ class SubscriptionPlan(TypedDict):
expiration_date: int
class KnowledgeRateLimitDict(TypedDict):
limit: int
subscription_plan: str
class BillingService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
@ -58,7 +63,7 @@ class BillingService:
return usage_info
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
params = {"tenant_id": tenant_id}
knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)

View File

@ -5,7 +5,7 @@ from urllib.parse import urlparse
import httpx
from graphon.nodes.http_request.exc import InvalidHttpMethodError
from sqlalchemy import select
from sqlalchemy import func, select
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
@ -103,8 +103,10 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str, tenant_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@ -112,8 +114,10 @@ class ExternalDatasetService:
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@ -132,8 +136,10 @@ class ExternalDatasetService:
@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@ -144,9 +150,12 @@ class ExternalDatasetService:
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
db.session.scalar(
select(func.count(ExternalKnowledgeBindings.id)).where(
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
)
)
or 0
)
if count > 0:
return True, count
@ -154,8 +163,10 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: ExternalKnowledgeBindings | None = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
external_knowledge_binding: ExternalKnowledgeBindings | None = db.session.scalar(
select(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
.limit(1)
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
@ -163,8 +174,10 @@ class ExternalDatasetService:
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("api template not found")
@ -238,12 +251,17 @@ class ExternalDatasetService:
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
if db.session.scalar(
select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1)
):
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(
ExternalKnowledgeApis.id == args.get("external_knowledge_api_id"),
ExternalKnowledgeApis.tenant_id == tenant_id,
)
.limit(1)
)
if external_knowledge_api is None:
@ -286,16 +304,18 @@ class ExternalDatasetService:
external_retrieval_parameters: dict,
metadata_condition: MetadataCondition | None = None,
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
.limit(1)
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("external api template not found")

View File

@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any
from typing import Any, TypedDict
from graphon.model_runtime.entities import LLMMode
@ -18,6 +18,16 @@ from models.enums import CreatorUserRole, DatasetQuerySource
logger = logging.getLogger(__name__)
class QueryDict(TypedDict):
content: str
class RetrieveResponseDict(TypedDict):
query: QueryDict
records: list[dict[str, Any]]
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
@ -150,7 +160,7 @@ class HitTestingService:
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
@classmethod
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]:
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> RetrieveResponseDict:
records = RetrievalService.format_retrieval_documents(documents)
return {
@ -161,7 +171,7 @@ class HitTestingService:
}
@classmethod
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]:
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> RetrieveResponseDict:
records = []
if dataset.provider == "external":
for document in documents:

View File

@ -1,6 +1,8 @@
import copy
import logging
from sqlalchemy import delete, func, select
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -25,10 +27,14 @@ class MetadataService:
raise ValueError("Metadata name cannot exceed 255 characters.")
current_user, current_tenant_id = current_account_with_tenant()
# check if metadata name already exists
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
if db.session.scalar(
select(DatasetMetadata)
.where(
DatasetMetadata.tenant_id == current_tenant_id,
DatasetMetadata.dataset_id == dataset_id,
DatasetMetadata.name == metadata_args.name,
)
.limit(1)
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
@ -54,10 +60,14 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
current_user, current_tenant_id = current_account_with_tenant()
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=name)
.first()
if db.session.scalar(
select(DatasetMetadata)
.where(
DatasetMetadata.tenant_id == current_tenant_id,
DatasetMetadata.dataset_id == dataset_id,
DatasetMetadata.name == name,
)
.limit(1)
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
@ -65,7 +75,11 @@ class MetadataService:
raise ValueError("Metadata name already exists in Built-in fields.")
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
.limit(1)
)
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
@ -74,9 +88,9 @@ class MetadataService:
metadata.updated_at = naive_utc_now()
# update related documents
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
dataset_metadata_bindings = db.session.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@ -101,15 +115,19 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
.limit(1)
)
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)
# deal related documents
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
dataset_metadata_bindings = db.session.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@ -224,16 +242,23 @@ class MetadataService:
# deal metadata binding (in the same transaction as the doc_metadata update)
if not operation.partial_update:
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
db.session.execute(
delete(DatasetMetadataBinding).where(
DatasetMetadataBinding.document_id == operation.document_id
)
)
current_user, current_tenant_id = current_account_with_tenant()
for metadata_value in operation.metadata_list:
# check if binding already exists
if operation.partial_update:
existing_binding = (
db.session.query(DatasetMetadataBinding)
.filter_by(document_id=operation.document_id, metadata_id=metadata_value.id)
.first()
existing_binding = db.session.scalar(
select(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.document_id == operation.document_id,
DatasetMetadataBinding.metadata_id == metadata_value.id,
)
.limit(1)
)
if existing_binding:
continue
@ -275,9 +300,13 @@ class MetadataService:
"id": item.get("id"),
"name": item.get("name"),
"type": item.get("type"),
"count": db.session.query(DatasetMetadataBinding)
.filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
.count(),
"count": db.session.scalar(
select(func.count(DatasetMetadataBinding.id)).where(
DatasetMetadataBinding.metadata_id == item.get("id"),
DatasetMetadataBinding.dataset_id == dataset.id,
)
)
or 0,
}
for item in dataset.doc_metadata or []
if item.get("id") != "built-in"

View File

@ -156,27 +156,27 @@ class RagPipelineService:
:param template_id: template id
:param template_info: template info
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
.limit(1)
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
# check template name is exist
template_name = template_info.name
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
template = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
PipelineCustomizedTemplate.id != template_id,
)
.first()
.limit(1)
)
if template:
raise ValueError("Template name is already exists")
@ -192,13 +192,13 @@ class RagPipelineService:
"""
Delete customized pipeline template.
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
.limit(1)
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
@ -210,14 +210,14 @@ class RagPipelineService:
Get draft workflow
"""
# fetch draft workflow by rag pipeline
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
.limit(1)
)
# return draft workflow
@ -232,28 +232,28 @@ class RagPipelineService:
return None
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == pipeline.workflow_id,
)
.first()
.limit(1)
)
return workflow
def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""Fetch a published workflow snapshot by ID for restore operations."""
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == workflow_id,
)
.first()
.limit(1)
)
if workflow and workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError("source workflow must be published")
@ -974,7 +974,7 @@ class RagPipelineService:
if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first()
document = db.session.get(Document, document_id.value)
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = error
@ -1178,12 +1178,12 @@ class RagPipelineService:
"""
Publish customized pipeline template
"""
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
pipeline = db.session.get(Pipeline, pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
if not pipeline.workflow_id:
raise ValueError("Pipeline workflow not found")
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
workflow = db.session.get(Workflow, pipeline.workflow_id)
if not workflow:
raise ValueError("Workflow not found")
with Session(db.engine) as session:
@ -1194,21 +1194,21 @@ class RagPipelineService:
# check template name is exist
template_name = args.get("name")
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
template = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
)
.first()
.limit(1)
)
if template:
raise ValueError("Template name is already exists")
max_position = (
db.session.query(func.max(PipelineCustomizedTemplate.position))
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
.scalar()
max_position = db.session.scalar(
select(func.max(PipelineCustomizedTemplate.position)).where(
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id
)
)
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@ -1239,13 +1239,14 @@ class RagPipelineService:
def is_workflow_exist(self, pipeline: Pipeline) -> bool:
return (
db.session.query(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
db.session.scalar(
select(func.count(Workflow.id)).where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
)
.count()
or 0
) > 0
def get_node_last_run(
@ -1353,11 +1354,11 @@ class RagPipelineService:
def get_recommended_plugins(self, type: str) -> dict:
# Query active recommended plugins
query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
if type and type != "all":
query = query.where(PipelineRecommendedPlugin.type == type)
stmt = stmt.where(PipelineRecommendedPlugin.type == type)
pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
pipeline_recommended_plugins = db.session.scalars(stmt.order_by(PipelineRecommendedPlugin.position.asc())).all()
if not pipeline_recommended_plugins:
return {
@ -1396,14 +1397,12 @@ class RagPipelineService:
"""
Retry error document
"""
document_pipeline_execution_log = (
db.session.query(DocumentPipelineExecutionLog)
.where(DocumentPipelineExecutionLog.document_id == document.id)
.first()
document_pipeline_execution_log = db.session.scalar(
select(DocumentPipelineExecutionLog).where(DocumentPipelineExecutionLog.document_id == document.id).limit(1)
)
if not document_pipeline_execution_log:
raise ValueError("Document pipeline execution log not found")
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
pipeline = db.session.get(Pipeline, document_pipeline_execution_log.pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
# convert to app config
@ -1432,23 +1431,23 @@ class RagPipelineService:
"""
Get datasource plugins
"""
dataset: Dataset | None = (
db.session.query(Dataset)
dataset: Dataset | None = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
pipeline: Pipeline | None = db.session.scalar(
select(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not pipeline:
raise ValueError("Pipeline not found")
@ -1530,23 +1529,23 @@ class RagPipelineService:
"""
Get pipeline
"""
dataset: Dataset | None = (
db.session.query(Dataset)
dataset: Dataset | None = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
pipeline: Pipeline | None = db.session.scalar(
select(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not pipeline:
raise ValueError("Pipeline not found")

View File

@ -3,7 +3,7 @@ import logging
from datetime import datetime
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import or_, select
from sqlalchemy import delete, or_, select
from sqlalchemy.orm import Session
from core.tools.__base.tool_provider import ToolProviderController
@ -42,20 +42,22 @@ class WorkflowToolManageService:
labels: list[str] | None = None,
):
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
)
.first()
.limit(1)
)
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)
)
if app is None:
raise ValueError(f"App {workflow_app_id} not found")
@ -122,30 +124,30 @@ class WorkflowToolManageService:
:return: the updated tool
"""
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
)
.first()
.limit(1)
)
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App | None = (
db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
)
if app is None:
@ -234,9 +236,11 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
"""
db.session.query(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
).delete()
db.session.execute(
delete(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
)
)
db.session.commit()
@ -251,10 +255,10 @@ class WorkflowToolManageService:
:param workflow_tool_id: the workflow tool id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)
@ -267,10 +271,10 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)
@ -284,8 +288,8 @@ class WorkflowToolManageService:
if db_tool is None:
raise ValueError("Tool not found")
workflow_app: App | None = (
db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
workflow_app: App | None = db.session.scalar(
select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1)
)
if workflow_app is None:
@ -331,10 +335,10 @@ class WorkflowToolManageService:
:param workflow_tool_id: the workflow tool id
:return: the list of tools
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
if db_tool is None:

View File

@ -3,7 +3,7 @@ import logging
import mimetypes
import secrets
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from typing import Any, TypedDict
import orjson
from flask import request
@ -50,6 +50,14 @@ logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
class RawWebhookDataDict(TypedDict):
method: str
headers: dict[str, str]
query_params: dict[str, str]
body: dict[str, Any]
files: dict[str, Any]
class WebhookService:
"""Service for handling webhook operations."""
@ -145,7 +153,7 @@ class WebhookService:
@classmethod
def extract_and_validate_webhook_data(
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict
) -> dict[str, Any]:
) -> RawWebhookDataDict:
"""Extract and validate webhook data in a single unified process.
Args:
@ -173,7 +181,7 @@ class WebhookService:
return processed_data
@classmethod
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> dict[str, Any]:
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> RawWebhookDataDict:
"""Extract raw data from incoming webhook request without type conversion.
Args:
@ -189,7 +197,7 @@ class WebhookService:
"""
cls._validate_content_length()
data = {
data: RawWebhookDataDict = {
"method": request.method,
"headers": dict(request.headers),
"query_params": dict(request.args),
@ -223,7 +231,7 @@ class WebhookService:
return data
@classmethod
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
def _process_and_validate_data(cls, raw_data: RawWebhookDataDict, node_data: WebhookData) -> RawWebhookDataDict:
"""Process and validate webhook data according to node configuration.
Args:
@ -664,7 +672,7 @@ class WebhookService:
raise ValueError(f"Required header missing: {header_name}")
@classmethod
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
def _validate_http_metadata(cls, webhook_data: RawWebhookDataDict, node_data: WebhookData) -> dict[str, Any]:
"""Validate HTTP method and content-type.
Args:
@ -729,7 +737,7 @@ class WebhookService:
return False
@classmethod
def build_workflow_inputs(cls, webhook_data: dict[str, Any]) -> dict[str, Any]:
def build_workflow_inputs(cls, webhook_data: RawWebhookDataDict) -> dict[str, Any]:
"""Construct workflow inputs payload from webhook data.
Args:
@ -747,7 +755,7 @@ class WebhookService:
@classmethod
def trigger_workflow_execution(
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: RawWebhookDataDict, workflow: Workflow
) -> None:
"""Trigger workflow execution via AsyncWorkflowService.

View File

@ -129,6 +129,7 @@ class VariableTruncator(BaseTruncator):
used_size += self.calculate_json_size(key)
if used_size > budget:
truncated_mapping[key] = "..."
is_truncated = True
continue
value_budget = (budget - used_size) // (length - len(truncated_mapping))
if isinstance(value, Segment):
@ -164,9 +165,9 @@ class VariableTruncator(BaseTruncator):
result = self._truncate_segment(segment, self._max_size_bytes)
if result.value_size > self._max_size_bytes:
if isinstance(result.value, str):
result = self._truncate_string(result.value, self._max_size_bytes)
return TruncationResult(StringSegment(value=result.value), True)
if isinstance(result.value, StringSegment):
fallback_result = self._truncate_string(result.value.value, self._max_size_bytes)
return TruncationResult(StringSegment(value=fallback_result.value), True)
# Apply final fallback - convert to JSON string and truncate
json_str = dumps_with_segments(result.value, ensure_ascii=False)

View File

@ -20,7 +20,7 @@ def app():
app = Flask(__name__)
app.config["TESTING"] = True
app.config["RESTX_MASK_HEADER"] = "X-Fields"
app.login_manager = SimpleNamespace(_load_user=lambda: None)
app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
return app

View File

@ -12,7 +12,7 @@ from models.account import Account, TenantAccountRole
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
flask_app.login_manager = SimpleNamespace(_load_user=lambda: None)
flask_app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
return flask_app

View File

@ -0,0 +1,17 @@
import json
from flask import Response
from extensions.ext_login import unauthorized_handler
def test_unauthorized_handler_returns_json_response() -> None:
response = unauthorized_handler()
assert isinstance(response, Response)
assert response.status_code == 401
assert response.content_type == "application/json"
assert json.loads(response.get_data(as_text=True)) == {
"code": "unauthorized",
"message": "Unauthorized.",
}

View File

@ -2,11 +2,12 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask, g
from flask_login import LoginManager, UserMixin
from flask import Flask, Response, g
from flask_login import UserMixin
from pytest_mock import MockerFixture
import libs.login as login_module
from extensions.ext_login import DifyLoginManager
from libs.login import current_user
from models.account import Account
@ -39,9 +40,12 @@ def login_app(mocker: MockerFixture) -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
login_manager = LoginManager()
login_manager = DifyLoginManager()
login_manager.init_app(app)
login_manager.unauthorized = mocker.Mock(name="unauthorized", return_value="Unauthorized")
login_manager.unauthorized = mocker.Mock(
name="unauthorized",
return_value=Response("Unauthorized", status=401, content_type="application/json"),
)
@login_manager.user_loader
def load_user(_user_id: str):
@ -109,18 +113,43 @@ class TestLoginRequired:
resolved_user: MockUser | None,
description: str,
):
"""Test that missing or unauthenticated users are redirected."""
"""Test that missing or unauthenticated users return the manager response."""
resolve_user = resolve_current_user(resolved_user)
with login_app.test_request_context():
result = protected_view()
assert result == "Unauthorized", description
assert result is login_app.login_manager.unauthorized.return_value, description
assert isinstance(result, Response)
assert result.status_code == 401
resolve_user.assert_called_once_with()
login_app.login_manager.unauthorized.assert_called_once_with()
csrf_check.assert_not_called()
def test_unauthorized_access_propagates_response_object(
self,
login_app: Flask,
protected_view,
csrf_check: MagicMock,
resolve_current_user,
mocker: MockerFixture,
) -> None:
"""Test that unauthorized responses are propagated as Flask Response objects."""
resolve_user = resolve_current_user(None)
response = Response("Unauthorized", status=401, content_type="application/json")
mocker.patch.object(
login_module, "_get_login_manager", return_value=SimpleNamespace(unauthorized=lambda: response)
)
with login_app.test_request_context():
result = protected_view()
assert result is response
assert isinstance(result, Response)
resolve_user.assert_called_once_with()
csrf_check.assert_not_called()
@pytest.mark.parametrize(
("method", "login_disabled"),
[
@ -168,10 +197,14 @@ class TestGetUser:
"""Test that _get_user loads user if not already in g."""
mock_user = MockUser("test_user")
def _load_user() -> None:
def load_user_from_request_context() -> None:
g._login_user = mock_user
load_user = mocker.patch.object(login_app.login_manager, "_load_user", side_effect=_load_user)
load_user = mocker.patch.object(
login_app.login_manager,
"load_user_from_request_context",
side_effect=load_user_from_request_context,
)
with login_app.test_request_context():
user = login_module._get_user()

View File

@ -401,10 +401,7 @@ class TestMetadataServiceCreateMetadata:
metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name="category", metadata_type="string")
# Mock query to return None (no existing metadata with same name)
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = None
# Mock BuiltInField enum iteration
with patch("services.metadata_service.BuiltInField") as mock_builtin:
@ -417,10 +414,6 @@ class TestMetadataServiceCreateMetadata:
assert result is not None
assert isinstance(result, DatasetMetadata)
# Verify query was made to check for duplicates
mock_db_session.query.assert_called()
mock_query.filter_by.assert_called()
# Verify metadata was added and committed
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
@ -468,10 +461,7 @@ class TestMetadataServiceCreateMetadata:
# Mock existing metadata with same name
existing_metadata = MetadataTestDataFactory.create_metadata_mock(name="category")
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_metadata
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = existing_metadata
# Act & Assert
with pytest.raises(ValueError, match="Metadata name already exists"):
@ -500,10 +490,7 @@ class TestMetadataServiceCreateMetadata:
)
# Mock query to return None (no duplicate in database)
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = None
# Mock BuiltInField to include the conflicting name
with patch("services.metadata_service.BuiltInField") as mock_builtin:
@ -597,27 +584,11 @@ class TestMetadataServiceUpdateMetadataName:
existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category")
# Mock query for duplicate check (no duplicate)
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
# Mock metadata retrieval
def query_side_effect(model):
if model == DatasetMetadata:
mock_meta_query = Mock()
mock_meta_query.filter_by.return_value = mock_meta_query
mock_meta_query.first.return_value = existing_metadata
return mock_meta_query
return mock_query
mock_db_session.query.side_effect = query_side_effect
# Mock scalar calls: first for duplicate check (None), second for metadata retrieval
mock_db_session.scalar.side_effect = [None, existing_metadata]
# Mock no metadata bindings (no documents to update)
mock_binding_query = Mock()
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.all.return_value = []
mock_db_session.scalars.return_value.all.return_value = []
# Mock BuiltInField enum
with patch("services.metadata_service.BuiltInField") as mock_builtin:
@ -655,22 +626,8 @@ class TestMetadataServiceUpdateMetadataName:
metadata_id = "non-existent-metadata"
new_name = "updated_category"
# Mock query for duplicate check (no duplicate)
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
# Mock metadata retrieval to return None
def query_side_effect(model):
if model == DatasetMetadata:
mock_meta_query = Mock()
mock_meta_query.filter_by.return_value = mock_meta_query
mock_meta_query.first.return_value = None # Not found
return mock_meta_query
return mock_query
mock_db_session.query.side_effect = query_side_effect
# Mock scalar calls: first for duplicate check (None), second for metadata retrieval (None = not found)
mock_db_session.scalar.side_effect = [None, None]
# Mock BuiltInField enum
with patch("services.metadata_service.BuiltInField") as mock_builtin:
@ -746,15 +703,10 @@ class TestMetadataServiceDeleteMetadata:
existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category")
# Mock metadata retrieval
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_metadata
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = existing_metadata
# Mock no metadata bindings (no documents to update)
mock_binding_query = Mock()
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.all.return_value = []
mock_db_session.scalars.return_value.all.return_value = []
# Act
result = MetadataService.delete_metadata(dataset_id, metadata_id)
@ -788,10 +740,7 @@ class TestMetadataServiceDeleteMetadata:
metadata_id = "non-existent-metadata"
# Mock metadata retrieval to return None
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="Metadata not found"):
@ -1013,10 +962,7 @@ class TestMetadataServiceGetDatasetMetadatas:
)
# Mock usage count queries
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 5 # 5 documents use this metadata
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = 5 # 5 documents use this metadata
# Act
result = MetadataService.get_dataset_metadatas(dataset)

View File

@ -292,7 +292,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
"""
api = Mock(spec=ExternalKnowledgeApis)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
mock_db_session.scalar.return_value = api
result = ExternalDatasetService.get_external_knowledge_api("api-id", "tenant-id")
assert result is api
@ -302,7 +302,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
When the record is absent, a ``ValueError`` is raised.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.get_external_knowledge_api("missing-id", "tenant-id")
@ -320,7 +320,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
existing_api = Mock(spec=ExternalKnowledgeApis)
existing_api.settings_dict = {"api_key": "stored-key"}
existing_api.settings = '{"api_key":"stored-key"}'
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_api
mock_db_session.scalar.return_value = existing_api
args = {
"name": "New Name",
@ -340,7 +340,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
Updating a nonexistent API template should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.update_external_knowledge_api(
@ -356,7 +356,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
"""
api = Mock(spec=ExternalKnowledgeApis)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
mock_db_session.scalar.return_value = api
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")
@ -368,7 +368,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
Deletion of a missing template should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
@ -394,7 +394,7 @@ class TestExternalDatasetServiceUsageAndBindings:
When there are bindings, ``external_knowledge_api_use_check`` returns True and count.
"""
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 3
mock_db_session.scalar.return_value = 3
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
@ -406,7 +406,7 @@ class TestExternalDatasetServiceUsageAndBindings:
Zero bindings should return ``(False, 0)``.
"""
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 0
mock_db_session.scalar.return_value = 0
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
@ -419,7 +419,7 @@ class TestExternalDatasetServiceUsageAndBindings:
"""
binding = Mock(spec=ExternalKnowledgeBindings)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = binding
mock_db_session.scalar.return_value = binding
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
assert result is binding
@ -429,7 +429,7 @@ class TestExternalDatasetServiceUsageAndBindings:
Missing binding should result in a ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="external knowledge binding not found"):
ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
@ -460,7 +460,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
)
# Raw string; the service itself calls json.loads on it
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
mock_db_session.scalar.return_value = external_api
process_parameter = {"foo": "value", "bar": "optional"}
@ -474,7 +474,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
When the referenced API template is missing, a ``ValueError`` is raised.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})
@ -488,7 +488,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
external_api.settings = (
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
mock_db_session.scalar.return_value = external_api
process_parameter = {"bar": "present"} # missing "foo"
@ -702,7 +702,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
}
# No existing dataset with same name.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
None, # duplicatename check
Mock(spec=ExternalKnowledgeApis), # external knowledge api
]
@ -724,7 +724,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
"""
existing_dataset = Mock(spec=Dataset)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_dataset
mock_db_session.scalar.return_value = existing_dataset
args = {
"name": "Existing",
@ -744,7 +744,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
"""
# First call: duplicate name check not found.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
None,
None, # external knowledge api lookup
]
@ -763,8 +763,10 @@ class TestExternalDatasetServiceCreateExternalDataset:
``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
"""
# duplicate name check
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
# duplicate name check — two calls to create_external_dataset, each does 2 scalar calls
mock_db_session.scalar.side_effect = [
None,
Mock(spec=ExternalKnowledgeApis),
None,
Mock(spec=ExternalKnowledgeApis),
]
@ -826,7 +828,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
# First query: binding; second query: api.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
binding,
api,
]
@ -861,7 +863,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
Missing binding should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="external knowledge binding not found"):
ExternalDatasetService.fetch_external_knowledge_retrieval(
@ -878,7 +880,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
"""
binding = ExternalDatasetTestDataFactory.create_external_binding()
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
binding,
None,
]
@ -901,7 +903,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
api = Mock(spec=ExternalKnowledgeApis)
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
binding,
api,
]

View File

@ -117,9 +117,7 @@ def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_serv
def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service) -> None:
first_query = mocker.Mock()
first_query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=first_query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Dataset not found"):
rag_pipeline_service.get_pipeline("tenant-1", "dataset-1")
@ -131,12 +129,8 @@ def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service
def test_update_customized_pipeline_template_success(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
# First query finds the template, second query (duplicate check) returns None
query_mock_1 = mocker.Mock()
query_mock_1.where.return_value.first.return_value = template
query_mock_2 = mocker.Mock()
query_mock_2.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", side_effect=[query_mock_1, query_mock_2])
# First scalar finds the template, second scalar (duplicate check) returns None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, None])
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@ -152,9 +146,7 @@ def test_update_customized_pipeline_template_success(mocker) -> None:
def test_update_customized_pipeline_template_not_found(mocker) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="x", description="d", icon_info=IconInfo(icon="i"))
@ -166,9 +158,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
duplicate = SimpleNamespace(name="dup")
query_mock = mocker.Mock()
query_mock.where.return_value.first.side_effect = [template, duplicate]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, duplicate])
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="dup", description="d", icon_info=IconInfo(icon="i"))
@ -181,9 +171,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
def test_delete_customized_pipeline_template_success(mocker) -> None:
template = SimpleNamespace(id="tpl-1")
query_mock = mocker.Mock()
query_mock.where.return_value.first.return_value = template
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
delete_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.delete")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@ -196,9 +184,7 @@ def test_delete_customized_pipeline_template_success(mocker) -> None:
def test_delete_customized_pipeline_template_not_found(mocker) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
with pytest.raises(ValueError, match="Customized pipeline template not found"):
@ -397,18 +383,14 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) -
def test_is_workflow_exist_returns_true_when_draft_exists(mocker, rag_pipeline_service) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.count.return_value = 1
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=1)
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
assert rag_pipeline_service.is_workflow_exist(pipeline) is True
def test_is_workflow_exist_returns_false_when_no_draft(mocker, rag_pipeline_service) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.count.return_value = 0
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=0)
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
assert rag_pipeline_service.is_workflow_exist(pipeline) is False
@ -738,8 +720,7 @@ def test_get_second_step_parameters_success(mocker, rag_pipeline_service) -> Non
def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_service) -> None:
from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate
from models.workflow import Workflow
from models.dataset import Pipeline
# 1. Setup mocks
pipeline = mocker.Mock(spec=Pipeline)
@ -754,36 +735,15 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
# Mock db itself to avoid app context errors
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
# Improved mocking for session.query
def mock_query_side_effect(model):
m = mocker.Mock()
if model == Pipeline:
m.where.return_value.first.return_value = pipeline
elif model == Workflow:
m.where.return_value.first.return_value = workflow
elif model == PipelineCustomizedTemplate:
m.where.return_value.first.return_value = None
elif model == Dataset:
m.where.return_value.first.return_value = mocker.Mock()
else:
# For func.max cases
m.where.return_value.scalar.return_value = 5
m.where.return_value.first.return_value = mocker.Mock()
return m
mock_db.session.query.side_effect = mock_query_side_effect
# Mock get() for Pipeline and Workflow PK lookups
mock_db.session.get.side_effect = [pipeline, workflow]
# Mock scalar() for template name check (None) and max position (5)
mock_db.session.scalar.side_effect = [None, 5]
# Mock retrieve_dataset
dataset = mocker.Mock()
pipeline.retrieve_dataset.return_value = dataset
# Mock max position
mocker.patch("services.rag_pipeline.rag_pipeline.func.max", return_value=1)
mocker.patch(
"services.rag_pipeline.rag_pipeline.db.session.query.return_value.where.return_value.scalar",
return_value=5,
)
# Mock RagPipelineDslService
mock_dsl_service = mocker.Mock()
mock_dsl_service.export_rag_pipeline_dsl.return_value = {"dsl": "content"}
@ -839,9 +799,7 @@ def test_get_datasource_plugins_success(mocker, rag_pipeline_service) -> None:
workflow.rag_pipeline_variables = []
# Mock queries
mock_query = mocker.Mock()
mock_query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
@ -881,11 +839,9 @@ def test_retry_error_document_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock()
# Mock queries
mock_query = mocker.Mock()
# Log lookup, then Pipeline lookup
mock_query.where.return_value.first.side_effect = [log, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
# Mock queries: Log lookup via scalar, Pipeline lookup via get
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
@ -913,7 +869,7 @@ def test_set_datasource_variables_success(mocker, rag_pipeline_service) -> None:
# Mock db aggressively
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.engine = mocker.Mock()
mock_db.session.query.return_value.where.return_value.first.return_value = mocker.Mock()
mock_db.session.scalar.return_value = mocker.Mock()
pipeline = mocker.Mock(spec=Pipeline)
pipeline.id = "p-1"
@ -976,7 +932,7 @@ def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock(spec=Workflow)
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value.where.return_value.first.return_value = workflow
mock_db.session.scalar.return_value = workflow
# 2. Run test
result = rag_pipeline_service.get_draft_workflow(pipeline)
@ -998,7 +954,7 @@ def test_get_published_workflow_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock(spec=Workflow)
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value.where.return_value.first.return_value = workflow
mock_db.session.scalar.return_value = workflow
# 2. Run test
result = rag_pipeline_service.get_published_workflow(pipeline)
@ -1319,11 +1275,8 @@ def test_get_rag_pipeline_workflow_run_node_executions_returns_sorted_executions
def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = []
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query
mock_db.session.scalars.return_value.all.return_value = []
result = rag_pipeline_service.get_recommended_plugins("all")
@ -1336,11 +1289,8 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra
def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_pipeline_service) -> None:
plugin_a = SimpleNamespace(plugin_id="plugin-a")
plugin_b = SimpleNamespace(plugin_id="plugin-b")
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = [plugin_a, plugin_b]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query
mock_db.session.scalars.return_value.all.return_value = [plugin_a, plugin_b]
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
mocker.patch(
"services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools",
@ -1568,9 +1518,7 @@ def test_get_second_step_parameters_filters_first_step_variables(mocker, rag_pip
def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Document pipeline execution log not found"):
rag_pipeline_service.retry_error_document(
@ -1581,9 +1529,7 @@ def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pi
def test_get_datasource_plugins_raises_when_workflow_not_found(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
with pytest.raises(ValueError, match="Pipeline or workflow not found"):
@ -1656,8 +1602,7 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker
document = SimpleNamespace(indexing_status="waiting", error=None)
query = mocker.Mock()
query.where.return_value.first.return_value = document
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@ -1712,9 +1657,7 @@ def test_run_datasource_node_preview_raises_for_unsupported_provider(mocker, rag
def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
@ -1722,9 +1665,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker
def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id=None)
query = mocker.Mock()
query.where.return_value.first.return_value = pipeline
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
with pytest.raises(ValueError, match="Pipeline workflow not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {"name": "template-name"})
@ -1732,8 +1673,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc
def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Dataset not found"):
rag_pipeline_service.get_pipeline("t1", "d1")
@ -1742,8 +1682,7 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service)
def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, None]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None])
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.get_pipeline("t1", "d1")
@ -1783,8 +1722,7 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:
def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
query = mocker.Mock()
query.where.return_value.first.return_value = template
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@ -2011,8 +1949,7 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_
def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [pipeline, None]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])
with pytest.raises(ValueError, match="Workflow not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
@ -2021,11 +1958,9 @@ def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocke
def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
workflow = SimpleNamespace(id="wf-1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [pipeline, workflow]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.engine = mocker.Mock()
mock_db.session.query.return_value = query
mock_db.session.get.side_effect = [pipeline, workflow]
session_ctx = mocker.MagicMock()
session_ctx.__enter__.return_value = SimpleNamespace()
session_ctx.__exit__.return_value = False
@ -2038,11 +1973,8 @@ def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker
def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipeline_service) -> None:
plugin = SimpleNamespace(plugin_id="plugin-a")
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = [plugin]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query
mock_db.session.scalars.return_value.all.return_value = [plugin]
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
mocker.patch("services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[])
mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids", return_value=[])
@ -2056,8 +1988,8 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin
def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
exec_log = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [exec_log, None]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.retry_error_document(
@ -2069,8 +2001,8 @@ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_
exec_log = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [exec_log, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
with pytest.raises(ValueError, match="Workflow not found"):
@ -2086,8 +2018,7 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r
graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
)
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
assert rag_pipeline_service.get_datasource_plugins("t1", "d1", True) == []
@ -2250,8 +2181,7 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
)
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow)
mocker.patch(
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials", return_value=[]
@ -2291,8 +2221,7 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
],
)
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
mocker.patch(
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials",
@ -2310,8 +2239,7 @@ def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service)
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
result = rag_pipeline_service.get_pipeline("t1", "d1")

View File

@ -173,9 +173,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock
query_results = {("Account", "email", "test@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
mock_password_dependencies["compare_password"].return_value = True
@ -188,9 +186,7 @@ class TestAccountService:
def test_authenticate_account_not_found(self, mock_db_dependencies):
"""Test authentication when account does not exist."""
# Setup smart database query mock - no matching results
query_results = {("Account", "email", "notfound@example.com"): None}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = None
# Execute test and verify exception
self._assert_exception_raised(
@ -202,9 +198,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
# Setup smart database query mock
query_results = {("Account", "email", "banned@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
# Execute test and verify exception
self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
@ -214,9 +208,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock
query_results = {("Account", "email", "test@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
mock_password_dependencies["compare_password"].return_value = False
@ -230,9 +222,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending")
# Setup smart database query mock
query_results = {("Account", "email", "pending@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
mock_password_dependencies["compare_password"].return_value = True
@ -422,12 +412,8 @@ class TestAccountService:
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock()
# Setup smart database query mock
query_results = {
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): mock_tenant_join,
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant_join
# Mock datetime
with patch("services.account_service.datetime") as mock_datetime:
@ -444,9 +430,7 @@ class TestAccountService:
def test_load_user_not_found(self, mock_db_dependencies):
"""Test user loading when user does not exist."""
# Setup smart database query mock - no matching results
query_results = {("Account", "id", "non-existent-user"): None}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = None
# Execute test
result = AccountService.load_user("non-existent-user")
@ -459,9 +443,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
# Setup smart database query mock
query_results = {("Account", "id", "user-123"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
# Execute test and verify exception
self._assert_exception_raised(
@ -476,13 +458,9 @@ class TestAccountService:
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False)
# Setup smart database query mock for complex scenario
query_results = {
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
("TenantAccountJoin", "order_by", "first_available"): mock_available_tenant, # First available tenant
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
# First scalar: current tenant (None), second scalar: available tenant
mock_db_dependencies["db"].session.scalar.side_effect = [None, mock_available_tenant]
# Mock datetime
with patch("services.account_service.datetime") as mock_datetime:
@ -503,13 +481,9 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock for no tenants scenario
query_results = {
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
("TenantAccountJoin", "order_by", "first_available"): None, # No available tenants
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
# First scalar: current tenant (None), second scalar: available tenant (None)
mock_db_dependencies["db"].session.scalar.side_effect = [None, None]
# Mock datetime
with patch("services.account_service.datetime") as mock_datetime:
@ -1060,7 +1034,7 @@ class TestRegisterService:
)
# Verify rollback operations were called
mock_db_dependencies["db"].session.query.assert_called()
mock_db_dependencies["db"].session.execute.assert_called()
# ==================== Registration Tests ====================
@ -1625,10 +1599,8 @@ class TestRegisterService:
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = mock_existing_account
# Mock the db.session.query for TenantAccountJoin
mock_db_query = MagicMock()
mock_db_query.filter_by.return_value.first.return_value = None # No existing member
mock_db_dependencies["db"].session.query.return_value = mock_db_query
# Mock scalar for TenantAccountJoin lookup - no existing member
mock_db_dependencies["db"].session.scalar.return_value = None
# Mock TenantService methods
with (
@ -1803,14 +1775,9 @@ class TestRegisterService:
}
mock_get_invitation_by_token.return_value = invitation_data
# Mock database queries - complex query mocking
mock_query1 = MagicMock()
mock_query1.where.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
# Mock scalar for tenant lookup, execute for account+role lookup
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal")
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
@ -1842,10 +1809,8 @@ class TestRegisterService:
}
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
# Mock database queries - no tenant found
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = None
mock_db_dependencies["db"].session.query.return_value = mock_query
# Mock scalar for tenant lookup - not found
mock_db_dependencies["db"].session.scalar.return_value = None
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
@ -1868,14 +1833,9 @@ class TestRegisterService:
}
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
# Mock database queries
mock_query1 = MagicMock()
mock_query1.filter.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.where.return_value.first.return_value = None # No account found
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
# Mock scalar for tenant, execute for account+role
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
mock_db_dependencies["db"].session.execute.return_value.first.return_value = None # No account found
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
@ -1901,14 +1861,9 @@ class TestRegisterService:
}
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
# Mock database queries
mock_query1 = MagicMock()
mock_query1.filter.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
# Mock scalar for tenant, execute for account+role
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal")
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")

View File

@ -799,10 +799,7 @@ class TestExternalDatasetServiceGetAPI:
api_id = "api-123"
expected_api = factory.create_external_knowledge_api_mock(api_id=api_id)
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = expected_api
mock_db.session.scalar.return_value = expected_api
# Act
tenant_id = "tenant-123"
@ -810,16 +807,12 @@ class TestExternalDatasetServiceGetAPI:
# Assert
assert result.id == api_id
mock_query.filter_by.assert_called_once_with(id=api_id, tenant_id=tenant_id)
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_api_not_found(self, mock_db, factory):
"""Test error when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@ -848,10 +841,7 @@ class TestExternalDatasetServiceUpdateAPI:
"settings": {"endpoint": "https://new.example.com", "api_key": "new-key"},
}
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)
@ -881,10 +871,7 @@ class TestExternalDatasetServiceUpdateAPI:
"settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
}
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, "user-123", api_id, args)
@ -897,10 +884,7 @@ class TestExternalDatasetServiceUpdateAPI:
def test_update_external_knowledge_api_not_found(self, mock_db, factory):
"""Test error when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
args = {"name": "Updated API"}
@ -912,10 +896,7 @@ class TestExternalDatasetServiceUpdateAPI:
def test_update_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
"""Test error when tenant ID doesn't match."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
args = {"name": "Updated API"}
@ -934,10 +915,7 @@ class TestExternalDatasetServiceUpdateAPI:
args = {"name": "New Name Only"}
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
result = ExternalDatasetService.update_external_knowledge_api("tenant-123", "user-123", "api-123", args)
@ -958,10 +936,7 @@ class TestExternalDatasetServiceDeleteAPI:
existing_api = factory.create_external_knowledge_api_mock(api_id=api_id, tenant_id=tenant_id)
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
ExternalDatasetService.delete_external_knowledge_api(tenant_id, api_id)
@ -974,10 +949,7 @@ class TestExternalDatasetServiceDeleteAPI:
def test_delete_external_knowledge_api_not_found(self, mock_db, factory):
"""Test error when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@ -987,10 +959,7 @@ class TestExternalDatasetServiceDeleteAPI:
def test_delete_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
"""Test error when tenant ID doesn't match."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@ -1006,10 +975,7 @@ class TestExternalDatasetServiceAPIUseCheck:
# Arrange
api_id = "api-123"
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 1
mock_db.session.scalar.return_value = 1
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
@ -1024,10 +990,7 @@ class TestExternalDatasetServiceAPIUseCheck:
# Arrange
api_id = "api-123"
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 10
mock_db.session.scalar.return_value = 10
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
@ -1042,10 +1005,7 @@ class TestExternalDatasetServiceAPIUseCheck:
# Arrange
api_id = "api-123"
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 0
mock_db.session.scalar.return_value = 0
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
@ -1067,10 +1027,7 @@ class TestExternalDatasetServiceGetBinding:
expected_binding = factory.create_external_knowledge_binding_mock(tenant_id=tenant_id, dataset_id=dataset_id)
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = expected_binding
mock_db.session.scalar.return_value = expected_binding
# Act
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id(tenant_id, dataset_id)
@ -1083,10 +1040,7 @@ class TestExternalDatasetServiceGetBinding:
def test_get_external_knowledge_binding_not_found(self, mock_db, factory):
"""Test error when binding is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="external knowledge binding not found"):
@ -1113,10 +1067,7 @@ class TestExternalDatasetServiceDocumentValidate:
api = factory.create_external_knowledge_api_mock(api_id=api_id, settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
process_parameter = {"param1": "value1", "param2": "value2"}
@ -1134,10 +1085,7 @@ class TestExternalDatasetServiceDocumentValidate:
api = factory.create_external_knowledge_api_mock(api_id=api_id, settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
process_parameter = {}
@ -1149,10 +1097,7 @@ class TestExternalDatasetServiceDocumentValidate:
def test_document_create_args_validate_api_not_found(self, mock_db, factory):
"""Test validation fails when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@ -1165,10 +1110,7 @@ class TestExternalDatasetServiceDocumentValidate:
settings = {}
api = factory.create_external_knowledge_api_mock(settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
# Act & Assert - should not raise
ExternalDatasetService.document_create_args_validate("tenant-123", "api-123", {})
@ -1186,10 +1128,7 @@ class TestExternalDatasetServiceDocumentValidate:
api = factory.create_external_knowledge_api_mock(settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
process_parameter = {"required_param": "value"}
@ -1498,24 +1437,7 @@ class TestExternalDatasetServiceCreateDataset:
api = factory.create_external_knowledge_api_mock(api_id="api-123")
# Mock database queries
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [None, api]
# Act
result = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)
@ -1534,10 +1456,7 @@ class TestExternalDatasetServiceCreateDataset:
# Arrange
existing_dataset = factory.create_dataset_mock(name="Duplicate Dataset")
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_dataset
mock_db.session.scalar.return_value = existing_dataset
args = {"name": "Duplicate Dataset"}
@ -1549,23 +1468,7 @@ class TestExternalDatasetServiceCreateDataset:
def test_create_external_dataset_api_not_found_error(self, mock_db, factory):
"""Test error when external knowledge API is not found."""
# Arrange
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = None
mock_db.session.scalar.side_effect = [None, None]
args = {"name": "Test Dataset", "external_knowledge_api_id": "nonexistent-api"}
@ -1579,23 +1482,7 @@ class TestExternalDatasetServiceCreateDataset:
# Arrange
api = factory.create_external_knowledge_api_mock()
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [None, api]
args = {"name": "Test Dataset", "external_knowledge_api_id": "api-123"}
@ -1609,23 +1496,7 @@ class TestExternalDatasetServiceCreateDataset:
# Arrange
api = factory.create_external_knowledge_api_mock()
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [None, api]
args = {"name": "Test Dataset", "external_knowledge_id": "knowledge-123"}
@ -1651,23 +1522,7 @@ class TestExternalDatasetServiceFetchRetrieval:
)
api = factory.create_external_knowledge_api_mock(api_id="api-123")
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 200
@ -1695,10 +1550,7 @@ class TestExternalDatasetServiceFetchRetrieval:
def test_fetch_external_knowledge_retrieval_binding_not_found_error(self, mock_db, factory):
"""Test error when external knowledge binding is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="external knowledge binding not found"):
@ -1712,23 +1564,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 200
@ -1751,23 +1587,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 200
@ -1799,23 +1619,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 500
@ -1856,23 +1660,7 @@ class TestExternalDatasetServiceFetchRetrieval:
)
api = factory.create_external_knowledge_api_mock(api_id="api-123")
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = status_code
@ -1891,23 +1679,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 503

View File

@ -85,3 +85,644 @@ def test_get_provider_list_strips_credentials(service_with_fake_configurations:
assert len(custom_models) == 1
# The sanitizer should drop credentials in list response
assert custom_models[0].credentials is None
# === Merged from test_model_provider_service.py ===
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
from core.entities.model_entities import ModelStatus
from models.provider import ProviderType
from services import model_provider_service as service_module
from services.errors.app_model_config import ProviderNotFoundError
from services.model_provider_service import ModelProviderService
def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]:
manager = MagicMock()
service = ModelProviderService()
service._get_provider_manager = MagicMock(return_value=manager)
return service, manager
def _build_provider_configuration(
*,
provider_name: str = "openai",
supported_model_types: list[ModelType] | None = None,
custom_models: list[Any] | None = None,
custom_config_available: bool = True,
) -> SimpleNamespace:
if supported_model_types is None:
supported_model_types = [ModelType.LLM]
return SimpleNamespace(
provider=SimpleNamespace(
provider=provider_name,
label=I18nObject(en_US=provider_name),
description=None,
icon_small=None,
icon_small_dark=None,
background=None,
help=None,
supported_model_types=supported_model_types,
configurate_methods=[],
provider_credential_schema=None,
model_credential_schema=None,
),
preferred_provider_type=ProviderType.CUSTOM,
custom_configuration=SimpleNamespace(
provider=SimpleNamespace(
current_credential_id="cred-1",
current_credential_name="Credential 1",
available_credentials=[],
),
models=custom_models,
can_added_models=[],
),
system_configuration=SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]),
is_custom_configuration_available=lambda: custom_config_available,
)
def test__get_provider_configuration_should_return_configuration_when_provider_exists() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
provider_configuration = SimpleNamespace(name="provider-config")
manager.get_configurations.return_value = {"openai": provider_configuration}
# Act
result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai")
# Assert
assert result is provider_configuration
def test__get_provider_configuration_should_raise_error_when_provider_is_missing() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
manager.get_configurations.return_value = {}
# Act / Assert
with pytest.raises(ProviderNotFoundError, match="does not exist"):
service._get_provider_configuration(tenant_id="tenant-1", provider="missing")
def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
allowed = _build_provider_configuration(
provider_name="openai",
supported_model_types=[ModelType.LLM],
custom_config_available=False,
)
filtered = _build_provider_configuration(
provider_name="embedding",
supported_model_types=[ModelType.TEXT_EMBEDDING],
custom_config_available=True,
)
manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered}
# Act
result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value)
# Assert
assert len(result) == 1
assert result[0].provider == "openai"
assert result[0].custom_configuration.status.value == "no-configure"
def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
class _Model:
def __init__(self, model_name: str) -> None:
self.model_name = model_name
def model_dump(self) -> dict[str, Any]:
return {
"model": self.model_name,
"label": {"en_US": self.model_name},
"model_type": ModelType.LLM,
"features": [],
"fetch_from": FetchFrom.PREDEFINED_MODEL,
"model_properties": {},
"deprecated": False,
"status": ModelStatus.ACTIVE,
"load_balancing_enabled": False,
"has_invalid_load_balancing_configs": False,
"provider": {
"provider": "openai",
"label": {"en_US": "OpenAI"},
"icon_small": None,
"icon_small_dark": None,
"supported_model_types": [ModelType.LLM],
},
}
provider_configurations = SimpleNamespace(
get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")])
)
manager.get_configurations.return_value = provider_configurations
# Act
result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai")
# Assert
assert len(result) == 2
assert result[0].model == "gpt-4o"
assert result[1].provider.provider == "openai"
provider_configurations.get_models.assert_called_once_with(provider="openai")
@pytest.mark.parametrize(
("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"),
[
(
"get_provider_credential",
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
"get_provider_credential",
{"credential_id": "cred-1"},
{"token": "abc"},
),
(
"validate_provider_credentials",
{"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}},
"validate_provider_credentials",
({"token": "abc"},),
None,
),
(
"create_provider_credential",
{"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}, "credential_name": "A"},
"create_provider_credential",
({"token": "abc"}, "A"),
None,
),
(
"update_provider_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"credentials": {"token": "abc"},
"credential_id": "cred-1",
"credential_name": "B",
},
"update_provider_credential",
{"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"},
None,
),
(
"remove_provider_credential",
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
"delete_provider_credential",
{"credential_id": "cred-1"},
None,
),
(
"switch_active_provider_credential",
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
"switch_active_provider_credential",
{"credential_id": "cred-1"},
None,
),
],
)
def test_provider_credential_methods_should_delegate_to_provider_configuration(
method_name: str,
method_kwargs: dict[str, Any],
provider_method_name: str,
provider_call_kwargs: Any,
provider_return: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = ModelProviderService()
provider_configuration = MagicMock()
getattr(provider_configuration, provider_method_name).return_value = provider_return
get_provider_config_mock = MagicMock(return_value=provider_configuration)
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
# Act
result = getattr(service, method_name)(**method_kwargs)
# Assert
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
provider_method = getattr(provider_configuration, provider_method_name)
if isinstance(provider_call_kwargs, tuple):
provider_method.assert_called_once_with(*provider_call_kwargs)
elif isinstance(provider_call_kwargs, dict):
provider_method.assert_called_once_with(**provider_call_kwargs)
else:
provider_method.assert_called_once_with(provider_call_kwargs)
if method_name == "get_provider_credential":
assert result == {"token": "abc"}
@pytest.mark.parametrize(
("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"),
[
(
"get_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credential_id": "cred-1",
},
"get_custom_model_credential",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
{"api_key": "x"},
),
(
"validate_model_credentials",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
},
"validate_custom_model_credentials",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}},
None,
),
(
"create_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
"credential_name": "cred-a",
},
"create_custom_model_credential",
{
"model_type": ModelType.LLM,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
"credential_name": "cred-a",
},
None,
),
(
"update_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
"credential_id": "cred-1",
"credential_name": "cred-b",
},
"update_custom_model_credential",
{
"model_type": ModelType.LLM,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
"credential_id": "cred-1",
"credential_name": "cred-b",
},
None,
),
(
"remove_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credential_id": "cred-1",
},
"delete_custom_model_credential",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
None,
),
(
"switch_active_custom_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credential_id": "cred-1",
},
"switch_custom_model_credential",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
None,
),
(
"add_model_credential_to_model_list",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credential_id": "cred-1",
},
"add_model_credential_to_model",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
None,
),
(
"remove_model",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
},
"delete_custom_model",
{"model_type": ModelType.LLM, "model": "gpt-4o"},
None,
),
],
)
def test_custom_model_methods_should_convert_model_type_and_delegate(
method_name: str,
method_kwargs: dict[str, Any],
provider_method_name: str,
expected_kwargs: dict[str, Any],
provider_return: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = ModelProviderService()
provider_configuration = MagicMock()
getattr(provider_configuration, provider_method_name).return_value = provider_return
get_provider_config_mock = MagicMock(return_value=provider_configuration)
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
# Act
result = getattr(service, method_name)(**method_kwargs)
# Assert
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs)
if method_name == "get_model_credential":
assert result == {"api_key": "x"}
def test_get_models_by_model_type_should_group_active_non_deprecated_models() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
openai_provider = SimpleNamespace(
provider="openai",
label=I18nObject(en_US="OpenAI"),
icon_small=None,
icon_small_dark=None,
)
anthropic_provider = SimpleNamespace(
provider="anthropic",
label=I18nObject(en_US="Anthropic"),
icon_small=None,
icon_small_dark=None,
)
models = [
SimpleNamespace(
provider=openai_provider,
model="gpt-4o",
label=I18nObject(en_US="GPT-4o"),
model_type=ModelType.LLM,
features=[],
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
status=ModelStatus.ACTIVE,
load_balancing_enabled=False,
deprecated=False,
),
SimpleNamespace(
provider=openai_provider,
model="old-openai",
label=I18nObject(en_US="Old OpenAI"),
model_type=ModelType.LLM,
features=[],
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
status=ModelStatus.ACTIVE,
load_balancing_enabled=False,
deprecated=True,
),
SimpleNamespace(
provider=anthropic_provider,
model="old-anthropic",
label=I18nObject(en_US="Old Anthropic"),
model_type=ModelType.LLM,
features=[],
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
status=ModelStatus.ACTIVE,
load_balancing_enabled=False,
deprecated=True,
),
]
provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models))
manager.get_configurations.return_value = provider_configurations
# Act
result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
# Assert
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
assert len(result) == 1
assert result[0].provider == "openai"
assert len(result[0].models) == 1
assert result[0].models[0].model == "gpt-4o"
@pytest.mark.parametrize(
("credentials", "schema", "expected_count"),
[
(None, None, 0),
({"api_key": "x"}, None, 0),
(
{"api_key": "x"},
SimpleNamespace(
parameter_rules=[
ParameterRule(
name="temperature",
label=I18nObject(en_US="Temperature"),
type=ParameterType.FLOAT,
)
]
),
1,
),
],
)
def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema(
credentials: dict[str, Any] | None,
schema: Any,
expected_count: int,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = ModelProviderService()
provider_configuration = MagicMock()
provider_configuration.get_current_credentials.return_value = credentials
provider_configuration.get_model_schema.return_value = schema
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
# Act
result = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o")
# Assert
assert len(result) == expected_count
provider_configuration.get_current_credentials.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4o")
if credentials:
provider_configuration.get_model_schema.assert_called_once_with(
model_type=ModelType.LLM,
model="gpt-4o",
credentials=credentials,
)
else:
provider_configuration.get_model_schema.assert_not_called()
def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
manager.get_default_model.return_value = SimpleNamespace(
model="gpt-4o",
model_type=ModelType.LLM,
provider=SimpleNamespace(
provider="openai",
label=I18nObject(en_US="OpenAI"),
icon_small=None,
supported_model_types=[ModelType.LLM],
),
)
# Act
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
# Assert
assert result is not None
assert result.model == "gpt-4o"
assert result.provider.provider == "openai"
manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM)
def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
manager.get_default_model.return_value = None
# Act
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
# Assert
assert result is None
def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
manager.get_default_model.side_effect = RuntimeError("boom")
# Act
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
# Assert
assert result is None
def test_update_default_model_of_model_type_should_delegate_to_provider_manager() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
# Act
service.update_default_model_of_model_type(
tenant_id="tenant-1",
model_type=ModelType.LLM.value,
provider="openai",
model="gpt-4o",
)
# Assert
manager.update_default_model_record.assert_called_once_with(
tenant_id="tenant-1",
model_type=ModelType.LLM,
provider="openai",
model="gpt-4o",
)
def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
service = ModelProviderService()
factory_instance = MagicMock()
factory_instance.get_provider_icon.return_value = (b"icon-bytes", "image/png")
factory_constructor = MagicMock(return_value=factory_instance)
monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_constructor)
# Act
result = service.get_model_provider_icon(
tenant_id="tenant-1",
provider="openai",
icon_type="icon_small",
lang="en_US",
)
# Assert
factory_constructor.assert_called_once_with(tenant_id="tenant-1")
factory_instance.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
assert result == (b"icon-bytes", "image/png")
def test_switch_preferred_provider_should_convert_enum_and_delegate(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
service = ModelProviderService()
provider_configuration = MagicMock()
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
# Act
service.switch_preferred_provider(
tenant_id="tenant-1",
provider="openai",
preferred_provider_type=ProviderType.SYSTEM.value,
)
# Assert
provider_configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM)
@pytest.mark.parametrize(
("method_name", "provider_method_name"),
[
("enable_model", "enable_model"),
("disable_model", "disable_model"),
],
)
def test_model_enablement_methods_should_convert_model_type_and_delegate(
method_name: str,
provider_method_name: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = ModelProviderService()
provider_configuration = MagicMock()
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
# Act
getattr(service, method_name)(
tenant_id="tenant-1",
provider="openai",
model="gpt-4o",
model_type=ModelType.LLM.value,
)
# Assert
getattr(provider_configuration, provider_method_name).assert_called_once_with(
model="gpt-4o",
model_type=ModelType.LLM,
)

View File

@ -316,7 +316,7 @@ class TestRecommendedAppServiceGetDetail:
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result == expected_detail
@ -346,7 +346,7 @@ class TestRecommendedAppServiceGetDetail:
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result["name"] == f"App from {mode}"
@ -369,7 +369,7 @@ class TestRecommendedAppServiceGetDetail:
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result is None
@ -392,7 +392,7 @@ class TestRecommendedAppServiceGetDetail:
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result == {}
@ -432,9 +432,197 @@ class TestRecommendedAppServiceGetDetail:
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result["model_config"] == complex_model_config
assert len(result["workflows"]) == 2
assert len(result["tools"]) == 3
# === Merged from test_recommended_app_service_additional.py ===
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from services import recommended_app_service as service_module
from services.recommended_app_service import RecommendedAppService
def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any]:
return cast(dict[str, Any], result)
@pytest.fixture
def mocked_db_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
# Arrange
session = MagicMock()
monkeypatch.setattr(service_module, "db", SimpleNamespace(session=session))
# Assert
return session
def _mock_factory_for_apps(
monkeypatch: pytest.MonkeyPatch,
*,
mode: str,
result: dict[str, Any],
fallback_result: dict[str, Any] | None = None,
) -> tuple[MagicMock, MagicMock]:
retrieval_instance = MagicMock()
retrieval_instance.get_recommended_apps_and_categories.return_value = result
retrieval_factory = MagicMock(return_value=retrieval_instance)
monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False)
monkeypatch.setattr(
service_module.RecommendAppRetrievalFactory,
"get_recommend_app_factory",
MagicMock(return_value=retrieval_factory),
)
builtin_instance = MagicMock()
if fallback_result is not None:
builtin_instance.fetch_recommended_apps_from_builtin.return_value = fallback_result
monkeypatch.setattr(
service_module.RecommendAppRetrievalFactory,
"get_buildin_recommend_app_retrieval",
MagicMock(return_value=builtin_instance),
)
return retrieval_instance, builtin_instance
def test_get_recommended_apps_and_categories_should_not_query_trial_table_when_trial_feature_disabled(
monkeypatch: pytest.MonkeyPatch,
mocked_db_session: MagicMock,
) -> None:
# Arrange
expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]}
retrieval_instance, builtin_instance = _mock_factory_for_apps(
monkeypatch,
mode="remote",
result=expected,
)
monkeypatch.setattr(
service_module.FeatureService,
"get_system_features",
MagicMock(return_value=SimpleNamespace(enable_trial_app=False)),
)
# Act
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
# Assert
assert result == expected
retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
builtin_instance.fetch_recommended_apps_from_builtin.assert_not_called()
mocked_db_session.scalar.assert_not_called()
def test_get_recommended_apps_and_categories_should_fallback_and_enrich_can_trial_when_trial_feature_enabled(
monkeypatch: pytest.MonkeyPatch,
mocked_db_session: MagicMock,
) -> None:
# Arrange
remote_result = {"recommended_apps": [], "categories": []}
fallback_result = {"recommended_apps": [{"app_id": "app-1"}, {"app_id": "app-2"}], "categories": ["all"]}
_, builtin_instance = _mock_factory_for_apps(
monkeypatch,
mode="remote",
result=remote_result,
fallback_result=fallback_result,
)
monkeypatch.setattr(
service_module.FeatureService,
"get_system_features",
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
)
mocked_db_session.scalar.side_effect = [SimpleNamespace(id="trial-app"), None]
# Act
result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP")
# Assert
builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
assert result["recommended_apps"][0]["can_trial"] is True
assert result["recommended_apps"][1]["can_trial"] is False
assert mocked_db_session.scalar.call_count == 2
@pytest.mark.parametrize(
("trial_query_result", "expected_can_trial"),
[
(SimpleNamespace(id="trial"), True),
(None, False),
],
)
def test_get_recommend_app_detail_should_set_can_trial_when_trial_feature_enabled(
monkeypatch: pytest.MonkeyPatch,
mocked_db_session: MagicMock,
trial_query_result: Any,
expected_can_trial: bool,
) -> None:
# Arrange
detail = {"id": "app-1", "name": "Test App"}
retrieval_instance = MagicMock()
retrieval_instance.get_recommend_app_detail.return_value = detail
retrieval_factory = MagicMock(return_value=retrieval_instance)
monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False)
monkeypatch.setattr(
service_module.RecommendAppRetrievalFactory,
"get_recommend_app_factory",
MagicMock(return_value=retrieval_factory),
)
monkeypatch.setattr(
service_module.FeatureService,
"get_system_features",
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
)
mocked_db_session.scalar.return_value = trial_query_result
# Act
result = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail("app-1"))
# Assert
assert result["id"] == "app-1"
assert result["can_trial"] is expected_can_trial
mocked_db_session.scalar.assert_called_once()
def test_add_trial_app_record_should_increment_count_when_existing_record_found(
mocked_db_session: MagicMock,
) -> None:
# Arrange
existing_record = SimpleNamespace(count=3)
mocked_db_session.scalar.return_value = existing_record
# Act
RecommendedAppService.add_trial_app_record("app-1", "account-1")
# Assert
assert existing_record.count == 4
mocked_db_session.scalar.assert_called_once()
mocked_db_session.commit.assert_called_once()
mocked_db_session.add.assert_not_called()
def test_add_trial_app_record_should_create_new_record_when_no_existing_record(
mocked_db_session: MagicMock,
) -> None:
# Arrange
mocked_db_session.scalar.return_value = None
# Act
RecommendedAppService.add_trial_app_record("app-2", "account-2")
# Assert
mocked_db_session.scalar.assert_called_once()
mocked_db_session.add.assert_called_once()
added = mocked_db_session.add.call_args.args[0]
assert added.app_id == "app-2"
assert added.account_id == "account-2"
assert added.count == 1
mocked_db_session.commit.assert_called_once()

View File

@ -1,12 +1,15 @@
import unittest
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
from events.event_handlers.sync_workflow_schedule_when_app_published import (
sync_schedule_from_workflow,
)
@ -14,6 +17,8 @@ from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
from models.account import Account, TenantAccountJoin
from models.trigger import WorkflowSchedulePlan
from models.workflow import Workflow
from services.errors.account import AccountNotFoundError
from services.trigger import schedule_service as service_module
from services.trigger.schedule_service import ScheduleService
@ -775,5 +780,158 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
mock_session.commit.assert_called_once()
@pytest.fixture
def session_mock() -> MagicMock:
return MagicMock(spec=Session)
def _workflow(**kwargs: Any) -> Workflow:
return cast(Workflow, SimpleNamespace(**kwargs))
def test_update_schedule_should_update_only_node_id_without_recomputing_time(
session_mock: MagicMock,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
schedule = MagicMock(spec=WorkflowSchedulePlan)
schedule.cron_expression = "0 10 * * *"
schedule.timezone = "UTC"
session_mock.get.return_value = schedule
next_run_mock = MagicMock(return_value=datetime(2026, 1, 1, 10, 0, tzinfo=UTC))
monkeypatch.setattr(service_module, "calculate_next_run_at", next_run_mock)
# Act
result = ScheduleService.update_schedule(
session=session_mock,
schedule_id="schedule-1",
updates=SchedulePlanUpdate(node_id="node-new"),
)
# Assert
assert result is schedule
assert schedule.node_id == "node-new"
next_run_mock.assert_not_called()
session_mock.flush.assert_called_once()
def test_get_tenant_owner_should_raise_when_account_record_missing(session_mock: MagicMock) -> None:
# Arrange
join = SimpleNamespace(account_id="account-404")
session_mock.execute.return_value.scalar_one_or_none.return_value = join
session_mock.get.return_value = None
# Act / Assert
with pytest.raises(AccountNotFoundError, match="Account not found: account-404"):
ScheduleService.get_tenant_owner(session=session_mock, tenant_id="tenant-1")
def test_get_tenant_owner_should_raise_when_no_owner_or_admin_found(session_mock: MagicMock) -> None:
# Arrange
session_mock.execute.return_value.scalar_one_or_none.side_effect = [None, None]
# Act / Assert
with pytest.raises(AccountNotFoundError, match="Account not found for tenant: tenant-1"):
ScheduleService.get_tenant_owner(session=session_mock, tenant_id="tenant-1")
def test_update_next_run_at_should_raise_when_schedule_not_found(session_mock: MagicMock) -> None:
# Arrange
session_mock.get.return_value = None
# Act / Assert
with pytest.raises(ScheduleNotFoundError, match="Schedule not found: schedule-1"):
ScheduleService.update_next_run_at(session=session_mock, schedule_id="schedule-1")
def test_to_schedule_config_should_build_from_cron_mode() -> None:
# Arrange
node_config: dict[str, Any] = {
"id": "node-1",
"data": {
"mode": "cron",
"cron_expression": "0 12 * * *",
"timezone": "Asia/Kolkata",
},
}
# Act
result = ScheduleService.to_schedule_config(node_config=node_config)
# Assert
assert result.node_id == "node-1"
assert result.cron_expression == "0 12 * * *"
assert result.timezone == "Asia/Kolkata"
def test_to_schedule_config_should_raise_for_cron_mode_without_expression() -> None:
# Arrange
node_config = {"id": "node-1", "data": {"mode": "cron", "cron_expression": ""}}
# Act / Assert
with pytest.raises(ScheduleConfigError, match="Cron expression is required for cron mode"):
ScheduleService.to_schedule_config(node_config=node_config)
def test_to_schedule_config_should_build_from_visual_mode(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
node_config = {
"id": "node-1",
"data": {
"mode": "visual",
"frequency": "daily",
"visual_config": {"time": "9:30 AM"},
"timezone": "UTC",
},
}
monkeypatch.setattr(ScheduleService, "visual_to_cron", MagicMock(return_value="30 9 * * *"))
# Act
result = ScheduleService.to_schedule_config(node_config=node_config)
# Assert
assert result.cron_expression == "30 9 * * *"
def test_to_schedule_config_should_raise_for_invalid_mode() -> None:
# Arrange
node_config = {"id": "node-1", "data": {"mode": "manual"}}
# Act / Assert
with pytest.raises(ScheduleConfigError, match="Invalid schedule mode: manual"):
ScheduleService.to_schedule_config(node_config=node_config)
def test_extract_schedule_config_should_raise_when_graph_is_empty() -> None:
# Arrange
workflow = _workflow(graph_dict={})
# Act / Assert
with pytest.raises(ScheduleConfigError, match="Workflow graph is empty"):
ScheduleService.extract_schedule_config(workflow=workflow)
def test_extract_schedule_config_should_raise_when_mode_invalid() -> None:
# Arrange
workflow = _workflow(
graph_dict={
"nodes": [
{
"id": "schedule-1",
"data": {
"type": TRIGGER_SCHEDULE_NODE_TYPE,
"mode": "invalid",
},
}
]
}
)
# Act / Assert
with pytest.raises(ScheduleConfigError, match="Invalid schedule mode: invalid"):
ScheduleService.extract_schedule_config(workflow=workflow)
if __name__ == "__main__":
unittest.main()

View File

@ -12,6 +12,7 @@ This test suite covers all functionality of the current VariableTruncator includ
import functools
import json
import uuid
from collections.abc import Mapping
from typing import Any
from uuid import uuid4
@ -199,14 +200,14 @@ class TestArrayTruncation:
def test_small_array_no_truncation(self, small_truncator: VariableTruncator):
"""Test that small arrays are not truncated."""
small_array = [1, 2]
small_array: list[object] = [1, 2]
result = small_truncator._truncate_array(small_array, 1000)
assert result.value == small_array
assert result.truncated is False
def test_array_element_limit_truncation(self, small_truncator: VariableTruncator):
"""Test that arrays over element limit are truncated."""
large_array = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3
large_array: list[object] = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3
result = small_truncator._truncate_array(large_array, 1000)
assert result.truncated is True
@ -215,7 +216,7 @@ class TestArrayTruncation:
def test_array_size_budget_truncation(self, small_truncator: VariableTruncator):
"""Test array truncation due to size budget constraints."""
# Create array with strings that will exceed size budget
large_strings = ["very long string " * 5, "another long string " * 5]
large_strings: list[object] = ["very long string " * 5, "another long string " * 5]
result = small_truncator._truncate_array(large_strings, 50)
assert result.truncated is True
@ -276,10 +277,10 @@ class TestObjectTruncation:
# Values should be truncated if they exist
for key, value in result.value.items():
if isinstance(value, str):
original_value = obj_with_long_values[key]
# Value should be same or smaller
assert len(value) <= len(original_value)
assert isinstance(value, str)
original_value = obj_with_long_values[key]
# Value should be same or smaller
assert len(value) <= len(original_value)
def test_object_key_dropping(self, small_truncator):
"""Test object truncation where keys are dropped due to size constraints."""
@ -506,10 +507,9 @@ class TestEdgeCases:
truncator = VariableTruncator(string_length_limit=10)
# Unicode characters
unicode_text = "🌍🚀🌍🚀🌍🚀🌍🚀🌍🚀" # Each emoji counts as 1 character
unicode_text = "你好世界你好世界你好世界" # Multi-byte UTF-8 characters
result = truncator.truncate(StringSegment(value=unicode_text))
if len(unicode_text) > 10:
assert result.truncated is True
assert result.truncated is True
# Special JSON characters
special_chars = '{"key": "value with \\"quotes\\" and \\n newlines"}'
@ -631,13 +631,12 @@ class TestIntegrationScenarios:
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
# Should handle all data types appropriately
if result.truncated:
# Verify the result is smaller or equal than original
original_size = truncator.calculate_json_size(mixed_data)
if isinstance(result.result, ObjectSegment):
result_size = truncator.calculate_json_size(result.result.value)
assert result_size <= original_size
assert result.truncated is True
assert isinstance(result.result, ObjectSegment)
# Verify the result is smaller or equal than original
original_size = truncator.calculate_json_size(mixed_data)
result_size = truncator.calculate_json_size(result.result.value)
assert result_size <= original_size
def test_file_and_array_file_variable_mapping(self, file):
truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300)
@ -675,3 +674,229 @@ def test_dummy_variable_truncator_methods():
assert isinstance(result, TruncationResult)
assert result.result == segment
assert result.truncated is False
# === Merged from test_variable_truncator_additional.py ===
from typing import Any
import pytest
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment
from graphon.variables.types import SegmentType
from services import variable_truncator as truncator_module
from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator
class _AbstractPassthrough(BaseTruncator):
def truncate(self, segment: Any) -> TruncationResult:
# Arrange / Act
return super().truncate(segment) # type: ignore[misc]
def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
# Arrange / Act
return super().truncate_variable_mapping(v) # type: ignore[misc]
def test_base_truncator_methods_should_execute_abstract_placeholders() -> None:
# Arrange
passthrough = _AbstractPassthrough()
# Act
truncate_result = passthrough.truncate(StringSegment(value="x"))
mapping_result = passthrough.truncate_variable_mapping({"a": 1})
# Assert
assert truncate_result is None
assert mapping_result is None
def test_default_should_use_dify_config_limits(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE", 111)
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH", 7)
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH", 33)
# Act
truncator = VariableTruncator.default()
# Assert
assert truncator._max_size_bytes == 111
assert truncator._array_element_limit == 7
assert truncator._string_length_limit == 33
def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis() -> None:
# Arrange
truncator = VariableTruncator(max_size_bytes=5)
mapping = {"very_long_key": "value"}
# Act
result, truncated = truncator.truncate_variable_mapping(mapping)
# Assert
assert result == {"very_long_key": "..."}
assert truncated is True
def test_truncate_variable_mapping_should_handle_segment_values() -> None:
# Arrange
truncator = VariableTruncator(max_size_bytes=100)
mapping = {"seg": StringSegment(value="hello")}
# Act
result, truncated = truncator.truncate_variable_mapping(mapping)
# Assert
assert isinstance(result["seg"], StringSegment)
assert result["seg"].value == "hello"
assert truncated is False
@pytest.mark.parametrize(
("value", "expected"),
[
(None, False),
(True, False),
(1, False),
(1.5, False),
("x", True),
({"k": "v"}, True),
],
)
def test_json_value_needs_truncation_should_match_expected_rules(value: Any, expected: bool) -> None:
# Arrange
# Act
result = VariableTruncator._json_value_needs_truncation(value)
# Assert
assert result is expected
def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
truncator = VariableTruncator(max_size_bytes=10)
forced_result = truncator_module._PartResult(
value=StringSegment(value="this is too long"),
value_size=100,
truncated=True,
)
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
# Act
result = truncator.truncate(StringSegment(value="input"))
# Assert
assert result.truncated is True
assert isinstance(result.result, StringSegment)
assert not result.result.value.startswith('"')
def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
truncator = VariableTruncator()
monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _segment: True)
# Act / Assert
with pytest.raises(AssertionError):
truncator._truncate_segment(IntegerSegment(value=1), 10)
def test_calculate_json_size_should_unwrap_segment_values() -> None:
# Arrange
segment = StringSegment(value="abc")
# Act
size = VariableTruncator.calculate_json_size(segment)
# Assert
assert size == VariableTruncator.calculate_json_size("abc")
def test_calculate_json_size_should_handle_updated_variable_instances() -> None:
# Arrange
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
# Act
size = VariableTruncator.calculate_json_size(updated)
# Assert
assert size > 0
def test_maybe_qa_structure_should_validate_shape() -> None:
# Arrange
# Act / Assert
assert VariableTruncator._maybe_qa_structure({"qa_chunks": []}) is True
assert VariableTruncator._maybe_qa_structure({"qa_chunks": "not-list"}) is False
assert VariableTruncator._maybe_qa_structure({}) is False
def test_maybe_parent_child_structure_should_validate_shape() -> None:
# Arrange
# Act / Assert
assert VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": []}) is True
assert VariableTruncator._maybe_parent_child_structure({"parent_mode": 1, "parent_child_chunks": []}) is False
assert (
VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": "bad"}) is False
)
def test_truncate_object_should_truncate_segment_values_inside_object() -> None:
# Arrange
truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30)
mapping = {"s": StringSegment(value="long-content")}
# Act
result = truncator._truncate_object(mapping, 20)
# Assert
assert result.truncated is True
assert isinstance(result.value["s"], StringSegment)
def test_truncate_json_primitives_should_handle_updated_variable_input() -> None:
# Arrange
truncator = VariableTruncator(max_size_bytes=100)
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
# Act
result = truncator._truncate_json_primitives(updated, 100)
# Assert
assert isinstance(result.value, dict)
def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type() -> None:
# Arrange
truncator = VariableTruncator()
# Act / Assert
with pytest.raises(AssertionError):
truncator._truncate_json_primitives(object(), 100) # type: ignore[arg-type]
def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
truncator = VariableTruncator(max_size_bytes=10)
forced_segment = ObjectSegment(value={"k": "v"})
forced_result = truncator_module._PartResult(value=forced_segment, value_size=100, truncated=True)
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
# Act
result = truncator.truncate(ObjectSegment(value={"a": "b"}))
# Assert
assert result.truncated is True
assert isinstance(result.result, StringSegment)

View File

@ -559,3 +559,757 @@ class TestWebhookServiceUnit:
result = _prepare_webhook_execution("test_webhook", is_debug=True)
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
# === Merged from test_webhook_service_additional.py ===
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from flask import Flask
from graphon.variables.types import SegmentType
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import RequestEntityTooLarge
from core.workflow.nodes.trigger_webhook.entities import (
ContentType,
WebhookBodyParameter,
WebhookData,
WebhookParameter,
)
from models.enums import AppTriggerStatus
from models.model import App
from models.trigger import WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService
class _FakeQuery:
def __init__(self, result: Any) -> None:
self._result = result
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def first(self) -> Any:
return self._result
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
@pytest.fixture
def flask_app() -> Flask:
return Flask(__name__)
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
def _workflow(**kwargs: Any) -> Workflow:
return cast(Workflow, SimpleNamespace(**kwargs))
def _app(**kwargs: Any) -> App:
return cast(App, SimpleNamespace(**kwargs))
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
fake_session = MagicMock()
fake_session.query.return_value = _FakeQuery(None)
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="Webhook not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(None)]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="App trigger not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="rate limited"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="disabled"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(None)]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="Workflow not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(workflow)]
_patch_session(monkeypatch, fake_session)
# Act
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
# Assert
assert got_trigger is webhook_trigger
assert got_workflow is workflow
assert got_node_config == {"data": {"key": "value"}}
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(workflow)]
_patch_session(monkeypatch, fake_session)
# Act
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
"webhook-1", is_debug=True
)
# Assert
assert got_trigger is webhook_trigger
assert got_workflow is workflow
assert got_node_config == {"data": {"mode": "debug"}}
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
warning_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
webhook_trigger = MagicMock()
# Act
with flask_app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/vnd.custom"},
data="plain content",
):
result = WebhookService.extract_webhook_data(webhook_trigger)
# Assert
assert result["body"] == {"raw": "plain content"}
warning_mock.assert_called_once()
def test_extract_webhook_data_should_raise_for_request_too_large(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
# Act / Assert
with flask_app.test_request_context("/webhook", method="POST", data="ab"):
with pytest.raises(RequestEntityTooLarge):
WebhookService.extract_webhook_data(MagicMock())
def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None:
# Arrange
webhook_trigger = MagicMock()
# Act
with flask_app.test_request_context("/webhook", method="POST", data=b""):
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
# Assert
assert body == {"raw": None}
assert files == {}
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = MagicMock()
monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream"))
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
# Act
with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
# Assert
assert body == {"raw": None}
assert files == {}
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
# Act
with flask_app.test_request_context("/webhook", method="POST", data="abc"):
body, files = WebhookService._extract_text_body()
# Assert
assert body == {"raw": ""}
assert files == {}
def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
fake_magic = MagicMock()
fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
monkeypatch.setattr(service_module, "magic", fake_magic)
# Act
result = WebhookService._detect_binary_mimetype(b"binary")
# Assert
assert result == "application/octet-stream"
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
file_obj = MagicMock()
file_obj.to_dict.return_value = {"id": "f-1"}
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
uploaded = MagicMock()
uploaded.filename = "file.unknown"
uploaded.content_type = None
uploaded.read.return_value = b"content"
# Act
result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
# Assert
assert result == {"f": {"id": "f-1"}}
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
manager = MagicMock()
manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
expected_file = MagicMock()
monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
# Act
result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
# Assert
assert result is expected_file
manager.create_file_by_raw.assert_called_once()
@pytest.mark.parametrize(
("raw_value", "param_type", "expected"),
[
("42", SegmentType.NUMBER, 42),
("3.14", SegmentType.NUMBER, 3.14),
("yes", SegmentType.BOOLEAN, True),
("no", SegmentType.BOOLEAN, False),
],
)
def test_convert_form_value_should_convert_supported_types(
raw_value: str,
param_type: str,
expected: Any,
) -> None:
# Arrange
# Act
result = WebhookService._convert_form_value("param", raw_value, param_type)
# Assert
assert result == expected
def test_convert_form_value_should_raise_for_unsupported_type() -> None:
# Arrange
# Act / Assert
with pytest.raises(ValueError, match="Unsupported type"):
WebhookService._convert_form_value("p", "x", SegmentType.FILE)
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
warning_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
# Act
result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
# Assert
assert result == {"x": 1}
warning_mock.assert_called_once()
def test_validate_and_convert_value_should_wrap_conversion_errors() -> None:
# Arrange
# Act / Assert
with pytest.raises(ValueError, match="validation failed"):
WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
def test_process_parameters_should_raise_when_required_parameter_missing() -> None:
# Arrange
raw_params = {"optional": "x"}
config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
# Act / Assert
with pytest.raises(ValueError, match="Required parameter missing"):
WebhookService._process_parameters(raw_params, config, is_form_data=True)
def test_process_parameters_should_include_unconfigured_parameters() -> None:
# Arrange
raw_params = {"known": "1", "unknown": "x"}
config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
# Act
result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
# Assert
assert result == {"known": 1, "unknown": "x"}
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None:
# Arrange
# Act / Assert
with pytest.raises(ValueError, match="Required body content missing"):
WebhookService._process_body_parameters(
raw_body={"raw": ""},
body_configs=[WebhookBodyParameter(name="raw", required=True)],
content_type=ContentType.TEXT,
)
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None:
# Arrange
raw_body = {"message": "hello", "extra": "x"}
body_configs = [
WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
]
# Act
result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
# Assert
assert result == {"message": "hello", "extra": "x"}
def test_validate_required_headers_should_accept_sanitized_header_names() -> None:
# Arrange
headers = {"x_api_key": "123"}
configs = [WebhookParameter(name="x-api-key", required=True)]
# Act
WebhookService._validate_required_headers(headers, configs)
# Assert
assert True
def test_validate_required_headers_should_raise_when_required_header_missing() -> None:
# Arrange
headers = {"x-other": "123"}
configs = [WebhookParameter(name="x-api-key", required=True)]
# Act / Assert
with pytest.raises(ValueError, match="Required header missing"):
WebhookService._validate_required_headers(headers, configs)
def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None:
# Arrange
webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
node_data = WebhookData(method="post", content_type=ContentType.TEXT)
# Act
result = WebhookService._validate_http_metadata(webhook_data, node_data)
# Assert
assert result["valid"] is False
assert "Content-type mismatch" in result["error"]
def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None:
# Arrange
headers = {"content-type": "application/json; charset=utf-8"}
# Act
result = WebhookService._extract_content_type(headers)
# Assert
assert result == "application/json"
def test_build_workflow_inputs_should_include_expected_keys() -> None:
# Arrange
webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}
# Act
result = WebhookService.build_workflow_inputs(webhook_data)
# Assert
assert result["webhook_data"] == webhook_data
assert result["webhook_headers"] == {"h": "v"}
assert result["webhook_query_params"] == {"q": 1}
assert result["webhook_body"] == {"b": 2}
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
webhook_data = {"body": {"x": 1}}
session = MagicMock()
_patch_session(monkeypatch, session)
end_user = SimpleNamespace(id="end-user-1")
monkeypatch.setattr(
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user)
)
quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
monkeypatch.setattr(service_module, "QuotaType", quota_type)
trigger_async_mock = MagicMock()
monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
# Act
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
# Assert
trigger_async_mock.assert_called_once()
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
session = MagicMock()
_patch_session(monkeypatch, session)
monkeypatch.setattr(
service_module.EndUserService,
"get_or_create_end_user_by_type",
MagicMock(return_value=SimpleNamespace(id="end-user-1")),
)
quota_type = SimpleNamespace(
TRIGGER=SimpleNamespace(
consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
)
)
monkeypatch.setattr(service_module, "QuotaType", quota_type)
mark_rate_limited_mock = MagicMock()
monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock)
# Act / Assert
with pytest.raises(QuotaExceededError):
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
mark_rate_limited_mock.assert_called_once_with("tenant-1")
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
session = MagicMock()
_patch_session(monkeypatch, session)
monkeypatch.setattr(
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom"))
)
logger_exception_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
# Act / Assert
with pytest.raises(RuntimeError, match="boom"):
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
logger_exception_mock.assert_called_once()
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(
walk_nodes=lambda _node_type: [
(f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
]
)
# Act / Assert
with pytest.raises(ValueError, match="maximum webhook node limit"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
lock = MagicMock()
lock.acquire.return_value = False
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
# Act / Assert
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
class _WorkflowWebhookTrigger:
app_id = "app_id"
tenant_id = "tenant_id"
webhook_id = "webhook_id"
node_id = "node_id"
def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
self.id = None
self.app_id = app_id
self.tenant_id = tenant_id
self.node_id = node_id
self.webhook_id = webhook_id
self.created_by = created_by
class _Select:
def where(self, *args: Any, **kwargs: Any) -> "_Select":
return self
class _Session:
def __init__(self) -> None:
self.added: list[Any] = []
self.deleted: list[Any] = []
self.commit_count = 0
self.existing_records = [SimpleNamespace(node_id="node-stale")]
def scalars(self, _stmt: Any) -> Any:
return SimpleNamespace(all=lambda: self.existing_records)
def add(self, obj: Any) -> None:
self.added.append(obj)
def flush(self) -> None:
for idx, obj in enumerate(self.added, start=1):
if obj.id is None:
obj.id = f"rec-{idx}"
def commit(self) -> None:
self.commit_count += 1
def delete(self, obj: Any) -> None:
self.deleted.append(obj)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.return_value = None
fake_session = _Session()
monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
redis_set_mock = MagicMock()
redis_delete_mock = MagicMock()
monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
_patch_session(monkeypatch, fake_session)
# Act
WebhookService.sync_webhook_relationships(app, workflow)
# Assert
assert len(fake_session.added) == 1
assert len(fake_session.deleted) == 1
assert fake_session.commit_count == 2
redis_set_mock.assert_called_once()
redis_delete_mock.assert_called_once()
lock.release.assert_called_once()
def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [])
class _Select:
def where(self, *args: Any, **kwargs: Any) -> "_Select":
return self
class _Session:
def scalars(self, _stmt: Any) -> Any:
return SimpleNamespace(all=lambda: [])
def commit(self) -> None:
return None
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = RuntimeError("release failed")
logger_exception_mock = MagicMock()
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
_patch_session(monkeypatch, _Session())
# Act
WebhookService.sync_webhook_relationships(app, workflow)
# Assert
assert logger_exception_mock.call_count == 1
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None:
# Arrange
node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
# Act
body, status = WebhookService.generate_webhook_response(node_config)
# Assert
assert status == 200
assert "message" in body
def test_generate_webhook_id_should_return_24_character_identifier() -> None:
# Arrange
# Act
webhook_id = WebhookService.generate_webhook_id()
# Assert
assert isinstance(webhook_id, str)
assert len(webhook_id) == 24
def test_sanitize_key_should_return_original_value_for_non_string_input() -> None:
# Arrange
# Act
result = WebhookService._sanitize_key(123) # type: ignore[arg-type]
# Assert
assert result == 123

View File

@ -176,3 +176,300 @@ class TestWorkflowRunService:
service = WorkflowRunService(session_factory)
assert service._session_factory == session_factory
# === Merged from test_workflow_run_service.py ===
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from models import Account, App, EndUser, WorkflowRunTriggeredFrom
from services import workflow_run_service as service_module
from services.workflow_run_service import WorkflowRunService
@pytest.fixture
def repository_factory_mocks(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, MagicMock, Any]:
# Arrange
node_repo = MagicMock()
workflow_run_repo = MagicMock()
factory = SimpleNamespace(
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
# Assert
return node_repo, workflow_run_repo, factory
def _app_model(**kwargs: Any) -> App:
return cast(App, SimpleNamespace(**kwargs))
def _account(**kwargs: Any) -> Account:
return cast(Account, SimpleNamespace(**kwargs))
def _end_user(**kwargs: Any) -> EndUser:
return cast(EndUser, SimpleNamespace(**kwargs))
def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
monkeypatch: pytest.MonkeyPatch,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
# Arrange
session_factory = MagicMock(name="session_factory")
sessionmaker_mock = MagicMock(return_value=session_factory)
monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine="db-engine"))
# Act
service = WorkflowRunService()
# Assert
sessionmaker_mock.assert_called_once_with(bind="db-engine", expire_on_commit=False)
assert service._session_factory is session_factory
def test___init___should_create_sessionmaker_when_engine_is_provided(
monkeypatch: pytest.MonkeyPatch,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
# Arrange
class FakeEngine:
pass
session_factory = MagicMock(name="session_factory")
sessionmaker_mock = MagicMock(return_value=session_factory)
monkeypatch.setattr(service_module, "Engine", FakeEngine)
monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
engine = cast(Engine, FakeEngine())
# Act
service = WorkflowRunService(session_factory=engine)
# Assert
sessionmaker_mock.assert_called_once_with(bind=engine, expire_on_commit=False)
assert service._session_factory is session_factory
def test___init___should_keep_provided_sessionmaker_and_create_repositories(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
# Arrange
node_repo, workflow_run_repo, factory = repository_factory_mocks
session_factory = MagicMock(name="session_factory")
# Act
service = WorkflowRunService(session_factory=session_factory)
# Assert
assert service._session_factory is session_factory
assert service._node_execution_service_repo is node_repo
assert service._workflow_run_repo is workflow_run_repo
factory.create_api_workflow_node_execution_repository.assert_called_once_with(session_factory)
factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
def test_get_paginate_workflow_runs_should_forward_filters_and_parse_limit(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
# Arrange
_, workflow_run_repo, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
expected = MagicMock(name="pagination")
workflow_run_repo.get_paginated_workflow_runs.return_value = expected
args = {"limit": "7", "last_id": "last-1", "status": "succeeded"}
# Act
result = service.get_paginate_workflow_runs(
app_model=app_model,
args=args,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Assert
assert result is expected
workflow_run_repo.get_paginated_workflow_runs.assert_called_once_with(
tenant_id="tenant-1",
app_id="app-1",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
limit=7,
last_id="last-1",
status="succeeded",
)
def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_when_message_exists(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
run_with_message = SimpleNamespace(
id="run-1",
status="running",
message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
)
run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
pagination = SimpleNamespace(data=[run_with_message, run_without_message])
monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))
# Act
result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"})
# Assert
assert result is pagination
assert len(result.data) == 2
assert result.data[0].message_id == "msg-1"
assert result.data[0].conversation_id == "conv-1"
assert result.data[0].status == "running"
assert not hasattr(result.data[1], "message_id")
assert result.data[1].id == "run-2"
def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
# Arrange
_, workflow_run_repo, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
expected = MagicMock(name="workflow_run")
workflow_run_repo.get_workflow_run_by_id.return_value = expected
# Act
result = service.get_workflow_run(app_model=app_model, run_id="run-1")
# Assert
assert result is expected
workflow_run_repo.get_workflow_run_by_id.assert_called_once_with(
tenant_id="tenant-1",
app_id="app-1",
run_id="run-1",
)
def test_get_workflow_runs_count_should_forward_optional_filters(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
# Arrange
_, workflow_run_repo, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
expected = {"total": 3, "succeeded": 2}
workflow_run_repo.get_workflow_runs_count.return_value = expected
# Act
result = service.get_workflow_runs_count(
app_model=app_model,
status="succeeded",
time_range="7d",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Assert
assert result == expected
workflow_run_repo.get_workflow_runs_count.assert_called_once_with(
tenant_id="tenant-1",
app_id="app-1",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
status="succeeded",
time_range="7d",
)
def test_get_workflow_run_node_executions_should_return_empty_list_when_run_not_found(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=None))
app_model = _app_model(id="app-1")
user = _account(current_tenant_id="tenant-1")
# Act
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
# Assert
assert result == []
def test_get_workflow_run_node_executions_should_use_end_user_tenant_id(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
node_repo, _, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
class FakeEndUser:
def __init__(self, tenant_id: str) -> None:
self.tenant_id = tenant_id
monkeypatch.setattr(service_module, "EndUser", FakeEndUser)
user = cast(EndUser, FakeEndUser(tenant_id="tenant-end-user"))
app_model = _app_model(id="app-1")
expected = [SimpleNamespace(id="exec-1")]
node_repo.get_executions_by_workflow_run.return_value = expected
# Act
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
# Assert
assert result == expected
node_repo.get_executions_by_workflow_run.assert_called_once_with(
tenant_id="tenant-end-user",
app_id="app-1",
workflow_run_id="run-1",
)
def test_get_workflow_run_node_executions_should_use_account_current_tenant_id(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
node_repo, _, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
app_model = _app_model(id="app-1")
user = _account(current_tenant_id="tenant-account")
expected = [SimpleNamespace(id="exec-1"), SimpleNamespace(id="exec-2")]
node_repo.get_executions_by_workflow_run.return_value = expected
# Act
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
# Assert
assert result == expected
node_repo.get_executions_by_workflow_run.assert_called_once_with(
tenant_id="tenant-account",
app_id="app-1",
workflow_run_id="run-1",
)
def test_get_workflow_run_node_executions_should_raise_when_resolved_tenant_id_is_none(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
app_model = _app_model(id="app-1")
user = _account(current_tenant_id=None)
# Act / Assert
with pytest.raises(ValueError, match="tenant_id cannot be None"):
service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)

View File

@ -0,0 +1,831 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity,
DatasetEntity,
DatasetRetrieveConfigEntity,
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
)
from core.helper import encrypter
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from models.model import Account, App, AppMode, AppModelConfig
from services.workflow import workflow_converter as converter_module
from services.workflow.workflow_converter import WorkflowConverter
try:
from graphon.enums import BuiltinNodeTypes
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.message_entities import PromptMessageRole
from graphon.variables.input_entities import VariableEntity, VariableEntityType
except ModuleNotFoundError:
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
@pytest.fixture
def converter() -> WorkflowConverter:
return WorkflowConverter()
def _app_model(**kwargs: Any) -> App:
return cast(App, SimpleNamespace(**kwargs))
def _account(**kwargs: Any) -> Account:
return cast(Account, SimpleNamespace(**kwargs))
def _app_model_config(**kwargs: Any) -> AppModelConfig:
return cast(AppModelConfig, SimpleNamespace(**kwargs))
def _build_start_graph() -> dict[str, Any]:
return {
"nodes": [
{
"id": "start",
"position": None,
"data": {"type": BuiltinNodeTypes.START, "variables": [{"variable": "name"}, {"variable": "city"}]},
}
],
"edges": [],
}
def _build_model_config(mode: str | LLMMode) -> ModelConfigEntity:
return ModelConfigEntity(provider="openai", model="gpt-4", mode=mode, parameters={}, stop=[])
@pytest.fixture
def default_variables() -> list[VariableEntity]:
return [
VariableEntity(variable="text_input", label="text-input", type=VariableEntityType.TEXT_INPUT),
VariableEntity(variable="paragraph", label="paragraph", type=VariableEntityType.PARAGRAPH),
VariableEntity(variable="select", label="select", type=VariableEntityType.SELECT),
]
def test__convert_to_start_node(default_variables: list[VariableEntity]) -> None:
result = WorkflowConverter()._convert_to_start_node(default_variables)
assert result["id"] == "start"
assert result["data"]["type"] == BuiltinNodeTypes.START
assert result["data"]["variables"][0]["type"] == "text-input"
assert result["data"]["variables"][0]["variable"] == "text_input"
def test__convert_to_http_request_node_for_chatbot(default_variables: list[VariableEntity]) -> None:
app_model = MagicMock()
app_model.id = "app_id"
app_model.tenant_id = "tenant_id"
app_model.mode = AppMode.CHAT
extension = APIBasedExtension(
tenant_id="tenant_id",
name="api-1",
api_key="encrypted_api_key",
api_endpoint="https://dify.ai",
)
extension.id = "api_based_extension_id"
workflow_converter = WorkflowConverter()
workflow_converter._get_api_based_extension = MagicMock(return_value=extension)
encrypter.decrypt_token = MagicMock(return_value="api_key")
external_data_variables = [
ExternalDataVariableEntity(
variable="external_variable",
type="api",
config={"api_based_extension_id": "api_based_extension_id"},
),
]
nodes, mapping = workflow_converter._convert_to_http_request_node(
app_model=app_model,
variables=default_variables,
external_data_variables=external_data_variables,
)
assert len(nodes) == 2
assert nodes[0]["data"]["type"] == BuiltinNodeTypes.HTTP_REQUEST
assert nodes[1]["data"]["type"] == BuiltinNodeTypes.CODE
body = json.loads(nodes[0]["data"]["body"]["data"])
assert body["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
assert body["params"]["query"] == "{{#sys.query#}}"
assert body["params"]["inputs"]["text_input"] == "{{#start.text_input#}}"
assert mapping == {"external_variable": "code_1"}
def test__convert_to_http_request_node_for_workflow_app(default_variables: list[VariableEntity]) -> None:
app_model = MagicMock()
app_model.id = "app_id"
app_model.tenant_id = "tenant_id"
app_model.mode = AppMode.WORKFLOW
extension = APIBasedExtension(
tenant_id="tenant_id",
name="api-1",
api_key="encrypted_api_key",
api_endpoint="https://dify.ai",
)
extension.id = "api_based_extension_id"
workflow_converter = WorkflowConverter()
workflow_converter._get_api_based_extension = MagicMock(return_value=extension)
encrypter.decrypt_token = MagicMock(return_value="api_key")
external_data_variables = [
ExternalDataVariableEntity(
variable="external_variable",
type="api",
config={"api_based_extension_id": "api_based_extension_id"},
),
]
nodes, _ = workflow_converter._convert_to_http_request_node(
app_model=app_model,
variables=default_variables,
external_data_variables=external_data_variables,
)
body = json.loads(nodes[0]["data"]["body"]["data"])
assert body["params"]["query"] == ""
def test__convert_to_knowledge_retrieval_node_for_chatbot() -> None:
dataset_config = DatasetEntity(
dataset_ids=["dataset_id_1", "dataset_id_2"],
retrieve_config=DatasetRetrieveConfigEntity(
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=5,
score_threshold=0.8,
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
reranking_enabled=True,
),
)
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
new_app_mode=AppMode.ADVANCED_CHAT,
dataset_config=dataset_config,
model_config=model_config,
)
assert node is not None
assert node["data"]["query_variable_selector"] == ["sys", "query"]
assert node["data"]["multiple_retrieval_config"]["top_k"] == 5
def test__convert_to_knowledge_retrieval_node_for_workflow_app() -> None:
dataset_config = DatasetEntity(
dataset_ids=["dataset_id_1", "dataset_id_2"],
retrieve_config=DatasetRetrieveConfigEntity(
query_variable="query",
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=5,
score_threshold=0.8,
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
reranking_enabled=True,
),
)
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
new_app_mode=AppMode.WORKFLOW,
dataset_config=dataset_config,
model_config=model_config,
)
assert node is not None
assert node["data"]["query_variable_selector"] == ["start", "query"]
def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables: list[VariableEntity]) -> None:
workflow_converter = WorkflowConverter()
graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[])
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="You are a helper for {{text_input}} and {{paragraph}}",
)
node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=AppMode.ADVANCED_CHAT,
model_config=model_config,
graph=graph,
prompt_template=prompt_template,
)
assert node["data"]["type"] == BuiltinNodeTypes.LLM
assert node["data"]["memory"] is not None
assert node["data"]["prompt_template"][0]["role"] == "user"
assert "{{#start.text_input#}}" in node["data"]["prompt_template"][0]["text"]
def test__convert_to_llm_node_for_chatbot_simple_chat_model_with_empty_template(
default_variables: list[VariableEntity],
monkeypatch: pytest.MonkeyPatch,
) -> None:
workflow_converter = WorkflowConverter()
graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[])
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="ignored",
)
monkeypatch.setattr(
converter_module.SimplePromptTransform,
"get_prompt_template",
lambda self, **kwargs: {"prompt_template": PromptTemplateParser(""), "prompt_rules": {}},
)
node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=AppMode.ADVANCED_CHAT,
model_config=model_config,
graph=graph,
prompt_template=prompt_template,
)
assert node["data"]["prompt_template"] == []
def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables: list[VariableEntity]) -> None:
workflow_converter = WorkflowConverter()
graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[])
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
messages=[AdvancedChatMessageEntity(text="Hello {{text_input}}", role=PromptMessageRole.USER)]
),
)
node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=AppMode.ADVANCED_CHAT,
model_config=model_config,
graph=graph,
prompt_template=prompt_template,
)
assert isinstance(node["data"]["prompt_template"], list)
assert node["data"]["prompt_template"][0]["role"] == PromptMessageRole.USER.value
def test__convert_to_llm_node_for_chatbot_advanced_chat_model_without_template(
default_variables: list[VariableEntity],
) -> None:
workflow_converter = WorkflowConverter()
graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[])
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_chat_prompt_template=None,
)
node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=AppMode.WORKFLOW,
model_config=model_config,
graph=graph,
prompt_template=prompt_template,
)
assert node["data"]["prompt_template"] == []
assert node["data"]["memory"] is None
def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables: list[VariableEntity]) -> None:
workflow_converter = WorkflowConverter()
graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
model_config = ModelConfigEntity(
provider="openai",
model="gpt-3.5-turbo-instruct",
mode=LLMMode.COMPLETION.value,
parameters={},
stop=[],
)
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
prompt="Hello {{text_input}} and {{#query#}}",
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"),
),
)
node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.COMPLETION,
new_app_mode=AppMode.ADVANCED_CHAT,
model_config=model_config,
graph=graph,
prompt_template=prompt_template,
)
assert node["data"]["prompt_template"]["text"].find("{{#sys.query#}}") != -1
assert node["data"]["memory"]["role_prefix"]["user"] == "Human"
def test__convert_to_end_node() -> None:
node = WorkflowConverter()._convert_to_end_node()
assert node["id"] == "end"
assert node["data"]["type"] == BuiltinNodeTypes.END
def test__convert_to_answer_node() -> None:
node = WorkflowConverter()._convert_to_answer_node()
assert node["id"] == "answer"
assert node["data"]["type"] == BuiltinNodeTypes.ANSWER
def test_convert_to_workflow_should_raise_when_app_model_config_is_missing(converter: WorkflowConverter) -> None:
app_model = _app_model(app_model_config=None)
with pytest.raises(ValueError, match="App model config is required"):
converter.convert_to_workflow(
app_model=app_model,
account=_account(id="account-1"),
name="new-app",
icon_type="emoji",
icon="robot",
icon_background="#fff",
)
@pytest.mark.parametrize(
("source_mode", "expected_mode"),
[
(AppMode.CHAT, AppMode.ADVANCED_CHAT),
(AppMode.COMPLETION, AppMode.WORKFLOW),
],
)
def test_convert_to_workflow_should_create_new_app_with_fallback_fields(
converter: WorkflowConverter,
monkeypatch: pytest.MonkeyPatch,
source_mode: AppMode,
expected_mode: AppMode,
) -> None:
class FakeApp:
def __init__(self) -> None:
self.id = "new-app-id"
workflow = SimpleNamespace(app_id=None)
monkeypatch.setattr(converter, "convert_app_model_config_to_workflow", MagicMock(return_value=workflow))
monkeypatch.setattr(converter_module, "App", FakeApp)
db_session = SimpleNamespace(add=MagicMock(), flush=MagicMock(), commit=MagicMock())
monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
send_mock = MagicMock()
monkeypatch.setattr(converter_module.app_was_created, "send", send_mock)
account = _account(id="account-1")
app_model = _app_model(
tenant_id="tenant-1",
name="Source App",
mode=source_mode,
icon_type="emoji",
icon="sparkles",
icon_background="#123456",
enable_site=True,
enable_api=True,
api_rpm=10,
api_rph=100,
is_public=False,
app_model_config=_app_model_config(id="config-1"),
)
new_app = converter.convert_to_workflow(
app_model=app_model,
account=account,
name="",
icon_type="",
icon="",
icon_background="",
)
assert new_app.name == "Source App(workflow)"
assert new_app.mode == expected_mode
assert new_app.icon_type == "emoji"
assert new_app.icon == "sparkles"
assert new_app.icon_background == "#123456"
assert new_app.created_by == "account-1"
assert workflow.app_id == "new-app-id"
db_session.add.assert_called_once()
db_session.flush.assert_called_once()
db_session.commit.assert_called_once()
send_mock.assert_called_once_with(new_app, account=account)
def test_convert_app_model_config_to_workflow_should_build_advanced_chat_graph_and_features(
converter: WorkflowConverter,
monkeypatch: pytest.MonkeyPatch,
) -> None:
app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT)
app_config = SimpleNamespace(
variables=[SimpleNamespace(variable="name")],
external_data_variables=[SimpleNamespace(variable="ext")],
dataset=SimpleNamespace(id="dataset"),
model=SimpleNamespace(),
prompt_template=SimpleNamespace(),
additional_features=SimpleNamespace(file_upload=SimpleNamespace()),
app_model_config_dict={
"opening_statement": "hello",
"suggested_questions": ["q1"],
"suggested_questions_after_answer": True,
"speech_to_text": True,
"text_to_speech": {"enabled": True},
"file_upload": {"enabled": True},
"sensitive_word_avoidance": {"enabled": True},
"retriever_resource": {"enabled": True},
},
)
class FakeWorkflow:
VERSION_DRAFT = "draft"
def __init__(self, **kwargs: Any) -> None:
self.__dict__.update(kwargs)
monkeypatch.setattr(converter, "_get_new_app_mode", MagicMock(return_value=AppMode.ADVANCED_CHAT))
monkeypatch.setattr(converter, "_convert_to_app_config", MagicMock(return_value=app_config))
monkeypatch.setattr(
converter,
"_convert_to_start_node",
MagicMock(
return_value={"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START, "variables": []}}
),
)
monkeypatch.setattr(
converter,
"_convert_to_http_request_node",
MagicMock(
return_value=(
[{"id": "http", "position": None, "data": {"type": BuiltinNodeTypes.HTTP_REQUEST}}],
{"ext": "code_1"},
)
),
)
monkeypatch.setattr(
converter,
"_convert_to_knowledge_retrieval_node",
MagicMock(
return_value={"id": "knowledge", "position": None, "data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}
),
)
monkeypatch.setattr(
converter,
"_convert_to_llm_node",
MagicMock(return_value={"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}),
)
monkeypatch.setattr(
converter,
"_convert_to_answer_node",
MagicMock(return_value={"id": "answer", "position": None, "data": {"type": BuiltinNodeTypes.ANSWER}}),
)
monkeypatch.setattr(converter_module, "Workflow", FakeWorkflow)
db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock())
monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
workflow = converter.convert_app_model_config_to_workflow(
app_model=app_model,
app_model_config=_app_model_config(id="cfg"),
account_id="account-1",
)
graph = json.loads(workflow.graph)
node_ids = [node["id"] for node in graph["nodes"]]
assert node_ids == ["start", "http", "knowledge", "llm", "answer"]
features = json.loads(workflow.features)
assert "opening_statement" in features
assert "retriever_resource" in features
db_session.add.assert_called_once()
db_session.commit.assert_called_once()
def test_convert_app_model_config_to_workflow_should_build_workflow_mode_with_end_node(
converter: WorkflowConverter,
monkeypatch: pytest.MonkeyPatch,
) -> None:
app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.COMPLETION)
app_config = SimpleNamespace(
variables=[SimpleNamespace(variable="name")],
external_data_variables=[],
dataset=SimpleNamespace(id="dataset"),
model=SimpleNamespace(),
prompt_template=SimpleNamespace(),
additional_features=None,
app_model_config_dict={
"text_to_speech": {"enabled": False},
"file_upload": {"enabled": False},
"sensitive_word_avoidance": {"enabled": False},
},
)
class FakeWorkflow:
VERSION_DRAFT = "draft"
def __init__(self, **kwargs: Any) -> None:
self.__dict__.update(kwargs)
monkeypatch.setattr(converter, "_get_new_app_mode", MagicMock(return_value=AppMode.WORKFLOW))
monkeypatch.setattr(converter, "_convert_to_app_config", MagicMock(return_value=app_config))
monkeypatch.setattr(
converter,
"_convert_to_start_node",
MagicMock(
return_value={"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START, "variables": []}}
),
)
monkeypatch.setattr(converter, "_convert_to_knowledge_retrieval_node", MagicMock(return_value=None))
monkeypatch.setattr(
converter,
"_convert_to_llm_node",
MagicMock(return_value={"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}),
)
monkeypatch.setattr(
converter,
"_convert_to_end_node",
MagicMock(return_value={"id": "end", "position": None, "data": {"type": BuiltinNodeTypes.END}}),
)
monkeypatch.setattr(converter_module, "Workflow", FakeWorkflow)
db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock())
monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
workflow = converter.convert_app_model_config_to_workflow(
app_model=app_model,
app_model_config=_app_model_config(id="cfg"),
account_id="account-1",
)
graph = json.loads(workflow.graph)
node_ids = [node["id"] for node in graph["nodes"]]
assert node_ids == ["start", "llm", "end"]
features = json.loads(workflow.features)
assert set(features.keys()) == {"text_to_speech", "file_upload", "sensitive_word_avoidance"}
def test_convert_to_app_config_should_route_to_correct_manager(
converter: WorkflowConverter,
monkeypatch: pytest.MonkeyPatch,
) -> None:
agent_result = SimpleNamespace(kind="agent")
chat_result = SimpleNamespace(kind="chat")
completion_result = SimpleNamespace(kind="completion")
monkeypatch.setattr(
converter_module.AgentChatAppConfigManager, "get_app_config", MagicMock(return_value=agent_result)
)
monkeypatch.setattr(converter_module.ChatAppConfigManager, "get_app_config", MagicMock(return_value=chat_result))
monkeypatch.setattr(
converter_module.CompletionAppConfigManager,
"get_app_config",
MagicMock(return_value=completion_result),
)
from_agent_mode = converter._convert_to_app_config(
app_model=_app_model(mode=AppMode.AGENT_CHAT, is_agent=False),
app_model_config=_app_model_config(id="cfg-1"),
)
from_agent_flag = converter._convert_to_app_config(
app_model=_app_model(mode=AppMode.CHAT, is_agent=True),
app_model_config=_app_model_config(id="cfg-2"),
)
from_chat_mode = converter._convert_to_app_config(
app_model=_app_model(mode=AppMode.CHAT, is_agent=False),
app_model_config=_app_model_config(id="cfg-3"),
)
from_completion_mode = converter._convert_to_app_config(
app_model=_app_model(mode=AppMode.COMPLETION, is_agent=False),
app_model_config=_app_model_config(id="cfg-4"),
)
assert from_agent_mode is agent_result
assert from_agent_flag is agent_result
assert from_chat_mode is chat_result
assert from_completion_mode is completion_result
def test_convert_to_app_config_should_raise_for_invalid_app_mode(converter: WorkflowConverter) -> None:
app_model = _app_model(mode=AppMode.WORKFLOW, is_agent=False)
with pytest.raises(ValueError, match="Invalid app mode"):
converter._convert_to_app_config(app_model=app_model, app_model_config=_app_model_config(id="cfg"))
def test_convert_to_http_request_node_should_skip_non_api_and_missing_extension_id(
converter: WorkflowConverter,
) -> None:
app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT)
external_data_variables = [
ExternalDataVariableEntity(variable="skip_type", type="dataset", config={"api_based_extension_id": "x"}),
ExternalDataVariableEntity(variable="skip_config", type="api", config={}),
]
nodes, mapping = converter._convert_to_http_request_node(
app_model=app_model,
variables=[],
external_data_variables=external_data_variables,
)
assert nodes == []
assert mapping == {}
def test_convert_to_knowledge_retrieval_node_should_return_none_for_workflow_without_query_variable(
converter: WorkflowConverter,
) -> None:
dataset_config = DatasetEntity(
dataset_ids=["ds-1"],
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=None,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
),
)
model_config = _build_model_config(mode=LLMMode.CHAT)
node = converter._convert_to_knowledge_retrieval_node(
new_app_mode=AppMode.WORKFLOW,
dataset_config=dataset_config,
model_config=model_config,
)
assert node is None
def test_convert_to_llm_node_should_raise_when_simple_chat_template_missing(
converter: WorkflowConverter,
) -> None:
graph = _build_start_graph()
model_config = _build_model_config(mode=LLMMode.CHAT)
prompt_template = PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.SIMPLE)
with pytest.raises(ValueError, match="Simple prompt template is required"):
converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=AppMode.ADVANCED_CHAT,
graph=graph,
model_config=model_config,
prompt_template=prompt_template,
)
def test_convert_to_llm_node_should_raise_when_prompt_template_parser_type_is_invalid_for_chat(
converter: WorkflowConverter,
monkeypatch: pytest.MonkeyPatch,
) -> None:
graph = _build_start_graph()
model_config = _build_model_config(mode=LLMMode.CHAT)
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="Hello {{name}}",
)
monkeypatch.setattr(
converter_module.SimplePromptTransform,
"get_prompt_template",
lambda self, **kwargs: {"prompt_template": "invalid"},
)
with pytest.raises(TypeError, match="Expected PromptTemplateParser"):
converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=AppMode.ADVANCED_CHAT,
graph=graph,
model_config=model_config,
prompt_template=prompt_template,
)
def test_convert_to_llm_node_should_raise_when_simple_completion_template_missing(
converter: WorkflowConverter,
) -> None:
graph = _build_start_graph()
model_config = _build_model_config(mode=LLMMode.COMPLETION)
prompt_template = PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.SIMPLE)
with pytest.raises(ValueError, match="Simple prompt template is required"):
converter._convert_to_llm_node(
original_app_mode=AppMode.COMPLETION,
new_app_mode=AppMode.WORKFLOW,
graph=graph,
model_config=model_config,
prompt_template=prompt_template,
)
def test_convert_to_llm_node_should_raise_when_completion_prompt_rules_type_is_invalid(
converter: WorkflowConverter,
monkeypatch: pytest.MonkeyPatch,
) -> None:
graph = _build_start_graph()
model_config = _build_model_config(mode=LLMMode.COMPLETION)
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="Hello {{name}}",
)
monkeypatch.setattr(
converter_module.SimplePromptTransform,
"get_prompt_template",
lambda self, **kwargs: {"prompt_template": PromptTemplateParser("Hello {{name}}"), "prompt_rules": "invalid"},
)
with pytest.raises(TypeError, match="Expected dict for prompt_rules"):
converter._convert_to_llm_node(
original_app_mode=AppMode.COMPLETION,
new_app_mode=AppMode.ADVANCED_CHAT,
graph=graph,
model_config=model_config,
prompt_template=prompt_template,
)
def test_convert_to_llm_node_should_use_empty_text_for_advanced_completion_without_template(
converter: WorkflowConverter,
) -> None:
graph = _build_start_graph()
model_config = _build_model_config(mode=LLMMode.COMPLETION)
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_completion_prompt_template=None,
)
llm_node = converter._convert_to_llm_node(
original_app_mode=AppMode.COMPLETION,
new_app_mode=AppMode.WORKFLOW,
graph=graph,
model_config=model_config,
prompt_template=prompt_template,
)
assert llm_node["data"]["prompt_template"]["text"] == ""
assert llm_node["data"]["memory"] is None
def test_replace_template_variables_should_replace_start_and_external_references(converter: WorkflowConverter) -> None:
template = "Hello {{name}} from {{city}} with {{weather}}"
variables = [{"variable": "name"}, {"variable": "city"}]
external_mapping = {"weather": "code_1"}
result = converter._replace_template_variables(template, variables, external_mapping)
assert result == "Hello {{#start.name#}} from {{#start.city#}} with {{#code_1.result#}}"
def test_graph_helpers_should_create_edges_append_nodes_and_choose_mode(converter: WorkflowConverter) -> None:
graph = {"nodes": [{"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START}}], "edges": []}
node = {"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}
edge = converter._create_edge("start", "llm")
updated_graph = converter._append_node(graph, node)
workflow_mode = converter._get_new_app_mode(_app_model(mode=AppMode.COMPLETION))
advanced_chat_mode = converter._get_new_app_mode(_app_model(mode=AppMode.CHAT))
assert edge == {"id": "start-llm", "source": "start", "target": "llm"}
assert updated_graph["nodes"][-1]["id"] == "llm"
assert updated_graph["edges"][-1]["source"] == "start"
assert workflow_mode == AppMode.WORKFLOW
assert advanced_chat_mode == AppMode.ADVANCED_CHAT
def test_get_api_based_extension_should_raise_when_extension_not_found(
converter: WorkflowConverter,
monkeypatch: pytest.MonkeyPatch,
) -> None:
db_session = SimpleNamespace(scalar=MagicMock(return_value=None))
monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
with pytest.raises(ValueError, match="API Based Extension not found"):
converter._get_api_based_extension(tenant_id="tenant-1", api_based_extension_id="ext-1")
db_session.scalar.assert_called_once()
def test_get_api_based_extension_should_return_entity_when_found(
converter: WorkflowConverter,
monkeypatch: pytest.MonkeyPatch,
) -> None:
extension = SimpleNamespace(id="ext-1")
db_session = SimpleNamespace(scalar=MagicMock(return_value=extension))
monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
result = converter._get_api_based_extension(tenant_id="tenant-1", api_based_extension_id="ext-1")
assert result is extension
db_session.scalar.assert_called_once()

View File

@ -1,10 +1,9 @@
from __future__ import annotations
import json
import queue
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from itertools import cycle
from threading import Event
import pytest
@ -224,3 +223,577 @@ def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -
buffer_state.task_id_ready.set()
task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
assert task_id == expected
# === Merged from test_workflow_event_snapshot_service_additional.py ===
import json
import queue
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from threading import Event
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from graphon.enums import WorkflowExecutionStatus
from graphon.runtime import GraphRuntimeState, VariablePool
from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import StreamEvent
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services import workflow_event_snapshot_service as service_module
from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream
def _build_workflow_run_additional(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
return WorkflowRun(
id="run-1",
tenant_id="tenant-1",
app_id="app-1",
workflow_id="workflow-1",
type="workflow",
triggered_from="app-run",
version="v1",
graph=None,
inputs=json.dumps({"query": "hello"}),
status=status,
outputs=json.dumps({}),
error=None,
elapsed_time=1.2,
total_tokens=5,
total_steps=2,
created_by_role=CreatorUserRole.END_USER,
created_by="user-1",
created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def _build_resumption_context_additional(task_id: str) -> WorkflowResumptionContext:
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-1",
app_id="app-1",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-1",
)
generate_entity = WorkflowAppGenerateEntity(
task_id=task_id,
app_config=app_config,
inputs={},
files=[],
user_id="user-1",
stream=True,
invoke_from=InvokeFrom.EXPLORE,
call_depth=0,
workflow_execution_id="run-1",
)
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
runtime_state.outputs = {"answer": "ok"}
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
return WorkflowResumptionContext(
generate_entity=wrapper,
serialized_graph_runtime_state=runtime_state.dumps(),
)
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionMaker:
def __init__(self, session: Any) -> None:
self._session = session
def __call__(self) -> _SessionContext:
return _SessionContext(self._session)
class _SubscriptionContext:
def __init__(self, subscription: Any) -> None:
self._subscription = subscription
def __enter__(self) -> Any:
return self._subscription
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _Topic:
def __init__(self, subscription: Any) -> None:
self._subscription = subscription
def subscribe(self) -> _SubscriptionContext:
return _SubscriptionContext(self._subscription)
class _StaticSubscription:
def receive(self, timeout: int = 1) -> None:
return None
@dataclass(frozen=True)
class _PauseEntity(WorkflowPauseEntity):
state: bytes
@property
def id(self) -> str:
return "pause-1"
@property
def workflow_execution_id(self) -> str:
return "run-1"
@property
def resumed_at(self) -> datetime | None:
return None
@property
def paused_at(self) -> datetime:
return datetime(2024, 1, 1, tzinfo=UTC)
def get_state(self) -> bytes:
return self.state
def get_pause_reasons(self) -> list[Any]:
return []
def test_get_message_context_should_return_none_when_no_message() -> None:
# Arrange
session = SimpleNamespace(scalar=MagicMock(return_value=None))
session_maker = _SessionMaker(session)
# Act
result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
# Assert
assert result is None
def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp() -> None:
# Arrange
message = SimpleNamespace(
id="msg-1",
conversation_id="conv-1",
created_at=None,
answer="answer",
)
session = SimpleNamespace(scalar=MagicMock(return_value=message))
session_maker = _SessionMaker(session)
# Act
result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
# Assert
assert result is not None
assert result.created_at == 0
assert result.message_id == "msg-1"
assert result.conversation_id == "conv-1"
assert result.answer == "answer"
def test_load_resumption_context_should_return_none_when_pause_entity_missing() -> None:
# Arrange
# Act
result = service_module._load_resumption_context(None)
# Assert
assert result is None
def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid() -> None:
# Arrange
pause_entity = _PauseEntity(state=b"not-a-valid-state")
# Act
result = service_module._load_resumption_context(pause_entity)
# Assert
assert result is None
def test_load_resumption_context_should_parse_valid_state_into_context() -> None:
# Arrange
context = _build_resumption_context_additional(task_id="task-ctx")
pause_entity = _PauseEntity(state=context.dumps().encode())
# Act
result = service_module._load_resumption_context(pause_entity)
# Assert
assert result is not None
assert result.get_generate_entity().task_id == "task-ctx"
def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing() -> None:
# Arrange
# Act
result = service_module._resolve_task_id(
resumption_context=None,
buffer_state=None,
workflow_run_id="run-1",
)
# Assert
assert result == "run-1"
@pytest.mark.parametrize(
("payload", "expected"),
[
(b'{"event":"node_started"}', {"event": "node_started"}),
(b"invalid-json", None),
(b"[]", None),
],
)
def test_parse_event_message_should_parse_only_json_object(
payload: bytes,
expected: dict[str, Any] | None,
) -> None:
# Arrange
# Act
result = service_module._parse_event_message(payload)
# Assert
assert result == expected
def test_is_terminal_event_should_recognize_finished_and_optional_paused_events() -> None:
# Arrange
finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value}
# Act
is_finished = service_module._is_terminal_event(finished_event, include_paused=False)
paused_without_flag = service_module._is_terminal_event(paused_event, include_paused=False)
paused_with_flag = service_module._is_terminal_event(paused_event, include_paused=True)
# Assert
assert is_finished is True
assert paused_without_flag is False
assert paused_with_flag is True
assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False
def test_apply_message_context_should_update_payload_when_context_exists() -> None:
# Arrange
payload: dict[str, Any] = {"event": "workflow_started"}
context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000)
# Act
service_module._apply_message_context(payload, context)
# Assert
assert payload["conversation_id"] == "conv-1"
assert payload["message_id"] == "msg-1"
assert payload["created_at"] == 1700000000
def test_start_buffering_should_capture_task_id_and_enqueue_event() -> None:
# Arrange
class Subscription:
def __init__(self) -> None:
self._calls = 0
def receive(self, timeout: int = 1) -> bytes | None:
self._calls += 1
if self._calls == 1:
return b'{"event":"node_started","task_id":"task-1"}'
return None
subscription = Subscription()
# Act
buffer_state = service_module._start_buffering(subscription)
ready = buffer_state.task_id_ready.wait(timeout=1)
event = buffer_state.queue.get(timeout=1)
buffer_state.stop_event.set()
finished = buffer_state.done_event.wait(timeout=1)
# Assert
assert ready is True
assert finished is True
assert buffer_state.task_id_hint == "task-1"
assert event["event"] == "node_started"
def test_start_buffering_should_drop_old_event_when_queue_is_full(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
class QueueWithSingleFull:
def __init__(self) -> None:
self._first_put = True
self.items: list[dict[str, Any]] = [{"event": "old"}]
def put_nowait(self, item: dict[str, Any]) -> None:
if self._first_put:
self._first_put = False
raise queue.Full
self.items.append(item)
def get_nowait(self) -> dict[str, Any]:
if not self.items:
raise queue.Empty
return self.items.pop(0)
def empty(self) -> bool:
return len(self.items) == 0
fake_queue = QueueWithSingleFull()
monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)
class Subscription:
def __init__(self) -> None:
self._calls = 0
def receive(self, timeout: int = 1) -> bytes | None:
self._calls += 1
if self._calls == 1:
return b'{"event":"node_started","task_id":"task-2"}'
return None
subscription = Subscription()
# Act
buffer_state = service_module._start_buffering(subscription)
ready = buffer_state.task_id_ready.wait(timeout=1)
buffer_state.stop_event.set()
finished = buffer_state.done_event.wait(timeout=1)
# Assert
assert ready is True
assert finished is True
assert fake_queue.items[-1]["task_id"] == "task-2"
def test_start_buffering_should_set_done_event_when_subscription_raises() -> None:
# Arrange
class Subscription:
def receive(self, timeout: int = 1) -> bytes | None:
raise RuntimeError("subscription failure")
subscription = Subscription()
# Act
buffer_state = service_module._start_buffering(subscription)
finished = buffer_state.done_event.wait(timeout=1)
# Assert
assert finished is True
def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
topic = _Topic(_StaticSubscription())
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
factory = SimpleNamespace(
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
monkeypatch.setattr(
service_module,
"_get_message_context",
MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
)
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
buffer_state = BufferState(
queue=queue.Queue(),
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint="task-1",
)
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
monkeypatch.setattr(
service_module,
"_build_snapshot_events",
MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
)
# Act
events = list(
build_workflow_event_stream(
app_mode=AppMode.ADVANCED_CHAT,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=MagicMock(),
)
)
# Assert
assert events[0] == StreamEvent.PING.value
finished_event = cast(Mapping[str, Any], events[1])
assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
assert buffer_state.stop_event.is_set() is True
node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
assert called_kwargs["workflow_run_id"] == "run-1"
def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
topic = _Topic(_StaticSubscription())
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
factory = SimpleNamespace(
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
class AlwaysEmptyQueue:
def empty(self) -> bool:
return False
def get(self, timeout: int = 1) -> None:
raise queue.Empty
buffer_state = BufferState(
queue=AlwaysEmptyQueue(), # type: ignore[arg-type]
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint="task-1",
)
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
time_values = cycle([0.0, 6.0, 21.0, 26.0])
monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))
# Act
events = list(
build_workflow_event_stream(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=MagicMock(),
idle_timeout=20.0,
ping_interval=5.0,
)
)
# Assert
assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
assert buffer_state.stop_event.is_set() is True
def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
topic = _Topic(_StaticSubscription())
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
factory = SimpleNamespace(
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
buffer_state = BufferState(
queue=queue.Queue(),
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint="task-1",
)
buffer_state.done_event.set()
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
# Act
events = list(
build_workflow_event_stream(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=MagicMock(),
)
)
# Assert
assert events == [StreamEvent.PING.value]
assert buffer_state.stop_event.is_set() is True
def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED)
topic = _Topic(_StaticSubscription())
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
factory = SimpleNamespace(
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
buffer_state = BufferState(
queue=queue.Queue(),
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint="task-1",
)
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
# Act
events = list(
build_workflow_event_stream(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=MagicMock(),
)
)
# Assert
assert events[0] == StreamEvent.PING.value
assert snapshot_builder.call_args.kwargs["pause_entity"] is None

View File

@ -10,6 +10,8 @@ This module tests the document indexing task functionality including:
"""
import uuid
from contextlib import nullcontext
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pytest
@ -1113,13 +1115,17 @@ class TestAdvancedScenarios:
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task)
# Assert
# Verify delete was called to clean up task key
mock_redis.delete.assert_called_once()
expected_task_key = f"tenant_document_indexing_task:{tenant_id}"
# Verify the correct key was deleted (contains tenant_id and "document_indexing")
delete_call_args = mock_redis.delete.call_args[0][0]
assert tenant_id in delete_call_args
assert "document_indexing" in delete_call_args
# Verify the task key for this tenant was deleted (do not assert call count; fixtures may be shared).
mock_redis.delete.assert_any_call(expected_task_key)
deleted_keys = [delete_call.args[0] for delete_call in mock_redis.delete.call_args_list if delete_call.args]
assert expected_task_key in deleted_keys
deleted_task_key = next(key for key in deleted_keys if key == expected_task_key)
assert tenant_id in deleted_task_key
assert "document_indexing" in deleted_task_key
def test_billing_disabled_skips_limit_checks(
self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner, mock_feature_service
@ -1510,3 +1516,475 @@ class TestRobustness:
# Verify the exception message
assert "Feature service" in str(exc_info.value) or isinstance(exc_info.value, Exception)
class _SessionContext:
def __init__(self, session: MagicMock) -> None:
self._session = session
def __enter__(self) -> MagicMock:
return self._session
def __exit__(self, exc_type, exc, tb) -> None: # type: ignore[override]
return None
class TestDocumentIndexingTaskSummaryFlow:
"""Additional coverage for summary and tenant queue branches."""
def test_should_return_when_dataset_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test early return when dataset does not exist."""
# Arrange
session = MagicMock()
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = None
session.query.side_effect = lambda model: dataset_query
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
features_mock = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.FeatureService.get_features", features_mock)
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
features_mock.assert_not_called()
def test_should_mark_documents_error_when_batch_upload_limit_exceeded(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Test batch upload limit triggers error handling."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
document = SimpleNamespace(id="doc-1", indexing_status=None, error=None, stopped_at=None)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.first.return_value = document
session = MagicMock()
session.query.side_effect = lambda model: dataset_query if model is Dataset else document_query
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
MagicMock(return_value=_SessionContext(session)),
)
features = SimpleNamespace(
billing=SimpleNamespace(
enabled=True,
subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL),
),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
monkeypatch.setattr("tasks.document_indexing_task.dify_config.BATCH_UPLOAD_LIMIT", "1")
# Act
_document_indexing("dataset-1", ["doc-1", "doc-2"])
# Assert
assert document.indexing_status == "error"
assert "batch upload limit" in document.error
session.commit.assert_called_once()
def test_should_queue_summary_generation_for_completed_documents(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test summary generation is queued for eligible documents."""
# Arrange
dataset = SimpleNamespace(
id="dataset-1",
tenant_id="tenant-1",
indexing_technique="high_quality",
summary_index_setting={"enable": True},
)
doc_eligible = SimpleNamespace(
id="doc-1",
indexing_status="completed",
doc_form="text",
need_summary=True,
)
doc_skip_form = SimpleNamespace(
id="doc-2",
indexing_status="completed",
doc_form="qa_model",
need_summary=True,
)
doc_skip_status = SimpleNamespace(
id="doc-3",
indexing_status="processing",
doc_form="text",
need_summary=True,
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
phase1_docs = [SimpleNamespace(id="doc-1"), SimpleNamespace(id="doc-2"), SimpleNamespace(id="doc-3")]
phase1_document_query = MagicMock()
phase1_document_query.where.return_value = phase1_document_query
phase1_document_query.all.return_value = phase1_docs
summary_document_query = MagicMock()
summary_document_query.where.return_value = summary_document_query
summary_document_query.all.return_value = [doc_eligible, doc_skip_form, doc_skip_status]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_document_query
session3.query.side_effect = lambda model: summary_document_query if model is Document else dataset_query
create_session_mock = MagicMock(
side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]
)
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
indexing_runner = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=indexing_runner))
delay_mock = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
# Act
_document_indexing("dataset-1", ["doc-1", "doc-2", "doc-3"])
# Assert
delay_mock.assert_called_once_with("dataset-1", "doc-1", None)
def test_should_continue_when_summary_queue_fails(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test summary queueing errors are swallowed."""
# Arrange
dataset = SimpleNamespace(
id="dataset-1",
tenant_id="tenant-1",
indexing_technique="high_quality",
summary_index_setting={"enable": True},
)
doc_eligible = SimpleNamespace(
id="doc-1",
indexing_status="completed",
doc_form="text",
need_summary=True,
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
phase1_query = MagicMock()
phase1_query.where.return_value = phase1_query
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
summary_query = MagicMock()
summary_query.where.return_value = summary_query
summary_query.all.return_value = [doc_eligible]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_query
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
indexing_runner = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=indexing_runner))
delay_mock = MagicMock(side_effect=Exception("boom"))
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
delay_mock.assert_called_once_with("dataset-1", "doc-1", None)
def test_should_return_when_dataset_missing_after_indexing(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test early return when dataset is missing after indexing."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.side_effect = [dataset, None]
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session3.query.side_effect = lambda model: dataset_query
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock()))
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
session3.query.assert_called()
def test_should_skip_summary_when_not_high_quality(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test summary generation skipped when indexing_technique is not high_quality."""
# Arrange
dataset = SimpleNamespace(
id="dataset-1",
tenant_id="tenant-1",
indexing_technique="economy",
summary_index_setting={"enable": True},
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session3.query.side_effect = lambda model: dataset_query
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock()))
delay_mock = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
delay_mock.assert_not_called()
def test_should_skip_summary_generation_when_indexing_paused(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test summary generation is skipped when indexing is paused."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
create_session_mock = MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)])
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
runner = MagicMock()
runner.run.side_effect = DocumentIsPausedError("paused")
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=runner))
delay_mock = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
delay_mock.assert_not_called()
def test_should_handle_indexing_runner_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test generic indexing runner exception is handled."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)]),
)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
runner = MagicMock()
runner.run.side_effect = RuntimeError("boom")
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=runner))
delay_mock = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
delay_mock.assert_not_called()
def test_should_log_missing_document_entry_in_summary_list(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test falsey document entries are handled in summary iteration."""
# Arrange
class _FalseyDocument:
def __init__(self, doc_id: str) -> None:
self.id = doc_id
def __bool__(self) -> bool:
return False
dataset = SimpleNamespace(
id="dataset-1",
tenant_id="tenant-1",
indexing_technique="high_quality",
summary_index_setting={"enable": True},
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
phase1_query = MagicMock()
phase1_query.where.return_value = phase1_query
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
summary_query = MagicMock()
summary_query.where.return_value = summary_query
summary_query.all.return_value = [_FalseyDocument("missing-doc")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_query
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock()))
delay_mock = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
delay_mock.assert_not_called()
def test_normal_document_indexing_task_should_delegate(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test normal indexing task delegates to tenant queue handler."""
# Arrange
handler = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task._document_indexing_with_tenant_queue", handler)
# Act
normal_document_indexing_task("tenant-1", "dataset-1", ["doc-1"])
# Assert
handler.assert_called_once_with("tenant-1", "dataset-1", ["doc-1"], normal_document_indexing_task)
def test_priority_document_indexing_task_should_delegate(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test priority indexing task delegates to tenant queue handler."""
# Arrange
handler = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task._document_indexing_with_tenant_queue", handler)
# Act
priority_document_indexing_task("tenant-1", "dataset-1", ["doc-1"])
# Assert
handler.assert_called_once_with("tenant-1", "dataset-1", ["doc-1"], priority_document_indexing_task)

40
api/uv.lock generated
View File

@ -53,23 +53,6 @@ dependencies = [
]
sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748, upload-time = "2026-03-28T17:19:40.6Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/7e/cb94129302d78c46662b47f9897d642fd0b33bdfef4b73b20c6ced35aa4c/aiohttp-3.13.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8ea0c64d1bcbf201b285c2246c51a0c035ba3bbd306640007bc5844a3b4658c1", size = 760027, upload-time = "2026-03-28T17:15:33.022Z" },
{ url = "https://files.pythonhosted.org/packages/5e/cd/2db3c9397c3bd24216b203dd739945b04f8b87bb036c640da7ddb63c75ef/aiohttp-3.13.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6f742e1fa45c0ed522b00ede565e18f97e4cf8d1883a712ac42d0339dfb0cce7", size = 508325, upload-time = "2026-03-28T17:15:34.714Z" },
{ url = "https://files.pythonhosted.org/packages/36/a3/d28b2722ec13107f2e37a86b8a169897308bab6a3b9e071ecead9d67bd9b/aiohttp-3.13.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dcfb50ee25b3b7a1222a9123be1f9f89e56e67636b561441f0b304e25aaef8f", size = 502402, upload-time = "2026-03-28T17:15:36.409Z" },
{ url = "https://files.pythonhosted.org/packages/fa/d6/acd47b5f17c4430e555590990a4746efbcb2079909bb865516892bf85f37/aiohttp-3.13.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3262386c4ff370849863ea93b9ea60fd59c6cf56bf8f93beac625cf4d677c04d", size = 1771224, upload-time = "2026-03-28T17:15:38.223Z" },
{ url = "https://files.pythonhosted.org/packages/98/af/af6e20113ba6a48fd1cd9e5832c4851e7613ef50c7619acdaee6ec5f1aff/aiohttp-3.13.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:473bb5aa4218dd254e9ae4834f20e31f5a0083064ac0136a01a62ddbae2eaa42", size = 1731530, upload-time = "2026-03-28T17:15:39.988Z" },
{ url = "https://files.pythonhosted.org/packages/81/16/78a2f5d9c124ad05d5ce59a9af94214b6466c3491a25fb70760e98e9f762/aiohttp-3.13.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e56423766399b4c77b965f6aaab6c9546617b8994a956821cc507d00b91d978c", size = 1827925, upload-time = "2026-03-28T17:15:41.944Z" },
{ url = "https://files.pythonhosted.org/packages/2a/1f/79acf0974ced805e0e70027389fccbb7d728e6f30fcac725fb1071e63075/aiohttp-3.13.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8af249343fafd5ad90366a16d230fc265cf1149f26075dc9fe93cfd7c7173942", size = 1923579, upload-time = "2026-03-28T17:15:44.071Z" },
{ url = "https://files.pythonhosted.org/packages/af/53/29f9e2054ea6900413f3b4c3eb9d8331f60678ec855f13ba8714c47fd48d/aiohttp-3.13.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bc0a5cf4f10ef5a2c94fdde488734b582a3a7a000b131263e27c9295bd682d9", size = 1767655, upload-time = "2026-03-28T17:15:45.911Z" },
{ url = "https://files.pythonhosted.org/packages/f3/57/462fe1d3da08109ba4aa8590e7aed57c059af2a7e80ec21f4bac5cfe1094/aiohttp-3.13.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:5c7ff1028e3c9fc5123a865ce17df1cb6424d180c503b8517afbe89aa566e6be", size = 1630439, upload-time = "2026-03-28T17:15:48.11Z" },
{ url = "https://files.pythonhosted.org/packages/d7/4b/4813344aacdb8127263e3eec343d24e973421143826364fa9fc847f6283f/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ba5cf98b5dcb9bddd857da6713a503fa6d341043258ca823f0f5ab7ab4a94ee8", size = 1745557, upload-time = "2026-03-28T17:15:50.13Z" },
{ url = "https://files.pythonhosted.org/packages/d4/01/1ef1adae1454341ec50a789f03cfafe4c4ac9c003f6a64515ecd32fe4210/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:d85965d3ba21ee4999e83e992fecb86c4614d6920e40705501c0a1f80a583c12", size = 1741796, upload-time = "2026-03-28T17:15:52.351Z" },
{ url = "https://files.pythonhosted.org/packages/22/04/8cdd99af988d2aa6922714d957d21383c559835cbd43fbf5a47ddf2e0f05/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:49f0b18a9b05d79f6f37ddd567695943fcefb834ef480f17a4211987302b2dc7", size = 1805312, upload-time = "2026-03-28T17:15:54.407Z" },
{ url = "https://files.pythonhosted.org/packages/fb/7f/b48d5577338d4b25bbdbae35c75dbfd0493cb8886dc586fbfb2e90862239/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7f78cb080c86fbf765920e5f1ef35af3f24ec4314d6675d0a21eaf41f6f2679c", size = 1621751, upload-time = "2026-03-28T17:15:56.564Z" },
{ url = "https://files.pythonhosted.org/packages/bc/89/4eecad8c1858e6d0893c05929e22343e0ebe3aec29a8a399c65c3cc38311/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:67a3ec705534a614b68bbf1c70efa777a21c3da3895d1c44510a41f5a7ae0453", size = 1826073, upload-time = "2026-03-28T17:15:58.489Z" },
{ url = "https://files.pythonhosted.org/packages/f5/5c/9dc8293ed31b46c39c9c513ac7ca152b3c3d38e0ea111a530ad12001b827/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d6630ec917e85c5356b2295744c8a97d40f007f96a1c76bf1928dc2e27465393", size = 1760083, upload-time = "2026-03-28T17:16:00.677Z" },
{ url = "https://files.pythonhosted.org/packages/1e/19/8bbf6a4994205d96831f97b7d21a0feed120136e6267b5b22d229c6dc4dc/aiohttp-3.13.4-cp311-cp311-win32.whl", hash = "sha256:54049021bc626f53a5394c29e8c444f726ee5a14b6e89e0ad118315b1f90f5e3", size = 439690, upload-time = "2026-03-28T17:16:02.902Z" },
{ url = "https://files.pythonhosted.org/packages/0c/f5/ac409ecd1007528d15c3e8c3a57d34f334c70d76cfb7128a28cffdebd4c1/aiohttp-3.13.4-cp311-cp311-win_amd64.whl", hash = "sha256:c033f2bc964156030772d31cbf7e5defea181238ce1f87b9455b786de7d30145", size = 463824, upload-time = "2026-03-28T17:16:05.058Z" },
{ url = "https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158, upload-time = "2026-03-28T17:16:06.901Z" },
{ url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037, upload-time = "2026-03-28T17:16:08.82Z" },
{ url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556, upload-time = "2026-03-28T17:16:10.63Z" },
@ -1586,7 +1569,7 @@ dev = [
{ name = "lxml-stubs", specifier = "~=0.5.1" },
{ name = "mypy", specifier = "~=1.19.1" },
{ name = "pandas-stubs", specifier = "~=3.0.0" },
{ name = "pyrefly", specifier = ">=0.57.1" },
{ name = "pyrefly", specifier = ">=0.59.1" },
{ name = "pytest", specifier = "~=9.0.2" },
{ name = "pytest-benchmark", specifier = "~=5.2.3" },
{ name = "pytest-cov", specifier = "~=7.1.0" },
@ -4839,18 +4822,19 @@ wheels = [
[[package]]
name = "pyrefly"
version = "0.57.1"
version = "0.59.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c9/c1/c17211e5bbd2b90a24447484713da7cc2cee4e9455e57b87016ffc69d426/pyrefly-0.57.1.tar.gz", hash = "sha256:b05f6f5ee3a6a5d502ca19d84cb9ab62d67f05083819964a48c1510f2993efc6", size = 5310800, upload-time = "2026-03-18T18:42:35.614Z" }
sdist = { url = "https://files.pythonhosted.org/packages/d5/ce/7882c2af92b2ff6505fcd3430eff8048ece6c6254cc90bdc76ecee12dfab/pyrefly-0.59.1.tar.gz", hash = "sha256:bf1675b0c38d45df2c8f8618cbdfa261a1b92430d9d31eba16e0282b551e210f", size = 5475432, upload-time = "2026-04-01T22:04:04.11Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/58/8af37856c8d45b365ece635a6728a14b0356b08d1ff1ac601d7120def1e0/pyrefly-0.57.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:91974bfbe951eebf5a7bc959c1f3921f0371c789cad84761511d695e9ab2265f", size = 12681847, upload-time = "2026-03-18T18:42:10.963Z" },
{ url = "https://files.pythonhosted.org/packages/5f/d7/fae6dd9d0355fc5b8df7793f1423b7433ca8e10b698ea934c35f0e4e6522/pyrefly-0.57.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:808087298537c70f5e7cdccb5bbaad482e7e056e947c0adf00fb612cbace9fdc", size = 12219634, upload-time = "2026-03-18T18:42:13.469Z" },
{ url = "https://files.pythonhosted.org/packages/29/8f/9511ae460f0690e837b9ba0f7e5e192079e16ff9a9ba8a272450e81f11f8/pyrefly-0.57.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b01f454fa5539e070c0cba17ddec46b3d2107d571d519bd8eca8f3142ba02a6", size = 34947757, upload-time = "2026-03-18T18:42:17.152Z" },
{ url = "https://files.pythonhosted.org/packages/07/43/f053bf9c65218f70e6a49561e9942c7233f8c3e4da8d42e5fe2aae50b3d2/pyrefly-0.57.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02ad59ea722191f51635f23e37574662116b82ca9d814529f7cb5528f041f381", size = 37621018, upload-time = "2026-03-18T18:42:20.79Z" },
{ url = "https://files.pythonhosted.org/packages/0e/76/9cea46de01665bbc125e4f215340c9365c8d56cda6198ff238a563ea8e75/pyrefly-0.57.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54bc0afe56776145e37733ff763e7e9679ee8a76c467b617dc3f227d4124a9e2", size = 40203649, upload-time = "2026-03-18T18:42:24.519Z" },
{ url = "https://files.pythonhosted.org/packages/fd/8b/2fb4a96d75e2a57df698a43e2970e441ba2704e3906cdc0386a055daa05a/pyrefly-0.57.1-py3-none-win32.whl", hash = "sha256:468e5839144b25bb0dce839bfc5fd879c9f38e68ebf5de561f30bed9ae19d8ca", size = 11732953, upload-time = "2026-03-18T18:42:27.379Z" },
{ url = "https://files.pythonhosted.org/packages/13/5a/4a197910fe2e9b102b15ae5e7687c45b7b5981275a11a564b41e185dd907/pyrefly-0.57.1-py3-none-win_amd64.whl", hash = "sha256:46db9c97093673c4fb7fab96d610e74d140661d54688a92d8e75ad885a56c141", size = 12537319, upload-time = "2026-03-18T18:42:30.196Z" },
{ url = "https://files.pythonhosted.org/packages/b5/c6/bc442874be1d9b63da1f9debb4f04b7d0c590a8dc4091921f3c288207242/pyrefly-0.57.1-py3-none-win_arm64.whl", hash = "sha256:feb1bbe3b0d8d5a70121dcdf1476e6a99cc056a26a49379a156f040729244dcb", size = 12013455, upload-time = "2026-03-18T18:42:32.928Z" },
{ url = "https://files.pythonhosted.org/packages/d0/10/04a0e05b08fc855b6fe38c3df549925fc3c2c6e750506870de7335d3e1f7/pyrefly-0.59.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:390db3cd14aa7e0268e847b60cd9ee18b04273eddfa38cf341ed3bb43f3fef2a", size = 12868133, upload-time = "2026-04-01T22:03:39.436Z" },
{ url = "https://files.pythonhosted.org/packages/c7/78/fa7be227c3e3fcacee501c1562278dd026186ffd1b5b5beb51d3941a3aed/pyrefly-0.59.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d246d417b6187c1650d7f855f61c68fbfd6d6155dc846d4e4d273a3e6b5175cb", size = 12379325, upload-time = "2026-04-01T22:03:42.046Z" },
{ url = "https://files.pythonhosted.org/packages/bb/13/6828ce1c98171b5f8388f33c4b0b9ea2ab8c49abe0ef8d793c31e30a05cb/pyrefly-0.59.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:575ac67b04412dc651a7143d27e38a40fbdd3c831c714d5520d0e9d4c8631ab4", size = 35826408, upload-time = "2026-04-01T22:03:45.067Z" },
{ url = "https://files.pythonhosted.org/packages/23/56/79ed8ece9a7ecad0113c394a06a084107db3ad8f1fefe19e7ded43c51245/pyrefly-0.59.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:062e6262ce1064d59dcad81ac0499bb7a3ad501e9bc8a677a50dc630ff0bf862", size = 38532699, upload-time = "2026-04-01T22:03:48.376Z" },
{ url = "https://files.pythonhosted.org/packages/18/7d/ecc025e0f0e3f295b497f523cc19cefaa39e57abede8fc353d29445d174b/pyrefly-0.59.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43ef4247f9e6f734feb93e1f2b75335b943629956e509f545cc9cdcccd76dd20", size = 36743570, upload-time = "2026-04-01T22:03:51.362Z" },
{ url = "https://files.pythonhosted.org/packages/2f/03/b1ce882ebcb87c673165c00451fbe4df17bf96ccfde18c75880dc87c5f5e/pyrefly-0.59.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a2d01723b84d042f4fa6ec871ffd52d0a7e83b0ea791c2e0bb0ff750abce56", size = 41236246, upload-time = "2026-04-01T22:03:54.361Z" },
{ url = "https://files.pythonhosted.org/packages/17/af/5e9c7afd510e7dd64a2204be0ed39e804089cbc4338675a28615c7176acb/pyrefly-0.59.1-py3-none-win32.whl", hash = "sha256:4ea70c780848f8376411e787643ae5d2d09da8a829362332b7b26d15ebcbaf56", size = 11884747, upload-time = "2026-04-01T22:03:56.776Z" },
{ url = "https://files.pythonhosted.org/packages/aa/c1/7db1077627453fd1068f0761f059a9512645c00c4c20acfb9f0c24ac02ec/pyrefly-0.59.1-py3-none-win_amd64.whl", hash = "sha256:67e6a08cfd129a0d2788d5e40a627f9860e0fe91a876238d93d5c63ff4af68ae", size = 12720608, upload-time = "2026-04-01T22:03:59.252Z" },
{ url = "https://files.pythonhosted.org/packages/07/16/4bb6e5fce5a9cf0992932d9435d964c33e507aaaf96fdfbb1be493078a4a/pyrefly-0.59.1-py3-none-win_arm64.whl", hash = "sha256:01179cb215cf079e8223a064f61a074f7079aa97ea705cbbc68af3d6713afd15", size = 12223158, upload-time = "2026-04-01T22:04:01.869Z" },
]
[[package]]

View File

@ -5,7 +5,6 @@
"prepare": "vp config"
},
"devDependencies": {
"taze": "catalog:",
"vite-plus": "catalog:"
},
"engines": {

View File

Before

Width:  |  Height:  |  Size: 36 KiB

After

Width:  |  Height:  |  Size: 36 KiB

View File

Before

Width:  |  Height:  |  Size: 465 B

After

Width:  |  Height:  |  Size: 465 B

View File

Before

Width:  |  Height:  |  Size: 643 B

After

Width:  |  Height:  |  Size: 643 B

View File

Before

Width:  |  Height:  |  Size: 297 B

After

Width:  |  Height:  |  Size: 297 B

View File

Before

Width:  |  Height:  |  Size: 3.0 KiB

After

Width:  |  Height:  |  Size: 3.0 KiB

View File

Before

Width:  |  Height:  |  Size: 607 B

After

Width:  |  Height:  |  Size: 607 B

View File

Before

Width:  |  Height:  |  Size: 599 B

After

Width:  |  Height:  |  Size: 599 B

View File

Before

Width:  |  Height:  |  Size: 2.4 KiB

After

Width:  |  Height:  |  Size: 2.4 KiB

View File

Before

Width:  |  Height:  |  Size: 874 B

After

Width:  |  Height:  |  Size: 874 B

View File

Before

Width:  |  Height:  |  Size: 435 B

After

Width:  |  Height:  |  Size: 435 B

View File

Before

Width:  |  Height:  |  Size: 1.6 KiB

After

Width:  |  Height:  |  Size: 1.6 KiB

View File

Before

Width:  |  Height:  |  Size: 364 KiB

After

Width:  |  Height:  |  Size: 364 KiB

View File

Before

Width:  |  Height:  |  Size: 1.0 KiB

After

Width:  |  Height:  |  Size: 1.0 KiB

View File

Before

Width:  |  Height:  |  Size: 1.4 KiB

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

Before

Width:  |  Height:  |  Size: 8.1 KiB

After

Width:  |  Height:  |  Size: 8.1 KiB

View File

Before

Width:  |  Height:  |  Size: 1.3 KiB

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

Before

Width:  |  Height:  |  Size: 561 B

After

Width:  |  Height:  |  Size: 561 B

View File

Before

Width:  |  Height:  |  Size: 7.5 KiB

After

Width:  |  Height:  |  Size: 7.5 KiB

View File

Before

Width:  |  Height:  |  Size: 193 B

After

Width:  |  Height:  |  Size: 193 B

View File

Before

Width:  |  Height:  |  Size: 717 B

After

Width:  |  Height:  |  Size: 717 B

View File

Before

Width:  |  Height:  |  Size: 2.2 KiB

After

Width:  |  Height:  |  Size: 2.2 KiB

View File

Before

Width:  |  Height:  |  Size: 56 KiB

After

Width:  |  Height:  |  Size: 56 KiB

View File

Before

Width:  |  Height:  |  Size: 1.7 KiB

After

Width:  |  Height:  |  Size: 1.7 KiB

View File

Before

Width:  |  Height:  |  Size: 214 B

After

Width:  |  Height:  |  Size: 214 B

View File

Before

Width:  |  Height:  |  Size: 3.1 KiB

After

Width:  |  Height:  |  Size: 3.1 KiB

View File

Before

Width:  |  Height:  |  Size: 3.3 KiB

After

Width:  |  Height:  |  Size: 3.3 KiB

View File

Before

Width:  |  Height:  |  Size: 3.5 KiB

After

Width:  |  Height:  |  Size: 3.5 KiB

View File

Before

Width:  |  Height:  |  Size: 1.9 KiB

After

Width:  |  Height:  |  Size: 1.9 KiB

View File

Before

Width:  |  Height:  |  Size: 3.5 KiB

After

Width:  |  Height:  |  Size: 3.5 KiB

View File

Before

Width:  |  Height:  |  Size: 1.9 KiB

After

Width:  |  Height:  |  Size: 1.9 KiB

View File

Before

Width:  |  Height:  |  Size: 2.3 KiB

After

Width:  |  Height:  |  Size: 2.3 KiB

View File

Before

Width:  |  Height:  |  Size: 2.6 KiB

After

Width:  |  Height:  |  Size: 2.6 KiB

View File

Before

Width:  |  Height:  |  Size: 4.5 KiB

After

Width:  |  Height:  |  Size: 4.5 KiB

View File

Before

Width:  |  Height:  |  Size: 1.6 KiB

After

Width:  |  Height:  |  Size: 1.6 KiB

View File

Before

Width:  |  Height:  |  Size: 2.4 KiB

After

Width:  |  Height:  |  Size: 2.4 KiB

View File

Before

Width:  |  Height:  |  Size: 2.3 KiB

After

Width:  |  Height:  |  Size: 2.3 KiB

Some files were not shown because too many files have changed in this diff Show More