Merge branch 'main' into feat/workflow-run-history-infinite-scroll
2
.gitignore
vendored
@ -212,7 +212,7 @@ api/.vscode
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
/node_modules
|
||||
node_modules
|
||||
.vite-hooks/_
|
||||
|
||||
# plugin migrate
|
||||
|
||||
@ -89,6 +89,12 @@ if $web_modified; then
|
||||
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
|
||||
fi
|
||||
|
||||
echo "Running knip"
|
||||
if ! pnpm run knip; then
|
||||
echo "Knip check failed. Please run 'pnpm run knip' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Running unit tests check"
|
||||
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true)
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
@ -71,7 +71,7 @@ class AppImportApi(Resource):
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
|
||||
@ -193,7 +193,7 @@ workflow_draft_variable_list_model = console_ns.model(
|
||||
)
|
||||
|
||||
|
||||
def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -210,7 +210,7 @@ def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@wraps(f)
|
||||
def wrapper(*args: Any, **kwargs: Any):
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import overload
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -23,14 +23,30 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(
|
||||
view: Callable[..., Any] | None = None,
|
||||
@overload
|
||||
def get_app_model[**P, R](
|
||||
view: Callable[P, R],
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model[**P, R](
|
||||
view: None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def get_app_model[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: Any, **kwargs: Any):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
@ -68,14 +84,30 @@ def get_app_model(
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def get_app_model_with_trial(
|
||||
view: Callable[..., Any] | None = None,
|
||||
@overload
|
||||
def get_app_model_with_trial[**P, R](
|
||||
view: Callable[P, R],
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model_with_trial[**P, R](
|
||||
view: None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def get_app_model_with_trial[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: Any, **kwargs: Any):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
|
||||
@ -158,10 +158,11 @@ class DataSourceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
binding_id = str(binding_id)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id)
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
|
||||
).scalar_one_or_none()
|
||||
if data_source_binding is None:
|
||||
raise NotFound("Data source binding not found.")
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response, request
|
||||
@ -55,7 +56,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _api_prerequisite(f):
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -70,7 +71,7 @@ def _api_prerequisite(f):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Any, cast, overload
|
||||
from typing import cast, overload
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
@ -230,94 +231,73 @@ def cloud_edition_billing_rate_limit_check[**P, R](
|
||||
return interceptor
|
||||
|
||||
|
||||
def validate_dataset_token(
|
||||
view: Callable[..., Any] | None = None,
|
||||
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@wraps(view_func)
|
||||
def decorated(*args: Any, **kwargs: Any) -> Any:
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
|
||||
positional_parameters = [
|
||||
parameter
|
||||
for parameter in inspect.signature(view).parameters.values()
|
||||
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
||||
]
|
||||
expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"})
|
||||
|
||||
# get url path dataset_id from positional args or kwargs
|
||||
# Flask passes URL path parameters as positional arguments
|
||||
dataset_id = None
|
||||
@wraps(view)
|
||||
def decorated(*args: object, **kwargs: object) -> R:
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
|
||||
# First try to get from kwargs (explicit parameter)
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
# Flask may pass URL path parameters positionally, so inspect both kwargs and args.
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
|
||||
# If not in kwargs, try to extract from positional args
|
||||
if not dataset_id and args:
|
||||
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
|
||||
# Check if first arg is likely a class instance (has __dict__ or __class__)
|
||||
if len(args) > 1 and hasattr(args[0], "__dict__"):
|
||||
# This is a class method, dataset_id should be in args[1]
|
||||
potential_id = args[1]
|
||||
# Validate it's a string-like UUID, not another object
|
||||
try:
|
||||
# Try to convert to string and check if it's a valid UUID format
|
||||
str_id = str(potential_id)
|
||||
# Basic check: UUIDs are 36 chars with hyphens
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from class method args")
|
||||
elif len(args) > 0:
|
||||
# Not a class method, check if args[0] looks like a UUID
|
||||
potential_id = args[0]
|
||||
try:
|
||||
str_id = str(potential_id)
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from positional args")
|
||||
if not dataset_id and args:
|
||||
potential_id = args[0]
|
||||
try:
|
||||
str_id = str(potential_id)
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from positional args")
|
||||
|
||||
# Validate dataset if dataset_id is provided
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == api_token.tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == api_token.tenant_id,
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
tenant_account_join = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||
.where(Tenant.status == TenantStatus.NORMAL)
|
||||
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.get(Account, ta.account_id)
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
|
||||
tenant_account_join = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||
.where(Tenant.status == TenantStatus.NORMAL)
|
||||
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.get(Account, ta.account_id)
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
if args and isinstance(args[0], Resource):
|
||||
return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs)
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
|
||||
return view_func(api_token.tenant_id, *args, **kwargs)
|
||||
if expects_bound_instance:
|
||||
if not args:
|
||||
raise TypeError("validate_dataset_token expected a bound resource instance.")
|
||||
return view(args[0], api_token.tenant_id, *args[1:], **kwargs)
|
||||
|
||||
return decorated
|
||||
return view(api_token.tenant_id, *args, **kwargs)
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
|
||||
# if view is None, it means that the decorator is used without parentheses
|
||||
# use the decorator as a function for method_decorators
|
||||
return decorator
|
||||
return decorated
|
||||
|
||||
|
||||
def validate_and_get_api_token(scope: str | None = None):
|
||||
|
||||
@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound, RequestEntityTooLarge
|
||||
from controllers.trigger import bp
|
||||
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
||||
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
from services.trigger.webhook_service import RawWebhookDataDict, WebhookService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -23,6 +23,7 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
|
||||
webhook_id, is_debug=is_debug
|
||||
)
|
||||
|
||||
webhook_data: RawWebhookDataDict
|
||||
try:
|
||||
# Use new unified extraction and validation
|
||||
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
@ -3,13 +3,19 @@
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import orjson
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class IdentityDict(TypedDict, total=False):
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
user_type: str
|
||||
|
||||
|
||||
class StructuredJSONFormatter(logging.Formatter):
|
||||
"""
|
||||
JSON log formatter following the specified schema:
|
||||
@ -84,7 +90,7 @@ class StructuredJSONFormatter(logging.Formatter):
|
||||
|
||||
return log_dict
|
||||
|
||||
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
|
||||
def _extract_identity(self, record: logging.LogRecord) -> IdentityDict | None:
|
||||
tenant_id = getattr(record, "tenant_id", None)
|
||||
user_id = getattr(record, "user_id", None)
|
||||
user_type = getattr(record, "user_type", None)
|
||||
@ -92,7 +98,7 @@ class StructuredJSONFormatter(logging.Formatter):
|
||||
if not any([tenant_id, user_id, user_type]):
|
||||
return None
|
||||
|
||||
identity: dict[str, str] = {}
|
||||
identity: IdentityDict = {}
|
||||
if tenant_id:
|
||||
identity["tenant_id"] = tenant_id
|
||||
if user_id:
|
||||
|
||||
@ -4,7 +4,7 @@ from collections.abc import Callable
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, Self, cast
|
||||
from typing import Any, Self
|
||||
|
||||
from httpx import HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
@ -338,12 +338,11 @@ class BaseSession[
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
validated_request = cast(ReceiveRequestT, validated_request)
|
||||
|
||||
responder = RequestResponder[ReceiveRequestT, SendResultT](
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
||||
request=validated_request,
|
||||
request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
)
|
||||
@ -359,15 +358,14 @@ class BaseSession[
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
notification = cast(ReceiveNotificationT, notification)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
cancelled_id = notification.root.params.requestId
|
||||
if cancelled_id in self._in_flight:
|
||||
self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
self._received_notification(notification)
|
||||
self._handle_incoming(notification)
|
||||
self._received_notification(notification) # type: ignore[arg-type]
|
||||
self._handle_incoming(notification) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
|
||||
|
||||
@ -1,5 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from flask import Flask
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_login import DifyLoginManager
|
||||
|
||||
|
||||
class DifyApp(Flask):
|
||||
pass
|
||||
"""Flask application type with Dify-specific extension attributes."""
|
||||
|
||||
login_manager: DifyLoginManager
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
import flask_login
|
||||
from flask import Response, request
|
||||
from flask import Request, Response, request
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
@ -16,13 +17,35 @@ from models import Account, Tenant, TenantAccountJoin
|
||||
from models.model import AppMCPServer, EndUser
|
||||
from services.account_service import AccountService
|
||||
|
||||
login_manager = flask_login.LoginManager()
|
||||
type LoginUser = Account | EndUser
|
||||
|
||||
|
||||
class DifyLoginManager(flask_login.LoginManager):
|
||||
"""Project-specific Flask-Login manager with a stable unauthorized contract.
|
||||
|
||||
Dify registers `unauthorized_handler` below to always return a JSON `Response`.
|
||||
Overriding this method lets callers rely on that narrower return type instead of
|
||||
Flask-Login's broader callback contract.
|
||||
"""
|
||||
|
||||
def unauthorized(self) -> Response:
|
||||
"""Return the registered unauthorized handler result as a Flask `Response`."""
|
||||
return cast(Response, super().unauthorized())
|
||||
|
||||
def load_user_from_request_context(self) -> None:
|
||||
"""Populate Flask-Login's request-local user cache for the current request."""
|
||||
self._load_user()
|
||||
|
||||
|
||||
login_manager = DifyLoginManager()
|
||||
|
||||
|
||||
# Flask-Login configuration
|
||||
@login_manager.request_loader
|
||||
def load_user_from_request(request_from_flask_login):
|
||||
def load_user_from_request(request_from_flask_login: Request) -> LoginUser | None:
|
||||
"""Load user based on the request."""
|
||||
del request_from_flask_login
|
||||
|
||||
# Skip authentication for documentation endpoints
|
||||
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
|
||||
return None
|
||||
@ -100,10 +123,12 @@ def load_user_from_request(request_from_flask_login):
|
||||
raise NotFound("End user not found.")
|
||||
return end_user
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@user_logged_in.connect
|
||||
@user_loaded_from_request.connect
|
||||
def on_user_logged_in(_sender, user):
|
||||
def on_user_logged_in(_sender: object, user: LoginUser) -> None:
|
||||
"""Called when a user logged in.
|
||||
|
||||
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
|
||||
@ -114,8 +139,10 @@ def on_user_logged_in(_sender, user):
|
||||
|
||||
|
||||
@login_manager.unauthorized_handler
|
||||
def unauthorized_handler():
|
||||
def unauthorized_handler() -> Response:
|
||||
"""Handle unauthorized requests."""
|
||||
# Keep this as a concrete `Response`; `DifyLoginManager.unauthorized()` narrows
|
||||
# Flask-Login's callback contract based on this override.
|
||||
return Response(
|
||||
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
|
||||
status=401,
|
||||
@ -123,5 +150,5 @@ def unauthorized_handler():
|
||||
)
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
def init_app(app: DifyApp) -> None:
|
||||
login_manager.init_app(app)
|
||||
|
||||
@ -2,19 +2,19 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from flask import current_app, g, has_request_context, request
|
||||
from flask import Response, current_app, g, has_request_context, request
|
||||
from flask_login.config import EXEMPT_METHODS
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_login import DifyLoginManager
|
||||
from libs.token import check_csrf_token
|
||||
from models import Account
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from flask.typing import ResponseReturnValue
|
||||
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
@ -29,7 +29,13 @@ def _resolve_current_user() -> EndUser | Account | None:
|
||||
return get_current_object() if callable(get_current_object) else user_proxy # type: ignore
|
||||
|
||||
|
||||
def current_account_with_tenant():
|
||||
def _get_login_manager() -> DifyLoginManager:
|
||||
"""Return the project login manager with Dify's narrowed unauthorized contract."""
|
||||
app = cast(DifyApp, current_app)
|
||||
return app.login_manager
|
||||
|
||||
|
||||
def current_account_with_tenant() -> tuple[Account, str]:
|
||||
"""
|
||||
Resolve the underlying account for the current user proxy and ensure tenant context exists.
|
||||
Allows tests to supply plain Account mocks without the LocalProxy helper.
|
||||
@ -42,7 +48,7 @@ def current_account_with_tenant():
|
||||
return user, user.current_tenant_id
|
||||
|
||||
|
||||
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
|
||||
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
"""
|
||||
If you decorate a view with this, it will ensure that the current user is
|
||||
logged in and authenticated before calling the actual view. (If they are
|
||||
@ -77,13 +83,16 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
|
||||
user = _resolve_current_user()
|
||||
if user is None or not user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized() # type: ignore
|
||||
# `DifyLoginManager` guarantees that the registered unauthorized handler
|
||||
# is surfaced here as a concrete Flask `Response`.
|
||||
unauthorized_response: Response = _get_login_manager().unauthorized()
|
||||
return unauthorized_response
|
||||
g._login_user = user
|
||||
# we put csrf validation here for less conflicts
|
||||
# TODO: maybe find a better place for it.
|
||||
@ -96,7 +105,7 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu
|
||||
def _get_user() -> EndUser | Account | None:
|
||||
if has_request_context():
|
||||
if "_login_user" not in g:
|
||||
current_app.login_manager._load_user() # type: ignore
|
||||
_get_login_manager().load_user_from_request_context()
|
||||
|
||||
return g._login_user
|
||||
|
||||
|
||||
@ -171,7 +171,7 @@ dev = [
|
||||
"sseclient-py>=1.8.0",
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"pyrefly>=0.57.1",
|
||||
"pyrefly>=0.59.1",
|
||||
]
|
||||
|
||||
############################################################
|
||||
|
||||
@ -8,7 +8,7 @@ from hashlib import sha256
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@ -144,22 +144,26 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
def load_user(user_id: str) -> None | Account:
|
||||
account = db.session.query(Account).filter_by(id=user_id).first()
|
||||
account = db.session.get(Account, user_id)
|
||||
if not account:
|
||||
return None
|
||||
|
||||
if account.status == AccountStatus.BANNED:
|
||||
raise Unauthorized("Account is banned.")
|
||||
|
||||
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
|
||||
current_tenant = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
|
||||
.limit(1)
|
||||
)
|
||||
if current_tenant:
|
||||
account.set_tenant_id(current_tenant.tenant_id)
|
||||
else:
|
||||
available_ta = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter_by(account_id=account.id)
|
||||
available_ta = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id)
|
||||
.order_by(TenantAccountJoin.id.asc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not available_ta:
|
||||
return None
|
||||
@ -195,7 +199,7 @@ class AccountService:
|
||||
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
|
||||
"""authenticate account with email and password"""
|
||||
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
if not account:
|
||||
raise AccountPasswordError("Invalid email or password.")
|
||||
|
||||
@ -371,8 +375,10 @@ class AccountService:
|
||||
"""Link account integrate"""
|
||||
try:
|
||||
# Query whether there is an existing binding record for the same provider
|
||||
account_integrate: AccountIntegrate | None = (
|
||||
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
|
||||
account_integrate: AccountIntegrate | None = db.session.scalar(
|
||||
select(AccountIntegrate)
|
||||
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if account_integrate:
|
||||
@ -416,7 +422,9 @@ class AccountService:
|
||||
def update_account_email(account: Account, email: str) -> Account:
|
||||
"""Update account email"""
|
||||
account.email = email
|
||||
account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first()
|
||||
account_integrate = db.session.scalar(
|
||||
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
|
||||
)
|
||||
if account_integrate:
|
||||
db.session.delete(account_integrate)
|
||||
db.session.add(account)
|
||||
@ -818,7 +826,7 @@ class AccountService:
|
||||
)
|
||||
)
|
||||
|
||||
account = db.session.query(Account).where(Account.email == email).first()
|
||||
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
if not account:
|
||||
return None
|
||||
|
||||
@ -1018,7 +1026,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
def check_email_unique(email: str) -> bool:
|
||||
return db.session.query(Account).filter_by(email=email).first() is None
|
||||
return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
|
||||
|
||||
|
||||
class TenantService:
|
||||
@ -1384,10 +1392,10 @@ class RegisterService:
|
||||
db.session.add(dify_setup)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.query(DifySetup).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.execute(delete(DifySetup))
|
||||
db.session.execute(delete(TenantAccountJoin))
|
||||
db.session.execute(delete(Account))
|
||||
db.session.execute(delete(Tenant))
|
||||
db.session.commit()
|
||||
|
||||
logger.exception("Setup account failed, email: %s, name: %s", email, name)
|
||||
@ -1488,7 +1496,11 @@ class RegisterService:
|
||||
TenantService.switch_tenant(account, tenant.id)
|
||||
else:
|
||||
TenantService.check_member_permission(tenant, inviter, account, "add")
|
||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
ta = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not ta:
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
@ -1545,21 +1557,18 @@ class RegisterService:
|
||||
if not invitation_data:
|
||||
return None
|
||||
|
||||
tenant = (
|
||||
db.session.query(Tenant)
|
||||
.where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
|
||||
.first()
|
||||
tenant = db.session.scalar(
|
||||
select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1)
|
||||
)
|
||||
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
tenant_account = (
|
||||
db.session.query(Account, TenantAccountJoin.role)
|
||||
tenant_account = db.session.execute(
|
||||
select(Account, TenantAccountJoin.role)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
|
||||
.first()
|
||||
)
|
||||
).first()
|
||||
|
||||
if not tenant_account:
|
||||
return None
|
||||
|
||||
@ -4,6 +4,8 @@ import uuid
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from typing import TypedDict
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -23,6 +25,27 @@ from tasks.annotation.enable_annotation_reply_task import enable_annotation_repl
|
||||
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
|
||||
|
||||
|
||||
class AnnotationJobStatusDict(TypedDict):
|
||||
job_id: str
|
||||
job_status: str
|
||||
|
||||
|
||||
class EmbeddingModelDict(TypedDict):
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class AnnotationSettingDict(TypedDict):
|
||||
id: str
|
||||
enabled: bool
|
||||
score_threshold: float
|
||||
embedding_model: EmbeddingModelDict | dict
|
||||
|
||||
|
||||
class AnnotationSettingDisabledDict(TypedDict):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class AppAnnotationService:
|
||||
@classmethod
|
||||
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
@ -85,7 +108,7 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def enable_app_annotation(cls, args: dict, app_id: str):
|
||||
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
|
||||
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(enable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
@ -109,7 +132,7 @@ class AppAnnotationService:
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def disable_app_annotation(cls, app_id: str):
|
||||
def disable_app_annotation(cls, app_id: str) -> AnnotationJobStatusDict:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(disable_app_annotation_key)
|
||||
@ -567,7 +590,7 @@ class AppAnnotationService:
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_app_annotation_setting_by_app_id(cls, app_id: str):
|
||||
def get_app_annotation_setting_by_app_id(cls, app_id: str) -> AnnotationSettingDict | AnnotationSettingDisabledDict:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# get app info
|
||||
app = (
|
||||
@ -602,7 +625,9 @@ class AppAnnotationService:
|
||||
return {"enabled": False}
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
|
||||
def update_app_annotation_setting(
|
||||
cls, app_id: str, annotation_setting_id: str, args: dict
|
||||
) -> AnnotationSettingDict:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# get app info
|
||||
app = (
|
||||
|
||||
@ -32,6 +32,11 @@ class SubscriptionPlan(TypedDict):
|
||||
expiration_date: int
|
||||
|
||||
|
||||
class KnowledgeRateLimitDict(TypedDict):
|
||||
limit: int
|
||||
subscription_plan: str
|
||||
|
||||
|
||||
class BillingService:
|
||||
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
|
||||
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
|
||||
@ -58,7 +63,7 @@ class BillingService:
|
||||
return usage_info
|
||||
|
||||
@classmethod
|
||||
def get_knowledge_rate_limit(cls, tenant_id: str):
|
||||
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)
|
||||
|
||||
@ -5,7 +5,7 @@ from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from graphon.nodes.http_request.exc import InvalidHttpMethodError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.helper import ssrf_proxy
|
||||
@ -103,8 +103,10 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_external_knowledge_api(external_knowledge_api_id: str, tenant_id: str) -> ExternalKnowledgeApis:
|
||||
external_knowledge_api: ExternalKnowledgeApis | None = (
|
||||
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
|
||||
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
raise ValueError("api template not found")
|
||||
@ -112,8 +114,10 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
|
||||
external_knowledge_api: ExternalKnowledgeApis | None = (
|
||||
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
|
||||
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
raise ValueError("api template not found")
|
||||
@ -132,8 +136,10 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
|
||||
external_knowledge_api = (
|
||||
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
|
||||
external_knowledge_api = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
raise ValueError("api template not found")
|
||||
@ -144,9 +150,12 @@ class ExternalDatasetService:
|
||||
@staticmethod
|
||||
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
|
||||
count = (
|
||||
db.session.query(ExternalKnowledgeBindings)
|
||||
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
|
||||
.count()
|
||||
db.session.scalar(
|
||||
select(func.count(ExternalKnowledgeBindings.id)).where(
|
||||
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if count > 0:
|
||||
return True, count
|
||||
@ -154,8 +163,10 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
|
||||
external_knowledge_binding: ExternalKnowledgeBindings | None = (
|
||||
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
|
||||
external_knowledge_binding: ExternalKnowledgeBindings | None = db.session.scalar(
|
||||
select(ExternalKnowledgeBindings)
|
||||
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not external_knowledge_binding:
|
||||
raise ValueError("external knowledge binding not found")
|
||||
@ -163,8 +174,10 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
|
||||
external_knowledge_api = (
|
||||
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
|
||||
external_knowledge_api = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if external_knowledge_api is None or external_knowledge_api.settings is None:
|
||||
raise ValueError("api template not found")
|
||||
@ -238,12 +251,17 @@ class ExternalDatasetService:
|
||||
@staticmethod
|
||||
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
|
||||
# check if dataset name already exists
|
||||
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
|
||||
if db.session.scalar(
|
||||
select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1)
|
||||
):
|
||||
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
|
||||
external_knowledge_api = (
|
||||
db.session.query(ExternalKnowledgeApis)
|
||||
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
|
||||
.first()
|
||||
external_knowledge_api = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(
|
||||
ExternalKnowledgeApis.id == args.get("external_knowledge_api_id"),
|
||||
ExternalKnowledgeApis.tenant_id == tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if external_knowledge_api is None:
|
||||
@ -286,16 +304,18 @@ class ExternalDatasetService:
|
||||
external_retrieval_parameters: dict,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
):
|
||||
external_knowledge_binding = (
|
||||
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
|
||||
external_knowledge_binding = db.session.scalar(
|
||||
select(ExternalKnowledgeBindings)
|
||||
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not external_knowledge_binding:
|
||||
raise ValueError("external knowledge binding not found")
|
||||
|
||||
external_knowledge_api = (
|
||||
db.session.query(ExternalKnowledgeApis)
|
||||
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
|
||||
.first()
|
||||
external_knowledge_api = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
|
||||
.limit(1)
|
||||
)
|
||||
if external_knowledge_api is None or external_knowledge_api.settings is None:
|
||||
raise ValueError("external api template not found")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from graphon.model_runtime.entities import LLMMode
|
||||
|
||||
@ -18,6 +18,16 @@ from models.enums import CreatorUserRole, DatasetQuerySource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QueryDict(TypedDict):
|
||||
content: str
|
||||
|
||||
|
||||
class RetrieveResponseDict(TypedDict):
|
||||
query: QueryDict
|
||||
records: list[dict[str, Any]]
|
||||
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
@ -150,7 +160,7 @@ class HitTestingService:
|
||||
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
|
||||
|
||||
@classmethod
|
||||
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]:
|
||||
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> RetrieveResponseDict:
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
|
||||
return {
|
||||
@ -161,7 +171,7 @@ class HitTestingService:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]:
|
||||
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> RetrieveResponseDict:
|
||||
records = []
|
||||
if dataset.provider == "external":
|
||||
for document in documents:
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import copy
|
||||
import logging
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -25,10 +27,14 @@ class MetadataService:
|
||||
raise ValueError("Metadata name cannot exceed 255 characters.")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# check if metadata name already exists
|
||||
if (
|
||||
db.session.query(DatasetMetadata)
|
||||
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
|
||||
.first()
|
||||
if db.session.scalar(
|
||||
select(DatasetMetadata)
|
||||
.where(
|
||||
DatasetMetadata.tenant_id == current_tenant_id,
|
||||
DatasetMetadata.dataset_id == dataset_id,
|
||||
DatasetMetadata.name == metadata_args.name,
|
||||
)
|
||||
.limit(1)
|
||||
):
|
||||
raise ValueError("Metadata name already exists.")
|
||||
for field in BuiltInField:
|
||||
@ -54,10 +60,14 @@ class MetadataService:
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
# check if metadata name already exists
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if (
|
||||
db.session.query(DatasetMetadata)
|
||||
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=name)
|
||||
.first()
|
||||
if db.session.scalar(
|
||||
select(DatasetMetadata)
|
||||
.where(
|
||||
DatasetMetadata.tenant_id == current_tenant_id,
|
||||
DatasetMetadata.dataset_id == dataset_id,
|
||||
DatasetMetadata.name == name,
|
||||
)
|
||||
.limit(1)
|
||||
):
|
||||
raise ValueError("Metadata name already exists.")
|
||||
for field in BuiltInField:
|
||||
@ -65,7 +75,11 @@ class MetadataService:
|
||||
raise ValueError("Metadata name already exists in Built-in fields.")
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
|
||||
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
|
||||
metadata = db.session.scalar(
|
||||
select(DatasetMetadata)
|
||||
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
|
||||
.limit(1)
|
||||
)
|
||||
if metadata is None:
|
||||
raise ValueError("Metadata not found.")
|
||||
old_name = metadata.name
|
||||
@ -74,9 +88,9 @@ class MetadataService:
|
||||
metadata.updated_at = naive_utc_now()
|
||||
|
||||
# update related documents
|
||||
dataset_metadata_bindings = (
|
||||
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
|
||||
)
|
||||
dataset_metadata_bindings = db.session.scalars(
|
||||
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
|
||||
).all()
|
||||
if dataset_metadata_bindings:
|
||||
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
|
||||
documents = DocumentService.get_document_by_ids(document_ids)
|
||||
@ -101,15 +115,19 @@ class MetadataService:
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
|
||||
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
|
||||
metadata = db.session.scalar(
|
||||
select(DatasetMetadata)
|
||||
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
|
||||
.limit(1)
|
||||
)
|
||||
if metadata is None:
|
||||
raise ValueError("Metadata not found.")
|
||||
db.session.delete(metadata)
|
||||
|
||||
# deal related documents
|
||||
dataset_metadata_bindings = (
|
||||
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
|
||||
)
|
||||
dataset_metadata_bindings = db.session.scalars(
|
||||
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
|
||||
).all()
|
||||
if dataset_metadata_bindings:
|
||||
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
|
||||
documents = DocumentService.get_document_by_ids(document_ids)
|
||||
@ -224,16 +242,23 @@ class MetadataService:
|
||||
|
||||
# deal metadata binding (in the same transaction as the doc_metadata update)
|
||||
if not operation.partial_update:
|
||||
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
|
||||
db.session.execute(
|
||||
delete(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.document_id == operation.document_id
|
||||
)
|
||||
)
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
for metadata_value in operation.metadata_list:
|
||||
# check if binding already exists
|
||||
if operation.partial_update:
|
||||
existing_binding = (
|
||||
db.session.query(DatasetMetadataBinding)
|
||||
.filter_by(document_id=operation.document_id, metadata_id=metadata_value.id)
|
||||
.first()
|
||||
existing_binding = db.session.scalar(
|
||||
select(DatasetMetadataBinding)
|
||||
.where(
|
||||
DatasetMetadataBinding.document_id == operation.document_id,
|
||||
DatasetMetadataBinding.metadata_id == metadata_value.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if existing_binding:
|
||||
continue
|
||||
@ -275,9 +300,13 @@ class MetadataService:
|
||||
"id": item.get("id"),
|
||||
"name": item.get("name"),
|
||||
"type": item.get("type"),
|
||||
"count": db.session.query(DatasetMetadataBinding)
|
||||
.filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
|
||||
.count(),
|
||||
"count": db.session.scalar(
|
||||
select(func.count(DatasetMetadataBinding.id)).where(
|
||||
DatasetMetadataBinding.metadata_id == item.get("id"),
|
||||
DatasetMetadataBinding.dataset_id == dataset.id,
|
||||
)
|
||||
)
|
||||
or 0,
|
||||
}
|
||||
for item in dataset.doc_metadata or []
|
||||
if item.get("id") != "built-in"
|
||||
|
||||
@ -156,27 +156,27 @@ class RagPipelineService:
|
||||
:param template_id: template id
|
||||
:param template_info: template info
|
||||
"""
|
||||
customized_template: PipelineCustomizedTemplate | None = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.id == template_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not customized_template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
# check template name is exist
|
||||
template_name = template_info.name
|
||||
if template_name:
|
||||
template = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
template = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.name == template_name,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
PipelineCustomizedTemplate.id != template_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if template:
|
||||
raise ValueError("Template name is already exists")
|
||||
@ -192,13 +192,13 @@ class RagPipelineService:
|
||||
"""
|
||||
Delete customized pipeline template.
|
||||
"""
|
||||
customized_template: PipelineCustomizedTemplate | None = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.id == template_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not customized_template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
@ -210,14 +210,14 @@ class RagPipelineService:
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by rag pipeline
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# return draft workflow
|
||||
@ -232,28 +232,28 @@ class RagPipelineService:
|
||||
return None
|
||||
|
||||
# fetch published workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.id == pipeline.workflow_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
|
||||
"""Fetch a published workflow snapshot by ID for restore operations."""
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.id == workflow_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if workflow and workflow.version == Workflow.VERSION_DRAFT:
|
||||
raise IsDraftWorkflowError("source workflow must be published")
|
||||
@ -974,7 +974,7 @@ class RagPipelineService:
|
||||
if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
|
||||
document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
|
||||
if document_id:
|
||||
document = db.session.query(Document).where(Document.id == document_id.value).first()
|
||||
document = db.session.get(Document, document_id.value)
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = error
|
||||
@ -1178,12 +1178,12 @@ class RagPipelineService:
|
||||
"""
|
||||
Publish customized pipeline template
|
||||
"""
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
|
||||
pipeline = db.session.get(Pipeline, pipeline_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
if not pipeline.workflow_id:
|
||||
raise ValueError("Pipeline workflow not found")
|
||||
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
|
||||
workflow = db.session.get(Workflow, pipeline.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
with Session(db.engine) as session:
|
||||
@ -1194,21 +1194,21 @@ class RagPipelineService:
|
||||
# check template name is exist
|
||||
template_name = args.get("name")
|
||||
if template_name:
|
||||
template = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
template = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.name == template_name,
|
||||
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if template:
|
||||
raise ValueError("Template name is already exists")
|
||||
|
||||
max_position = (
|
||||
db.session.query(func.max(PipelineCustomizedTemplate.position))
|
||||
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
|
||||
.scalar()
|
||||
max_position = db.session.scalar(
|
||||
select(func.max(PipelineCustomizedTemplate.position)).where(
|
||||
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id
|
||||
)
|
||||
)
|
||||
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
@ -1239,13 +1239,14 @@ class RagPipelineService:
|
||||
|
||||
def is_workflow_exist(self, pipeline: Pipeline) -> bool:
|
||||
return (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
db.session.scalar(
|
||||
select(func.count(Workflow.id)).where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
) > 0
|
||||
|
||||
def get_node_last_run(
|
||||
@ -1353,11 +1354,11 @@ class RagPipelineService:
|
||||
|
||||
def get_recommended_plugins(self, type: str) -> dict:
|
||||
# Query active recommended plugins
|
||||
query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
|
||||
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
|
||||
if type and type != "all":
|
||||
query = query.where(PipelineRecommendedPlugin.type == type)
|
||||
stmt = stmt.where(PipelineRecommendedPlugin.type == type)
|
||||
|
||||
pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
|
||||
pipeline_recommended_plugins = db.session.scalars(stmt.order_by(PipelineRecommendedPlugin.position.asc())).all()
|
||||
|
||||
if not pipeline_recommended_plugins:
|
||||
return {
|
||||
@ -1396,14 +1397,12 @@ class RagPipelineService:
|
||||
"""
|
||||
Retry error document
|
||||
"""
|
||||
document_pipeline_execution_log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document.id)
|
||||
.first()
|
||||
document_pipeline_execution_log = db.session.scalar(
|
||||
select(DocumentPipelineExecutionLog).where(DocumentPipelineExecutionLog.document_id == document.id).limit(1)
|
||||
)
|
||||
if not document_pipeline_execution_log:
|
||||
raise ValueError("Document pipeline execution log not found")
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
|
||||
pipeline = db.session.get(Pipeline, document_pipeline_execution_log.pipeline_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
# convert to app config
|
||||
@ -1432,23 +1431,23 @@ class RagPipelineService:
|
||||
"""
|
||||
Get datasource plugins
|
||||
"""
|
||||
dataset: Dataset | None = (
|
||||
db.session.query(Dataset)
|
||||
dataset: Dataset | None = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
pipeline: Pipeline | None = (
|
||||
db.session.query(Pipeline)
|
||||
pipeline: Pipeline | None = db.session.scalar(
|
||||
select(Pipeline)
|
||||
.where(
|
||||
Pipeline.id == dataset.pipeline_id,
|
||||
Pipeline.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
@ -1530,23 +1529,23 @@ class RagPipelineService:
|
||||
"""
|
||||
Get pipeline
|
||||
"""
|
||||
dataset: Dataset | None = (
|
||||
db.session.query(Dataset)
|
||||
dataset: Dataset | None = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
pipeline: Pipeline | None = (
|
||||
db.session.query(Pipeline)
|
||||
pipeline: Pipeline | None = db.session.scalar(
|
||||
select(Pipeline)
|
||||
.where(
|
||||
Pipeline.id == dataset.pipeline_id,
|
||||
Pipeline.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
from datetime import datetime
|
||||
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy import delete, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@ -42,20 +42,22 @@ class WorkflowToolManageService:
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
existing_workflow_tool_provider = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
# name or app_id
|
||||
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
|
||||
|
||||
app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
|
||||
app: App | None = db.session.scalar(
|
||||
select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_app_id} not found")
|
||||
@ -122,30 +124,30 @@ class WorkflowToolManageService:
|
||||
:return: the updated tool
|
||||
"""
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
existing_workflow_tool_provider = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.name == name,
|
||||
WorkflowToolProvider.id != workflow_tool_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} already exists")
|
||||
|
||||
workflow_tool_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if workflow_tool_provider is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
app: App | None = (
|
||||
db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
|
||||
app: App | None = db.session.scalar(
|
||||
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
|
||||
if app is None:
|
||||
@ -234,9 +236,11 @@ class WorkflowToolManageService:
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_tool_id: the workflow tool id
|
||||
"""
|
||||
db.session.query(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
|
||||
).delete()
|
||||
db.session.execute(
|
||||
delete(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
|
||||
)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@ -251,10 +255,10 @@ class WorkflowToolManageService:
|
||||
:param workflow_tool_id: the workflow tool id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_tool: WorkflowToolProvider | None = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@ -267,10 +271,10 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_tool: WorkflowToolProvider | None = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@ -284,8 +288,8 @@ class WorkflowToolManageService:
|
||||
if db_tool is None:
|
||||
raise ValueError("Tool not found")
|
||||
|
||||
workflow_app: App | None = (
|
||||
db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
|
||||
workflow_app: App | None = db.session.scalar(
|
||||
select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1)
|
||||
)
|
||||
|
||||
if workflow_app is None:
|
||||
@ -331,10 +335,10 @@ class WorkflowToolManageService:
|
||||
:param workflow_tool_id: the workflow tool id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_tool: WorkflowToolProvider | None = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if db_tool is None:
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import mimetypes
|
||||
import secrets
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import orjson
|
||||
from flask import request
|
||||
@ -50,6 +50,14 @@ logger = logging.getLogger(__name__)
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
class RawWebhookDataDict(TypedDict):
|
||||
method: str
|
||||
headers: dict[str, str]
|
||||
query_params: dict[str, str]
|
||||
body: dict[str, Any]
|
||||
files: dict[str, Any]
|
||||
|
||||
|
||||
class WebhookService:
|
||||
"""Service for handling webhook operations."""
|
||||
|
||||
@ -145,7 +153,7 @@ class WebhookService:
|
||||
@classmethod
|
||||
def extract_and_validate_webhook_data(
|
||||
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict
|
||||
) -> dict[str, Any]:
|
||||
) -> RawWebhookDataDict:
|
||||
"""Extract and validate webhook data in a single unified process.
|
||||
|
||||
Args:
|
||||
@ -173,7 +181,7 @@ class WebhookService:
|
||||
return processed_data
|
||||
|
||||
@classmethod
|
||||
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> dict[str, Any]:
|
||||
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> RawWebhookDataDict:
|
||||
"""Extract raw data from incoming webhook request without type conversion.
|
||||
|
||||
Args:
|
||||
@ -189,7 +197,7 @@ class WebhookService:
|
||||
"""
|
||||
cls._validate_content_length()
|
||||
|
||||
data = {
|
||||
data: RawWebhookDataDict = {
|
||||
"method": request.method,
|
||||
"headers": dict(request.headers),
|
||||
"query_params": dict(request.args),
|
||||
@ -223,7 +231,7 @@ class WebhookService:
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
|
||||
def _process_and_validate_data(cls, raw_data: RawWebhookDataDict, node_data: WebhookData) -> RawWebhookDataDict:
|
||||
"""Process and validate webhook data according to node configuration.
|
||||
|
||||
Args:
|
||||
@ -664,7 +672,7 @@ class WebhookService:
|
||||
raise ValueError(f"Required header missing: {header_name}")
|
||||
|
||||
@classmethod
|
||||
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
|
||||
def _validate_http_metadata(cls, webhook_data: RawWebhookDataDict, node_data: WebhookData) -> dict[str, Any]:
|
||||
"""Validate HTTP method and content-type.
|
||||
|
||||
Args:
|
||||
@ -729,7 +737,7 @@ class WebhookService:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def build_workflow_inputs(cls, webhook_data: dict[str, Any]) -> dict[str, Any]:
|
||||
def build_workflow_inputs(cls, webhook_data: RawWebhookDataDict) -> dict[str, Any]:
|
||||
"""Construct workflow inputs payload from webhook data.
|
||||
|
||||
Args:
|
||||
@ -747,7 +755,7 @@ class WebhookService:
|
||||
|
||||
@classmethod
|
||||
def trigger_workflow_execution(
|
||||
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow
|
||||
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: RawWebhookDataDict, workflow: Workflow
|
||||
) -> None:
|
||||
"""Trigger workflow execution via AsyncWorkflowService.
|
||||
|
||||
|
||||
@ -129,6 +129,7 @@ class VariableTruncator(BaseTruncator):
|
||||
used_size += self.calculate_json_size(key)
|
||||
if used_size > budget:
|
||||
truncated_mapping[key] = "..."
|
||||
is_truncated = True
|
||||
continue
|
||||
value_budget = (budget - used_size) // (length - len(truncated_mapping))
|
||||
if isinstance(value, Segment):
|
||||
@ -164,9 +165,9 @@ class VariableTruncator(BaseTruncator):
|
||||
result = self._truncate_segment(segment, self._max_size_bytes)
|
||||
|
||||
if result.value_size > self._max_size_bytes:
|
||||
if isinstance(result.value, str):
|
||||
result = self._truncate_string(result.value, self._max_size_bytes)
|
||||
return TruncationResult(StringSegment(value=result.value), True)
|
||||
if isinstance(result.value, StringSegment):
|
||||
fallback_result = self._truncate_string(result.value.value, self._max_size_bytes)
|
||||
return TruncationResult(StringSegment(value=fallback_result.value), True)
|
||||
|
||||
# Apply final fallback - convert to JSON string and truncate
|
||||
json_str = dumps_with_segments(result.value, ensure_ascii=False)
|
||||
|
||||
@ -20,7 +20,7 @@ def app():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
app.login_manager = SimpleNamespace(_load_user=lambda: None)
|
||||
app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ from models.account import Account, TenantAccountRole
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.login_manager = SimpleNamespace(_load_user=lambda: None)
|
||||
flask_app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
|
||||
return flask_app
|
||||
|
||||
|
||||
|
||||
17
api/tests/unit_tests/extensions/test_ext_login.py
Normal file
@ -0,0 +1,17 @@
|
||||
import json
|
||||
|
||||
from flask import Response
|
||||
|
||||
from extensions.ext_login import unauthorized_handler
|
||||
|
||||
|
||||
def test_unauthorized_handler_returns_json_response() -> None:
|
||||
response = unauthorized_handler()
|
||||
|
||||
assert isinstance(response, Response)
|
||||
assert response.status_code == 401
|
||||
assert response.content_type == "application/json"
|
||||
assert json.loads(response.get_data(as_text=True)) == {
|
||||
"code": "unauthorized",
|
||||
"message": "Unauthorized.",
|
||||
}
|
||||
@ -2,11 +2,12 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from flask_login import LoginManager, UserMixin
|
||||
from flask import Flask, Response, g
|
||||
from flask_login import UserMixin
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
import libs.login as login_module
|
||||
from extensions.ext_login import DifyLoginManager
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
|
||||
@ -39,9 +40,12 @@ def login_app(mocker: MockerFixture) -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
|
||||
login_manager = LoginManager()
|
||||
login_manager = DifyLoginManager()
|
||||
login_manager.init_app(app)
|
||||
login_manager.unauthorized = mocker.Mock(name="unauthorized", return_value="Unauthorized")
|
||||
login_manager.unauthorized = mocker.Mock(
|
||||
name="unauthorized",
|
||||
return_value=Response("Unauthorized", status=401, content_type="application/json"),
|
||||
)
|
||||
|
||||
@login_manager.user_loader
|
||||
def load_user(_user_id: str):
|
||||
@ -109,18 +113,43 @@ class TestLoginRequired:
|
||||
resolved_user: MockUser | None,
|
||||
description: str,
|
||||
):
|
||||
"""Test that missing or unauthenticated users are redirected."""
|
||||
"""Test that missing or unauthenticated users return the manager response."""
|
||||
|
||||
resolve_user = resolve_current_user(resolved_user)
|
||||
|
||||
with login_app.test_request_context():
|
||||
result = protected_view()
|
||||
|
||||
assert result == "Unauthorized", description
|
||||
assert result is login_app.login_manager.unauthorized.return_value, description
|
||||
assert isinstance(result, Response)
|
||||
assert result.status_code == 401
|
||||
resolve_user.assert_called_once_with()
|
||||
login_app.login_manager.unauthorized.assert_called_once_with()
|
||||
csrf_check.assert_not_called()
|
||||
|
||||
def test_unauthorized_access_propagates_response_object(
|
||||
self,
|
||||
login_app: Flask,
|
||||
protected_view,
|
||||
csrf_check: MagicMock,
|
||||
resolve_current_user,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test that unauthorized responses are propagated as Flask Response objects."""
|
||||
resolve_user = resolve_current_user(None)
|
||||
response = Response("Unauthorized", status=401, content_type="application/json")
|
||||
mocker.patch.object(
|
||||
login_module, "_get_login_manager", return_value=SimpleNamespace(unauthorized=lambda: response)
|
||||
)
|
||||
|
||||
with login_app.test_request_context():
|
||||
result = protected_view()
|
||||
|
||||
assert result is response
|
||||
assert isinstance(result, Response)
|
||||
resolve_user.assert_called_once_with()
|
||||
csrf_check.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method", "login_disabled"),
|
||||
[
|
||||
@ -168,10 +197,14 @@ class TestGetUser:
|
||||
"""Test that _get_user loads user if not already in g."""
|
||||
mock_user = MockUser("test_user")
|
||||
|
||||
def _load_user() -> None:
|
||||
def load_user_from_request_context() -> None:
|
||||
g._login_user = mock_user
|
||||
|
||||
load_user = mocker.patch.object(login_app.login_manager, "_load_user", side_effect=_load_user)
|
||||
load_user = mocker.patch.object(
|
||||
login_app.login_manager,
|
||||
"load_user_from_request_context",
|
||||
side_effect=load_user_from_request_context,
|
||||
)
|
||||
|
||||
with login_app.test_request_context():
|
||||
user = login_module._get_user()
|
||||
|
||||
@ -401,10 +401,7 @@ class TestMetadataServiceCreateMetadata:
|
||||
metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name="category", metadata_type="string")
|
||||
|
||||
# Mock query to return None (no existing metadata with same name)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
# Mock BuiltInField enum iteration
|
||||
with patch("services.metadata_service.BuiltInField") as mock_builtin:
|
||||
@ -417,10 +414,6 @@ class TestMetadataServiceCreateMetadata:
|
||||
assert result is not None
|
||||
assert isinstance(result, DatasetMetadata)
|
||||
|
||||
# Verify query was made to check for duplicates
|
||||
mock_db_session.query.assert_called()
|
||||
mock_query.filter_by.assert_called()
|
||||
|
||||
# Verify metadata was added and committed
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
@ -468,10 +461,7 @@ class TestMetadataServiceCreateMetadata:
|
||||
|
||||
# Mock existing metadata with same name
|
||||
existing_metadata = MetadataTestDataFactory.create_metadata_mock(name="category")
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_metadata
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = existing_metadata
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Metadata name already exists"):
|
||||
@ -500,10 +490,7 @@ class TestMetadataServiceCreateMetadata:
|
||||
)
|
||||
|
||||
# Mock query to return None (no duplicate in database)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
# Mock BuiltInField to include the conflicting name
|
||||
with patch("services.metadata_service.BuiltInField") as mock_builtin:
|
||||
@ -597,27 +584,11 @@ class TestMetadataServiceUpdateMetadataName:
|
||||
|
||||
existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category")
|
||||
|
||||
# Mock query for duplicate check (no duplicate)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Mock metadata retrieval
|
||||
def query_side_effect(model):
|
||||
if model == DatasetMetadata:
|
||||
mock_meta_query = Mock()
|
||||
mock_meta_query.filter_by.return_value = mock_meta_query
|
||||
mock_meta_query.first.return_value = existing_metadata
|
||||
return mock_meta_query
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = query_side_effect
|
||||
# Mock scalar calls: first for duplicate check (None), second for metadata retrieval
|
||||
mock_db_session.scalar.side_effect = [None, existing_metadata]
|
||||
|
||||
# Mock no metadata bindings (no documents to update)
|
||||
mock_binding_query = Mock()
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.all.return_value = []
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Mock BuiltInField enum
|
||||
with patch("services.metadata_service.BuiltInField") as mock_builtin:
|
||||
@ -655,22 +626,8 @@ class TestMetadataServiceUpdateMetadataName:
|
||||
metadata_id = "non-existent-metadata"
|
||||
new_name = "updated_category"
|
||||
|
||||
# Mock query for duplicate check (no duplicate)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Mock metadata retrieval to return None
|
||||
def query_side_effect(model):
|
||||
if model == DatasetMetadata:
|
||||
mock_meta_query = Mock()
|
||||
mock_meta_query.filter_by.return_value = mock_meta_query
|
||||
mock_meta_query.first.return_value = None # Not found
|
||||
return mock_meta_query
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = query_side_effect
|
||||
# Mock scalar calls: first for duplicate check (None), second for metadata retrieval (None = not found)
|
||||
mock_db_session.scalar.side_effect = [None, None]
|
||||
|
||||
# Mock BuiltInField enum
|
||||
with patch("services.metadata_service.BuiltInField") as mock_builtin:
|
||||
@ -746,15 +703,10 @@ class TestMetadataServiceDeleteMetadata:
|
||||
existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category")
|
||||
|
||||
# Mock metadata retrieval
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_metadata
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = existing_metadata
|
||||
|
||||
# Mock no metadata bindings (no documents to update)
|
||||
mock_binding_query = Mock()
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.all.return_value = []
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = MetadataService.delete_metadata(dataset_id, metadata_id)
|
||||
@ -788,10 +740,7 @@ class TestMetadataServiceDeleteMetadata:
|
||||
metadata_id = "non-existent-metadata"
|
||||
|
||||
# Mock metadata retrieval to return None
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Metadata not found"):
|
||||
@ -1013,10 +962,7 @@ class TestMetadataServiceGetDatasetMetadatas:
|
||||
)
|
||||
|
||||
# Mock usage count queries
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.count.return_value = 5 # 5 documents use this metadata
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = 5 # 5 documents use this metadata
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_dataset_metadatas(dataset)
|
||||
|
||||
@ -292,7 +292,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
"""
|
||||
|
||||
api = Mock(spec=ExternalKnowledgeApis)
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
|
||||
mock_db_session.scalar.return_value = api
|
||||
|
||||
result = ExternalDatasetService.get_external_knowledge_api("api-id", "tenant-id")
|
||||
assert result is api
|
||||
@ -302,7 +302,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
When the record is absent, a ``ValueError`` is raised.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.get_external_knowledge_api("missing-id", "tenant-id")
|
||||
@ -320,7 +320,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
existing_api = Mock(spec=ExternalKnowledgeApis)
|
||||
existing_api.settings_dict = {"api_key": "stored-key"}
|
||||
existing_api.settings = '{"api_key":"stored-key"}'
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_api
|
||||
mock_db_session.scalar.return_value = existing_api
|
||||
|
||||
args = {
|
||||
"name": "New Name",
|
||||
@ -340,7 +340,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
Updating a non‑existent API template should raise ``ValueError``.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.update_external_knowledge_api(
|
||||
@ -356,7 +356,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
"""
|
||||
|
||||
api = Mock(spec=ExternalKnowledgeApis)
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
|
||||
mock_db_session.scalar.return_value = api
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")
|
||||
|
||||
@ -368,7 +368,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
Deletion of a missing template should raise ``ValueError``.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
|
||||
@ -394,7 +394,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
When there are bindings, ``external_knowledge_api_use_check`` returns True and count.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 3
|
||||
mock_db_session.scalar.return_value = 3
|
||||
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
|
||||
|
||||
@ -406,7 +406,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
Zero bindings should return ``(False, 0)``.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 0
|
||||
mock_db_session.scalar.return_value = 0
|
||||
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
|
||||
|
||||
@ -419,7 +419,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
"""
|
||||
|
||||
binding = Mock(spec=ExternalKnowledgeBindings)
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = binding
|
||||
mock_db_session.scalar.return_value = binding
|
||||
|
||||
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
|
||||
assert result is binding
|
||||
@ -429,7 +429,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
Missing binding should result in a ``ValueError``.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="external knowledge binding not found"):
|
||||
ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
|
||||
@ -460,7 +460,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
|
||||
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
|
||||
)
|
||||
# Raw string; the service itself calls json.loads on it
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
|
||||
mock_db_session.scalar.return_value = external_api
|
||||
|
||||
process_parameter = {"foo": "value", "bar": "optional"}
|
||||
|
||||
@ -474,7 +474,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
|
||||
When the referenced API template is missing, a ``ValueError`` is raised.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})
|
||||
@ -488,7 +488,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
|
||||
external_api.settings = (
|
||||
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
|
||||
)
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
|
||||
mock_db_session.scalar.return_value = external_api
|
||||
|
||||
process_parameter = {"bar": "present"} # missing "foo"
|
||||
|
||||
@ -702,7 +702,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
|
||||
}
|
||||
|
||||
# No existing dataset with same name.
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_db_session.scalar.side_effect = [
|
||||
None, # duplicate‑name check
|
||||
Mock(spec=ExternalKnowledgeApis), # external knowledge api
|
||||
]
|
||||
@ -724,7 +724,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
|
||||
"""
|
||||
|
||||
existing_dataset = Mock(spec=Dataset)
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_dataset
|
||||
mock_db_session.scalar.return_value = existing_dataset
|
||||
|
||||
args = {
|
||||
"name": "Existing",
|
||||
@ -744,7 +744,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
|
||||
"""
|
||||
|
||||
# First call: duplicate name check – not found.
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_db_session.scalar.side_effect = [
|
||||
None,
|
||||
None, # external knowledge api lookup
|
||||
]
|
||||
@ -763,8 +763,10 @@ class TestExternalDatasetServiceCreateExternalDataset:
|
||||
``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
|
||||
"""
|
||||
|
||||
# duplicate name check
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
# duplicate name check — two calls to create_external_dataset, each does 2 scalar calls
|
||||
mock_db_session.scalar.side_effect = [
|
||||
None,
|
||||
Mock(spec=ExternalKnowledgeApis),
|
||||
None,
|
||||
Mock(spec=ExternalKnowledgeApis),
|
||||
]
|
||||
@ -826,7 +828,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
|
||||
|
||||
# First query: binding; second query: api.
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_db_session.scalar.side_effect = [
|
||||
binding,
|
||||
api,
|
||||
]
|
||||
@ -861,7 +863,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
Missing binding should raise ``ValueError``.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="external knowledge binding not found"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
@ -878,7 +880,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
"""
|
||||
|
||||
binding = ExternalDatasetTestDataFactory.create_external_binding()
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_db_session.scalar.side_effect = [
|
||||
binding,
|
||||
None,
|
||||
]
|
||||
@ -901,7 +903,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
api = Mock(spec=ExternalKnowledgeApis)
|
||||
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_db_session.scalar.side_effect = [
|
||||
binding,
|
||||
api,
|
||||
]
|
||||
|
||||
@ -117,9 +117,7 @@ def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_serv
|
||||
|
||||
|
||||
def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service) -> None:
|
||||
first_query = mocker.Mock()
|
||||
first_query.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=first_query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Dataset not found"):
|
||||
rag_pipeline_service.get_pipeline("tenant-1", "dataset-1")
|
||||
@ -131,12 +129,8 @@ def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service
|
||||
def test_update_customized_pipeline_template_success(mocker) -> None:
|
||||
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
|
||||
|
||||
# First query finds the template, second query (duplicate check) returns None
|
||||
query_mock_1 = mocker.Mock()
|
||||
query_mock_1.where.return_value.first.return_value = template
|
||||
query_mock_2 = mocker.Mock()
|
||||
query_mock_2.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", side_effect=[query_mock_1, query_mock_2])
|
||||
# First scalar finds the template, second scalar (duplicate check) returns None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, None])
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
|
||||
@ -152,9 +146,7 @@ def test_update_customized_pipeline_template_success(mocker) -> None:
|
||||
|
||||
|
||||
def test_update_customized_pipeline_template_not_found(mocker) -> None:
|
||||
query_mock = mocker.Mock()
|
||||
query_mock.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
|
||||
info = PipelineTemplateInfoEntity(name="x", description="d", icon_info=IconInfo(icon="i"))
|
||||
@ -166,9 +158,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
|
||||
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
|
||||
duplicate = SimpleNamespace(name="dup")
|
||||
|
||||
query_mock = mocker.Mock()
|
||||
query_mock.where.return_value.first.side_effect = [template, duplicate]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, duplicate])
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
|
||||
info = PipelineTemplateInfoEntity(name="dup", description="d", icon_info=IconInfo(icon="i"))
|
||||
@ -181,9 +171,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
|
||||
|
||||
def test_delete_customized_pipeline_template_success(mocker) -> None:
|
||||
template = SimpleNamespace(id="tpl-1")
|
||||
query_mock = mocker.Mock()
|
||||
query_mock.where.return_value.first.return_value = template
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
|
||||
delete_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.delete")
|
||||
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
|
||||
@ -196,9 +184,7 @@ def test_delete_customized_pipeline_template_success(mocker) -> None:
|
||||
|
||||
|
||||
def test_delete_customized_pipeline_template_not_found(mocker) -> None:
|
||||
query_mock = mocker.Mock()
|
||||
query_mock.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
|
||||
with pytest.raises(ValueError, match="Customized pipeline template not found"):
|
||||
@ -397,18 +383,14 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) -
|
||||
|
||||
|
||||
def test_is_workflow_exist_returns_true_when_draft_exists(mocker, rag_pipeline_service) -> None:
|
||||
query_mock = mocker.Mock()
|
||||
query_mock.where.return_value.count.return_value = 1
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=1)
|
||||
|
||||
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
|
||||
assert rag_pipeline_service.is_workflow_exist(pipeline) is True
|
||||
|
||||
|
||||
def test_is_workflow_exist_returns_false_when_no_draft(mocker, rag_pipeline_service) -> None:
|
||||
query_mock = mocker.Mock()
|
||||
query_mock.where.return_value.count.return_value = 0
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=0)
|
||||
|
||||
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
|
||||
assert rag_pipeline_service.is_workflow_exist(pipeline) is False
|
||||
@ -738,8 +720,7 @@ def test_get_second_step_parameters_success(mocker, rag_pipeline_service) -> Non
|
||||
|
||||
|
||||
def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_service) -> None:
|
||||
from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate
|
||||
from models.workflow import Workflow
|
||||
from models.dataset import Pipeline
|
||||
|
||||
# 1. Setup mocks
|
||||
pipeline = mocker.Mock(spec=Pipeline)
|
||||
@ -754,36 +735,15 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
|
||||
# Mock db itself to avoid app context errors
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
|
||||
# Improved mocking for session.query
|
||||
def mock_query_side_effect(model):
|
||||
m = mocker.Mock()
|
||||
if model == Pipeline:
|
||||
m.where.return_value.first.return_value = pipeline
|
||||
elif model == Workflow:
|
||||
m.where.return_value.first.return_value = workflow
|
||||
elif model == PipelineCustomizedTemplate:
|
||||
m.where.return_value.first.return_value = None
|
||||
elif model == Dataset:
|
||||
m.where.return_value.first.return_value = mocker.Mock()
|
||||
else:
|
||||
# For func.max cases
|
||||
m.where.return_value.scalar.return_value = 5
|
||||
m.where.return_value.first.return_value = mocker.Mock()
|
||||
return m
|
||||
|
||||
mock_db.session.query.side_effect = mock_query_side_effect
|
||||
# Mock get() for Pipeline and Workflow PK lookups
|
||||
mock_db.session.get.side_effect = [pipeline, workflow]
|
||||
# Mock scalar() for template name check (None) and max position (5)
|
||||
mock_db.session.scalar.side_effect = [None, 5]
|
||||
|
||||
# Mock retrieve_dataset
|
||||
dataset = mocker.Mock()
|
||||
pipeline.retrieve_dataset.return_value = dataset
|
||||
|
||||
# Mock max position
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.func.max", return_value=1)
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline.db.session.query.return_value.where.return_value.scalar",
|
||||
return_value=5,
|
||||
)
|
||||
|
||||
# Mock RagPipelineDslService
|
||||
mock_dsl_service = mocker.Mock()
|
||||
mock_dsl_service.export_rag_pipeline_dsl.return_value = {"dsl": "content"}
|
||||
@ -839,9 +799,7 @@ def test_get_datasource_plugins_success(mocker, rag_pipeline_service) -> None:
|
||||
workflow.rag_pipeline_variables = []
|
||||
|
||||
# Mock queries
|
||||
mock_query = mocker.Mock()
|
||||
mock_query.where.return_value.first.side_effect = [dataset, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
|
||||
|
||||
@ -881,11 +839,9 @@ def test_retry_error_document_success(mocker, rag_pipeline_service) -> None:
|
||||
|
||||
workflow = mocker.Mock()
|
||||
|
||||
# Mock queries
|
||||
mock_query = mocker.Mock()
|
||||
# Log lookup, then Pipeline lookup
|
||||
mock_query.where.return_value.first.side_effect = [log, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
|
||||
# Mock queries: Log lookup via scalar, Pipeline lookup via get
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=log)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
|
||||
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
|
||||
|
||||
@ -913,7 +869,7 @@ def test_set_datasource_variables_success(mocker, rag_pipeline_service) -> None:
|
||||
# Mock db aggressively
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.engine = mocker.Mock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mocker.Mock()
|
||||
mock_db.session.scalar.return_value = mocker.Mock()
|
||||
|
||||
pipeline = mocker.Mock(spec=Pipeline)
|
||||
pipeline.id = "p-1"
|
||||
@ -976,7 +932,7 @@ def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None:
|
||||
workflow = mocker.Mock(spec=Workflow)
|
||||
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = workflow
|
||||
mock_db.session.scalar.return_value = workflow
|
||||
|
||||
# 2. Run test
|
||||
result = rag_pipeline_service.get_draft_workflow(pipeline)
|
||||
@ -998,7 +954,7 @@ def test_get_published_workflow_success(mocker, rag_pipeline_service) -> None:
|
||||
workflow = mocker.Mock(spec=Workflow)
|
||||
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = workflow
|
||||
mock_db.session.scalar.return_value = workflow
|
||||
|
||||
# 2. Run test
|
||||
result = rag_pipeline_service.get_published_workflow(pipeline)
|
||||
@ -1319,11 +1275,8 @@ def test_get_rag_pipeline_workflow_run_node_executions_returns_sorted_executions
|
||||
|
||||
|
||||
def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, rag_pipeline_service) -> None:
|
||||
query = mocker.Mock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value.all.return_value = []
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.query.return_value = query
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
result = rag_pipeline_service.get_recommended_plugins("all")
|
||||
|
||||
@ -1336,11 +1289,8 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra
|
||||
def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_pipeline_service) -> None:
|
||||
plugin_a = SimpleNamespace(plugin_id="plugin-a")
|
||||
plugin_b = SimpleNamespace(plugin_id="plugin-b")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value.all.return_value = [plugin_a, plugin_b]
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.query.return_value = query
|
||||
mock_db.session.scalars.return_value.all.return_value = [plugin_a, plugin_b]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools",
|
||||
@ -1568,9 +1518,7 @@ def test_get_second_step_parameters_filters_first_step_variables(mocker, rag_pip
|
||||
|
||||
|
||||
def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pipeline_service) -> None:
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Document pipeline execution log not found"):
|
||||
rag_pipeline_service.retry_error_document(
|
||||
@ -1581,9 +1529,7 @@ def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pi
|
||||
def test_get_datasource_plugins_raises_when_workflow_not_found(mocker, rag_pipeline_service) -> None:
|
||||
dataset = SimpleNamespace(pipeline_id="p1")
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [dataset, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline or workflow not found"):
|
||||
@ -1656,8 +1602,7 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker
|
||||
|
||||
document = SimpleNamespace(indexing_status="waiting", error=None)
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.return_value = document
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
|
||||
add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
|
||||
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
|
||||
@ -1712,9 +1657,7 @@ def test_run_datasource_node_preview_raises_for_unsupported_provider(mocker, rag
|
||||
|
||||
|
||||
def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker, rag_pipeline_service) -> None:
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline not found"):
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
|
||||
@ -1722,9 +1665,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker
|
||||
|
||||
def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(mocker, rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id=None)
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.return_value = pipeline
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline workflow not found"):
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {"name": "template-name"})
|
||||
@ -1732,8 +1673,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc
|
||||
|
||||
def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Dataset not found"):
|
||||
rag_pipeline_service.get_pipeline("t1", "d1")
|
||||
@ -1742,8 +1682,7 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service)
|
||||
def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
|
||||
dataset = SimpleNamespace(pipeline_id="p1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [dataset, None]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None])
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline not found"):
|
||||
rag_pipeline_service.get_pipeline("t1", "d1")
|
||||
@ -1783,8 +1722,7 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:
|
||||
def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
|
||||
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.return_value = template
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
|
||||
commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
|
||||
@ -2011,8 +1949,7 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_
|
||||
def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [pipeline, None]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
|
||||
@ -2021,11 +1958,9 @@ def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocke
|
||||
def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
|
||||
workflow = SimpleNamespace(id="wf-1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [pipeline, workflow]
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.engine = mocker.Mock()
|
||||
mock_db.session.query.return_value = query
|
||||
mock_db.session.get.side_effect = [pipeline, workflow]
|
||||
session_ctx = mocker.MagicMock()
|
||||
session_ctx.__enter__.return_value = SimpleNamespace()
|
||||
session_ctx.__exit__.return_value = False
|
||||
@ -2038,11 +1973,8 @@ def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker
|
||||
|
||||
def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipeline_service) -> None:
|
||||
plugin = SimpleNamespace(plugin_id="plugin-a")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value.all.return_value = [plugin]
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.query.return_value = query
|
||||
mock_db.session.scalars.return_value.all.return_value = [plugin]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[])
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids", return_value=[])
|
||||
@ -2056,8 +1988,8 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin
|
||||
def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
|
||||
exec_log = SimpleNamespace(pipeline_id="p1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [exec_log, None]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline not found"):
|
||||
rag_pipeline_service.retry_error_document(
|
||||
@ -2069,8 +2001,8 @@ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_
|
||||
exec_log = SimpleNamespace(pipeline_id="p1")
|
||||
pipeline = SimpleNamespace(id="p1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [exec_log, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
@ -2086,8 +2018,7 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r
|
||||
graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
|
||||
)
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [dataset, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
|
||||
|
||||
assert rag_pipeline_service.get_datasource_plugins("t1", "d1", True) == []
|
||||
@ -2250,8 +2181,7 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
|
||||
rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
|
||||
)
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [dataset, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow)
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials", return_value=[]
|
||||
@ -2291,8 +2221,7 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
|
||||
],
|
||||
)
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [dataset, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials",
|
||||
@ -2310,8 +2239,7 @@ def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service)
|
||||
dataset = SimpleNamespace(pipeline_id="p1")
|
||||
pipeline = SimpleNamespace(id="p1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [dataset, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
|
||||
result = rag_pipeline_service.get_pipeline("t1", "d1")
|
||||
|
||||
|
||||
@ -173,9 +173,7 @@ class TestAccountService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {("Account", "email", "test@example.com"): mock_account}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||
|
||||
mock_password_dependencies["compare_password"].return_value = True
|
||||
|
||||
@ -188,9 +186,7 @@ class TestAccountService:
|
||||
|
||||
def test_authenticate_account_not_found(self, mock_db_dependencies):
|
||||
"""Test authentication when account does not exist."""
|
||||
# Setup smart database query mock - no matching results
|
||||
query_results = {("Account", "email", "notfound@example.com"): None}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(
|
||||
@ -202,9 +198,7 @@ class TestAccountService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {("Account", "email", "banned@example.com"): mock_account}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
|
||||
@ -214,9 +208,7 @@ class TestAccountService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {("Account", "email", "test@example.com"): mock_account}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||
|
||||
mock_password_dependencies["compare_password"].return_value = False
|
||||
|
||||
@ -230,9 +222,7 @@ class TestAccountService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending")
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {("Account", "email", "pending@example.com"): mock_account}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||
|
||||
mock_password_dependencies["compare_password"].return_value = True
|
||||
|
||||
@ -422,12 +412,8 @@ class TestAccountService:
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock()
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {
|
||||
("Account", "id", "user-123"): mock_account,
|
||||
("TenantAccountJoin", "account_id", "user-123"): mock_tenant_join,
|
||||
}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.get.return_value = mock_account
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant_join
|
||||
|
||||
# Mock datetime
|
||||
with patch("services.account_service.datetime") as mock_datetime:
|
||||
@ -444,9 +430,7 @@ class TestAccountService:
|
||||
|
||||
def test_load_user_not_found(self, mock_db_dependencies):
|
||||
"""Test user loading when user does not exist."""
|
||||
# Setup smart database query mock - no matching results
|
||||
query_results = {("Account", "id", "non-existent-user"): None}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.get.return_value = None
|
||||
|
||||
# Execute test
|
||||
result = AccountService.load_user("non-existent-user")
|
||||
@ -459,9 +443,7 @@ class TestAccountService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {("Account", "id", "user-123"): mock_account}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.get.return_value = mock_account
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(
|
||||
@ -476,13 +458,9 @@ class TestAccountService:
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False)
|
||||
|
||||
# Setup smart database query mock for complex scenario
|
||||
query_results = {
|
||||
("Account", "id", "user-123"): mock_account,
|
||||
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
|
||||
("TenantAccountJoin", "order_by", "first_available"): mock_available_tenant, # First available tenant
|
||||
}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.get.return_value = mock_account
|
||||
# First scalar: current tenant (None), second scalar: available tenant
|
||||
mock_db_dependencies["db"].session.scalar.side_effect = [None, mock_available_tenant]
|
||||
|
||||
# Mock datetime
|
||||
with patch("services.account_service.datetime") as mock_datetime:
|
||||
@ -503,13 +481,9 @@ class TestAccountService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
|
||||
# Setup smart database query mock for no tenants scenario
|
||||
query_results = {
|
||||
("Account", "id", "user-123"): mock_account,
|
||||
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
|
||||
("TenantAccountJoin", "order_by", "first_available"): None, # No available tenants
|
||||
}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.get.return_value = mock_account
|
||||
# First scalar: current tenant (None), second scalar: available tenant (None)
|
||||
mock_db_dependencies["db"].session.scalar.side_effect = [None, None]
|
||||
|
||||
# Mock datetime
|
||||
with patch("services.account_service.datetime") as mock_datetime:
|
||||
@ -1060,7 +1034,7 @@ class TestRegisterService:
|
||||
)
|
||||
|
||||
# Verify rollback operations were called
|
||||
mock_db_dependencies["db"].session.query.assert_called()
|
||||
mock_db_dependencies["db"].session.execute.assert_called()
|
||||
|
||||
# ==================== Registration Tests ====================
|
||||
|
||||
@ -1625,10 +1599,8 @@ class TestRegisterService:
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = mock_existing_account
|
||||
|
||||
# Mock the db.session.query for TenantAccountJoin
|
||||
mock_db_query = MagicMock()
|
||||
mock_db_query.filter_by.return_value.first.return_value = None # No existing member
|
||||
mock_db_dependencies["db"].session.query.return_value = mock_db_query
|
||||
# Mock scalar for TenantAccountJoin lookup - no existing member
|
||||
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||
|
||||
# Mock TenantService methods
|
||||
with (
|
||||
@ -1803,14 +1775,9 @@ class TestRegisterService:
|
||||
}
|
||||
mock_get_invitation_by_token.return_value = invitation_data
|
||||
|
||||
# Mock database queries - complex query mocking
|
||||
mock_query1 = MagicMock()
|
||||
mock_query1.where.return_value.first.return_value = mock_tenant
|
||||
|
||||
mock_query2 = MagicMock()
|
||||
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
|
||||
# Mock scalar for tenant lookup, execute for account+role lookup
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
|
||||
mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
@ -1842,10 +1809,8 @@ class TestRegisterService:
|
||||
}
|
||||
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
||||
|
||||
# Mock database queries - no tenant found
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = None
|
||||
mock_db_dependencies["db"].session.query.return_value = mock_query
|
||||
# Mock scalar for tenant lookup - not found
|
||||
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
@ -1868,14 +1833,9 @@ class TestRegisterService:
|
||||
}
|
||||
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
||||
|
||||
# Mock database queries
|
||||
mock_query1 = MagicMock()
|
||||
mock_query1.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
mock_query2 = MagicMock()
|
||||
mock_query2.join.return_value.where.return_value.first.return_value = None # No account found
|
||||
|
||||
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
|
||||
# Mock scalar for tenant, execute for account+role
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
|
||||
mock_db_dependencies["db"].session.execute.return_value.first.return_value = None # No account found
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
@ -1901,14 +1861,9 @@ class TestRegisterService:
|
||||
}
|
||||
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
||||
|
||||
# Mock database queries
|
||||
mock_query1 = MagicMock()
|
||||
mock_query1.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
mock_query2 = MagicMock()
|
||||
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
|
||||
# Mock scalar for tenant, execute for account+role
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
|
||||
mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
|
||||
@ -799,10 +799,7 @@ class TestExternalDatasetServiceGetAPI:
|
||||
api_id = "api-123"
|
||||
expected_api = factory.create_external_knowledge_api_mock(api_id=api_id)
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = expected_api
|
||||
mock_db.session.scalar.return_value = expected_api
|
||||
|
||||
# Act
|
||||
tenant_id = "tenant-123"
|
||||
@ -810,16 +807,12 @@ class TestExternalDatasetServiceGetAPI:
|
||||
|
||||
# Assert
|
||||
assert result.id == api_id
|
||||
mock_query.filter_by.assert_called_once_with(id=api_id, tenant_id=tenant_id)
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_api_not_found(self, mock_db, factory):
|
||||
"""Test error when API is not found."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
@ -848,10 +841,7 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
"settings": {"endpoint": "https://new.example.com", "api_key": "new-key"},
|
||||
}
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_api
|
||||
mock_db.session.scalar.return_value = existing_api
|
||||
|
||||
# Act
|
||||
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)
|
||||
@ -881,10 +871,7 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
"settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
|
||||
}
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_api
|
||||
mock_db.session.scalar.return_value = existing_api
|
||||
|
||||
# Act
|
||||
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, "user-123", api_id, args)
|
||||
@ -897,10 +884,7 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
def test_update_external_knowledge_api_not_found(self, mock_db, factory):
|
||||
"""Test error when API is not found."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
args = {"name": "Updated API"}
|
||||
|
||||
@ -912,10 +896,7 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
def test_update_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
|
||||
"""Test error when tenant ID doesn't match."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
args = {"name": "Updated API"}
|
||||
|
||||
@ -934,10 +915,7 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
|
||||
args = {"name": "New Name Only"}
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_api
|
||||
mock_db.session.scalar.return_value = existing_api
|
||||
|
||||
# Act
|
||||
result = ExternalDatasetService.update_external_knowledge_api("tenant-123", "user-123", "api-123", args)
|
||||
@ -958,10 +936,7 @@ class TestExternalDatasetServiceDeleteAPI:
|
||||
|
||||
existing_api = factory.create_external_knowledge_api_mock(api_id=api_id, tenant_id=tenant_id)
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_api
|
||||
mock_db.session.scalar.return_value = existing_api
|
||||
|
||||
# Act
|
||||
ExternalDatasetService.delete_external_knowledge_api(tenant_id, api_id)
|
||||
@ -974,10 +949,7 @@ class TestExternalDatasetServiceDeleteAPI:
|
||||
def test_delete_external_knowledge_api_not_found(self, mock_db, factory):
|
||||
"""Test error when API is not found."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
@ -987,10 +959,7 @@ class TestExternalDatasetServiceDeleteAPI:
|
||||
def test_delete_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
|
||||
"""Test error when tenant ID doesn't match."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
@ -1006,10 +975,7 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.count.return_value = 1
|
||||
mock_db.session.scalar.return_value = 1
|
||||
|
||||
# Act
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
||||
@ -1024,10 +990,7 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.count.return_value = 10
|
||||
mock_db.session.scalar.return_value = 10
|
||||
|
||||
# Act
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
||||
@ -1042,10 +1005,7 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.count.return_value = 0
|
||||
mock_db.session.scalar.return_value = 0
|
||||
|
||||
# Act
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
||||
@ -1067,10 +1027,7 @@ class TestExternalDatasetServiceGetBinding:
|
||||
|
||||
expected_binding = factory.create_external_knowledge_binding_mock(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = expected_binding
|
||||
mock_db.session.scalar.return_value = expected_binding
|
||||
|
||||
# Act
|
||||
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id(tenant_id, dataset_id)
|
||||
@ -1083,10 +1040,7 @@ class TestExternalDatasetServiceGetBinding:
|
||||
def test_get_external_knowledge_binding_not_found(self, mock_db, factory):
|
||||
"""Test error when binding is not found."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="external knowledge binding not found"):
|
||||
@ -1113,10 +1067,7 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
|
||||
api = factory.create_external_knowledge_api_mock(api_id=api_id, settings=[settings])
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = api
|
||||
mock_db.session.scalar.return_value = api
|
||||
|
||||
process_parameter = {"param1": "value1", "param2": "value2"}
|
||||
|
||||
@ -1134,10 +1085,7 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
|
||||
api = factory.create_external_knowledge_api_mock(api_id=api_id, settings=[settings])
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = api
|
||||
mock_db.session.scalar.return_value = api
|
||||
|
||||
process_parameter = {}
|
||||
|
||||
@ -1149,10 +1097,7 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
def test_document_create_args_validate_api_not_found(self, mock_db, factory):
|
||||
"""Test validation fails when API is not found."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
@ -1165,10 +1110,7 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
settings = {}
|
||||
api = factory.create_external_knowledge_api_mock(settings=[settings])
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = api
|
||||
mock_db.session.scalar.return_value = api
|
||||
|
||||
# Act & Assert - should not raise
|
||||
ExternalDatasetService.document_create_args_validate("tenant-123", "api-123", {})
|
||||
@ -1186,10 +1128,7 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
|
||||
api = factory.create_external_knowledge_api_mock(settings=[settings])
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = api
|
||||
mock_db.session.scalar.return_value = api
|
||||
|
||||
process_parameter = {"required_param": "value"}
|
||||
|
||||
@ -1498,24 +1437,7 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
|
||||
api = factory.create_external_knowledge_api_mock(api_id="api-123")
|
||||
|
||||
# Mock database queries
|
||||
mock_dataset_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == Dataset:
|
||||
return mock_dataset_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_dataset_query.filter_by.return_value = mock_dataset_query
|
||||
mock_dataset_query.first.return_value = None
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [None, api]
|
||||
|
||||
# Act
|
||||
result = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)
|
||||
@ -1534,10 +1456,7 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
# Arrange
|
||||
existing_dataset = factory.create_dataset_mock(name="Duplicate Dataset")
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_dataset
|
||||
mock_db.session.scalar.return_value = existing_dataset
|
||||
|
||||
args = {"name": "Duplicate Dataset"}
|
||||
|
||||
@ -1549,23 +1468,7 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
def test_create_external_dataset_api_not_found_error(self, mock_db, factory):
|
||||
"""Test error when external knowledge API is not found."""
|
||||
# Arrange
|
||||
mock_dataset_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == Dataset:
|
||||
return mock_dataset_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_dataset_query.filter_by.return_value = mock_dataset_query
|
||||
mock_dataset_query.first.return_value = None
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = None
|
||||
mock_db.session.scalar.side_effect = [None, None]
|
||||
|
||||
args = {"name": "Test Dataset", "external_knowledge_api_id": "nonexistent-api"}
|
||||
|
||||
@ -1579,23 +1482,7 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
# Arrange
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_dataset_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == Dataset:
|
||||
return mock_dataset_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_dataset_query.filter_by.return_value = mock_dataset_query
|
||||
mock_dataset_query.first.return_value = None
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [None, api]
|
||||
|
||||
args = {"name": "Test Dataset", "external_knowledge_api_id": "api-123"}
|
||||
|
||||
@ -1609,23 +1496,7 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
# Arrange
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_dataset_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == Dataset:
|
||||
return mock_dataset_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_dataset_query.filter_by.return_value = mock_dataset_query
|
||||
mock_dataset_query.first.return_value = None
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [None, api]
|
||||
|
||||
args = {"name": "Test Dataset", "external_knowledge_id": "knowledge-123"}
|
||||
|
||||
@ -1651,23 +1522,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
)
|
||||
api = factory.create_external_knowledge_api_mock(api_id="api-123")
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
@ -1695,10 +1550,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
def test_fetch_external_knowledge_retrieval_binding_not_found_error(self, mock_db, factory):
|
||||
"""Test error when external knowledge binding is not found."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="external knowledge binding not found"):
|
||||
@ -1712,23 +1564,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
@ -1751,23 +1587,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
@ -1799,23 +1619,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
@ -1856,23 +1660,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
)
|
||||
api = factory.create_external_knowledge_api_mock(api_id="api-123")
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
@ -1891,23 +1679,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 503
|
||||
|
||||
@ -85,3 +85,644 @@ def test_get_provider_list_strips_credentials(service_with_fake_configurations:
|
||||
assert len(custom_models) == 1
|
||||
# The sanitizer should drop credentials in list response
|
||||
assert custom_models[0].credentials is None
|
||||
|
||||
|
||||
# === Merged from test_model_provider_service.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
|
||||
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from models.provider import ProviderType
|
||||
from services import model_provider_service as service_module
|
||||
from services.errors.app_model_config import ProviderNotFoundError
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]:
|
||||
manager = MagicMock()
|
||||
service = ModelProviderService()
|
||||
service._get_provider_manager = MagicMock(return_value=manager)
|
||||
return service, manager
|
||||
|
||||
|
||||
def _build_provider_configuration(
|
||||
*,
|
||||
provider_name: str = "openai",
|
||||
supported_model_types: list[ModelType] | None = None,
|
||||
custom_models: list[Any] | None = None,
|
||||
custom_config_available: bool = True,
|
||||
) -> SimpleNamespace:
|
||||
if supported_model_types is None:
|
||||
supported_model_types = [ModelType.LLM]
|
||||
return SimpleNamespace(
|
||||
provider=SimpleNamespace(
|
||||
provider=provider_name,
|
||||
label=I18nObject(en_US=provider_name),
|
||||
description=None,
|
||||
icon_small=None,
|
||||
icon_small_dark=None,
|
||||
background=None,
|
||||
help=None,
|
||||
supported_model_types=supported_model_types,
|
||||
configurate_methods=[],
|
||||
provider_credential_schema=None,
|
||||
model_credential_schema=None,
|
||||
),
|
||||
preferred_provider_type=ProviderType.CUSTOM,
|
||||
custom_configuration=SimpleNamespace(
|
||||
provider=SimpleNamespace(
|
||||
current_credential_id="cred-1",
|
||||
current_credential_name="Credential 1",
|
||||
available_credentials=[],
|
||||
),
|
||||
models=custom_models,
|
||||
can_added_models=[],
|
||||
),
|
||||
system_configuration=SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]),
|
||||
is_custom_configuration_available=lambda: custom_config_available,
|
||||
)
|
||||
|
||||
|
||||
def test__get_provider_configuration_should_return_configuration_when_provider_exists() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
provider_configuration = SimpleNamespace(name="provider-config")
|
||||
manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
|
||||
# Act
|
||||
result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai")
|
||||
|
||||
# Assert
|
||||
assert result is provider_configuration
|
||||
|
||||
|
||||
def test__get_provider_configuration_should_raise_error_when_provider_is_missing() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_configurations.return_value = {}
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ProviderNotFoundError, match="does not exist"):
|
||||
service._get_provider_configuration(tenant_id="tenant-1", provider="missing")
|
||||
|
||||
|
||||
def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
allowed = _build_provider_configuration(
|
||||
provider_name="openai",
|
||||
supported_model_types=[ModelType.LLM],
|
||||
custom_config_available=False,
|
||||
)
|
||||
filtered = _build_provider_configuration(
|
||||
provider_name="embedding",
|
||||
supported_model_types=[ModelType.TEXT_EMBEDDING],
|
||||
custom_config_available=True,
|
||||
)
|
||||
manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered}
|
||||
|
||||
# Act
|
||||
result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0].provider == "openai"
|
||||
assert result[0].custom_configuration.status.value == "no-configure"
|
||||
|
||||
|
||||
def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
|
||||
class _Model:
|
||||
def __init__(self, model_name: str) -> None:
|
||||
self.model_name = model_name
|
||||
|
||||
def model_dump(self) -> dict[str, Any]:
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"label": {"en_US": self.model_name},
|
||||
"model_type": ModelType.LLM,
|
||||
"features": [],
|
||||
"fetch_from": FetchFrom.PREDEFINED_MODEL,
|
||||
"model_properties": {},
|
||||
"deprecated": False,
|
||||
"status": ModelStatus.ACTIVE,
|
||||
"load_balancing_enabled": False,
|
||||
"has_invalid_load_balancing_configs": False,
|
||||
"provider": {
|
||||
"provider": "openai",
|
||||
"label": {"en_US": "OpenAI"},
|
||||
"icon_small": None,
|
||||
"icon_small_dark": None,
|
||||
"supported_model_types": [ModelType.LLM],
|
||||
},
|
||||
}
|
||||
|
||||
provider_configurations = SimpleNamespace(
|
||||
get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")])
|
||||
)
|
||||
manager.get_configurations.return_value = provider_configurations
|
||||
|
||||
# Act
|
||||
result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai")
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0].model == "gpt-4o"
|
||||
assert result[1].provider.provider == "openai"
|
||||
provider_configurations.get_models.assert_called_once_with(provider="openai")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"),
|
||||
[
|
||||
(
|
||||
"get_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
|
||||
"get_provider_credential",
|
||||
{"credential_id": "cred-1"},
|
||||
{"token": "abc"},
|
||||
),
|
||||
(
|
||||
"validate_provider_credentials",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}},
|
||||
"validate_provider_credentials",
|
||||
({"token": "abc"},),
|
||||
None,
|
||||
),
|
||||
(
|
||||
"create_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}, "credential_name": "A"},
|
||||
"create_provider_credential",
|
||||
({"token": "abc"}, "A"),
|
||||
None,
|
||||
),
|
||||
(
|
||||
"update_provider_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"credentials": {"token": "abc"},
|
||||
"credential_id": "cred-1",
|
||||
"credential_name": "B",
|
||||
},
|
||||
"update_provider_credential",
|
||||
{"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"remove_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
|
||||
"delete_provider_credential",
|
||||
{"credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"switch_active_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
|
||||
"switch_active_provider_credential",
|
||||
{"credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_provider_credential_methods_should_delegate_to_provider_configuration(
|
||||
method_name: str,
|
||||
method_kwargs: dict[str, Any],
|
||||
provider_method_name: str,
|
||||
provider_call_kwargs: Any,
|
||||
provider_return: Any,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
getattr(provider_configuration, provider_method_name).return_value = provider_return
|
||||
get_provider_config_mock = MagicMock(return_value=provider_configuration)
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
|
||||
|
||||
# Act
|
||||
result = getattr(service, method_name)(**method_kwargs)
|
||||
|
||||
# Assert
|
||||
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
|
||||
provider_method = getattr(provider_configuration, provider_method_name)
|
||||
if isinstance(provider_call_kwargs, tuple):
|
||||
provider_method.assert_called_once_with(*provider_call_kwargs)
|
||||
elif isinstance(provider_call_kwargs, dict):
|
||||
provider_method.assert_called_once_with(**provider_call_kwargs)
|
||||
else:
|
||||
provider_method.assert_called_once_with(provider_call_kwargs)
|
||||
if method_name == "get_provider_credential":
|
||||
assert result == {"token": "abc"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"),
|
||||
[
|
||||
(
|
||||
"get_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"get_custom_model_credential",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
{"api_key": "x"},
|
||||
),
|
||||
(
|
||||
"validate_model_credentials",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
},
|
||||
"validate_custom_model_credentials",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"create_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_name": "cred-a",
|
||||
},
|
||||
"create_custom_model_credential",
|
||||
{
|
||||
"model_type": ModelType.LLM,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_name": "cred-a",
|
||||
},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"update_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_id": "cred-1",
|
||||
"credential_name": "cred-b",
|
||||
},
|
||||
"update_custom_model_credential",
|
||||
{
|
||||
"model_type": ModelType.LLM,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_id": "cred-1",
|
||||
"credential_name": "cred-b",
|
||||
},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"remove_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"delete_custom_model_credential",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"switch_active_custom_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"switch_custom_model_credential",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"add_model_credential_to_model_list",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"add_model_credential_to_model",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"remove_model",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
},
|
||||
"delete_custom_model",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o"},
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_custom_model_methods_should_convert_model_type_and_delegate(
|
||||
method_name: str,
|
||||
method_kwargs: dict[str, Any],
|
||||
provider_method_name: str,
|
||||
expected_kwargs: dict[str, Any],
|
||||
provider_return: Any,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
getattr(provider_configuration, provider_method_name).return_value = provider_return
|
||||
get_provider_config_mock = MagicMock(return_value=provider_configuration)
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
|
||||
|
||||
# Act
|
||||
result = getattr(service, method_name)(**method_kwargs)
|
||||
|
||||
# Assert
|
||||
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
|
||||
getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs)
|
||||
if method_name == "get_model_credential":
|
||||
assert result == {"api_key": "x"}
|
||||
|
||||
|
||||
def test_get_models_by_model_type_should_group_active_non_deprecated_models() -> None:
    """Deprecated models are filtered out; the rest are grouped by provider."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()

    def _provider(name: str, label: str) -> SimpleNamespace:
        # Minimal provider entity shape the service reads from.
        return SimpleNamespace(
            provider=name,
            label=I18nObject(en_US=label),
            icon_small=None,
            icon_small_dark=None,
        )

    def _model(provider: SimpleNamespace, name: str, label: str, *, deprecated: bool) -> SimpleNamespace:
        # Minimal model entity shape; only `deprecated` varies across cases.
        return SimpleNamespace(
            provider=provider,
            model=name,
            label=I18nObject(en_US=label),
            model_type=ModelType.LLM,
            features=[],
            fetch_from=FetchFrom.PREDEFINED_MODEL,
            model_properties={},
            status=ModelStatus.ACTIVE,
            load_balancing_enabled=False,
            deprecated=deprecated,
        )

    openai_provider = _provider("openai", "OpenAI")
    anthropic_provider = _provider("anthropic", "Anthropic")
    models = [
        _model(openai_provider, "gpt-4o", "GPT-4o", deprecated=False),
        _model(openai_provider, "old-openai", "Old OpenAI", deprecated=True),
        _model(anthropic_provider, "old-anthropic", "Old Anthropic", deprecated=True),
    ]
    provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models))
    manager.get_configurations.return_value = provider_configurations

    # Act
    result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)

    # Assert: only the single non-deprecated OpenAI model survives grouping.
    provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
    assert len(result) == 1
    assert result[0].provider == "openai"
    assert len(result[0].models) == 1
    assert result[0].models[0].model == "gpt-4o"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("credentials", "schema", "expected_count"),
    [
        # No credentials -> schema lookup skipped, no rules.
        (None, None, 0),
        # Credentials but no schema -> still no rules.
        ({"api_key": "x"}, None, 0),
        # Credentials and a schema with one rule -> that rule is returned.
        (
            {"api_key": "x"},
            SimpleNamespace(
                parameter_rules=[
                    ParameterRule(
                        name="temperature",
                        label=I18nObject(en_US="Temperature"),
                        type=ParameterType.FLOAT,
                    )
                ]
            ),
            1,
        ),
    ],
)
def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema(
    credentials: dict[str, Any] | None,
    schema: Any,
    expected_count: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Parameter rules are only produced when both credentials and a schema exist."""
    # Arrange
    service = ModelProviderService()
    configuration = MagicMock()
    configuration.get_current_credentials.return_value = credentials
    configuration.get_model_schema.return_value = schema
    monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=configuration))

    # Act
    rules = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o")

    # Assert
    assert len(rules) == expected_count
    configuration.get_current_credentials.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4o")
    if credentials:
        configuration.get_model_schema.assert_called_once_with(
            model_type=ModelType.LLM,
            model="gpt-4o",
            credentials=credentials,
        )
    else:
        # Without credentials the schema must never be fetched.
        configuration.get_model_schema.assert_not_called()
|
||||
|
||||
|
||||
def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model() -> None:
    """A default model from the manager is converted into a response object."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    manager.get_default_model.return_value = SimpleNamespace(
        model="gpt-4o",
        model_type=ModelType.LLM,
        provider=SimpleNamespace(
            provider="openai",
            label=I18nObject(en_US="OpenAI"),
            icon_small=None,
            supported_model_types=[ModelType.LLM],
        ),
    )

    # Act
    response = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)

    # Assert
    assert response is not None
    assert response.model == "gpt-4o"
    assert response.provider.provider == "openai"
    manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM)


def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none() -> None:
    """No configured default model -> the service answers None."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    manager.get_default_model.return_value = None

    # Act / Assert
    assert service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value) is None


def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception() -> None:
    """Manager failures are swallowed and surfaced as None."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    manager.get_default_model.side_effect = RuntimeError("boom")

    # Act / Assert
    assert service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value) is None
|
||||
|
||||
|
||||
def test_update_default_model_of_model_type_should_delegate_to_provider_manager() -> None:
    """Updating the default model converts the enum value and forwards to the manager."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    request_kwargs = {
        "tenant_id": "tenant-1",
        "model_type": ModelType.LLM.value,
        "provider": "openai",
        "model": "gpt-4o",
    }

    # Act
    service.update_default_model_of_model_type(**request_kwargs)

    # Assert: the string model type is converted to the ModelType enum.
    manager.update_default_model_record.assert_called_once_with(
        tenant_id="tenant-1",
        model_type=ModelType.LLM,
        provider="openai",
        model="gpt-4o",
    )
|
||||
|
||||
|
||||
def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(monkeypatch: pytest.MonkeyPatch) -> None:
    """Icon bytes and mimetype come straight from the plugin provider factory."""
    # Arrange
    service = ModelProviderService()
    factory = MagicMock()
    factory.get_provider_icon.return_value = (b"icon-bytes", "image/png")
    factory_ctor = MagicMock(return_value=factory)
    monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_ctor)

    # Act
    icon = service.get_model_provider_icon(
        tenant_id="tenant-1",
        provider="openai",
        icon_type="icon_small",
        lang="en_US",
    )

    # Assert
    factory_ctor.assert_called_once_with(tenant_id="tenant-1")
    factory.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
    assert icon == (b"icon-bytes", "image/png")
|
||||
|
||||
|
||||
def test_switch_preferred_provider_should_convert_enum_and_delegate(monkeypatch: pytest.MonkeyPatch) -> None:
    """The raw provider-type string is converted to ProviderType before delegation."""
    # Arrange
    service = ModelProviderService()
    configuration = MagicMock()
    monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=configuration))

    # Act
    service.switch_preferred_provider(
        tenant_id="tenant-1",
        provider="openai",
        preferred_provider_type=ProviderType.SYSTEM.value,
    )

    # Assert: the configuration receives the enum, not the raw string.
    configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("method_name", "provider_method_name"),
    [
        ("enable_model", "enable_model"),
        ("disable_model", "disable_model"),
    ],
)
def test_model_enablement_methods_should_convert_model_type_and_delegate(
    method_name: str,
    provider_method_name: str,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """enable_model/disable_model convert the model-type string and delegate as-is."""
    # Arrange
    service = ModelProviderService()
    configuration = MagicMock()
    monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=configuration))

    # Act
    getattr(service, method_name)(
        tenant_id="tenant-1",
        provider="openai",
        model="gpt-4o",
        model_type=ModelType.LLM.value,
    )

    # Assert
    delegate = getattr(configuration, provider_method_name)
    delegate.assert_called_once_with(model="gpt-4o", model_type=ModelType.LLM)
|
||||
|
||||
@ -316,7 +316,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result == expected_detail
|
||||
@ -346,7 +346,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result["name"] == f"App from {mode}"
|
||||
@ -369,7 +369,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
@ -392,7 +392,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
@ -432,9 +432,197 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result["model_config"] == complex_model_config
|
||||
assert len(result["workflows"]) == 2
|
||||
assert len(result["tools"]) == 3
|
||||
|
||||
|
||||
# === Merged from test_recommended_app_service_additional.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from services import recommended_app_service as service_module
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
|
||||
def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], result)
|
||||
|
||||
|
||||
@pytest.fixture
def mocked_db_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
    """Replace the module-level ``db`` with a stub exposing a mock session."""
    session = MagicMock()
    fake_db = SimpleNamespace(session=session)
    monkeypatch.setattr(service_module, "db", fake_db)
    return session
|
||||
|
||||
|
||||
def _mock_factory_for_apps(
    monkeypatch: pytest.MonkeyPatch,
    *,
    mode: str,
    result: dict[str, Any],
    fallback_result: dict[str, Any] | None = None,
) -> tuple[MagicMock, MagicMock]:
    """Patch both the remote retrieval factory and the builtin fallback.

    Returns the (remote retrieval, builtin retrieval) mock pair so callers can
    assert which code path handled the request.
    """
    remote = MagicMock()
    remote.get_recommended_apps_and_categories.return_value = result
    monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False)
    # get_recommend_app_factory() -> factory callable -> retrieval instance.
    monkeypatch.setattr(
        service_module.RecommendAppRetrievalFactory,
        "get_recommend_app_factory",
        MagicMock(return_value=MagicMock(return_value=remote)),
    )

    builtin = MagicMock()
    if fallback_result is not None:
        builtin.fetch_recommended_apps_from_builtin.return_value = fallback_result
    monkeypatch.setattr(
        service_module.RecommendAppRetrievalFactory,
        "get_buildin_recommend_app_retrieval",
        MagicMock(return_value=builtin),
    )
    return remote, builtin
|
||||
|
||||
|
||||
def test_get_recommended_apps_and_categories_should_not_query_trial_table_when_trial_feature_disabled(
    monkeypatch: pytest.MonkeyPatch,
    mocked_db_session: MagicMock,
) -> None:
    """With trials disabled the remote result is returned untouched and no DB query runs."""
    # Arrange
    expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]}
    remote, builtin = _mock_factory_for_apps(monkeypatch, mode="remote", result=expected)
    monkeypatch.setattr(
        service_module.FeatureService,
        "get_system_features",
        MagicMock(return_value=SimpleNamespace(enable_trial_app=False)),
    )

    # Act
    result = RecommendedAppService.get_recommended_apps_and_categories("en-US")

    # Assert
    assert result == expected
    remote.get_recommended_apps_and_categories.assert_called_once_with("en-US")
    builtin.fetch_recommended_apps_from_builtin.assert_not_called()
    mocked_db_session.scalar.assert_not_called()


def test_get_recommended_apps_and_categories_should_fallback_and_enrich_can_trial_when_trial_feature_enabled(
    monkeypatch: pytest.MonkeyPatch,
    mocked_db_session: MagicMock,
) -> None:
    """Empty remote results fall back to builtin apps, each annotated with can_trial."""
    # Arrange
    fallback = {"recommended_apps": [{"app_id": "app-1"}, {"app_id": "app-2"}], "categories": ["all"]}
    _, builtin = _mock_factory_for_apps(
        monkeypatch,
        mode="remote",
        result={"recommended_apps": [], "categories": []},
        fallback_result=fallback,
    )
    monkeypatch.setattr(
        service_module.FeatureService,
        "get_system_features",
        MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
    )
    # First app has a trial record, second does not.
    mocked_db_session.scalar.side_effect = [SimpleNamespace(id="trial-app"), None]

    # Act
    result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP")

    # Assert: fallback always fetches the "en-US" builtin set.
    builtin.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
    assert result["recommended_apps"][0]["can_trial"] is True
    assert result["recommended_apps"][1]["can_trial"] is False
    assert mocked_db_session.scalar.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("trial_query_result", "expected_can_trial"),
    [
        # A trial row exists -> can_trial True; no row -> False.
        (SimpleNamespace(id="trial"), True),
        (None, False),
    ],
)
def test_get_recommend_app_detail_should_set_can_trial_when_trial_feature_enabled(
    monkeypatch: pytest.MonkeyPatch,
    mocked_db_session: MagicMock,
    trial_query_result: Any,
    expected_can_trial: bool,
) -> None:
    """With trials enabled the detail dict gains a can_trial flag from the DB lookup."""
    # Arrange
    retrieval = MagicMock()
    retrieval.get_recommend_app_detail.return_value = {"id": "app-1", "name": "Test App"}
    monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False)
    monkeypatch.setattr(
        service_module.RecommendAppRetrievalFactory,
        "get_recommend_app_factory",
        MagicMock(return_value=MagicMock(return_value=retrieval)),
    )
    monkeypatch.setattr(
        service_module.FeatureService,
        "get_system_features",
        MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
    )
    mocked_db_session.scalar.return_value = trial_query_result

    # Act
    detail = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail("app-1"))

    # Assert
    assert detail["id"] == "app-1"
    assert detail["can_trial"] is expected_can_trial
    mocked_db_session.scalar.assert_called_once()
|
||||
|
||||
|
||||
def test_add_trial_app_record_should_increment_count_when_existing_record_found(
    mocked_db_session: MagicMock,
) -> None:
    """An existing trial record is incremented in place rather than duplicated."""
    # Arrange
    record = SimpleNamespace(count=3)
    mocked_db_session.scalar.return_value = record

    # Act
    RecommendedAppService.add_trial_app_record("app-1", "account-1")

    # Assert
    assert record.count == 4
    mocked_db_session.scalar.assert_called_once()
    mocked_db_session.commit.assert_called_once()
    mocked_db_session.add.assert_not_called()


def test_add_trial_app_record_should_create_new_record_when_no_existing_record(
    mocked_db_session: MagicMock,
) -> None:
    """Without a prior record a fresh one is added with its count initialised to 1."""
    # Arrange
    mocked_db_session.scalar.return_value = None

    # Act
    RecommendedAppService.add_trial_app_record("app-2", "account-2")

    # Assert
    mocked_db_session.scalar.assert_called_once()
    mocked_db_session.add.assert_called_once()
    (created,) = mocked_db_session.add.call_args.args
    assert created.app_id == "app-2"
    assert created.account_id == "account-2"
    assert created.count == 1
    mocked_db_session.commit.assert_called_once()
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
import unittest
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
|
||||
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
|
||||
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError
|
||||
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
|
||||
from events.event_handlers.sync_workflow_schedule_when_app_published import (
|
||||
sync_schedule_from_workflow,
|
||||
)
|
||||
@ -14,6 +17,8 @@ from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from models.trigger import WorkflowSchedulePlan
|
||||
from models.workflow import Workflow
|
||||
from services.errors.account import AccountNotFoundError
|
||||
from services.trigger import schedule_service as service_module
|
||||
from services.trigger.schedule_service import ScheduleService
|
||||
|
||||
|
||||
@ -775,5 +780,158 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.fixture
def session_mock() -> MagicMock:
    """A SQLAlchemy Session stand-in restricted to the real Session interface."""
    return MagicMock(spec=Session)


def _workflow(**kwargs: Any) -> Workflow:
    """Build a lightweight Workflow stand-in carrying only the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(Workflow, stub)
|
||||
|
||||
|
||||
def test_update_schedule_should_update_only_node_id_without_recomputing_time(
    session_mock: MagicMock,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Changing only the node id must not trigger a next-run recalculation."""
    # Arrange
    plan = MagicMock(spec=WorkflowSchedulePlan)
    plan.cron_expression = "0 10 * * *"
    plan.timezone = "UTC"
    session_mock.get.return_value = plan
    recompute = MagicMock(return_value=datetime(2026, 1, 1, 10, 0, tzinfo=UTC))
    monkeypatch.setattr(service_module, "calculate_next_run_at", recompute)

    # Act
    updated = ScheduleService.update_schedule(
        session=session_mock,
        schedule_id="schedule-1",
        updates=SchedulePlanUpdate(node_id="node-new"),
    )

    # Assert
    assert updated is plan
    assert plan.node_id == "node-new"
    recompute.assert_not_called()
    session_mock.flush.assert_called_once()
|
||||
|
||||
|
||||
def test_get_tenant_owner_should_raise_when_account_record_missing(session_mock: MagicMock) -> None:
    """A tenant join pointing at a missing account must raise AccountNotFoundError."""
    # Arrange
    session_mock.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(account_id="account-404")
    session_mock.get.return_value = None

    # Act / Assert
    with pytest.raises(AccountNotFoundError, match="Account not found: account-404"):
        ScheduleService.get_tenant_owner(session=session_mock, tenant_id="tenant-1")


def test_get_tenant_owner_should_raise_when_no_owner_or_admin_found(session_mock: MagicMock) -> None:
    """Neither an owner nor an admin join -> AccountNotFoundError for the tenant."""
    # Arrange: both the owner and the admin lookups come back empty.
    session_mock.execute.return_value.scalar_one_or_none.side_effect = [None, None]

    # Act / Assert
    with pytest.raises(AccountNotFoundError, match="Account not found for tenant: tenant-1"):
        ScheduleService.get_tenant_owner(session=session_mock, tenant_id="tenant-1")


def test_update_next_run_at_should_raise_when_schedule_not_found(session_mock: MagicMock) -> None:
    """Refreshing the next run of a missing schedule raises ScheduleNotFoundError."""
    # Arrange
    session_mock.get.return_value = None

    # Act / Assert
    with pytest.raises(ScheduleNotFoundError, match="Schedule not found: schedule-1"):
        ScheduleService.update_next_run_at(session=session_mock, schedule_id="schedule-1")
|
||||
|
||||
|
||||
def test_to_schedule_config_should_build_from_cron_mode() -> None:
    """Cron mode copies the expression and timezone straight into the config."""
    # Arrange
    node_config: dict[str, Any] = {
        "id": "node-1",
        "data": {
            "mode": "cron",
            "cron_expression": "0 12 * * *",
            "timezone": "Asia/Kolkata",
        },
    }

    # Act
    config = ScheduleService.to_schedule_config(node_config=node_config)

    # Assert
    assert config.node_id == "node-1"
    assert config.cron_expression == "0 12 * * *"
    assert config.timezone == "Asia/Kolkata"


def test_to_schedule_config_should_raise_for_cron_mode_without_expression() -> None:
    """An empty cron expression in cron mode is a configuration error."""
    # Arrange
    node_config = {"id": "node-1", "data": {"mode": "cron", "cron_expression": ""}}

    # Act / Assert
    with pytest.raises(ScheduleConfigError, match="Cron expression is required for cron mode"):
        ScheduleService.to_schedule_config(node_config=node_config)


def test_to_schedule_config_should_build_from_visual_mode(monkeypatch: pytest.MonkeyPatch) -> None:
    """Visual mode derives the cron expression via visual_to_cron."""
    # Arrange
    node_config = {
        "id": "node-1",
        "data": {
            "mode": "visual",
            "frequency": "daily",
            "visual_config": {"time": "9:30 AM"},
            "timezone": "UTC",
        },
    }
    monkeypatch.setattr(ScheduleService, "visual_to_cron", MagicMock(return_value="30 9 * * *"))

    # Act
    config = ScheduleService.to_schedule_config(node_config=node_config)

    # Assert
    assert config.cron_expression == "30 9 * * *"


def test_to_schedule_config_should_raise_for_invalid_mode() -> None:
    """Unknown schedule modes are rejected outright."""
    # Arrange
    node_config = {"id": "node-1", "data": {"mode": "manual"}}

    # Act / Assert
    with pytest.raises(ScheduleConfigError, match="Invalid schedule mode: manual"):
        ScheduleService.to_schedule_config(node_config=node_config)
|
||||
|
||||
|
||||
def test_extract_schedule_config_should_raise_when_graph_is_empty() -> None:
    """An empty workflow graph cannot yield a schedule config."""
    # Arrange
    workflow = _workflow(graph_dict={})

    # Act / Assert
    with pytest.raises(ScheduleConfigError, match="Workflow graph is empty"):
        ScheduleService.extract_schedule_config(workflow=workflow)


def test_extract_schedule_config_should_raise_when_mode_invalid() -> None:
    """A schedule trigger node with an unknown mode is rejected during extraction."""
    # Arrange
    trigger_node = {
        "id": "schedule-1",
        "data": {
            "type": TRIGGER_SCHEDULE_NODE_TYPE,
            "mode": "invalid",
        },
    }
    workflow = _workflow(graph_dict={"nodes": [trigger_node]})

    # Act / Assert
    with pytest.raises(ScheduleConfigError, match="Invalid schedule mode: invalid"):
        ScheduleService.extract_schedule_config(workflow=workflow)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this module's unittest-style tests directly.
    unittest.main()
|
||||
|
||||
@ -12,6 +12,7 @@ This test suite covers all functionality of the current VariableTruncator includ
|
||||
import functools
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
@ -199,14 +200,14 @@ class TestArrayTruncation:
|
||||
|
||||
def test_small_array_no_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test that small arrays are not truncated."""
|
||||
small_array = [1, 2]
|
||||
small_array: list[object] = [1, 2]
|
||||
result = small_truncator._truncate_array(small_array, 1000)
|
||||
assert result.value == small_array
|
||||
assert result.truncated is False
|
||||
|
||||
def test_array_element_limit_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test that arrays over element limit are truncated."""
|
||||
large_array = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3
|
||||
large_array: list[object] = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3
|
||||
result = small_truncator._truncate_array(large_array, 1000)
|
||||
|
||||
assert result.truncated is True
|
||||
@ -215,7 +216,7 @@ class TestArrayTruncation:
|
||||
def test_array_size_budget_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test array truncation due to size budget constraints."""
|
||||
# Create array with strings that will exceed size budget
|
||||
large_strings = ["very long string " * 5, "another long string " * 5]
|
||||
large_strings: list[object] = ["very long string " * 5, "another long string " * 5]
|
||||
result = small_truncator._truncate_array(large_strings, 50)
|
||||
|
||||
assert result.truncated is True
|
||||
@ -276,10 +277,10 @@ class TestObjectTruncation:
|
||||
|
||||
# Values should be truncated if they exist
|
||||
for key, value in result.value.items():
|
||||
if isinstance(value, str):
|
||||
original_value = obj_with_long_values[key]
|
||||
# Value should be same or smaller
|
||||
assert len(value) <= len(original_value)
|
||||
assert isinstance(value, str)
|
||||
original_value = obj_with_long_values[key]
|
||||
# Value should be same or smaller
|
||||
assert len(value) <= len(original_value)
|
||||
|
||||
def test_object_key_dropping(self, small_truncator):
|
||||
"""Test object truncation where keys are dropped due to size constraints."""
|
||||
@ -506,10 +507,9 @@ class TestEdgeCases:
|
||||
truncator = VariableTruncator(string_length_limit=10)
|
||||
|
||||
# Unicode characters
|
||||
unicode_text = "🌍🚀🌍🚀🌍🚀🌍🚀🌍🚀" # Each emoji counts as 1 character
|
||||
unicode_text = "你好世界你好世界你好世界" # Multi-byte UTF-8 characters
|
||||
result = truncator.truncate(StringSegment(value=unicode_text))
|
||||
if len(unicode_text) > 10:
|
||||
assert result.truncated is True
|
||||
assert result.truncated is True
|
||||
|
||||
# Special JSON characters
|
||||
special_chars = '{"key": "value with \\"quotes\\" and \\n newlines"}'
|
||||
@ -631,13 +631,12 @@ class TestIntegrationScenarios:
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
# Should handle all data types appropriately
|
||||
if result.truncated:
|
||||
# Verify the result is smaller or equal than original
|
||||
original_size = truncator.calculate_json_size(mixed_data)
|
||||
if isinstance(result.result, ObjectSegment):
|
||||
result_size = truncator.calculate_json_size(result.result.value)
|
||||
assert result_size <= original_size
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, ObjectSegment)
|
||||
# Verify the result is smaller or equal than original
|
||||
original_size = truncator.calculate_json_size(mixed_data)
|
||||
result_size = truncator.calculate_json_size(result.result.value)
|
||||
assert result_size <= original_size
|
||||
|
||||
def test_file_and_array_file_variable_mapping(self, file):
|
||||
truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300)
|
||||
@ -675,3 +674,229 @@ def test_dummy_variable_truncator_methods():
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.result == segment
|
||||
assert result.truncated is False
|
||||
|
||||
|
||||
# === Merged from test_variable_truncator_additional.py ===
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
|
||||
from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment
|
||||
from graphon.variables.types import SegmentType
|
||||
|
||||
from services import variable_truncator as truncator_module
|
||||
from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator
|
||||
|
||||
|
||||
class _AbstractPassthrough(BaseTruncator):
    """Concrete shim that forwards straight into BaseTruncator's abstract bodies."""

    def truncate(self, segment: Any) -> TruncationResult:
        # Execute the abstract placeholder body for coverage.
        return super().truncate(segment)  # type: ignore[misc]

    def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
        # Execute the abstract placeholder body for coverage.
        return super().truncate_variable_mapping(v)  # type: ignore[misc]


def test_base_truncator_methods_should_execute_abstract_placeholders() -> None:
    """The abstract method bodies run and fall through to an implicit None."""
    # Arrange
    shim = _AbstractPassthrough()

    # Act
    truncate_result = shim.truncate(StringSegment(value="x"))
    mapping_result = shim.truncate_variable_mapping({"a": 1})

    # Assert
    assert truncate_result is None
    assert mapping_result is None
|
||||
|
||||
|
||||
def test_default_should_use_dify_config_limits(monkeypatch: pytest.MonkeyPatch) -> None:
    """VariableTruncator.default() must read its three limits from dify_config."""
    # Arrange
    limits = {
        "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE": 111,
        "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH": 7,
        "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH": 33,
    }
    for setting, value in limits.items():
        monkeypatch.setattr(truncator_module.dify_config, setting, value)

    # Act
    truncator = VariableTruncator.default()

    # Assert
    assert truncator._max_size_bytes == 111
    assert truncator._array_element_limit == 7
    assert truncator._string_length_limit == 33
|
||||
|
||||
|
||||
def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis() -> None:
    """Keys whose values blow the size budget are replaced by an ellipsis marker."""
    # Arrange
    truncator = VariableTruncator(max_size_bytes=5)

    # Act
    result, truncated = truncator.truncate_variable_mapping({"very_long_key": "value"})

    # Assert
    assert result == {"very_long_key": "..."}
    assert truncated is True


def test_truncate_variable_mapping_should_handle_segment_values() -> None:
    """Segment values that fit the budget pass through unchanged."""
    # Arrange
    truncator = VariableTruncator(max_size_bytes=100)

    # Act
    result, truncated = truncator.truncate_variable_mapping({"seg": StringSegment(value="hello")})

    # Assert
    seg = result["seg"]
    assert isinstance(seg, StringSegment)
    assert seg.value == "hello"
    assert truncated is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        # Scalars never need truncation.
        (None, False),
        (True, False),
        (1, False),
        (1.5, False),
        # Strings and containers are truncation candidates.
        ("x", True),
        ({"k": "v"}, True),
    ],
)
def test_json_value_needs_truncation_should_match_expected_rules(value: Any, expected: bool) -> None:
    """Only strings and containers are flagged as needing truncation."""
    # Act / Assert
    assert VariableTruncator._json_value_needs_truncation(value) is expected
|
||||
|
||||
|
||||
def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """If segment truncation still overflows the budget, truncate() falls back to a plain string."""
    truncator = VariableTruncator(max_size_bytes=10)
    # Force _truncate_segment to report a result that is still over budget.
    oversized = truncator_module._PartResult(
        value=StringSegment(value="this is too long"),
        value_size=100,
        truncated=True,
    )
    monkeypatch.setattr(truncator, "_truncate_segment", lambda *_a, **_kw: oversized)

    outcome = truncator.truncate(StringSegment(value="input"))

    assert outcome.truncated is True
    assert isinstance(outcome.result, StringSegment)
    # The fallback must be a raw string, not a JSON-quoted one.
    assert not outcome.result.value.startswith('"')
|
||||
|
||||
|
||||
def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A segment that claims to need truncation but has no handler is a programming error."""
    truncator = VariableTruncator()
    # Make every segment look truncatable, including unsupported integer segments.
    monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _seg: True)

    with pytest.raises(AssertionError):
        truncator._truncate_segment(IntegerSegment(value=1), 10)
|
||||
|
||||
|
||||
def test_calculate_json_size_should_unwrap_segment_values() -> None:
    """Segments are measured by their wrapped value, not the wrapper object."""
    wrapped_size = VariableTruncator.calculate_json_size(StringSegment(value="abc"))

    assert wrapped_size == VariableTruncator.calculate_json_size("abc")
|
||||
|
||||
|
||||
def test_calculate_json_size_should_handle_updated_variable_instances() -> None:
    """UpdatedVariable models must be serializable for size measurement."""
    variable = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")

    assert VariableTruncator.calculate_json_size(variable) > 0
|
||||
|
||||
|
||||
def test_maybe_qa_structure_should_validate_shape() -> None:
    """A QA structure requires a list under the "qa_chunks" key — nothing else qualifies."""
    assert VariableTruncator._maybe_qa_structure({"qa_chunks": []}) is True
    assert VariableTruncator._maybe_qa_structure({"qa_chunks": "not-list"}) is False
    assert VariableTruncator._maybe_qa_structure({}) is False
|
||||
|
||||
|
||||
def test_maybe_parent_child_structure_should_validate_shape() -> None:
    """Parent/child detection needs a string parent_mode AND a list of chunks."""
    good = {"parent_mode": "full", "parent_child_chunks": []}
    bad_mode = {"parent_mode": 1, "parent_child_chunks": []}
    bad_chunks = {"parent_mode": "full", "parent_child_chunks": "bad"}

    assert VariableTruncator._maybe_parent_child_structure(good) is True
    assert VariableTruncator._maybe_parent_child_structure(bad_mode) is False
    assert VariableTruncator._maybe_parent_child_structure(bad_chunks) is False
|
||||
|
||||
|
||||
def test_truncate_object_should_truncate_segment_values_inside_object() -> None:
    """Segment values nested inside an object get truncated like plain values."""
    truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30)

    part = truncator._truncate_object({"s": StringSegment(value="long-content")}, 20)

    assert part.truncated is True
    assert isinstance(part.value["s"], StringSegment)
|
||||
|
||||
|
||||
def test_truncate_json_primitives_should_handle_updated_variable_input() -> None:
    """UpdatedVariable inputs are dumped to a plain dict before truncation."""
    truncator = VariableTruncator(max_size_bytes=100)
    variable = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")

    part = truncator._truncate_json_primitives(variable, 100)

    assert isinstance(part.value, dict)
|
||||
|
||||
|
||||
def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type() -> None:
    """Feeding a non-JSON value into _truncate_json_primitives is a programming error."""
    with pytest.raises(AssertionError):
        VariableTruncator()._truncate_json_primitives(object(), 100)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Non-string segments still over budget after truncation are serialized to a string segment."""
    truncator = VariableTruncator(max_size_bytes=10)
    # Simulate a truncation pass that still reports an oversized object segment.
    still_too_big = truncator_module._PartResult(
        value=ObjectSegment(value={"k": "v"}),
        value_size=100,
        truncated=True,
    )
    monkeypatch.setattr(truncator, "_truncate_segment", lambda *_a, **_kw: still_too_big)

    outcome = truncator.truncate(ObjectSegment(value={"a": "b"}))

    assert outcome.truncated is True
    assert isinstance(outcome.result, StringSegment)
|
||||
|
||||
@ -559,3 +559,757 @@ class TestWebhookServiceUnit:
|
||||
|
||||
result = _prepare_webhook_execution("test_webhook", is_debug=True)
|
||||
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
|
||||
|
||||
|
||||
# === Merged from test_webhook_service_additional.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from graphon.variables.types import SegmentType
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import RequestEntityTooLarge
|
||||
|
||||
from core.workflow.nodes.trigger_webhook.entities import (
|
||||
ContentType,
|
||||
WebhookBodyParameter,
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import App
|
||||
from models.trigger import WorkflowWebhookTrigger
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.trigger import webhook_service as service_module
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class _FakeQuery:
|
||||
def __init__(self, result: Any) -> None:
|
||||
self._result = result
|
||||
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
return self._result
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
def flask_app() -> Flask:
    """Provide a throwaway Flask application for building request contexts."""
    app = Flask(__name__)
    return app
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
    """Route the service module's DB access through the given fake session object."""
    fake_db = SimpleNamespace(engine=MagicMock(), session=MagicMock())
    monkeypatch.setattr(service_module, "db", fake_db)
    # Any Session(...) call inside the service now yields our fake session.
    monkeypatch.setattr(service_module, "Session", lambda *_args, **_kwargs: _SessionContext(session))
|
||||
|
||||
|
||||
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
    """Build a duck-typed WorkflowWebhookTrigger from arbitrary keyword attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(WorkflowWebhookTrigger, stub)
|
||||
|
||||
|
||||
def _workflow(**kwargs: Any) -> Workflow:
    """Build a duck-typed Workflow from arbitrary keyword attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(Workflow, stub)
|
||||
|
||||
|
||||
def _app(**kwargs: Any) -> App:
    """Build a duck-typed App from arbitrary keyword attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(App, stub)
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
    """An unknown webhook id surfaces as a ValueError."""
    session = MagicMock()
    session.query.return_value = _FakeQuery(None)  # webhook lookup misses
    _patch_session(monkeypatch, session)

    with pytest.raises(ValueError, match="Webhook not found"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A webhook without a matching app trigger surfaces as a ValueError."""
    session = MagicMock()
    # First query finds the webhook trigger; second (app trigger) misses.
    session.query.side_effect = [
        _FakeQuery(SimpleNamespace(app_id="app-1", node_id="node-1")),
        _FakeQuery(None),
    ]
    _patch_session(monkeypatch, session)

    with pytest.raises(ValueError, match="App trigger not found"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A rate-limited app trigger blocks webhook resolution with a ValueError."""
    session = MagicMock()
    session.query.side_effect = [
        _FakeQuery(SimpleNamespace(app_id="app-1", node_id="node-1")),
        _FakeQuery(SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)),
    ]
    _patch_session(monkeypatch, session)

    with pytest.raises(ValueError, match="rate limited"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A disabled app trigger blocks webhook resolution with a ValueError."""
    session = MagicMock()
    session.query.side_effect = [
        _FakeQuery(SimpleNamespace(app_id="app-1", node_id="node-1")),
        _FakeQuery(SimpleNamespace(status=AppTriggerStatus.DISABLED)),
    ]
    _patch_session(monkeypatch, session)

    with pytest.raises(ValueError, match="disabled"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
    """A missing workflow for an otherwise valid webhook surfaces as a ValueError."""
    session = MagicMock()
    session.query.side_effect = [
        _FakeQuery(SimpleNamespace(app_id="app-1", node_id="node-1")),
        _FakeQuery(SimpleNamespace(status=AppTriggerStatus.ENABLED)),
        _FakeQuery(None),  # workflow lookup misses
    ]
    _patch_session(monkeypatch, session)

    with pytest.raises(ValueError, match="Workflow not found"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Happy path: trigger, workflow and the node config are returned together."""
    trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    enabled = SimpleNamespace(status=AppTriggerStatus.ENABLED)
    workflow = MagicMock()
    workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}

    session = MagicMock()
    session.query.side_effect = [_FakeQuery(trigger), _FakeQuery(enabled), _FakeQuery(workflow)]
    _patch_session(monkeypatch, session)

    got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")

    assert got_trigger is trigger
    assert got_workflow is workflow
    assert got_node_config == {"data": {"key": "value"}}
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
    """Debug mode skips the app-trigger status check and still resolves everything."""
    trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    workflow = MagicMock()
    workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}

    session = MagicMock()
    # Only two queries run in debug mode: webhook trigger, then workflow.
    session.query.side_effect = [_FakeQuery(trigger), _FakeQuery(workflow)]
    _patch_session(monkeypatch, session)

    got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
        "webhook-1", is_debug=True
    )

    assert got_trigger is trigger
    assert got_workflow is workflow
    assert got_node_config == {"data": {"mode": "debug"}}
|
||||
|
||||
|
||||
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
    flask_app: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Unrecognized content types are parsed as raw text and a warning is logged."""
    warn_spy = MagicMock()
    monkeypatch.setattr(service_module.logger, "warning", warn_spy)
    trigger = MagicMock()

    with flask_app.test_request_context(
        "/webhook",
        method="POST",
        headers={"Content-Type": "application/vnd.custom"},
        data="plain content",
    ):
        extracted = WebhookService.extract_webhook_data(trigger)

    assert extracted["body"] == {"raw": "plain content"}
    warn_spy.assert_called_once()
|
||||
|
||||
|
||||
def test_extract_webhook_data_should_raise_for_request_too_large(
    flask_app: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Bodies above WEBHOOK_REQUEST_BODY_MAX_SIZE raise RequestEntityTooLarge."""
    # Shrink the limit to one byte so a two-byte payload overflows.
    monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)

    with flask_app.test_request_context("/webhook", method="POST", data="ab"):
        with pytest.raises(RequestEntityTooLarge):
            WebhookService.extract_webhook_data(MagicMock())
|
||||
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None:
    """An empty octet-stream payload yields a null raw value and no files."""
    trigger = MagicMock()

    with flask_app.test_request_context("/webhook", method="POST", data=b""):
        body, files = WebhookService._extract_octet_stream_body(trigger)

    assert body == {"raw": None}
    assert files == {}
|
||||
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
    flask_app: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Failures while persisting the binary degrade to an empty result instead of raising."""
    trigger = MagicMock()
    monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream"))
    # Simulate file persistence blowing up mid-processing.
    monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))

    with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
        body, files = WebhookService._extract_octet_stream_body(trigger)

    assert body == {"raw": None}
    assert files == {}
|
||||
|
||||
|
||||
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
    flask_app: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Failures reading the request body degrade to an empty raw string."""
    monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))

    with flask_app.test_request_context("/webhook", method="POST", data="abc"):
        body, files = WebhookService._extract_text_body()

    assert body == {"raw": ""}
    assert files == {}
|
||||
|
||||
|
||||
def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """libmagic failures fall back to the generic application/octet-stream type."""
    broken_magic = MagicMock()
    broken_magic.from_buffer.side_effect = RuntimeError("magic failed")
    monkeypatch.setattr(service_module, "magic", broken_magic)

    assert WebhookService._detect_binary_mimetype(b"binary") == "application/octet-stream"
|
||||
|
||||
|
||||
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Uploads with no guessable mimetype still produce file mappings via the fallback path."""
    trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
    built_file = MagicMock()
    built_file.to_dict.return_value = {"id": "f-1"}
    monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=built_file))
    # Force mimetype guessing to come up empty so the fallback kicks in.
    monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))

    upload = MagicMock()
    upload.filename = "file.unknown"
    upload.content_type = None
    upload.read.return_value = b"content"

    processed = WebhookService._process_file_uploads({"f": upload}, trigger)

    assert processed == {"f": {"id": "f-1"}}
|
||||
|
||||
|
||||
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Binary payloads are persisted via ToolFileManager then wrapped by file_factory."""
    trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
    tool_manager = MagicMock()
    tool_manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
    monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=tool_manager))
    built_file = MagicMock()
    monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=built_file))

    created = WebhookService._create_file_from_binary(b"abc", "text/plain", trigger)

    assert created is built_file
    tool_manager.create_file_by_raw.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("raw_value", "param_type", "expected"),
    [
        ("42", SegmentType.NUMBER, 42),
        ("3.14", SegmentType.NUMBER, 3.14),
        ("yes", SegmentType.BOOLEAN, True),
        ("no", SegmentType.BOOLEAN, False),
    ],
)
def test_convert_form_value_should_convert_supported_types(
    raw_value: str,
    param_type: SegmentType,  # fixed annotation: parametrized values are SegmentType members, not str
    expected: Any,
) -> None:
    """Numeric and boolean form values are converted to native Python types."""
    result = WebhookService._convert_form_value("param", raw_value, param_type)

    assert result == expected
|
||||
|
||||
|
||||
def test_convert_form_value_should_raise_for_unsupported_type() -> None:
    """File-typed parameters cannot be converted from plain form values."""
    with pytest.raises(ValueError, match="Unsupported type"):
        WebhookService._convert_form_value("p", "x", SegmentType.FILE)
|
||||
|
||||
|
||||
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Unknown segment types pass the value through unchanged but log a warning."""
    warn_spy = MagicMock()
    monkeypatch.setattr(service_module.logger, "warning", warn_spy)

    validated = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")

    assert validated == {"x": 1}
    warn_spy.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_and_convert_value_should_wrap_conversion_errors() -> None:
    """Conversion failures are surfaced as a validation ValueError."""
    with pytest.raises(ValueError, match="validation failed"):
        WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
|
||||
|
||||
|
||||
def test_process_parameters_should_raise_when_required_parameter_missing() -> None:
    """A missing required parameter aborts processing with a ValueError."""
    config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]

    with pytest.raises(ValueError, match="Required parameter missing"):
        WebhookService._process_parameters({"optional": "x"}, config, is_form_data=True)
|
||||
|
||||
|
||||
def test_process_parameters_should_include_unconfigured_parameters() -> None:
    """Parameters without a config entry are passed through unconverted."""
    config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]

    processed = WebhookService._process_parameters({"known": "1", "unknown": "x"}, config, is_form_data=True)

    # "known" is converted per config; "unknown" is forwarded verbatim.
    assert processed == {"known": 1, "unknown": "x"}
|
||||
|
||||
|
||||
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None:
    """An empty raw text body fails validation when the raw parameter is required."""
    with pytest.raises(ValueError, match="Required body content missing"):
        WebhookService._process_body_parameters(
            raw_body={"raw": ""},
            body_configs=[WebhookBodyParameter(name="raw", required=True)],
            content_type=ContentType.TEXT,
        )
|
||||
|
||||
|
||||
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None:
    """File-typed configs are ignored for multipart bodies; other values pass through."""
    configs = [
        WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
        WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
    ]
    incoming = {"message": "hello", "extra": "x"}

    processed = WebhookService._process_body_parameters(incoming, configs, ContentType.FORM_DATA)

    # The required file config is skipped entirely; unconfigured "extra" is kept.
    assert processed == {"message": "hello", "extra": "x"}
|
||||
|
||||
|
||||
def test_validate_required_headers_should_accept_sanitized_header_names() -> None:
    """Header names are sanitized (dashes become underscores) before matching.

    The call must complete without raising — the previous trailing ``assert True``
    was a no-op and has been removed; "does not raise" is the actual assertion.
    """
    headers = {"x_api_key": "123"}
    configs = [WebhookParameter(name="x-api-key", required=True)]

    # Act / Assert: must not raise — "x_api_key" satisfies the "x-api-key" requirement.
    WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
|
||||
def test_validate_required_headers_should_raise_when_required_header_missing() -> None:
    """A missing required header raises a ValueError."""
    configs = [WebhookParameter(name="x-api-key", required=True)]

    with pytest.raises(ValueError, match="Required header missing"):
        WebhookService._validate_required_headers({"x-other": "123"}, configs)
|
||||
|
||||
|
||||
def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None:
    """A JSON request against a text-configured node reports a content-type mismatch."""
    request_meta = {"method": "POST", "headers": {"Content-Type": "application/json"}}
    node_data = WebhookData(method="post", content_type=ContentType.TEXT)

    verdict = WebhookService._validate_http_metadata(request_meta, node_data)

    assert verdict["valid"] is False
    assert "Content-type mismatch" in verdict["error"]
|
||||
|
||||
|
||||
def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None:
    """The lowercase content-type header is honored and its parameters stripped."""
    extracted = WebhookService._extract_content_type({"content-type": "application/json; charset=utf-8"})

    assert extracted == "application/json"
|
||||
|
||||
|
||||
def test_build_workflow_inputs_should_include_expected_keys() -> None:
    """Workflow inputs expose the full payload plus headers/query/body shortcuts."""
    payload = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}

    inputs = WebhookService.build_workflow_inputs(payload)

    assert inputs["webhook_data"] == payload
    assert inputs["webhook_headers"] == {"h": "v"}
    assert inputs["webhook_query_params"] == {"q": 1}
    assert inputs["webhook_body"] == {"b": 2}
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
    """Happy path: quota is consumed and the async workflow is enqueued exactly once."""
    trigger = _workflow_trigger(
        app_id="app-1",
        node_id="node-1",
        tenant_id="tenant-1",
        webhook_id="webhook-1",
    )
    workflow = _workflow(id="wf-1")
    _patch_session(monkeypatch, MagicMock())

    monkeypatch.setattr(
        service_module.EndUserService,
        "get_or_create_end_user_by_type",
        MagicMock(return_value=SimpleNamespace(id="end-user-1")),
    )
    # Quota consumption succeeds silently on the happy path.
    monkeypatch.setattr(service_module, "QuotaType", SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock())))
    async_spy = MagicMock()
    monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", async_spy)

    WebhookService.trigger_workflow_execution(trigger, {"body": {"x": 1}}, workflow)

    async_spy.assert_called_once()
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Quota exhaustion marks the tenant's triggers rate-limited and re-raises."""
    trigger = _workflow_trigger(
        app_id="app-1",
        node_id="node-1",
        tenant_id="tenant-1",
        webhook_id="webhook-1",
    )
    workflow = _workflow(id="wf-1")
    _patch_session(monkeypatch, MagicMock())

    monkeypatch.setattr(
        service_module.EndUserService,
        "get_or_create_end_user_by_type",
        MagicMock(return_value=SimpleNamespace(id="end-user-1")),
    )
    # Quota consumption blows up with a QuotaExceededError.
    quota_error = QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1)
    monkeypatch.setattr(
        service_module,
        "QuotaType",
        SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock(side_effect=quota_error))),
    )
    rate_limit_spy = MagicMock()
    monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", rate_limit_spy)

    with pytest.raises(QuotaExceededError):
        WebhookService.trigger_workflow_execution(trigger, {"body": {}}, workflow)
    rate_limit_spy.assert_called_once_with("tenant-1")
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
    """Unexpected failures are logged via logger.exception and then propagated."""
    trigger = _workflow_trigger(
        app_id="app-1",
        node_id="node-1",
        tenant_id="tenant-1",
        webhook_id="webhook-1",
    )
    workflow = _workflow(id="wf-1")
    _patch_session(monkeypatch, MagicMock())

    monkeypatch.setattr(
        service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom"))
    )
    exception_spy = MagicMock()
    monkeypatch.setattr(service_module.logger, "exception", exception_spy)

    with pytest.raises(RuntimeError, match="boom"):
        WebhookService.trigger_workflow_execution(trigger, {"body": {}}, workflow)
    exception_spy.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None:
    """Workflows holding more webhook nodes than the cap are rejected up front."""
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    over_the_cap = WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1
    workflow = _workflow(walk_nodes=lambda _node_type: [(f"node-{i}", {}) for i in range(over_the_cap)])

    with pytest.raises(ValueError, match="maximum webhook node limit"):
        WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None:
    """Failure to take the Redis lock aborts synchronization with a RuntimeError."""
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])

    unavailable_lock = MagicMock()
    unavailable_lock.acquire.return_value = False
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=unavailable_lock))

    with pytest.raises(RuntimeError, match="Failed to acquire lock"):
        WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """New webhook nodes get records created; records for removed nodes are deleted."""
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])

    class _WorkflowWebhookTrigger:
        # Class-level column stand-ins so select()-style attribute access works.
        app_id = "app_id"
        tenant_id = "tenant_id"
        webhook_id = "webhook_id"
        node_id = "node_id"

        def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
            self.id = None
            self.app_id = app_id
            self.tenant_id = tenant_id
            self.node_id = node_id
            self.webhook_id = webhook_id
            self.created_by = created_by

    class _Select:
        def where(self, *args: Any, **kwargs: Any) -> "_Select":
            return self

    class _Session:
        def __init__(self) -> None:
            self.added: list[Any] = []
            self.deleted: list[Any] = []
            self.commit_count = 0
            # One pre-existing record whose node no longer appears in the workflow.
            self.existing_records = [SimpleNamespace(node_id="node-stale")]

        def scalars(self, _stmt: Any) -> Any:
            return SimpleNamespace(all=lambda: self.existing_records)

        def add(self, obj: Any) -> None:
            self.added.append(obj)

        def flush(self) -> None:
            # Mimic the DB assigning primary keys on flush.
            for idx, obj in enumerate(self.added, start=1):
                if obj.id is None:
                    obj.id = f"rec-{idx}"

        def commit(self) -> None:
            self.commit_count += 1

        def delete(self, obj: Any) -> None:
            self.deleted.append(obj)

    held_lock = MagicMock()
    held_lock.acquire.return_value = True
    held_lock.release.return_value = None
    fake_session = _Session()

    monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
    monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=held_lock))
    redis_set_spy = MagicMock()
    redis_delete_spy = MagicMock()
    monkeypatch.setattr(service_module.redis_client, "set", redis_set_spy)
    monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_spy)
    monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
    _patch_session(monkeypatch, fake_session)

    WebhookService.sync_webhook_relationships(app, workflow)

    # One record created for node-new, one deleted for node-stale, two commits.
    assert len(fake_session.added) == 1
    assert len(fake_session.deleted) == 1
    assert fake_session.commit_count == 2
    redis_set_spy.assert_called_once()
    redis_delete_spy.assert_called_once()
    held_lock.release.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None:
    """A failing lock release is logged via logger.exception instead of propagating."""
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [])

    class _Select:
        def where(self, *args: Any, **kwargs: Any) -> "_Select":
            return self

    class _Session:
        def scalars(self, _stmt: Any) -> Any:
            return SimpleNamespace(all=lambda: [])

        def commit(self) -> None:
            return None

    flaky_lock = MagicMock()
    flaky_lock.acquire.return_value = True
    flaky_lock.release.side_effect = RuntimeError("release failed")
    exception_spy = MagicMock()

    monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=flaky_lock))
    monkeypatch.setattr(service_module.logger, "exception", exception_spy)
    _patch_session(monkeypatch, _Session())

    WebhookService.sync_webhook_relationships(app, workflow)

    assert exception_spy.call_count == 1
|
||||
|
||||
|
||||
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None:
    """Malformed JSON in the configured response body falls back to a default message payload."""
    config = {"data": {"status_code": 200, "response_body": "{bad-json"}}

    payload, status_code = WebhookService.generate_webhook_response(config)

    assert status_code == 200
    assert "message" in payload
|
||||
|
||||
|
||||
def test_generate_webhook_id_should_return_24_character_identifier() -> None:
    """generate_webhook_id yields a 24-character string identifier."""
    identifier = WebhookService.generate_webhook_id()

    assert isinstance(identifier, str)
    assert len(identifier) == 24
|
||||
|
||||
|
||||
def test_sanitize_key_should_return_original_value_for_non_string_input() -> None:
    """_sanitize_key passes non-string values through untouched."""
    sanitized = WebhookService._sanitize_key(123)  # type: ignore[arg-type]

    assert sanitized == 123
|
||||
|
||||
@ -176,3 +176,300 @@ class TestWorkflowRunService:
|
||||
service = WorkflowRunService(session_factory)
|
||||
|
||||
assert service._session_factory == session_factory
|
||||
|
||||
|
||||
# === Merged from test_workflow_run_service.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from models import Account, App, EndUser, WorkflowRunTriggeredFrom
|
||||
from services import workflow_run_service as service_module
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
@pytest.fixture
def repository_factory_mocks(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, MagicMock, Any]:
    """Patch DifyAPIRepositoryFactory with mocks; return (node_repo, run_repo, factory)."""
    node_repo = MagicMock()
    run_repo = MagicMock()
    fake_factory = SimpleNamespace(
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
        create_api_workflow_run_repository=MagicMock(return_value=run_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", fake_factory)
    return node_repo, run_repo, fake_factory
|
||||
|
||||
|
||||
def _app_model(**kwargs: Any) -> App:
    """Build a lightweight App stand-in carrying the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(App, stub)
|
||||
|
||||
|
||||
def _account(**kwargs: Any) -> Account:
    """Build a lightweight Account stand-in carrying the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(Account, stub)
|
||||
|
||||
|
||||
def _end_user(**kwargs: Any) -> EndUser:
    """Build a lightweight EndUser stand-in carrying the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(EndUser, stub)
|
||||
|
||||
|
||||
def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
    monkeypatch: pytest.MonkeyPatch,
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """Without an explicit factory, __init__ builds a sessionmaker bound to db.engine."""
    factory_instance = MagicMock(name="session_factory")
    fake_sessionmaker = MagicMock(return_value=factory_instance)
    monkeypatch.setattr(service_module, "sessionmaker", fake_sessionmaker)
    monkeypatch.setattr(service_module, "db", SimpleNamespace(engine="db-engine"))

    service = WorkflowRunService()

    fake_sessionmaker.assert_called_once_with(bind="db-engine", expire_on_commit=False)
    assert service._session_factory is factory_instance
|
||||
|
||||
|
||||
def test___init___should_create_sessionmaker_when_engine_is_provided(
    monkeypatch: pytest.MonkeyPatch,
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """Passing an Engine makes __init__ wrap it in a sessionmaker."""

    # Arrange
    class FakeEngine:
        pass

    session_factory = MagicMock(name="session_factory")
    sessionmaker_mock = MagicMock(return_value=session_factory)
    # The service's isinstance check must recognise the fake, so replace the
    # module's Engine symbol for the duration of the test.
    monkeypatch.setattr(service_module, "Engine", FakeEngine)
    monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
    # Use a string forward reference in cast(): the bare name `Engine` is not
    # imported in this test module, so `cast(Engine, ...)` would raise a
    # NameError at runtime. cast() never evaluates a string type argument.
    engine = cast("Engine", FakeEngine())

    # Act
    service = WorkflowRunService(session_factory=engine)

    # Assert
    sessionmaker_mock.assert_called_once_with(bind=engine, expire_on_commit=False)
    assert service._session_factory is session_factory
|
||||
|
||||
|
||||
def test___init___should_keep_provided_sessionmaker_and_create_repositories(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """A ready-made sessionmaker is stored as-is and both repositories are built from it."""
    node_repo, run_repo, factory = repository_factory_mocks
    provided = MagicMock(name="session_factory")

    service = WorkflowRunService(session_factory=provided)

    assert service._session_factory is provided
    assert service._node_execution_service_repo is node_repo
    assert service._workflow_run_repo is run_repo
    factory.create_api_workflow_node_execution_repository.assert_called_once_with(provided)
    factory.create_api_workflow_run_repository.assert_called_once_with(provided)
|
||||
|
||||
|
||||
def test_get_paginate_workflow_runs_should_forward_filters_and_parse_limit(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """Filters are forwarded verbatim and the string limit is coerced to int."""
    _, run_repo, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    app = _app_model(tenant_id="tenant-1", id="app-1")
    page = MagicMock(name="pagination")
    run_repo.get_paginated_workflow_runs.return_value = page

    result = service.get_paginate_workflow_runs(
        app_model=app,
        args={"limit": "7", "last_id": "last-1", "status": "succeeded"},
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )

    assert result is page
    run_repo.get_paginated_workflow_runs.assert_called_once_with(
        tenant_id="tenant-1",
        app_id="app-1",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        limit=7,
        last_id="last-1",
        status="succeeded",
    )
|
||||
|
||||
|
||||
def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_when_message_exists(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Runs with a message gain message_id/conversation_id; runs without are left untouched."""
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    app = _app_model(tenant_id="tenant-1", id="app-1")
    with_message = SimpleNamespace(
        id="run-1",
        status="running",
        message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
    )
    without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
    page = SimpleNamespace(data=[with_message, without_message])
    monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=page))

    result = service.get_paginate_advanced_chat_workflow_runs(app_model=app, args={"limit": "2"})

    assert result is page
    assert len(result.data) == 2
    first, second = result.data
    assert first.message_id == "msg-1"
    assert first.conversation_id == "conv-1"
    assert first.status == "running"
    # The run without an attached message must not grow message fields.
    assert not hasattr(second, "message_id")
    assert second.id == "run-2"
|
||||
|
||||
|
||||
def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """get_workflow_run is a thin pass-through to the repository lookup."""
    _, run_repo, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    app = _app_model(tenant_id="tenant-1", id="app-1")
    run = MagicMock(name="workflow_run")
    run_repo.get_workflow_run_by_id.return_value = run

    result = service.get_workflow_run(app_model=app, run_id="run-1")

    assert result is run
    run_repo.get_workflow_run_by_id.assert_called_once_with(
        tenant_id="tenant-1",
        app_id="app-1",
        run_id="run-1",
    )
|
||||
|
||||
|
||||
def test_get_workflow_runs_count_should_forward_optional_filters(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """Optional status/time_range filters are forwarded to the repository unchanged."""
    _, run_repo, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    app = _app_model(tenant_id="tenant-1", id="app-1")
    counts = {"total": 3, "succeeded": 2}
    run_repo.get_workflow_runs_count.return_value = counts

    result = service.get_workflow_runs_count(
        app_model=app,
        status="succeeded",
        time_range="7d",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )

    assert result == counts
    run_repo.get_workflow_runs_count.assert_called_once_with(
        tenant_id="tenant-1",
        app_id="app-1",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        status="succeeded",
        time_range="7d",
    )
|
||||
|
||||
|
||||
def test_get_workflow_run_node_executions_should_return_empty_list_when_run_not_found(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An unknown run id yields an empty execution list instead of an error."""
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=None))

    executions = service.get_workflow_run_node_executions(
        app_model=_app_model(id="app-1"),
        run_id="run-1",
        user=_account(current_tenant_id="tenant-1"),
    )

    assert executions == []
|
||||
|
||||
|
||||
def test_get_workflow_run_node_executions_should_use_end_user_tenant_id(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """For EndUser callers the tenant id is taken from user.tenant_id."""
    node_repo, _, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))

    class FakeEndUser:
        def __init__(self, tenant_id: str) -> None:
            self.tenant_id = tenant_id

    # The service's isinstance(user, EndUser) check must recognise the fake.
    monkeypatch.setattr(service_module, "EndUser", FakeEndUser)
    caller = cast(EndUser, FakeEndUser(tenant_id="tenant-end-user"))
    executions = [SimpleNamespace(id="exec-1")]
    node_repo.get_executions_by_workflow_run.return_value = executions

    result = service.get_workflow_run_node_executions(
        app_model=_app_model(id="app-1"),
        run_id="run-1",
        user=caller,
    )

    assert result == executions
    node_repo.get_executions_by_workflow_run.assert_called_once_with(
        tenant_id="tenant-end-user",
        app_id="app-1",
        workflow_run_id="run-1",
    )
|
||||
|
||||
|
||||
def test_get_workflow_run_node_executions_should_use_account_current_tenant_id(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """For Account callers the tenant id is taken from user.current_tenant_id."""
    node_repo, _, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
    executions = [SimpleNamespace(id="exec-1"), SimpleNamespace(id="exec-2")]
    node_repo.get_executions_by_workflow_run.return_value = executions

    result = service.get_workflow_run_node_executions(
        app_model=_app_model(id="app-1"),
        run_id="run-1",
        user=_account(current_tenant_id="tenant-account"),
    )

    assert result == executions
    node_repo.get_executions_by_workflow_run.assert_called_once_with(
        tenant_id="tenant-account",
        app_id="app-1",
        workflow_run_id="run-1",
    )
|
||||
|
||||
|
||||
def test_get_workflow_run_node_executions_should_raise_when_resolved_tenant_id_is_none(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A None tenant id on the caller is rejected with a ValueError."""
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))

    with pytest.raises(ValueError, match="tenant_id cannot be None"):
        service.get_workflow_run_node_executions(
            app_model=_app_model(id="app-1"),
            run_id="run-1",
            user=_account(current_tenant_id=None),
        )
|
||||
|
||||
@ -0,0 +1,831 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AdvancedChatMessageEntity,
|
||||
AdvancedChatPromptTemplateEntity,
|
||||
AdvancedCompletionPromptTemplateEntity,
|
||||
DatasetEntity,
|
||||
DatasetRetrieveConfigEntity,
|
||||
ExternalDataVariableEntity,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
from models.model import Account, App, AppMode, AppModelConfig
|
||||
from services.workflow import workflow_converter as converter_module
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
try:
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
except ModuleNotFoundError:
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
|
||||
@pytest.fixture
def converter() -> WorkflowConverter:
    """Fresh WorkflowConverter instance for each test."""
    instance = WorkflowConverter()
    return instance
|
||||
|
||||
|
||||
def _app_model(**kwargs: Any) -> App:
    """Build a lightweight App stand-in carrying the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(App, stub)
|
||||
|
||||
|
||||
def _account(**kwargs: Any) -> Account:
    """Build a lightweight Account stand-in carrying the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(Account, stub)
|
||||
|
||||
|
||||
def _app_model_config(**kwargs: Any) -> AppModelConfig:
    """Build a lightweight AppModelConfig stand-in carrying the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(AppModelConfig, stub)
|
||||
|
||||
|
||||
def _build_start_graph() -> dict[str, Any]:
    """Minimal graph containing only a start node with two input variables."""
    start_node = {
        "id": "start",
        "position": None,
        "data": {
            "type": BuiltinNodeTypes.START,
            "variables": [{"variable": "name"}, {"variable": "city"}],
        },
    }
    return {"nodes": [start_node], "edges": []}
|
||||
|
||||
|
||||
def _build_model_config(mode: str | LLMMode) -> ModelConfigEntity:
    """OpenAI gpt-4 model config with the given mode and no parameters or stops."""
    config = ModelConfigEntity(provider="openai", model="gpt-4", mode=mode, parameters={}, stop=[])
    return config
|
||||
|
||||
|
||||
@pytest.fixture
def default_variables() -> list[VariableEntity]:
    """One variable of each basic input type: text-input, paragraph, select."""
    specs = [
        ("text_input", "text-input", VariableEntityType.TEXT_INPUT),
        ("paragraph", "paragraph", VariableEntityType.PARAGRAPH),
        ("select", "select", VariableEntityType.SELECT),
    ]
    return [VariableEntity(variable=name, label=label, type=kind) for name, label, kind in specs]
|
||||
|
||||
|
||||
def test__convert_to_start_node(default_variables: list[VariableEntity]) -> None:
    """The start node carries the converted variables with their serialized type strings."""
    node = WorkflowConverter()._convert_to_start_node(default_variables)

    assert node["id"] == "start"
    assert node["data"]["type"] == BuiltinNodeTypes.START
    first_variable = node["data"]["variables"][0]
    assert first_variable["type"] == "text-input"
    assert first_variable["variable"] == "text_input"
|
||||
|
||||
|
||||
def test__convert_to_http_request_node_for_chatbot(
    default_variables: list[VariableEntity],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Chat apps produce an HTTP request + code node pair whose body references sys.query."""
    app_model = MagicMock()
    app_model.id = "app_id"
    app_model.tenant_id = "tenant_id"
    app_model.mode = AppMode.CHAT

    extension = APIBasedExtension(
        tenant_id="tenant_id",
        name="api-1",
        api_key="encrypted_api_key",
        api_endpoint="https://dify.ai",
    )
    extension.id = "api_based_extension_id"

    workflow_converter = WorkflowConverter()
    workflow_converter._get_api_based_extension = MagicMock(return_value=extension)
    # Patch through monkeypatch rather than bare assignment: assigning
    # `encrypter.decrypt_token = MagicMock(...)` mutates the shared module and
    # would leak the mock into every test that runs afterwards.
    monkeypatch.setattr(encrypter, "decrypt_token", MagicMock(return_value="api_key"))

    external_data_variables = [
        ExternalDataVariableEntity(
            variable="external_variable",
            type="api",
            config={"api_based_extension_id": "api_based_extension_id"},
        ),
    ]

    nodes, mapping = workflow_converter._convert_to_http_request_node(
        app_model=app_model,
        variables=default_variables,
        external_data_variables=external_data_variables,
    )

    assert len(nodes) == 2
    assert nodes[0]["data"]["type"] == BuiltinNodeTypes.HTTP_REQUEST
    assert nodes[1]["data"]["type"] == BuiltinNodeTypes.CODE
    body = json.loads(nodes[0]["data"]["body"]["data"])
    assert body["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
    # Chat mode forwards the user query plus every start-node variable.
    assert body["params"]["query"] == "{{#sys.query#}}"
    assert body["params"]["inputs"]["text_input"] == "{{#start.text_input#}}"
    assert mapping == {"external_variable": "code_1"}
|
||||
|
||||
|
||||
def test__convert_to_http_request_node_for_workflow_app(
    default_variables: list[VariableEntity],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Workflow apps have no conversational query, so the request body's query is empty."""
    app_model = MagicMock()
    app_model.id = "app_id"
    app_model.tenant_id = "tenant_id"
    app_model.mode = AppMode.WORKFLOW

    extension = APIBasedExtension(
        tenant_id="tenant_id",
        name="api-1",
        api_key="encrypted_api_key",
        api_endpoint="https://dify.ai",
    )
    extension.id = "api_based_extension_id"

    workflow_converter = WorkflowConverter()
    workflow_converter._get_api_based_extension = MagicMock(return_value=extension)
    # Patch through monkeypatch rather than bare assignment: assigning
    # `encrypter.decrypt_token = MagicMock(...)` mutates the shared module and
    # would leak the mock into every test that runs afterwards.
    monkeypatch.setattr(encrypter, "decrypt_token", MagicMock(return_value="api_key"))

    external_data_variables = [
        ExternalDataVariableEntity(
            variable="external_variable",
            type="api",
            config={"api_based_extension_id": "api_based_extension_id"},
        ),
    ]

    nodes, _ = workflow_converter._convert_to_http_request_node(
        app_model=app_model,
        variables=default_variables,
        external_data_variables=external_data_variables,
    )

    body = json.loads(nodes[0]["data"]["body"]["data"])
    assert body["params"]["query"] == ""
|
||||
|
||||
|
||||
def test__convert_to_knowledge_retrieval_node_for_chatbot() -> None:
    """Advanced-chat apps wire the knowledge node's query to sys.query."""
    retrieve_config = DatasetRetrieveConfigEntity(
        retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
        top_k=5,
        score_threshold=0.8,
        reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
        reranking_enabled=True,
    )
    dataset_config = DatasetEntity(
        dataset_ids=["dataset_id_1", "dataset_id_2"],
        retrieve_config=retrieve_config,
    )

    node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
        new_app_mode=AppMode.ADVANCED_CHAT,
        dataset_config=dataset_config,
        model_config=_build_model_config("chat"),
    )

    assert node is not None
    assert node["data"]["query_variable_selector"] == ["sys", "query"]
    assert node["data"]["multiple_retrieval_config"]["top_k"] == 5
|
||||
|
||||
|
||||
def test__convert_to_knowledge_retrieval_node_for_workflow_app() -> None:
    """Workflow apps wire the knowledge node's query to the start-node variable."""
    retrieve_config = DatasetRetrieveConfigEntity(
        query_variable="query",
        retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
        top_k=5,
        score_threshold=0.8,
        reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
        reranking_enabled=True,
    )
    dataset_config = DatasetEntity(
        dataset_ids=["dataset_id_1", "dataset_id_2"],
        retrieve_config=retrieve_config,
    )

    node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
        new_app_mode=AppMode.WORKFLOW,
        dataset_config=dataset_config,
        model_config=_build_model_config("chat"),
    )

    assert node is not None
    assert node["data"]["query_variable_selector"] == ["start", "query"]
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables: list[VariableEntity]) -> None:
    """Simple chat prompts become a user message with start-node variable refs and memory enabled."""
    workflow_converter = WorkflowConverter()
    graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
        simple_prompt_template="You are a helper for {{text_input}} and {{paragraph}}",
    )

    node = workflow_converter._convert_to_llm_node(
        original_app_mode=AppMode.CHAT,
        new_app_mode=AppMode.ADVANCED_CHAT,
        model_config=_build_model_config(LLMMode.CHAT.value),
        graph=graph,
        prompt_template=prompt_template,
    )

    assert node["data"]["type"] == BuiltinNodeTypes.LLM
    assert node["data"]["memory"] is not None
    first_message = node["data"]["prompt_template"][0]
    assert first_message["role"] == "user"
    assert "{{#start.text_input#}}" in first_message["text"]
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_simple_chat_model_with_empty_template(
    default_variables: list[VariableEntity],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An empty resolved prompt template produces an empty prompt list on the LLM node."""
    workflow_converter = WorkflowConverter()
    graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
        simple_prompt_template="ignored",
    )
    # Force the transform to resolve to an empty template regardless of input.
    monkeypatch.setattr(
        converter_module.SimplePromptTransform,
        "get_prompt_template",
        lambda self, **kwargs: {"prompt_template": PromptTemplateParser(""), "prompt_rules": {}},
    )

    node = workflow_converter._convert_to_llm_node(
        original_app_mode=AppMode.CHAT,
        new_app_mode=AppMode.ADVANCED_CHAT,
        model_config=_build_model_config(LLMMode.CHAT.value),
        graph=graph,
        prompt_template=prompt_template,
    )

    assert node["data"]["prompt_template"] == []
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables: list[VariableEntity]) -> None:
    """Advanced chat templates are carried over as a list of role-tagged messages."""
    workflow_converter = WorkflowConverter()
    graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
    messages = [AdvancedChatMessageEntity(text="Hello {{text_input}}", role=PromptMessageRole.USER)]
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
        advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(messages=messages),
    )

    node = workflow_converter._convert_to_llm_node(
        original_app_mode=AppMode.CHAT,
        new_app_mode=AppMode.ADVANCED_CHAT,
        model_config=_build_model_config(LLMMode.CHAT.value),
        graph=graph,
        prompt_template=prompt_template,
    )

    assert isinstance(node["data"]["prompt_template"], list)
    assert node["data"]["prompt_template"][0]["role"] == PromptMessageRole.USER.value
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_advanced_chat_model_without_template(
    default_variables: list[VariableEntity],
) -> None:
    """A missing advanced chat template yields no prompts, and workflow mode has no memory."""
    workflow_converter = WorkflowConverter()
    graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
        advanced_chat_prompt_template=None,
    )

    node = workflow_converter._convert_to_llm_node(
        original_app_mode=AppMode.CHAT,
        new_app_mode=AppMode.WORKFLOW,
        model_config=_build_model_config(LLMMode.CHAT.value),
        graph=graph,
        prompt_template=prompt_template,
    )

    assert node["data"]["prompt_template"] == []
    assert node["data"]["memory"] is None
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables: list[VariableEntity]) -> None:
    """Completion prompts rewrite {{#query#}} to sys.query and keep role prefixes in memory."""
    workflow_converter = WorkflowConverter()
    graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
    model_config = ModelConfigEntity(
        provider="openai",
        model="gpt-3.5-turbo-instruct",
        mode=LLMMode.COMPLETION.value,
        parameters={},
        stop=[],
    )
    role_prefix = AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant")
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
        advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
            prompt="Hello {{text_input}} and {{#query#}}",
            role_prefix=role_prefix,
        ),
    )

    node = workflow_converter._convert_to_llm_node(
        original_app_mode=AppMode.COMPLETION,
        new_app_mode=AppMode.ADVANCED_CHAT,
        model_config=model_config,
        graph=graph,
        prompt_template=prompt_template,
    )

    assert "{{#sys.query#}}" in node["data"]["prompt_template"]["text"]
    assert node["data"]["memory"]["role_prefix"]["user"] == "Human"
|
||||
|
||||
|
||||
def test__convert_to_end_node() -> None:
    """The terminal node for workflow apps is an END node with id 'end'."""
    end_node = WorkflowConverter()._convert_to_end_node()

    assert end_node["id"] == "end"
    assert end_node["data"]["type"] == BuiltinNodeTypes.END
|
||||
|
||||
|
||||
def test__convert_to_answer_node() -> None:
    """The terminal node for advanced-chat apps is an ANSWER node with id 'answer'."""
    answer_node = WorkflowConverter()._convert_to_answer_node()

    assert answer_node["id"] == "answer"
    assert answer_node["data"]["type"] == BuiltinNodeTypes.ANSWER
|
||||
|
||||
|
||||
def test_convert_to_workflow_should_raise_when_app_model_config_is_missing(converter: WorkflowConverter) -> None:
    """convert_to_workflow refuses apps that have no model config."""
    app_without_config = _app_model(app_model_config=None)

    with pytest.raises(ValueError, match="App model config is required"):
        converter.convert_to_workflow(
            app_model=app_without_config,
            account=_account(id="account-1"),
            name="new-app",
            icon_type="emoji",
            icon="robot",
            icon_background="#fff",
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("source_mode", "expected_mode"),
    [
        (AppMode.CHAT, AppMode.ADVANCED_CHAT),
        (AppMode.COMPLETION, AppMode.WORKFLOW),
    ],
)
def test_convert_to_workflow_should_create_new_app_with_fallback_fields(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
    source_mode: AppMode,
    expected_mode: AppMode,
) -> None:
    """Empty name/icon arguments fall back to values derived from the source app.

    Also verifies that the new app is persisted through db.session, that the
    converted workflow is bound to the new app's id, and that the
    app_was_created signal is emitted with the new app and account.
    """

    class FakeApp:
        # Stand-in for the App ORM model; only the id (normally assigned by
        # the database on flush) matters for this test.
        def __init__(self) -> None:
            self.id = "new-app-id"

    workflow = SimpleNamespace(app_id=None)
    monkeypatch.setattr(converter, "convert_app_model_config_to_workflow", MagicMock(return_value=workflow))
    monkeypatch.setattr(converter_module, "App", FakeApp)

    # Fake db.session so no real database is touched.
    db_session = SimpleNamespace(add=MagicMock(), flush=MagicMock(), commit=MagicMock())
    monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))

    send_mock = MagicMock()
    monkeypatch.setattr(converter_module.app_was_created, "send", send_mock)

    account = _account(id="account-1")
    app_model = _app_model(
        tenant_id="tenant-1",
        name="Source App",
        mode=source_mode,
        icon_type="emoji",
        icon="sparkles",
        icon_background="#123456",
        enable_site=True,
        enable_api=True,
        api_rpm=10,
        api_rph=100,
        is_public=False,
        app_model_config=_app_model_config(id="config-1"),
    )

    # Empty strings for name/icon fields trigger the fallback to the source app's values.
    new_app = converter.convert_to_workflow(
        app_model=app_model,
        account=account,
        name="",
        icon_type="",
        icon="",
        icon_background="",
    )

    assert new_app.name == "Source App(workflow)"
    assert new_app.mode == expected_mode
    assert new_app.icon_type == "emoji"
    assert new_app.icon == "sparkles"
    assert new_app.icon_background == "#123456"
    assert new_app.created_by == "account-1"
    # The converted workflow is bound to the freshly flushed app id.
    assert workflow.app_id == "new-app-id"
    db_session.add.assert_called_once()
    db_session.flush.assert_called_once()
    db_session.commit.assert_called_once()
    send_mock.assert_called_once_with(new_app, account=account)
|
||||
|
||||
|
||||
def test_convert_app_model_config_to_workflow_should_build_advanced_chat_graph_and_features(
|
||||
converter: WorkflowConverter,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT)
|
||||
app_config = SimpleNamespace(
|
||||
variables=[SimpleNamespace(variable="name")],
|
||||
external_data_variables=[SimpleNamespace(variable="ext")],
|
||||
dataset=SimpleNamespace(id="dataset"),
|
||||
model=SimpleNamespace(),
|
||||
prompt_template=SimpleNamespace(),
|
||||
additional_features=SimpleNamespace(file_upload=SimpleNamespace()),
|
||||
app_model_config_dict={
|
||||
"opening_statement": "hello",
|
||||
"suggested_questions": ["q1"],
|
||||
"suggested_questions_after_answer": True,
|
||||
"speech_to_text": True,
|
||||
"text_to_speech": {"enabled": True},
|
||||
"file_upload": {"enabled": True},
|
||||
"sensitive_word_avoidance": {"enabled": True},
|
||||
"retriever_resource": {"enabled": True},
|
||||
},
|
||||
)
|
||||
|
||||
class FakeWorkflow:
|
||||
VERSION_DRAFT = "draft"
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
monkeypatch.setattr(converter, "_get_new_app_mode", MagicMock(return_value=AppMode.ADVANCED_CHAT))
|
||||
monkeypatch.setattr(converter, "_convert_to_app_config", MagicMock(return_value=app_config))
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_start_node",
|
||||
MagicMock(
|
||||
return_value={"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START, "variables": []}}
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_http_request_node",
|
||||
MagicMock(
|
||||
return_value=(
|
||||
[{"id": "http", "position": None, "data": {"type": BuiltinNodeTypes.HTTP_REQUEST}}],
|
||||
{"ext": "code_1"},
|
||||
)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_knowledge_retrieval_node",
|
||||
MagicMock(
|
||||
return_value={"id": "knowledge", "position": None, "data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_llm_node",
|
||||
MagicMock(return_value={"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_answer_node",
|
||||
MagicMock(return_value={"id": "answer", "position": None, "data": {"type": BuiltinNodeTypes.ANSWER}}),
|
||||
)
|
||||
monkeypatch.setattr(converter_module, "Workflow", FakeWorkflow)
|
||||
|
||||
db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock())
|
||||
monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
|
||||
|
||||
workflow = converter.convert_app_model_config_to_workflow(
|
||||
app_model=app_model,
|
||||
app_model_config=_app_model_config(id="cfg"),
|
||||
account_id="account-1",
|
||||
)
|
||||
|
||||
graph = json.loads(workflow.graph)
|
||||
node_ids = [node["id"] for node in graph["nodes"]]
|
||||
assert node_ids == ["start", "http", "knowledge", "llm", "answer"]
|
||||
|
||||
features = json.loads(workflow.features)
|
||||
assert "opening_statement" in features
|
||||
assert "retriever_resource" in features
|
||||
db_session.add.assert_called_once()
|
||||
db_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_convert_app_model_config_to_workflow_should_build_workflow_mode_with_end_node(
|
||||
converter: WorkflowConverter,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.COMPLETION)
|
||||
app_config = SimpleNamespace(
|
||||
variables=[SimpleNamespace(variable="name")],
|
||||
external_data_variables=[],
|
||||
dataset=SimpleNamespace(id="dataset"),
|
||||
model=SimpleNamespace(),
|
||||
prompt_template=SimpleNamespace(),
|
||||
additional_features=None,
|
||||
app_model_config_dict={
|
||||
"text_to_speech": {"enabled": False},
|
||||
"file_upload": {"enabled": False},
|
||||
"sensitive_word_avoidance": {"enabled": False},
|
||||
},
|
||||
)
|
||||
|
||||
class FakeWorkflow:
|
||||
VERSION_DRAFT = "draft"
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
monkeypatch.setattr(converter, "_get_new_app_mode", MagicMock(return_value=AppMode.WORKFLOW))
|
||||
monkeypatch.setattr(converter, "_convert_to_app_config", MagicMock(return_value=app_config))
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_start_node",
|
||||
MagicMock(
|
||||
return_value={"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START, "variables": []}}
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(converter, "_convert_to_knowledge_retrieval_node", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_llm_node",
|
||||
MagicMock(return_value={"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_end_node",
|
||||
MagicMock(return_value={"id": "end", "position": None, "data": {"type": BuiltinNodeTypes.END}}),
|
||||
)
|
||||
monkeypatch.setattr(converter_module, "Workflow", FakeWorkflow)
|
||||
|
||||
db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock())
|
||||
monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
|
||||
|
||||
workflow = converter.convert_app_model_config_to_workflow(
|
||||
app_model=app_model,
|
||||
app_model_config=_app_model_config(id="cfg"),
|
||||
account_id="account-1",
|
||||
)
|
||||
|
||||
graph = json.loads(workflow.graph)
|
||||
node_ids = [node["id"] for node in graph["nodes"]]
|
||||
assert node_ids == ["start", "llm", "end"]
|
||||
|
||||
features = json.loads(workflow.features)
|
||||
assert set(features.keys()) == {"text_to_speech", "file_upload", "sensitive_word_avoidance"}
|
||||
|
||||
|
||||
def test_convert_to_app_config_should_route_to_correct_manager(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """_convert_to_app_config dispatches to the config manager matching the app's mode/agent flag.

    AGENT_CHAT mode or an is_agent=True flag both route to AgentChatAppConfigManager;
    plain CHAT routes to ChatAppConfigManager; COMPLETION to CompletionAppConfigManager.
    """
    # Sentinel results let us verify which manager produced each conversion.
    agent_result = SimpleNamespace(kind="agent")
    chat_result = SimpleNamespace(kind="chat")
    completion_result = SimpleNamespace(kind="completion")
    monkeypatch.setattr(
        converter_module.AgentChatAppConfigManager, "get_app_config", MagicMock(return_value=agent_result)
    )
    monkeypatch.setattr(converter_module.ChatAppConfigManager, "get_app_config", MagicMock(return_value=chat_result))
    monkeypatch.setattr(
        converter_module.CompletionAppConfigManager,
        "get_app_config",
        MagicMock(return_value=completion_result),
    )

    # Agent routing via the app mode itself.
    from_agent_mode = converter._convert_to_app_config(
        app_model=_app_model(mode=AppMode.AGENT_CHAT, is_agent=False),
        app_model_config=_app_model_config(id="cfg-1"),
    )
    # Agent routing via the is_agent flag on a CHAT app.
    from_agent_flag = converter._convert_to_app_config(
        app_model=_app_model(mode=AppMode.CHAT, is_agent=True),
        app_model_config=_app_model_config(id="cfg-2"),
    )
    from_chat_mode = converter._convert_to_app_config(
        app_model=_app_model(mode=AppMode.CHAT, is_agent=False),
        app_model_config=_app_model_config(id="cfg-3"),
    )
    from_completion_mode = converter._convert_to_app_config(
        app_model=_app_model(mode=AppMode.COMPLETION, is_agent=False),
        app_model_config=_app_model_config(id="cfg-4"),
    )

    # Identity checks prove the exact patched manager was used, not a copy.
    assert from_agent_mode is agent_result
    assert from_agent_flag is agent_result
    assert from_chat_mode is chat_result
    assert from_completion_mode is completion_result
|
||||
|
||||
|
||||
def test_convert_to_app_config_should_raise_for_invalid_app_mode(converter: WorkflowConverter) -> None:
    """An app mode with no matching config manager (e.g. WORKFLOW) must raise ValueError."""
    app_model = _app_model(mode=AppMode.WORKFLOW, is_agent=False)

    with pytest.raises(ValueError, match="Invalid app mode"):
        converter._convert_to_app_config(app_model=app_model, app_model_config=_app_model_config(id="cfg"))
|
||||
|
||||
|
||||
def test_convert_to_http_request_node_should_skip_non_api_and_missing_extension_id(
    converter: WorkflowConverter,
) -> None:
    """External variables that are not type 'api', or lack api_based_extension_id, produce no nodes."""
    app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT)
    external_data_variables = [
        # Wrong type: has an extension id but is not an 'api' variable.
        ExternalDataVariableEntity(variable="skip_type", type="dataset", config={"api_based_extension_id": "x"}),
        # Right type but the extension id is missing from config.
        ExternalDataVariableEntity(variable="skip_config", type="api", config={}),
    ]

    nodes, mapping = converter._convert_to_http_request_node(
        app_model=app_model,
        variables=[],
        external_data_variables=external_data_variables,
    )

    # Both candidates were skipped, so no nodes and no variable mapping.
    assert nodes == []
    assert mapping == {}
|
||||
|
||||
|
||||
def test_convert_to_knowledge_retrieval_node_should_return_none_for_workflow_without_query_variable(
    converter: WorkflowConverter,
) -> None:
    """In WORKFLOW mode a dataset config without a query_variable yields no retrieval node."""
    dataset_config = DatasetEntity(
        dataset_ids=["ds-1"],
        retrieve_config=DatasetRetrieveConfigEntity(
            query_variable=None,  # no query source -> node cannot be wired up
            retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
        ),
    )
    model_config = _build_model_config(mode=LLMMode.CHAT)

    node = converter._convert_to_knowledge_retrieval_node(
        new_app_mode=AppMode.WORKFLOW,
        dataset_config=dataset_config,
        model_config=model_config,
    )

    assert node is None
|
||||
|
||||
|
||||
def test_convert_to_llm_node_should_raise_when_simple_chat_template_missing(
    converter: WorkflowConverter,
) -> None:
    """A SIMPLE prompt type with no simple_prompt_template must raise for chat-mode models."""
    graph = _build_start_graph()
    model_config = _build_model_config(mode=LLMMode.CHAT)
    # simple_prompt_template deliberately omitted.
    prompt_template = PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.SIMPLE)

    with pytest.raises(ValueError, match="Simple prompt template is required"):
        converter._convert_to_llm_node(
            original_app_mode=AppMode.CHAT,
            new_app_mode=AppMode.ADVANCED_CHAT,
            graph=graph,
            model_config=model_config,
            prompt_template=prompt_template,
        )
|
||||
|
||||
|
||||
def test_convert_to_llm_node_should_raise_when_prompt_template_parser_type_is_invalid_for_chat(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """If SimplePromptTransform hands back a non-PromptTemplateParser, conversion raises TypeError."""
    graph = _build_start_graph()
    model_config = _build_model_config(mode=LLMMode.CHAT)
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
        simple_prompt_template="Hello {{name}}",
    )
    # Force the transform to return a plain string where a parser object is expected.
    monkeypatch.setattr(
        converter_module.SimplePromptTransform,
        "get_prompt_template",
        lambda self, **kwargs: {"prompt_template": "invalid"},
    )

    with pytest.raises(TypeError, match="Expected PromptTemplateParser"):
        converter._convert_to_llm_node(
            original_app_mode=AppMode.CHAT,
            new_app_mode=AppMode.ADVANCED_CHAT,
            graph=graph,
            model_config=model_config,
            prompt_template=prompt_template,
        )
|
||||
|
||||
|
||||
def test_convert_to_llm_node_should_raise_when_simple_completion_template_missing(
    converter: WorkflowConverter,
) -> None:
    """A SIMPLE prompt type with no template must also raise for completion-mode models."""
    graph = _build_start_graph()
    model_config = _build_model_config(mode=LLMMode.COMPLETION)
    # simple_prompt_template deliberately omitted.
    prompt_template = PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.SIMPLE)

    with pytest.raises(ValueError, match="Simple prompt template is required"):
        converter._convert_to_llm_node(
            original_app_mode=AppMode.COMPLETION,
            new_app_mode=AppMode.WORKFLOW,
            graph=graph,
            model_config=model_config,
            prompt_template=prompt_template,
        )
|
||||
|
||||
|
||||
def test_convert_to_llm_node_should_raise_when_completion_prompt_rules_type_is_invalid(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Completion conversion requires prompt_rules to be a dict; anything else raises TypeError."""
    graph = _build_start_graph()
    model_config = _build_model_config(mode=LLMMode.COMPLETION)
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
        simple_prompt_template="Hello {{name}}",
    )
    # Valid parser, but prompt_rules is a string instead of the expected dict.
    monkeypatch.setattr(
        converter_module.SimplePromptTransform,
        "get_prompt_template",
        lambda self, **kwargs: {"prompt_template": PromptTemplateParser("Hello {{name}}"), "prompt_rules": "invalid"},
    )

    with pytest.raises(TypeError, match="Expected dict for prompt_rules"):
        converter._convert_to_llm_node(
            original_app_mode=AppMode.COMPLETION,
            new_app_mode=AppMode.ADVANCED_CHAT,
            graph=graph,
            model_config=model_config,
            prompt_template=prompt_template,
        )
|
||||
|
||||
|
||||
def test_convert_to_llm_node_should_use_empty_text_for_advanced_completion_without_template(
    converter: WorkflowConverter,
) -> None:
    """ADVANCED prompt type with a None completion template falls back to empty text and no memory."""
    graph = _build_start_graph()
    model_config = _build_model_config(mode=LLMMode.COMPLETION)
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
        advanced_completion_prompt_template=None,
    )

    llm_node = converter._convert_to_llm_node(
        original_app_mode=AppMode.COMPLETION,
        new_app_mode=AppMode.WORKFLOW,
        graph=graph,
        model_config=model_config,
        prompt_template=prompt_template,
    )

    # Missing template degrades gracefully rather than raising.
    assert llm_node["data"]["prompt_template"]["text"] == ""
    assert llm_node["data"]["memory"] is None
|
||||
|
||||
|
||||
def test_replace_template_variables_should_replace_start_and_external_references(converter: WorkflowConverter) -> None:
    """Start-node vars rewrite to {{#start.<var>#}}; externally mapped vars to {{#<node_id>.result#}}."""
    template = "Hello {{name}} from {{city}} with {{weather}}"
    variables = [{"variable": "name"}, {"variable": "city"}]
    # 'weather' is served by an external-data node with id 'code_1'.
    external_mapping = {"weather": "code_1"}

    result = converter._replace_template_variables(template, variables, external_mapping)

    assert result == "Hello {{#start.name#}} from {{#start.city#}} with {{#code_1.result#}}"
|
||||
|
||||
|
||||
def test_graph_helpers_should_create_edges_append_nodes_and_choose_mode(converter: WorkflowConverter) -> None:
    """Covers the small graph helpers: edge creation, node append, and old->new mode mapping."""
    graph = {"nodes": [{"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START}}], "edges": []}
    node = {"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}

    edge = converter._create_edge("start", "llm")
    # Appending a node also wires an edge from the previous last node.
    updated_graph = converter._append_node(graph, node)
    workflow_mode = converter._get_new_app_mode(_app_model(mode=AppMode.COMPLETION))
    advanced_chat_mode = converter._get_new_app_mode(_app_model(mode=AppMode.CHAT))

    # Edge id is "<source>-<target>".
    assert edge == {"id": "start-llm", "source": "start", "target": "llm"}
    assert updated_graph["nodes"][-1]["id"] == "llm"
    assert updated_graph["edges"][-1]["source"] == "start"
    # COMPLETION converts to WORKFLOW; CHAT converts to ADVANCED_CHAT.
    assert workflow_mode == AppMode.WORKFLOW
    assert advanced_chat_mode == AppMode.ADVANCED_CHAT
|
||||
|
||||
|
||||
def test_get_api_based_extension_should_raise_when_extension_not_found(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A DB lookup returning no row must surface as ValueError, not a silent None."""
    db_session = SimpleNamespace(scalar=MagicMock(return_value=None))
    monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))

    with pytest.raises(ValueError, match="API Based Extension not found"):
        converter._get_api_based_extension(tenant_id="tenant-1", api_based_extension_id="ext-1")
    # The query ran exactly once before failing.
    db_session.scalar.assert_called_once()
|
||||
|
||||
|
||||
def test_get_api_based_extension_should_return_entity_when_found(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When the DB lookup finds a row, the helper returns that exact object."""
    extension = SimpleNamespace(id="ext-1")
    db_session = SimpleNamespace(scalar=MagicMock(return_value=extension))
    monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))

    result = converter._get_api_based_extension(tenant_id="tenant-1", api_based_extension_id="ext-1")

    # Identity check: the entity is passed through untouched.
    assert result is extension
    db_session.scalar.assert_called_once()
|
||||
@ -1,10 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import queue
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from itertools import cycle
|
||||
from threading import Event
|
||||
|
||||
import pytest
|
||||
@ -224,3 +223,577 @@ def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -
|
||||
buffer_state.task_id_ready.set()
|
||||
task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
|
||||
assert task_id == expected
|
||||
|
||||
|
||||
# === Merged from test_workflow_event_snapshot_service_additional.py ===
|
||||
|
||||
|
||||
import json
|
||||
import queue
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from threading import Event
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from services import workflow_event_snapshot_service as service_module
|
||||
from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream
|
||||
|
||||
|
||||
def _build_workflow_run_additional(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
    """Build a minimal, fully populated WorkflowRun fixture.

    Only ``status`` varies between tests; every other field is a fixed,
    representative value so assertions stay deterministic.
    """
    return WorkflowRun(
        id="run-1",
        tenant_id="tenant-1",
        app_id="app-1",
        workflow_id="workflow-1",
        type="workflow",
        triggered_from="app-run",
        version="v1",
        graph=None,
        inputs=json.dumps({"query": "hello"}),
        status=status,
        outputs=json.dumps({}),
        error=None,
        elapsed_time=1.2,
        total_tokens=5,
        total_steps=2,
        created_by_role=CreatorUserRole.END_USER,
        created_by="user-1",
        created_at=datetime(2024, 1, 1, tzinfo=UTC),  # fixed, tz-aware timestamp
    )
|
||||
|
||||
|
||||
def _build_resumption_context_additional(task_id: str) -> WorkflowResumptionContext:
    """Build a WorkflowResumptionContext whose generate entity carries ``task_id``.

    The context wraps a real WorkflowAppGenerateEntity plus a serialized
    GraphRuntimeState so round-trip (dumps/loads) tests exercise actual
    serialization paths rather than mocks.
    """
    app_config = WorkflowUIBasedAppConfig(
        tenant_id="tenant-1",
        app_id="app-1",
        app_mode=AppMode.WORKFLOW,
        workflow_id="workflow-1",
    )
    generate_entity = WorkflowAppGenerateEntity(
        task_id=task_id,
        app_config=app_config,
        inputs={},
        files=[],
        user_id="user-1",
        stream=True,
        invoke_from=InvokeFrom.EXPLORE,
        call_depth=0,
        workflow_execution_id="run-1",
    )
    runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
    runtime_state.outputs = {"answer": "ok"}
    wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
    return WorkflowResumptionContext(
        generate_entity=wrapper,
        # Serialize the state here so tests can feed the bytes to loaders.
        serialized_graph_runtime_state=runtime_state.dumps(),
    )
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _SessionMaker:
    """Callable stand-in for ``sessionmaker``.

    Every call opens a fresh ``_SessionContext`` over the single session
    supplied at construction time.
    """

    def __init__(self, session: Any) -> None:
        self._fixed_session = session

    def __call__(self) -> _SessionContext:
        # Mirror sessionmaker(): calling the factory yields a context manager.
        return _SessionContext(self._fixed_session)
|
||||
|
||||
|
||||
class _SubscriptionContext:
|
||||
def __init__(self, subscription: Any) -> None:
|
||||
self._subscription = subscription
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._subscription
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _Topic:
    """Topic double whose ``subscribe()`` wraps one fixed subscription in a context."""

    def __init__(self, subscription: Any) -> None:
        self._sub = subscription

    def subscribe(self) -> _SubscriptionContext:
        # Each subscribe() call returns a fresh context over the same subscription.
        return _SubscriptionContext(self._sub)
|
||||
|
||||
|
||||
class _StaticSubscription:
|
||||
def receive(self, timeout: int = 1) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _PauseEntity(WorkflowPauseEntity):
    """Frozen test double for WorkflowPauseEntity with fixed identifiers.

    Only ``state`` varies per test; it is returned verbatim from get_state()
    so tests can feed arbitrary (valid or invalid) serialized context bytes.
    """

    # Raw serialized resumption-context bytes handed back by get_state().
    state: bytes

    @property
    def id(self) -> str:
        return "pause-1"

    @property
    def workflow_execution_id(self) -> str:
        return "run-1"

    @property
    def resumed_at(self) -> datetime | None:
        # Never resumed in these tests.
        return None

    @property
    def paused_at(self) -> datetime:
        return datetime(2024, 1, 1, tzinfo=UTC)

    def get_state(self) -> bytes:
        return self.state

    def get_pause_reasons(self) -> list[Any]:
        return []
|
||||
|
||||
|
||||
def test_get_message_context_should_return_none_when_no_message() -> None:
    """If the session finds no message row for the run, the helper returns None."""
    # Arrange
    session = SimpleNamespace(scalar=MagicMock(return_value=None))
    session_maker = _SessionMaker(session)

    # Act — cast satisfies the sessionmaker[Session] annotation with our double.
    result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")

    # Assert
    assert result is None
|
||||
|
||||
|
||||
def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp() -> None:
    """A message with created_at=None produces a context with created_at defaulted to 0."""
    # Arrange
    message = SimpleNamespace(
        id="msg-1",
        conversation_id="conv-1",
        created_at=None,  # missing timestamp is the condition under test
        answer="answer",
    )
    session = SimpleNamespace(scalar=MagicMock(return_value=message))
    session_maker = _SessionMaker(session)

    # Act
    result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")

    # Assert — remaining fields are copied through from the message row.
    assert result is not None
    assert result.created_at == 0
    assert result.message_id == "msg-1"
    assert result.conversation_id == "conv-1"
    assert result.answer == "answer"
|
||||
|
||||
|
||||
def test_load_resumption_context_should_return_none_when_pause_entity_missing() -> None:
    """No pause entity means nothing to resume: the loader returns None."""
    # Arrange — nothing to set up; the input itself is the missing entity.

    # Act
    result = service_module._load_resumption_context(None)

    # Assert
    assert result is None
|
||||
|
||||
|
||||
def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid() -> None:
    """Undeserializable state bytes are swallowed: the loader returns None, not an exception."""
    # Arrange
    pause_entity = _PauseEntity(state=b"not-a-valid-state")

    # Act
    result = service_module._load_resumption_context(pause_entity)

    # Assert
    assert result is None
|
||||
|
||||
|
||||
def test_load_resumption_context_should_parse_valid_state_into_context() -> None:
    """A valid serialized context round-trips: dumps() -> bytes -> loader -> same task id."""
    # Arrange
    context = _build_resumption_context_additional(task_id="task-ctx")
    pause_entity = _PauseEntity(state=context.dumps().encode())

    # Act
    result = service_module._load_resumption_context(pause_entity)

    # Assert
    assert result is not None
    assert result.get_generate_entity().task_id == "task-ctx"
|
||||
|
||||
|
||||
def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing() -> None:
    """With no resumption context and no buffer state, the run id is the fallback task id."""
    # Arrange — both task-id sources are absent.

    # Act
    result = service_module._resolve_task_id(
        resumption_context=None,
        buffer_state=None,
        workflow_run_id="run-1",
    )

    # Assert
    assert result == "run-1"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("payload", "expected"),
    [
        # Well-formed JSON object -> parsed dict.
        (b'{"event":"node_started"}', {"event": "node_started"}),
        # Malformed JSON -> None.
        (b"invalid-json", None),
        # Valid JSON but not an object (a list) -> None.
        (b"[]", None),
    ],
)
def test_parse_event_message_should_parse_only_json_object(
    payload: bytes,
    expected: dict[str, Any] | None,
) -> None:
    """_parse_event_message accepts only JSON objects; anything else yields None."""
    # Arrange — parametrized payload is the full setup.

    # Act
    result = service_module._parse_event_message(payload)

    # Assert
    assert result == expected
|
||||
|
||||
|
||||
def test_is_terminal_event_should_recognize_finished_and_optional_paused_events() -> None:
    """WORKFLOW_FINISHED is always terminal; WORKFLOW_PAUSED only when include_paused=True."""
    # Arrange
    finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
    paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value}

    # Act
    is_finished = service_module._is_terminal_event(finished_event, include_paused=False)
    paused_without_flag = service_module._is_terminal_event(paused_event, include_paused=False)
    paused_with_flag = service_module._is_terminal_event(paused_event, include_paused=True)

    # Assert
    assert is_finished is True
    assert paused_without_flag is False
    assert paused_with_flag is True
    # A bare string payload (not an event mapping) is never terminal.
    assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False
|
||||
|
||||
|
||||
def test_apply_message_context_should_update_payload_when_context_exists() -> None:
    """Applying a MessageContext stamps conversation/message ids and timestamp onto the payload."""
    # Arrange
    payload: dict[str, Any] = {"event": "workflow_started"}
    context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000)

    # Act — mutates payload in place.
    service_module._apply_message_context(payload, context)

    # Assert
    assert payload["conversation_id"] == "conv-1"
    assert payload["message_id"] == "msg-1"
    assert payload["created_at"] == 1700000000
|
||||
|
||||
|
||||
def test_start_buffering_should_capture_task_id_and_enqueue_event() -> None:
|
||||
# Arrange
|
||||
class Subscription:
|
||||
def __init__(self) -> None:
|
||||
self._calls = 0
|
||||
|
||||
def receive(self, timeout: int = 1) -> bytes | None:
|
||||
self._calls += 1
|
||||
if self._calls == 1:
|
||||
return b'{"event":"node_started","task_id":"task-1"}'
|
||||
return None
|
||||
|
||||
subscription = Subscription()
|
||||
|
||||
# Act
|
||||
buffer_state = service_module._start_buffering(subscription)
|
||||
ready = buffer_state.task_id_ready.wait(timeout=1)
|
||||
event = buffer_state.queue.get(timeout=1)
|
||||
buffer_state.stop_event.set()
|
||||
finished = buffer_state.done_event.wait(timeout=1)
|
||||
|
||||
# Assert
|
||||
assert ready is True
|
||||
assert finished is True
|
||||
assert buffer_state.task_id_hint == "task-1"
|
||||
assert event["event"] == "node_started"
|
||||
|
||||
|
||||
def test_start_buffering_should_drop_old_event_when_queue_is_full(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
class QueueWithSingleFull:
|
||||
def __init__(self) -> None:
|
||||
self._first_put = True
|
||||
self.items: list[dict[str, Any]] = [{"event": "old"}]
|
||||
|
||||
def put_nowait(self, item: dict[str, Any]) -> None:
|
||||
if self._first_put:
|
||||
self._first_put = False
|
||||
raise queue.Full
|
||||
self.items.append(item)
|
||||
|
||||
def get_nowait(self) -> dict[str, Any]:
|
||||
if not self.items:
|
||||
raise queue.Empty
|
||||
return self.items.pop(0)
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.items) == 0
|
||||
|
||||
fake_queue = QueueWithSingleFull()
|
||||
monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)
|
||||
|
||||
class Subscription:
|
||||
def __init__(self) -> None:
|
||||
self._calls = 0
|
||||
|
||||
def receive(self, timeout: int = 1) -> bytes | None:
|
||||
self._calls += 1
|
||||
if self._calls == 1:
|
||||
return b'{"event":"node_started","task_id":"task-2"}'
|
||||
return None
|
||||
|
||||
subscription = Subscription()
|
||||
|
||||
# Act
|
||||
buffer_state = service_module._start_buffering(subscription)
|
||||
ready = buffer_state.task_id_ready.wait(timeout=1)
|
||||
buffer_state.stop_event.set()
|
||||
finished = buffer_state.done_event.wait(timeout=1)
|
||||
|
||||
# Assert
|
||||
assert ready is True
|
||||
assert finished is True
|
||||
assert fake_queue.items[-1]["task_id"] == "task-2"
|
||||
|
||||
|
||||
def test_start_buffering_should_set_done_event_when_subscription_raises() -> None:
    """A subscription error must still signal done_event so consumers do not hang."""
    # Arrange
    class Subscription:
        def receive(self, timeout: int = 1) -> bytes | None:
            # Fail immediately on the first receive attempt.
            raise RuntimeError("subscription failure")

    subscription = Subscription()

    # Act — buffering runs on a background thread; wait bounded for completion.
    buffer_state = service_module._start_buffering(subscription)
    finished = buffer_state.done_event.wait(timeout=1)

    # Assert
    assert finished is True
|
||||
|
||||
|
||||
def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(
|
||||
service_module,
|
||||
"_get_message_context",
|
||||
MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
monkeypatch.setattr(
|
||||
service_module,
|
||||
"_build_snapshot_events",
|
||||
MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
|
||||
)
|
||||
|
||||
# Act
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert events[0] == StreamEvent.PING.value
|
||||
finished_event = cast(Mapping[str, Any], events[1])
|
||||
assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
|
||||
called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
|
||||
assert called_kwargs["workflow_run_id"] == "run-1"
|
||||
|
||||
|
||||
def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
|
||||
class AlwaysEmptyQueue:
|
||||
def empty(self) -> bool:
|
||||
return False
|
||||
|
||||
def get(self, timeout: int = 1) -> None:
|
||||
raise queue.Empty
|
||||
|
||||
buffer_state = BufferState(
|
||||
queue=AlwaysEmptyQueue(), # type: ignore[arg-type]
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
time_values = cycle([0.0, 6.0, 21.0, 26.0])
|
||||
monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))
|
||||
|
||||
# Act
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
idle_timeout=20.0,
|
||||
ping_interval=5.0,
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
|
||||
|
||||
def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
buffer_state.done_event.set()
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
|
||||
# Act
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert events == [StreamEvent.PING.value]
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
|
||||
|
||||
def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
|
||||
# Act
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert events[0] == StreamEvent.PING.value
|
||||
assert snapshot_builder.call_args.kwargs["pause_entity"] is None
|
||||
|
||||
@ -10,6 +10,8 @@ This module tests the document indexing task functionality including:
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from contextlib import nullcontext
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
@ -1113,13 +1115,17 @@ class TestAdvancedScenarios:
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task)
|
||||
|
||||
# Assert
|
||||
# Verify delete was called to clean up task key
|
||||
mock_redis.delete.assert_called_once()
|
||||
expected_task_key = f"tenant_document_indexing_task:{tenant_id}"
|
||||
|
||||
# Verify the correct key was deleted (contains tenant_id and "document_indexing")
|
||||
delete_call_args = mock_redis.delete.call_args[0][0]
|
||||
assert tenant_id in delete_call_args
|
||||
assert "document_indexing" in delete_call_args
|
||||
# Verify the task key for this tenant was deleted (do not assert call count; fixtures may be shared).
|
||||
mock_redis.delete.assert_any_call(expected_task_key)
|
||||
|
||||
deleted_keys = [delete_call.args[0] for delete_call in mock_redis.delete.call_args_list if delete_call.args]
|
||||
assert expected_task_key in deleted_keys
|
||||
|
||||
deleted_task_key = next(key for key in deleted_keys if key == expected_task_key)
|
||||
assert tenant_id in deleted_task_key
|
||||
assert "document_indexing" in deleted_task_key
|
||||
|
||||
def test_billing_disabled_skips_limit_checks(
|
||||
self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner, mock_feature_service
|
||||
@ -1510,3 +1516,475 @@ class TestRobustness:
|
||||
|
||||
# Verify the exception message
|
||||
assert "Feature service" in str(exc_info.value) or isinstance(exc_info.value, Exception)
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: MagicMock) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> MagicMock:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None: # type: ignore[override]
|
||||
return None
|
||||
|
||||
|
||||
class TestDocumentIndexingTaskSummaryFlow:
|
||||
"""Additional coverage for summary and tenant queue branches."""
|
||||
|
||||
def test_should_return_when_dataset_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test early return when dataset does not exist."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = None
|
||||
session.query.side_effect = lambda model: dataset_query
|
||||
|
||||
create_session_mock = MagicMock(return_value=_SessionContext(session))
|
||||
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
|
||||
features_mock = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.FeatureService.get_features", features_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
features_mock.assert_not_called()
|
||||
|
||||
def test_should_mark_documents_error_when_batch_upload_limit_exceeded(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Test batch upload limit triggers error handling."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||
document = SimpleNamespace(id="doc-1", indexing_status=None, error=None, stopped_at=None)
|
||||
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.first.return_value = document
|
||||
|
||||
session = MagicMock()
|
||||
session.query.side_effect = lambda model: dataset_query if model is Dataset else document_query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
MagicMock(return_value=_SessionContext(session)),
|
||||
)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(
|
||||
enabled=True,
|
||||
subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL),
|
||||
),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
monkeypatch.setattr("tasks.document_indexing_task.dify_config.BATCH_UPLOAD_LIMIT", "1")
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1", "doc-2"])
|
||||
|
||||
# Assert
|
||||
assert document.indexing_status == "error"
|
||||
assert "batch upload limit" in document.error
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_should_queue_summary_generation_for_completed_documents(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test summary generation is queued for eligible documents."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
indexing_technique="high_quality",
|
||||
summary_index_setting={"enable": True},
|
||||
)
|
||||
|
||||
doc_eligible = SimpleNamespace(
|
||||
id="doc-1",
|
||||
indexing_status="completed",
|
||||
doc_form="text",
|
||||
need_summary=True,
|
||||
)
|
||||
doc_skip_form = SimpleNamespace(
|
||||
id="doc-2",
|
||||
indexing_status="completed",
|
||||
doc_form="qa_model",
|
||||
need_summary=True,
|
||||
)
|
||||
doc_skip_status = SimpleNamespace(
|
||||
id="doc-3",
|
||||
indexing_status="processing",
|
||||
doc_form="text",
|
||||
need_summary=True,
|
||||
)
|
||||
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
phase1_docs = [SimpleNamespace(id="doc-1"), SimpleNamespace(id="doc-2"), SimpleNamespace(id="doc-3")]
|
||||
phase1_document_query = MagicMock()
|
||||
phase1_document_query.where.return_value = phase1_document_query
|
||||
phase1_document_query.all.return_value = phase1_docs
|
||||
|
||||
summary_document_query = MagicMock()
|
||||
summary_document_query.where.return_value = summary_document_query
|
||||
summary_document_query.all.return_value = [doc_eligible, doc_skip_form, doc_skip_status]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: phase1_document_query
|
||||
session3.query.side_effect = lambda model: summary_document_query if model is Document else dataset_query
|
||||
|
||||
create_session_mock = MagicMock(
|
||||
side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]
|
||||
)
|
||||
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
|
||||
indexing_runner = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=indexing_runner))
|
||||
delay_mock = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1", "doc-2", "doc-3"])
|
||||
|
||||
# Assert
|
||||
delay_mock.assert_called_once_with("dataset-1", "doc-1", None)
|
||||
|
||||
def test_should_continue_when_summary_queue_fails(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test summary queueing errors are swallowed."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
indexing_technique="high_quality",
|
||||
summary_index_setting={"enable": True},
|
||||
)
|
||||
|
||||
doc_eligible = SimpleNamespace(
|
||||
id="doc-1",
|
||||
indexing_status="completed",
|
||||
doc_form="text",
|
||||
need_summary=True,
|
||||
)
|
||||
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
phase1_query = MagicMock()
|
||||
phase1_query.where.return_value = phase1_query
|
||||
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value = summary_query
|
||||
summary_query.all.return_value = [doc_eligible]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: phase1_query
|
||||
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
|
||||
)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
|
||||
indexing_runner = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=indexing_runner))
|
||||
delay_mock = MagicMock(side_effect=Exception("boom"))
|
||||
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
delay_mock.assert_called_once_with("dataset-1", "doc-1", None)
|
||||
|
||||
def test_should_return_when_dataset_missing_after_indexing(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test early return when dataset is missing after indexing."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.side_effect = [dataset, None]
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: document_query
|
||||
session3.query.side_effect = lambda model: dataset_query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
|
||||
)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock()))
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
session3.query.assert_called()
|
||||
|
||||
def test_should_skip_summary_when_not_high_quality(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test summary generation skipped when indexing_technique is not high_quality."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
indexing_technique="economy",
|
||||
summary_index_setting={"enable": True},
|
||||
)
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: document_query
|
||||
session3.query.side_effect = lambda model: dataset_query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
|
||||
)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock()))
|
||||
|
||||
delay_mock = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
delay_mock.assert_not_called()
|
||||
|
||||
def test_should_skip_summary_generation_when_indexing_paused(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test summary generation is skipped when indexing is paused."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: document_query
|
||||
|
||||
create_session_mock = MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)])
|
||||
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
|
||||
runner = MagicMock()
|
||||
runner.run.side_effect = DocumentIsPausedError("paused")
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=runner))
|
||||
delay_mock = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
delay_mock.assert_not_called()
|
||||
|
||||
def test_should_handle_indexing_runner_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test generic indexing runner exception is handled."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: document_query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)]),
|
||||
)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
|
||||
runner = MagicMock()
|
||||
runner.run.side_effect = RuntimeError("boom")
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=runner))
|
||||
|
||||
delay_mock = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
delay_mock.assert_not_called()
|
||||
|
||||
def test_should_log_missing_document_entry_in_summary_list(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test falsey document entries are handled in summary iteration."""
|
||||
|
||||
# Arrange
|
||||
class _FalseyDocument:
|
||||
def __init__(self, doc_id: str) -> None:
|
||||
self.id = doc_id
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return False
|
||||
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
indexing_technique="high_quality",
|
||||
summary_index_setting={"enable": True},
|
||||
)
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
phase1_query = MagicMock()
|
||||
phase1_query.where.return_value = phase1_query
|
||||
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value = summary_query
|
||||
summary_query.all.return_value = [_FalseyDocument("missing-doc")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: phase1_query
|
||||
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
|
||||
)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock()))
|
||||
|
||||
delay_mock = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
delay_mock.assert_not_called()
|
||||
|
||||
def test_normal_document_indexing_task_should_delegate(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test normal indexing task delegates to tenant queue handler."""
|
||||
# Arrange
|
||||
handler = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task._document_indexing_with_tenant_queue", handler)
|
||||
|
||||
# Act
|
||||
normal_document_indexing_task("tenant-1", "dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
handler.assert_called_once_with("tenant-1", "dataset-1", ["doc-1"], normal_document_indexing_task)
|
||||
|
||||
def test_priority_document_indexing_task_should_delegate(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test priority indexing task delegates to tenant queue handler."""
|
||||
# Arrange
|
||||
handler = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task._document_indexing_with_tenant_queue", handler)
|
||||
|
||||
# Act
|
||||
priority_document_indexing_task("tenant-1", "dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
handler.assert_called_once_with("tenant-1", "dataset-1", ["doc-1"], priority_document_indexing_task)
|
||||
|
||||
40
api/uv.lock
generated
@ -53,23 +53,6 @@ dependencies = [
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748, upload-time = "2026-03-28T17:19:40.6Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/7e/cb94129302d78c46662b47f9897d642fd0b33bdfef4b73b20c6ced35aa4c/aiohttp-3.13.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8ea0c64d1bcbf201b285c2246c51a0c035ba3bbd306640007bc5844a3b4658c1", size = 760027, upload-time = "2026-03-28T17:15:33.022Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/cd/2db3c9397c3bd24216b203dd739945b04f8b87bb036c640da7ddb63c75ef/aiohttp-3.13.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6f742e1fa45c0ed522b00ede565e18f97e4cf8d1883a712ac42d0339dfb0cce7", size = 508325, upload-time = "2026-03-28T17:15:34.714Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/36/a3/d28b2722ec13107f2e37a86b8a169897308bab6a3b9e071ecead9d67bd9b/aiohttp-3.13.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dcfb50ee25b3b7a1222a9123be1f9f89e56e67636b561441f0b304e25aaef8f", size = 502402, upload-time = "2026-03-28T17:15:36.409Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/d6/acd47b5f17c4430e555590990a4746efbcb2079909bb865516892bf85f37/aiohttp-3.13.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3262386c4ff370849863ea93b9ea60fd59c6cf56bf8f93beac625cf4d677c04d", size = 1771224, upload-time = "2026-03-28T17:15:38.223Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/af/af6e20113ba6a48fd1cd9e5832c4851e7613ef50c7619acdaee6ec5f1aff/aiohttp-3.13.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:473bb5aa4218dd254e9ae4834f20e31f5a0083064ac0136a01a62ddbae2eaa42", size = 1731530, upload-time = "2026-03-28T17:15:39.988Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/81/16/78a2f5d9c124ad05d5ce59a9af94214b6466c3491a25fb70760e98e9f762/aiohttp-3.13.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e56423766399b4c77b965f6aaab6c9546617b8994a956821cc507d00b91d978c", size = 1827925, upload-time = "2026-03-28T17:15:41.944Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/1f/79acf0974ced805e0e70027389fccbb7d728e6f30fcac725fb1071e63075/aiohttp-3.13.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8af249343fafd5ad90366a16d230fc265cf1149f26075dc9fe93cfd7c7173942", size = 1923579, upload-time = "2026-03-28T17:15:44.071Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/af/53/29f9e2054ea6900413f3b4c3eb9d8331f60678ec855f13ba8714c47fd48d/aiohttp-3.13.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bc0a5cf4f10ef5a2c94fdde488734b582a3a7a000b131263e27c9295bd682d9", size = 1767655, upload-time = "2026-03-28T17:15:45.911Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/57/462fe1d3da08109ba4aa8590e7aed57c059af2a7e80ec21f4bac5cfe1094/aiohttp-3.13.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:5c7ff1028e3c9fc5123a865ce17df1cb6424d180c503b8517afbe89aa566e6be", size = 1630439, upload-time = "2026-03-28T17:15:48.11Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/4b/4813344aacdb8127263e3eec343d24e973421143826364fa9fc847f6283f/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ba5cf98b5dcb9bddd857da6713a503fa6d341043258ca823f0f5ab7ab4a94ee8", size = 1745557, upload-time = "2026-03-28T17:15:50.13Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/01/1ef1adae1454341ec50a789f03cfafe4c4ac9c003f6a64515ecd32fe4210/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:d85965d3ba21ee4999e83e992fecb86c4614d6920e40705501c0a1f80a583c12", size = 1741796, upload-time = "2026-03-28T17:15:52.351Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/22/04/8cdd99af988d2aa6922714d957d21383c559835cbd43fbf5a47ddf2e0f05/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:49f0b18a9b05d79f6f37ddd567695943fcefb834ef480f17a4211987302b2dc7", size = 1805312, upload-time = "2026-03-28T17:15:54.407Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/7f/b48d5577338d4b25bbdbae35c75dbfd0493cb8886dc586fbfb2e90862239/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7f78cb080c86fbf765920e5f1ef35af3f24ec4314d6675d0a21eaf41f6f2679c", size = 1621751, upload-time = "2026-03-28T17:15:56.564Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/89/4eecad8c1858e6d0893c05929e22343e0ebe3aec29a8a399c65c3cc38311/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:67a3ec705534a614b68bbf1c70efa777a21c3da3895d1c44510a41f5a7ae0453", size = 1826073, upload-time = "2026-03-28T17:15:58.489Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f5/5c/9dc8293ed31b46c39c9c513ac7ca152b3c3d38e0ea111a530ad12001b827/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d6630ec917e85c5356b2295744c8a97d40f007f96a1c76bf1928dc2e27465393", size = 1760083, upload-time = "2026-03-28T17:16:00.677Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/19/8bbf6a4994205d96831f97b7d21a0feed120136e6267b5b22d229c6dc4dc/aiohttp-3.13.4-cp311-cp311-win32.whl", hash = "sha256:54049021bc626f53a5394c29e8c444f726ee5a14b6e89e0ad118315b1f90f5e3", size = 439690, upload-time = "2026-03-28T17:16:02.902Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/f5/ac409ecd1007528d15c3e8c3a57d34f334c70d76cfb7128a28cffdebd4c1/aiohttp-3.13.4-cp311-cp311-win_amd64.whl", hash = "sha256:c033f2bc964156030772d31cbf7e5defea181238ce1f87b9455b786de7d30145", size = 463824, upload-time = "2026-03-28T17:16:05.058Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158, upload-time = "2026-03-28T17:16:06.901Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037, upload-time = "2026-03-28T17:16:08.82Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556, upload-time = "2026-03-28T17:16:10.63Z" },
|
||||
@ -1586,7 +1569,7 @@ dev = [
|
||||
{ name = "lxml-stubs", specifier = "~=0.5.1" },
|
||||
{ name = "mypy", specifier = "~=1.19.1" },
|
||||
{ name = "pandas-stubs", specifier = "~=3.0.0" },
|
||||
{ name = "pyrefly", specifier = ">=0.57.1" },
|
||||
{ name = "pyrefly", specifier = ">=0.59.1" },
|
||||
{ name = "pytest", specifier = "~=9.0.2" },
|
||||
{ name = "pytest-benchmark", specifier = "~=5.2.3" },
|
||||
{ name = "pytest-cov", specifier = "~=7.1.0" },
|
||||
@ -4839,18 +4822,19 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pyrefly"
|
||||
version = "0.57.1"
|
||||
version = "0.59.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c9/c1/c17211e5bbd2b90a24447484713da7cc2cee4e9455e57b87016ffc69d426/pyrefly-0.57.1.tar.gz", hash = "sha256:b05f6f5ee3a6a5d502ca19d84cb9ab62d67f05083819964a48c1510f2993efc6", size = 5310800, upload-time = "2026-03-18T18:42:35.614Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d5/ce/7882c2af92b2ff6505fcd3430eff8048ece6c6254cc90bdc76ecee12dfab/pyrefly-0.59.1.tar.gz", hash = "sha256:bf1675b0c38d45df2c8f8618cbdfa261a1b92430d9d31eba16e0282b551e210f", size = 5475432, upload-time = "2026-04-01T22:04:04.11Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/58/8af37856c8d45b365ece635a6728a14b0356b08d1ff1ac601d7120def1e0/pyrefly-0.57.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:91974bfbe951eebf5a7bc959c1f3921f0371c789cad84761511d695e9ab2265f", size = 12681847, upload-time = "2026-03-18T18:42:10.963Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/d7/fae6dd9d0355fc5b8df7793f1423b7433ca8e10b698ea934c35f0e4e6522/pyrefly-0.57.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:808087298537c70f5e7cdccb5bbaad482e7e056e947c0adf00fb612cbace9fdc", size = 12219634, upload-time = "2026-03-18T18:42:13.469Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/29/8f/9511ae460f0690e837b9ba0f7e5e192079e16ff9a9ba8a272450e81f11f8/pyrefly-0.57.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b01f454fa5539e070c0cba17ddec46b3d2107d571d519bd8eca8f3142ba02a6", size = 34947757, upload-time = "2026-03-18T18:42:17.152Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/07/43/f053bf9c65218f70e6a49561e9942c7233f8c3e4da8d42e5fe2aae50b3d2/pyrefly-0.57.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02ad59ea722191f51635f23e37574662116b82ca9d814529f7cb5528f041f381", size = 37621018, upload-time = "2026-03-18T18:42:20.79Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/76/9cea46de01665bbc125e4f215340c9365c8d56cda6198ff238a563ea8e75/pyrefly-0.57.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54bc0afe56776145e37733ff763e7e9679ee8a76c467b617dc3f227d4124a9e2", size = 40203649, upload-time = "2026-03-18T18:42:24.519Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/8b/2fb4a96d75e2a57df698a43e2970e441ba2704e3906cdc0386a055daa05a/pyrefly-0.57.1-py3-none-win32.whl", hash = "sha256:468e5839144b25bb0dce839bfc5fd879c9f38e68ebf5de561f30bed9ae19d8ca", size = 11732953, upload-time = "2026-03-18T18:42:27.379Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/13/5a/4a197910fe2e9b102b15ae5e7687c45b7b5981275a11a564b41e185dd907/pyrefly-0.57.1-py3-none-win_amd64.whl", hash = "sha256:46db9c97093673c4fb7fab96d610e74d140661d54688a92d8e75ad885a56c141", size = 12537319, upload-time = "2026-03-18T18:42:30.196Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/c6/bc442874be1d9b63da1f9debb4f04b7d0c590a8dc4091921f3c288207242/pyrefly-0.57.1-py3-none-win_arm64.whl", hash = "sha256:feb1bbe3b0d8d5a70121dcdf1476e6a99cc056a26a49379a156f040729244dcb", size = 12013455, upload-time = "2026-03-18T18:42:32.928Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/10/04a0e05b08fc855b6fe38c3df549925fc3c2c6e750506870de7335d3e1f7/pyrefly-0.59.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:390db3cd14aa7e0268e847b60cd9ee18b04273eddfa38cf341ed3bb43f3fef2a", size = 12868133, upload-time = "2026-04-01T22:03:39.436Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/78/fa7be227c3e3fcacee501c1562278dd026186ffd1b5b5beb51d3941a3aed/pyrefly-0.59.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d246d417b6187c1650d7f855f61c68fbfd6d6155dc846d4e4d273a3e6b5175cb", size = 12379325, upload-time = "2026-04-01T22:03:42.046Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/13/6828ce1c98171b5f8388f33c4b0b9ea2ab8c49abe0ef8d793c31e30a05cb/pyrefly-0.59.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:575ac67b04412dc651a7143d27e38a40fbdd3c831c714d5520d0e9d4c8631ab4", size = 35826408, upload-time = "2026-04-01T22:03:45.067Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/23/56/79ed8ece9a7ecad0113c394a06a084107db3ad8f1fefe19e7ded43c51245/pyrefly-0.59.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:062e6262ce1064d59dcad81ac0499bb7a3ad501e9bc8a677a50dc630ff0bf862", size = 38532699, upload-time = "2026-04-01T22:03:48.376Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/18/7d/ecc025e0f0e3f295b497f523cc19cefaa39e57abede8fc353d29445d174b/pyrefly-0.59.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43ef4247f9e6f734feb93e1f2b75335b943629956e509f545cc9cdcccd76dd20", size = 36743570, upload-time = "2026-04-01T22:03:51.362Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2f/03/b1ce882ebcb87c673165c00451fbe4df17bf96ccfde18c75880dc87c5f5e/pyrefly-0.59.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a2d01723b84d042f4fa6ec871ffd52d0a7e83b0ea791c2e0bb0ff750abce56", size = 41236246, upload-time = "2026-04-01T22:03:54.361Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/17/af/5e9c7afd510e7dd64a2204be0ed39e804089cbc4338675a28615c7176acb/pyrefly-0.59.1-py3-none-win32.whl", hash = "sha256:4ea70c780848f8376411e787643ae5d2d09da8a829362332b7b26d15ebcbaf56", size = 11884747, upload-time = "2026-04-01T22:03:56.776Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/aa/c1/7db1077627453fd1068f0761f059a9512645c00c4c20acfb9f0c24ac02ec/pyrefly-0.59.1-py3-none-win_amd64.whl", hash = "sha256:67e6a08cfd129a0d2788d5e40a627f9860e0fe91a876238d93d5c63ff4af68ae", size = 12720608, upload-time = "2026-04-01T22:03:59.252Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/07/16/4bb6e5fce5a9cf0992932d9435d964c33e507aaaf96fdfbb1be493078a4a/pyrefly-0.59.1-py3-none-win_arm64.whl", hash = "sha256:01179cb215cf079e8223a064f61a074f7079aa97ea705cbbc68af3d6713afd15", size = 12223158, upload-time = "2026-04-01T22:04:01.869Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
"prepare": "vp config"
|
||||
},
|
||||
"devDependencies": {
|
||||
"taze": "catalog:",
|
||||
"vite-plus": "catalog:"
|
||||
},
|
||||
"engines": {
|
||||
|
||||
|
Before Width: | Height: | Size: 36 KiB After Width: | Height: | Size: 36 KiB |
|
Before Width: | Height: | Size: 465 B After Width: | Height: | Size: 465 B |
|
Before Width: | Height: | Size: 643 B After Width: | Height: | Size: 643 B |
|
Before Width: | Height: | Size: 297 B After Width: | Height: | Size: 297 B |
|
Before Width: | Height: | Size: 13 KiB After Width: | Height: | Size: 13 KiB |
|
Before Width: | Height: | Size: 13 KiB After Width: | Height: | Size: 13 KiB |
|
Before Width: | Height: | Size: 3.0 KiB After Width: | Height: | Size: 3.0 KiB |
|
Before Width: | Height: | Size: 607 B After Width: | Height: | Size: 607 B |
|
Before Width: | Height: | Size: 599 B After Width: | Height: | Size: 599 B |
|
Before Width: | Height: | Size: 2.4 KiB After Width: | Height: | Size: 2.4 KiB |
|
Before Width: | Height: | Size: 874 B After Width: | Height: | Size: 874 B |
|
Before Width: | Height: | Size: 435 B After Width: | Height: | Size: 435 B |
|
Before Width: | Height: | Size: 1.6 KiB After Width: | Height: | Size: 1.6 KiB |
|
Before Width: | Height: | Size: 364 KiB After Width: | Height: | Size: 364 KiB |
|
Before Width: | Height: | Size: 1.0 KiB After Width: | Height: | Size: 1.0 KiB |
|
Before Width: | Height: | Size: 193 B After Width: | Height: | Size: 193 B |
|
Before Width: | Height: | Size: 1.4 KiB After Width: | Height: | Size: 1.4 KiB |
|
Before Width: | Height: | Size: 8.1 KiB After Width: | Height: | Size: 8.1 KiB |
|
Before Width: | Height: | Size: 1.3 KiB After Width: | Height: | Size: 1.3 KiB |
|
Before Width: | Height: | Size: 561 B After Width: | Height: | Size: 561 B |
|
Before Width: | Height: | Size: 7.5 KiB After Width: | Height: | Size: 7.5 KiB |
|
Before Width: | Height: | Size: 193 B After Width: | Height: | Size: 193 B |
|
Before Width: | Height: | Size: 717 B After Width: | Height: | Size: 717 B |
|
Before Width: | Height: | Size: 2.7 KiB After Width: | Height: | Size: 2.7 KiB |
|
Before Width: | Height: | Size: 1.4 KiB After Width: | Height: | Size: 1.4 KiB |
|
Before Width: | Height: | Size: 1.4 KiB After Width: | Height: | Size: 1.4 KiB |
|
Before Width: | Height: | Size: 2.2 KiB After Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 56 KiB After Width: | Height: | Size: 56 KiB |
|
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.7 KiB |
|
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.7 KiB |
|
Before Width: | Height: | Size: 214 B After Width: | Height: | Size: 214 B |
|
Before Width: | Height: | Size: 3.1 KiB After Width: | Height: | Size: 3.1 KiB |
|
Before Width: | Height: | Size: 3.3 KiB After Width: | Height: | Size: 3.3 KiB |
|
Before Width: | Height: | Size: 3.5 KiB After Width: | Height: | Size: 3.5 KiB |
|
Before Width: | Height: | Size: 1.9 KiB After Width: | Height: | Size: 1.9 KiB |
|
Before Width: | Height: | Size: 3.5 KiB After Width: | Height: | Size: 3.5 KiB |
|
Before Width: | Height: | Size: 1.9 KiB After Width: | Height: | Size: 1.9 KiB |
|
Before Width: | Height: | Size: 2.3 KiB After Width: | Height: | Size: 2.3 KiB |
|
Before Width: | Height: | Size: 2.6 KiB After Width: | Height: | Size: 2.6 KiB |
|
Before Width: | Height: | Size: 4.5 KiB After Width: | Height: | Size: 4.5 KiB |
|
Before Width: | Height: | Size: 1.6 KiB After Width: | Height: | Size: 1.6 KiB |
|
Before Width: | Height: | Size: 2.4 KiB After Width: | Height: | Size: 2.4 KiB |
|
Before Width: | Height: | Size: 5.5 KiB After Width: | Height: | Size: 5.5 KiB |
|
Before Width: | Height: | Size: 7.2 KiB After Width: | Height: | Size: 7.2 KiB |
|
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 16 KiB |
|
Before Width: | Height: | Size: 7.2 KiB After Width: | Height: | Size: 7.2 KiB |
|
Before Width: | Height: | Size: 5.3 KiB After Width: | Height: | Size: 5.3 KiB |
|
Before Width: | Height: | Size: 2.3 KiB After Width: | Height: | Size: 2.3 KiB |
|
Before Width: | Height: | Size: 8.3 KiB After Width: | Height: | Size: 8.3 KiB |
|
Before Width: | Height: | Size: 8.5 KiB After Width: | Height: | Size: 8.5 KiB |
|
Before Width: | Height: | Size: 9.8 KiB After Width: | Height: | Size: 9.8 KiB |
|
Before Width: | Height: | Size: 611 B After Width: | Height: | Size: 611 B |