mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 18:27:19 +08:00
Merge branch 'main' into jzh
This commit is contained in:
commit
6816f89189
5
.github/workflows/autofix.yml
vendored
5
.github/workflows/autofix.yml
vendored
@ -94,11 +94,6 @@ jobs:
|
||||
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
||||
- name: mdformat
|
||||
run: |
|
||||
uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
|
||||
|
||||
- name: Setup web environment
|
||||
if: steps.web-changes.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@ -120,7 +120,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@df37d2f0760a4b5683a6e617c9325bc1a36443f6 # v1.0.75
|
||||
uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@ -97,3 +97,8 @@ Feel free to reach out if you encounter any issues during the setup process.
|
||||
## Getting Help
|
||||
|
||||
If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
|
||||
|
||||
## Automated Agent Contributions
|
||||
|
||||
> [!NOTE]
|
||||
> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in.
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import flask_restx
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@ -33,16 +33,10 @@ api_key_list_model = console_ns.model(
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
if resource_model == App:
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if resource is None:
|
||||
flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.")
|
||||
@ -80,10 +74,13 @@ class BaseApiKeyListResource(Resource):
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||
.count()
|
||||
current_key_count: int = (
|
||||
db.session.scalar(
|
||||
select(func.count(ApiToken.id)).where(
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
@ -119,14 +116,14 @@ class BaseApiKeyResource(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
key = (
|
||||
db.session.query(ApiToken)
|
||||
key = db.session.scalar(
|
||||
select(ApiToken)
|
||||
.where(
|
||||
getattr(ApiToken, self.resource_id_field) == resource_id,
|
||||
ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if key is None:
|
||||
@ -137,7 +134,7 @@ class BaseApiKeyResource(Resource):
|
||||
assert key is not None # nosec - for type checker only
|
||||
ApiTokenCache.delete(key.token, key.type)
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -5,7 +5,7 @@ from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -376,8 +376,12 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
# FIXME, the type ignore in this file
|
||||
if args.annotation_status == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
query = (
|
||||
query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type]
|
||||
.join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
@ -511,8 +515,12 @@ class ChatConversationApi(Resource):
|
||||
|
||||
match args.annotation_status:
|
||||
case "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
query = (
|
||||
query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type]
|
||||
.join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
case "not_annotated":
|
||||
query = (
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import explore_banner_enabled
|
||||
@ -17,14 +18,18 @@ class BannerApi(Resource):
|
||||
language = request.args.get("language", "en-US")
|
||||
|
||||
# Build base query for enabled banners
|
||||
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
|
||||
base_query = select(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
|
||||
|
||||
# Try to get banners in the requested language
|
||||
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
|
||||
banners = db.session.scalars(
|
||||
base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort)
|
||||
).all()
|
||||
|
||||
# Fallback to en-US if no banners found and language is not en-US
|
||||
if not banners and language != "en-US":
|
||||
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
|
||||
banners = db.session.scalars(
|
||||
base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort)
|
||||
).all()
|
||||
# Convert banners to serializable format
|
||||
result = []
|
||||
for banner in banners:
|
||||
|
||||
@ -133,13 +133,15 @@ class InstalledAppsListApi(Resource):
|
||||
def post(self):
|
||||
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
|
||||
recommended_app = db.session.scalar(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).limit(1)
|
||||
)
|
||||
if recommended_app is None:
|
||||
raise NotFound("Recommended app not found")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
app = db.session.query(App).where(App.id == payload.app_id).first()
|
||||
app = db.session.get(App, payload.app_id)
|
||||
|
||||
if app is None:
|
||||
raise NotFound("App entity not found")
|
||||
@ -147,10 +149,10 @@ class InstalledAppsListApi(Resource):
|
||||
if not app.is_public:
|
||||
raise Forbidden("You can't install a non-public app")
|
||||
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
installed_app = db.session.scalar(
|
||||
select(InstalledApp)
|
||||
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if installed_app is None:
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, Literal, cast
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@ -476,7 +477,7 @@ class TrialSitApi(Resource):
|
||||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
@ -541,13 +542,7 @@ class AppWorkflowApi(Resource):
|
||||
if not app_model.workflow_id:
|
||||
raise AppUnavailableError()
|
||||
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
workflow = db.session.get(Workflow, app_model.workflow_id)
|
||||
return workflow
|
||||
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import abort
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
|
||||
@ -24,10 +25,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
||||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
installed_app = db.session.scalar(
|
||||
select(InstalledApp)
|
||||
.where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if installed_app is None:
|
||||
@ -78,7 +79,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
|
||||
trial_app = db.session.scalar(select(TrialApp).where(TrialApp.app_id == str(app_id)).limit(1))
|
||||
|
||||
if trial_app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
@ -87,10 +88,10 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
if app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
|
||||
account_trial_app_record = (
|
||||
db.session.query(AccountTrialAppRecord)
|
||||
account_trial_app_record = db.session.scalar(
|
||||
select(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if account_trial_app_record:
|
||||
if account_trial_app_record.count >= trial_app.trial_limit:
|
||||
|
||||
@ -2,6 +2,7 @@ from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.fastopenapi import console_router
|
||||
@ -100,6 +101,6 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
|
||||
|
||||
def get_setup_status() -> DifySetup | bool | None:
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
return db.session.query(DifySetup).first()
|
||||
return db.session.scalar(select(DifySetup).limit(1))
|
||||
|
||||
return True
|
||||
|
||||
@ -212,13 +212,13 @@ class AccountInitApi(Resource):
|
||||
raise ValueError("invitation_code is required")
|
||||
|
||||
# check invitation code
|
||||
invitation_code = (
|
||||
db.session.query(InvitationCode)
|
||||
invitation_code = db.session.scalar(
|
||||
select(InvitationCode)
|
||||
.where(
|
||||
InvitationCode.code == args.invitation_code,
|
||||
InvitationCode.status == InvitationCodeStatus.UNUSED,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not invitation_code:
|
||||
|
||||
@ -171,7 +171,7 @@ class MemberCancelInviteApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
||||
member = db.session.get(Account, str(member_id))
|
||||
if member is None:
|
||||
abort(404)
|
||||
else:
|
||||
|
||||
@ -7,6 +7,7 @@ from sqlalchemy import select
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
@ -29,6 +30,7 @@ from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.billing_service import BillingService, SubscriptionPlan
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
@ -108,9 +110,29 @@ class TenantListApi(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
tenants = TenantService.get_join_tenants(current_user)
|
||||
tenant_dicts = []
|
||||
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
|
||||
is_saas = dify_config.EDITION == "CLOUD" and dify_config.BILLING_ENABLED
|
||||
tenant_plans: dict[str, SubscriptionPlan] = {}
|
||||
|
||||
if is_saas:
|
||||
tenant_ids = [tenant.id for tenant in tenants]
|
||||
if tenant_ids:
|
||||
tenant_plans = BillingService.get_plan_bulk(tenant_ids)
|
||||
if not tenant_plans:
|
||||
logger.warning("get_plan_bulk returned empty result, falling back to legacy feature path")
|
||||
|
||||
for tenant in tenants:
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
plan: str = CloudPlan.SANDBOX
|
||||
if is_saas:
|
||||
tenant_plan = tenant_plans.get(tenant.id)
|
||||
if tenant_plan:
|
||||
plan = tenant_plan["plan"] or CloudPlan.SANDBOX
|
||||
else:
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
elif not is_enterprise_only:
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
|
||||
# Create a dictionary with tenant attributes
|
||||
tenant_dict = {
|
||||
@ -118,7 +140,7 @@ class TenantListApi(Resource):
|
||||
"name": tenant.name,
|
||||
"status": tenant.status,
|
||||
"created_at": tenant.created_at,
|
||||
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
|
||||
"plan": plan,
|
||||
"current": tenant.id == current_tenant_id if current_tenant_id else False,
|
||||
}
|
||||
|
||||
@ -198,7 +220,7 @@ class SwitchWorkspaceApi(Resource):
|
||||
except Exception:
|
||||
raise AccountNotLinkTenantError("Account not link tenant")
|
||||
|
||||
new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant
|
||||
new_tenant = db.session.get(Tenant, args.tenant_id) # Get new tenant
|
||||
if new_tenant is None:
|
||||
raise ValueError("Tenant not found")
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
|
||||
@ -218,13 +219,9 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# check setup
|
||||
if (
|
||||
dify_config.EDITION == "SELF_HOSTED"
|
||||
and os.environ.get("INIT_PASSWORD")
|
||||
and not db.session.query(DifySetup).first()
|
||||
):
|
||||
raise NotInitValidateError()
|
||||
elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
|
||||
if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)):
|
||||
if os.environ.get("INIT_PASSWORD"):
|
||||
raise NotInitValidateError()
|
||||
raise NotSetupError()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo
|
||||
from models.enums import ConversationFromSource, CreatorUserRole, MessageFileBelongsTo
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -130,10 +130,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
end_user_id = None
|
||||
account_id = None
|
||||
if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
from_source = "api"
|
||||
from_source = ConversationFromSource.API
|
||||
end_user_id = application_generate_entity.user_id
|
||||
else:
|
||||
from_source = "console"
|
||||
from_source = ConversationFromSource.CONSOLE
|
||||
account_id = application_generate_entity.user_id
|
||||
|
||||
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
|
||||
|
||||
@ -6,7 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.enums import CollectionBindingType
|
||||
from models.enums import CollectionBindingType, ConversationFromSource
|
||||
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
@ -68,9 +68,9 @@ class AnnotationReplyFeature:
|
||||
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
|
||||
if annotation:
|
||||
if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
|
||||
from_source = "api"
|
||||
from_source = ConversationFromSource.API
|
||||
else:
|
||||
from_source = "console"
|
||||
from_source = ConversationFromSource.CONSOLE
|
||||
|
||||
# insert annotation history
|
||||
AppAnnotationService.add_annotation_history(
|
||||
|
||||
@ -284,27 +284,29 @@ class TidbOnQdrantVector(BaseVector):
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
for node_id in ids:
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
if not ids:
|
||||
return
|
||||
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchAny(any=ids),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
all_collection_name = []
|
||||
|
||||
@ -34,10 +34,12 @@ from .enums import (
|
||||
AppMCPServerStatus,
|
||||
AppStatus,
|
||||
BannerStatus,
|
||||
ConversationFromSource,
|
||||
ConversationStatus,
|
||||
CreatorUserRole,
|
||||
FeedbackFromSource,
|
||||
FeedbackRating,
|
||||
InvokeFrom,
|
||||
MessageChainType,
|
||||
MessageFileBelongsTo,
|
||||
MessageStatus,
|
||||
@ -1022,10 +1024,12 @@ class Conversation(Base):
|
||||
#
|
||||
# Its value corresponds to the members of `InvokeFrom`.
|
||||
# (api/core/app/entities/app_invoke_entities.py)
|
||||
invoke_from = mapped_column(String(255), nullable=True)
|
||||
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
|
||||
|
||||
# ref: ConversationSource.
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_source: Mapped[ConversationFromSource] = mapped_column(
|
||||
EnumText(ConversationFromSource, length=255), nullable=False
|
||||
)
|
||||
from_end_user_id = mapped_column(StringUUID)
|
||||
from_account_id = mapped_column(StringUUID)
|
||||
read_at = mapped_column(sa.DateTime)
|
||||
@ -1374,8 +1378,10 @@ class Message(Base):
|
||||
)
|
||||
error: Mapped[str | None] = mapped_column(LongText)
|
||||
message_metadata: Mapped[str | None] = mapped_column(LongText)
|
||||
invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
|
||||
from_source: Mapped[ConversationFromSource] = mapped_column(
|
||||
EnumText(ConversationFromSource, length=255), nullable=False
|
||||
)
|
||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
|
||||
|
||||
@ -335,7 +335,11 @@ class BillingService:
|
||||
# Redis returns bytes, decode to string and parse JSON
|
||||
json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
|
||||
plan_dict = json.loads(json_str)
|
||||
# NOTE (hj24): New billing versions may return timestamp as str, and validate_python
|
||||
# in non-strict mode will coerce it to the expected int type.
|
||||
# To preserve compatibility, always keep non-strict mode here and avoid strict mode.
|
||||
subscription_plan = subscription_adapter.validate_python(plan_dict)
|
||||
# NOTE END
|
||||
tenant_plans[tenant_id] = subscription_plan
|
||||
except Exception:
|
||||
logger.exception(
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from httpx import get
|
||||
from sqlalchemy import select
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@ -28,9 +28,16 @@ from services.tools.tools_transform_service import ToolTransformService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApiSchemaParseResult(TypedDict):
|
||||
schema_type: str
|
||||
parameters_schema: list[dict[str, Any]]
|
||||
credentials_schema: list[dict[str, Any]]
|
||||
warning: dict[str, str]
|
||||
|
||||
|
||||
class ApiToolManageService:
|
||||
@staticmethod
|
||||
def parser_api_schema(schema: str) -> Mapping[str, Any]:
|
||||
def parser_api_schema(schema: str) -> ApiSchemaParseResult:
|
||||
"""
|
||||
parse api schema to tool bundle
|
||||
"""
|
||||
@ -71,7 +78,7 @@ class ApiToolManageService:
|
||||
]
|
||||
|
||||
return cast(
|
||||
Mapping,
|
||||
ApiSchemaParseResult,
|
||||
jsonable_encoder(
|
||||
{
|
||||
"schema_type": schema_type,
|
||||
|
||||
@ -18,6 +18,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.utils.encryption import ProviderConfigEncrypter
|
||||
from models.tools import MCPToolProvider
|
||||
@ -681,7 +682,7 @@ class MCPToolManageService:
|
||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||
|
||||
def _build_tool_provider_response(
|
||||
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
||||
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list[MCPTool]
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Build API response for tool provider."""
|
||||
user = db_provider.load_user()
|
||||
@ -703,7 +704,7 @@ class MCPToolManageService:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
if "unique_mcp_provider_server_identifier" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
raise
|
||||
raise error
|
||||
|
||||
def _is_valid_url(self, url: str) -> bool:
|
||||
"""Validate URL format."""
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import json
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@ -34,6 +36,17 @@ class _NodeType(TypedDict):
|
||||
data: dict[str, Any]
|
||||
|
||||
|
||||
class _EdgeType(TypedDict):
|
||||
id: str
|
||||
source: str
|
||||
target: str
|
||||
|
||||
|
||||
class WorkflowGraph(TypedDict):
|
||||
nodes: list[_NodeType]
|
||||
edges: list[_EdgeType]
|
||||
|
||||
|
||||
class WorkflowConverter:
|
||||
"""
|
||||
App Convert to Workflow Mode
|
||||
@ -107,7 +120,7 @@ class WorkflowConverter:
|
||||
app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)
|
||||
|
||||
# init workflow graph
|
||||
graph: dict[str, Any] = {"nodes": [], "edges": []}
|
||||
graph: WorkflowGraph = {"nodes": [], "edges": []}
|
||||
|
||||
# Convert list:
|
||||
# - variables -> start
|
||||
@ -385,7 +398,7 @@ class WorkflowConverter:
|
||||
self,
|
||||
original_app_mode: AppMode,
|
||||
new_app_mode: AppMode,
|
||||
graph: dict,
|
||||
graph: WorkflowGraph,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: FileUploadConfig | None = None,
|
||||
@ -595,7 +608,7 @@ class WorkflowConverter:
|
||||
"data": {"title": "ANSWER", "type": BuiltinNodeTypes.ANSWER, "answer": "{{#llm.text#}}"},
|
||||
}
|
||||
|
||||
def _create_edge(self, source: str, target: str):
|
||||
def _create_edge(self, source: str, target: str) -> _EdgeType:
|
||||
"""
|
||||
Create Edge
|
||||
:param source: source node id
|
||||
@ -604,7 +617,7 @@ class WorkflowConverter:
|
||||
"""
|
||||
return {"id": f"{source}-{target}", "source": source, "target": target}
|
||||
|
||||
def _append_node(self, graph: dict[str, Any], node: _NodeType):
|
||||
def _append_node(self, graph: WorkflowGraph, node: _NodeType):
|
||||
"""
|
||||
Append Node to Graph
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import Any
|
||||
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
|
||||
@ -14,6 +15,10 @@ from services.plugin.plugin_service import PluginService
|
||||
from services.workflow.entities import TriggerMetadata
|
||||
|
||||
|
||||
class LogViewDetails(TypedDict):
|
||||
trigger_metadata: dict[str, Any] | None
|
||||
|
||||
|
||||
# Since the workflow_app_log table has exceeded 100 million records, we use an additional details field to extend it
|
||||
class LogView:
|
||||
"""Lightweight wrapper for WorkflowAppLog with computed details.
|
||||
@ -22,12 +27,12 @@ class LogView:
|
||||
- Proxies all other attributes to the underlying `WorkflowAppLog`
|
||||
"""
|
||||
|
||||
def __init__(self, log: WorkflowAppLog, details: dict | None):
|
||||
def __init__(self, log: WorkflowAppLog, details: LogViewDetails | None):
|
||||
self.log = log
|
||||
self.details_ = details
|
||||
|
||||
@property
|
||||
def details(self) -> dict | None:
|
||||
def details(self) -> LogViewDetails | None:
|
||||
return self.details_
|
||||
|
||||
def __getattr__(self, name):
|
||||
|
||||
@ -35,7 +35,7 @@ from factories.variable_factory import build_segment, segment_to_variable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models import Account, App, Conversation
|
||||
from models.enums import DraftVariableType
|
||||
from models.enums import ConversationFromSource, DraftVariableType
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.file_service import FileService
|
||||
@ -601,7 +601,7 @@ class WorkflowDraftVariableService:
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
from_source="console",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=None,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
|
||||
@ -13,6 +13,7 @@ from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
@ -154,7 +155,7 @@ class TestChatMessageApiPermissions:
|
||||
re_sign_file_url_answer="",
|
||||
answer_tokens=0,
|
||||
provider_response_latency=0.0,
|
||||
from_source="console",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=None,
|
||||
from_account_id=mock_account.id,
|
||||
feedbacks=[],
|
||||
|
||||
@ -165,8 +165,9 @@ class DifyTestContainers:
|
||||
|
||||
# Start Dify Sandbox container for code execution environment
|
||||
# Dify Sandbox provides a secure environment for executing user code
|
||||
# Use pinned version 0.2.12 to match production docker-compose configuration
|
||||
logger.info("Initializing Dify Sandbox container...")
|
||||
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest").with_network(self.network)
|
||||
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.12").with_network(self.network)
|
||||
self.dify_sandbox.with_exposed_ports(8194)
|
||||
self.dify_sandbox.env = {
|
||||
"API_KEY": "test_api_key",
|
||||
|
||||
@ -13,7 +13,7 @@ from libs.datetime_utils import naive_utc_now
|
||||
from libs.token import _real_cookie_name, generate_csrf_token
|
||||
from models import Account, DifySetup, Tenant, TenantAccountJoin
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import ConversationFromSource, CreatorUserRole
|
||||
from models.model import App, AppMode, Conversation, Message
|
||||
from models.workflow import WorkflowRun
|
||||
from services.account_service import AccountService
|
||||
@ -75,7 +75,7 @@ def _create_conversation(db_session: Session, app_id: str, account_id: str) -> C
|
||||
inputs={},
|
||||
status="normal",
|
||||
mode=AppMode.CHAT,
|
||||
from_source=CreatorUserRole.ACCOUNT,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
db_session.add(conversation)
|
||||
@ -124,7 +124,7 @@ def _create_message(
|
||||
answer_price_unit=0.001,
|
||||
currency="USD",
|
||||
status="normal",
|
||||
from_source=CreatorUserRole.ACCOUNT,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
inputs={"query": "Hello"},
|
||||
|
||||
@ -7,6 +7,7 @@ from uuid import uuid4
|
||||
|
||||
from dify_graph.nodes.human_input.entities import FormDefinition, UserAction
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.enums import ConversationFromSource, InvokeFrom
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.human_input import HumanInputForm, HumanInputFormStatus
|
||||
from models.model import App, Conversation, Message
|
||||
@ -78,8 +79,8 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
status="normal",
|
||||
invoke_from="console",
|
||||
from_source="console",
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account.id,
|
||||
from_end_user_id=None,
|
||||
)
|
||||
@ -101,7 +102,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
|
||||
answer_unit_price=Decimal("0.001"),
|
||||
provider_response_latency=0.5,
|
||||
currency="USD",
|
||||
from_source="console",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock
|
||||
@ -12,15 +13,26 @@ from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from dify_graph.entities import WorkflowExecution
|
||||
from dify_graph.entities.pause_reason import PauseReasonType
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction
|
||||
from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.human_input import (
|
||||
BackstageRecipientPayload,
|
||||
HumanInputDelivery,
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
_build_human_input_required_reason,
|
||||
_PrivateWorkflowPauseEntity,
|
||||
_WorkflowRunError,
|
||||
)
|
||||
|
||||
@ -90,6 +102,19 @@ def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
|
||||
WorkflowRun.app_id == scope.app_id,
|
||||
)
|
||||
)
|
||||
|
||||
form_ids_subquery = select(HumanInputForm.id).where(
|
||||
HumanInputForm.tenant_id == scope.tenant_id,
|
||||
HumanInputForm.app_id == scope.app_id,
|
||||
)
|
||||
session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery)))
|
||||
session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery)))
|
||||
session.execute(
|
||||
delete(HumanInputForm).where(
|
||||
HumanInputForm.tenant_id == scope.tenant_id,
|
||||
HumanInputForm.app_id == scope.app_id,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
for state_key in scope.state_keys:
|
||||
@ -504,3 +529,200 @@ class TestDeleteWorkflowPause:
|
||||
|
||||
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"):
|
||||
repository.delete_workflow_pause(pause_entity=pause_entity)
|
||||
|
||||
|
||||
class TestPrivateWorkflowPauseEntity:
|
||||
"""Integration tests for _PrivateWorkflowPauseEntity using real DB models."""
|
||||
|
||||
def test_properties(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Entity properties delegate to the persisted WorkflowPause model."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=f"workflow-state-{uuid4()}.json",
|
||||
)
|
||||
db_session_with_containers.add(pause)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.refresh(pause)
|
||||
test_scope.state_keys.add(pause.state_object_key)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
|
||||
|
||||
assert entity.id == pause.id
|
||||
assert entity.workflow_execution_id == workflow_run.id
|
||||
assert entity.resumed_at is None
|
||||
|
||||
def test_get_state(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""get_state loads state data from storage using the persisted state_object_key."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
state_key = f"workflow-state-{uuid4()}.json"
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=state_key,
|
||||
)
|
||||
db_session_with_containers.add(pause)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.refresh(pause)
|
||||
test_scope.state_keys.add(state_key)
|
||||
|
||||
expected_state = b'{"test": "state"}'
|
||||
storage.save(state_key, expected_state)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
|
||||
result = entity.get_state()
|
||||
|
||||
assert result == expected_state
|
||||
|
||||
def test_get_state_caching(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""get_state caches the result so storage is only accessed once."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
state_key = f"workflow-state-{uuid4()}.json"
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=state_key,
|
||||
)
|
||||
db_session_with_containers.add(pause)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.refresh(pause)
|
||||
test_scope.state_keys.add(state_key)
|
||||
|
||||
expected_state = b'{"test": "state"}'
|
||||
storage.save(state_key, expected_state)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
|
||||
result1 = entity.get_state()
|
||||
# Delete from storage to prove second call uses cache
|
||||
storage.delete(state_key)
|
||||
test_scope.state_keys.discard(state_key)
|
||||
result2 = entity.get_state()
|
||||
|
||||
assert result1 == expected_state
|
||||
assert result2 == expected_state
|
||||
|
||||
|
||||
class TestBuildHumanInputRequiredReason:
|
||||
"""Integration tests for _build_human_input_required_reason using real DB models."""
|
||||
|
||||
def test_prefers_backstage_token_when_available(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Use backstage token when multiple recipient types may exist."""
|
||||
|
||||
expiration_time = naive_utc_now()
|
||||
form_definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
default_values={"name": "Alice"},
|
||||
node_title="Ask Name",
|
||||
display_in_ui=True,
|
||||
)
|
||||
|
||||
form_model = HumanInputForm(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
workflow_run_id=str(uuid4()),
|
||||
node_id="node-1",
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content="rendered",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
expiration_time=expiration_time,
|
||||
)
|
||||
db_session_with_containers.add(form_model)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
delivery = HumanInputDelivery(
|
||||
form_id=form_model.id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
channel_payload="{}",
|
||||
)
|
||||
db_session_with_containers.add(delivery)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
access_token = secrets.token_urlsafe(8)
|
||||
recipient = HumanInputFormRecipient(
|
||||
form_id=form_model.id,
|
||||
delivery_id=delivery.id,
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
recipient_payload=BackstageRecipientPayload().model_dump_json(),
|
||||
access_token=access_token,
|
||||
)
|
||||
db_session_with_containers.add(recipient)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create a pause so the reason has a valid pause_id
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=f"workflow-state-{uuid4()}.json",
|
||||
)
|
||||
db_session_with_containers.add(pause)
|
||||
db_session_with_containers.flush()
|
||||
test_scope.state_keys.add(pause.state_object_key)
|
||||
|
||||
reason_model = WorkflowPauseReason(
|
||||
pause_id=pause.id,
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
form_id=form_model.id,
|
||||
node_id="node-1",
|
||||
message="",
|
||||
)
|
||||
db_session_with_containers.add(reason_model)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh to ensure we have DB-round-tripped objects
|
||||
db_session_with_containers.refresh(form_model)
|
||||
db_session_with_containers.refresh(reason_model)
|
||||
db_session_with_containers.refresh(recipient)
|
||||
|
||||
reason = _build_human_input_required_reason(reason_model, form_model, [recipient])
|
||||
|
||||
assert isinstance(reason, HumanInputRequired)
|
||||
assert reason.form_token == access_token
|
||||
assert reason.node_title == "Ask Name"
|
||||
assert reason.form_content == "content"
|
||||
assert reason.inputs[0].output_variable_name == "name"
|
||||
assert reason.actions[0].id == "approve"
|
||||
|
||||
@ -0,0 +1,391 @@
|
||||
"""Integration tests for get_paginated_workflow_runs and get_workflow_runs_count using testcontainers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine, delete
|
||||
from sqlalchemy import exc as sa_exc
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from dify_graph.entities import WorkflowExecution
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowRun, WorkflowType
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository
|
||||
|
||||
|
||||
class _TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Concrete repository for tests where save() is not under test."""
|
||||
|
||||
def save(self, execution: WorkflowExecution) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _TestScope:
|
||||
"""Per-test data scope used to isolate DB rows."""
|
||||
|
||||
tenant_id: str = field(default_factory=lambda: str(uuid4()))
|
||||
app_id: str = field(default_factory=lambda: str(uuid4()))
|
||||
workflow_id: str = field(default_factory=lambda: str(uuid4()))
|
||||
user_id: str = field(default_factory=lambda: str(uuid4()))
|
||||
|
||||
|
||||
def _create_workflow_run(
|
||||
session: Session,
|
||||
scope: _TestScope,
|
||||
*,
|
||||
status: WorkflowExecutionStatus,
|
||||
triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
created_at_offset: timedelta | None = None,
|
||||
) -> WorkflowRun:
|
||||
"""Create and persist a workflow run bound to the current test scope."""
|
||||
now = naive_utc_now()
|
||||
workflow_run = WorkflowRun(
|
||||
id=str(uuid4()),
|
||||
tenant_id=scope.tenant_id,
|
||||
app_id=scope.app_id,
|
||||
workflow_id=scope.workflow_id,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=triggered_from,
|
||||
version="draft",
|
||||
graph="{}",
|
||||
inputs="{}",
|
||||
status=status,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=scope.user_id,
|
||||
created_at=now + created_at_offset if created_at_offset is not None else now,
|
||||
)
|
||||
session.add(workflow_run)
|
||||
session.commit()
|
||||
return workflow_run
|
||||
|
||||
|
||||
def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
|
||||
"""Remove test-created DB rows for a test scope."""
|
||||
session.execute(
|
||||
delete(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == scope.tenant_id,
|
||||
WorkflowRun.app_id == scope.app_id,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository(db_session_with_containers: Session) -> DifyAPISQLAlchemyWorkflowRunRepository:
|
||||
"""Build a repository backed by the testcontainers database engine."""
|
||||
engine = db_session_with_containers.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
return _TestWorkflowRunRepository(session_maker=sessionmaker(bind=engine, expire_on_commit=False))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_scope(db_session_with_containers: Session) -> _TestScope:
|
||||
"""Provide an isolated scope and clean related data after each test."""
|
||||
scope = _TestScope()
|
||||
yield scope
|
||||
_cleanup_scope_data(db_session_with_containers, scope)
|
||||
|
||||
|
||||
class TestGetPaginatedWorkflowRuns:
|
||||
"""Integration tests for get_paginated_workflow_runs."""
|
||||
|
||||
def test_returns_runs_without_status_filter(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Return all runs for the given tenant/app when no status filter is applied."""
|
||||
for status in (
|
||||
WorkflowExecutionStatus.SUCCEEDED,
|
||||
WorkflowExecutionStatus.FAILED,
|
||||
WorkflowExecutionStatus.RUNNING,
|
||||
):
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=status)
|
||||
|
||||
result = repository.get_paginated_workflow_runs(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
limit=20,
|
||||
last_id=None,
|
||||
status=None,
|
||||
)
|
||||
|
||||
assert len(result.data) == 3
|
||||
assert result.limit == 20
|
||||
assert result.has_more is False
|
||||
|
||||
def test_filters_by_status(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Return only runs matching the requested status."""
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
|
||||
|
||||
result = repository.get_paginated_workflow_runs(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
limit=20,
|
||||
last_id=None,
|
||||
status="succeeded",
|
||||
)
|
||||
|
||||
assert len(result.data) == 2
|
||||
assert all(run.status == WorkflowExecutionStatus.SUCCEEDED for run in result.data)
|
||||
|
||||
def test_pagination_has_more(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Return has_more=True when more records exist beyond the limit."""
|
||||
for i in range(5):
|
||||
_create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
created_at_offset=timedelta(seconds=i),
|
||||
)
|
||||
|
||||
result = repository.get_paginated_workflow_runs(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
limit=3,
|
||||
last_id=None,
|
||||
status=None,
|
||||
)
|
||||
|
||||
assert len(result.data) == 3
|
||||
assert result.has_more is True
|
||||
|
||||
def test_cursor_based_pagination(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Cursor-based pagination returns the next page of results."""
|
||||
for i in range(5):
|
||||
_create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
created_at_offset=timedelta(seconds=i),
|
||||
)
|
||||
|
||||
# First page
|
||||
page1 = repository.get_paginated_workflow_runs(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
limit=3,
|
||||
last_id=None,
|
||||
status=None,
|
||||
)
|
||||
assert len(page1.data) == 3
|
||||
assert page1.has_more is True
|
||||
|
||||
# Second page using cursor
|
||||
page2 = repository.get_paginated_workflow_runs(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
limit=3,
|
||||
last_id=page1.data[-1].id,
|
||||
status=None,
|
||||
)
|
||||
assert len(page2.data) == 2
|
||||
assert page2.has_more is False
|
||||
|
||||
# No overlap between pages
|
||||
page1_ids = {r.id for r in page1.data}
|
||||
page2_ids = {r.id for r in page2.data}
|
||||
assert page1_ids.isdisjoint(page2_ids)
|
||||
|
||||
def test_invalid_last_id_raises(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Raise ValueError when last_id refers to a non-existent run."""
|
||||
with pytest.raises(ValueError, match="Last workflow run not exists"):
|
||||
repository.get_paginated_workflow_runs(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
limit=20,
|
||||
last_id=str(uuid4()),
|
||||
status=None,
|
||||
)
|
||||
|
||||
def test_tenant_isolation(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Runs from other tenants are not returned."""
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
|
||||
|
||||
other_scope = _TestScope(app_id=test_scope.app_id)
|
||||
try:
|
||||
_create_workflow_run(db_session_with_containers, other_scope, status=WorkflowExecutionStatus.SUCCEEDED)
|
||||
|
||||
result = repository.get_paginated_workflow_runs(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
limit=20,
|
||||
last_id=None,
|
||||
status=None,
|
||||
)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].tenant_id == test_scope.tenant_id
|
||||
finally:
|
||||
_cleanup_scope_data(db_session_with_containers, other_scope)
|
||||
|
||||
|
||||
class TestGetWorkflowRunsCount:
|
||||
"""Integration tests for get_workflow_runs_count."""
|
||||
|
||||
def test_count_without_status_filter(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Count all runs grouped by status when no status filter is applied."""
|
||||
for _ in range(3):
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
|
||||
for _ in range(2):
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.RUNNING)
|
||||
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status=None,
|
||||
)
|
||||
|
||||
assert result["total"] == 6
|
||||
assert result["succeeded"] == 3
|
||||
assert result["failed"] == 2
|
||||
assert result["running"] == 1
|
||||
assert result["stopped"] == 0
|
||||
assert result["partial-succeeded"] == 0
|
||||
|
||||
def test_count_with_status_filter(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Count only runs matching the requested status."""
|
||||
for _ in range(3):
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
|
||||
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status="succeeded",
|
||||
)
|
||||
|
||||
assert result["total"] == 3
|
||||
assert result["succeeded"] == 3
|
||||
assert result["failed"] == 0
|
||||
|
||||
def test_count_with_invalid_status_raises(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Invalid status raises StatementError because the column uses an enum type."""
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
|
||||
|
||||
with pytest.raises(sa_exc.StatementError) as exc_info:
|
||||
repository.get_workflow_runs_count(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status="invalid_status",
|
||||
)
|
||||
assert isinstance(exc_info.value.orig, ValueError)
|
||||
|
||||
def test_count_with_time_range(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Time range filter excludes runs created outside the window."""
|
||||
# Recent run (within 1 day)
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
|
||||
# Old run (8 days ago)
|
||||
_create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
created_at_offset=timedelta(days=-8),
|
||||
)
|
||||
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status=None,
|
||||
time_range="7d",
|
||||
)
|
||||
|
||||
assert result["total"] == 1
|
||||
assert result["succeeded"] == 1
|
||||
|
||||
def test_count_with_status_and_time_range(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Both status and time_range filters apply together."""
|
||||
# Recent succeeded
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
|
||||
# Recent failed
|
||||
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
|
||||
# Old succeeded (outside time range)
|
||||
_create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
created_at_offset=timedelta(days=-8),
|
||||
)
|
||||
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status="succeeded",
|
||||
time_range="7d",
|
||||
)
|
||||
|
||||
assert result["total"] == 1
|
||||
assert result["succeeded"] == 1
|
||||
assert result["failed"] == 0
|
||||
@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from models import Account
|
||||
from models.enums import MessageFileBelongsTo
|
||||
from models.enums import ConversationFromSource, MessageFileBelongsTo
|
||||
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.agent_service import AgentService
|
||||
@ -165,7 +165,7 @@ class TestAgentService:
|
||||
inputs={},
|
||||
status="normal",
|
||||
mode="chat",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
@ -204,7 +204,7 @@ class TestAgentService:
|
||||
answer_unit_price=0.001,
|
||||
provider_response_latency=1.5,
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.commit()
|
||||
@ -406,7 +406,7 @@ class TestAgentService:
|
||||
inputs={},
|
||||
status="normal",
|
||||
mode="chat",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
@ -445,7 +445,7 @@ class TestAgentService:
|
||||
answer_unit_price=0.001,
|
||||
provider_response_latency=1.5,
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.commit()
|
||||
@ -478,7 +478,7 @@ class TestAgentService:
|
||||
inputs={},
|
||||
status="normal",
|
||||
mode="chat",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
@ -517,7 +517,7 @@ class TestAgentService:
|
||||
answer_unit_price=0.001,
|
||||
provider_response_latency=1.5,
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.commit()
|
||||
@ -624,7 +624,7 @@ class TestAgentService:
|
||||
inputs={},
|
||||
status="normal",
|
||||
mode="chat",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
app_model_config_id=None, # Explicitly set to None
|
||||
)
|
||||
db_session_with_containers.add(conversation)
|
||||
@ -647,7 +647,7 @@ class TestAgentService:
|
||||
answer_unit_price=0.001,
|
||||
provider_response_latency=1.5,
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from models import Account
|
||||
from models.enums import ConversationFromSource, InvokeFrom
|
||||
from models.model import MessageAnnotation
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.app_service import AppService
|
||||
@ -136,8 +137,8 @@ class TestAnnotationService:
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from="console",
|
||||
from_source="console",
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=None,
|
||||
from_account_id=account.id,
|
||||
)
|
||||
@ -174,8 +175,8 @@ class TestAnnotationService:
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
invoke_from="console",
|
||||
from_source="console",
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=None,
|
||||
from_account_id=account.id,
|
||||
)
|
||||
@ -721,7 +722,7 @@ class TestAnnotationService:
|
||||
query=f"Query {i}: {fake.sentence()}",
|
||||
user_id=account.id,
|
||||
message_id=fake.uuid4(),
|
||||
from_source="console",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
score=0.8 + (i * 0.1),
|
||||
)
|
||||
|
||||
@ -772,7 +773,7 @@ class TestAnnotationService:
|
||||
query=query,
|
||||
user_id=account.id,
|
||||
message_id=message_id,
|
||||
from_source="console",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
score=score,
|
||||
)
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from sqlalchemy import select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import App, Conversation, EndUser, Message, MessageAnnotation
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.conversation_service import ConversationService
|
||||
@ -107,7 +108,7 @@ class ConversationServiceIntegrationTestDataFactory:
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from=invoke_from.value,
|
||||
from_source="api" if isinstance(user, EndUser) else "console",
|
||||
from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=user.id if isinstance(user, EndUser) else None,
|
||||
from_account_id=user.id if isinstance(user, Account) else None,
|
||||
dialogue_count=0,
|
||||
@ -154,7 +155,7 @@ class ConversationServiceIntegrationTestDataFactory:
|
||||
currency="USD",
|
||||
status="normal",
|
||||
invoke_from=InvokeFrom.WEB_APP.value,
|
||||
from_source="api" if isinstance(user, EndUser) else "console",
|
||||
from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=user.id if isinstance(user, EndUser) else None,
|
||||
from_account_id=user.id if isinstance(user, Account) else None,
|
||||
)
|
||||
|
||||
@ -7,7 +7,7 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
@ -94,7 +94,7 @@ class TestAppMessageExportServiceIntegration:
|
||||
name="conv",
|
||||
inputs={"seed": 1},
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(conversation)
|
||||
@ -129,7 +129,7 @@ class TestAppMessageExportServiceIntegration:
|
||||
total_price=Decimal("0.003"),
|
||||
currency="USD",
|
||||
message_metadata=message_metadata,
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
@ -4,7 +4,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.enums import FeedbackRating
|
||||
from models.enums import ConversationFromSource, FeedbackRating, InvokeFrom
|
||||
from models.model import MessageFeedback
|
||||
from services.app_service import AppService
|
||||
from services.errors.message import (
|
||||
@ -149,8 +149,8 @@ class TestMessageService:
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from="console",
|
||||
from_source="console",
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=None,
|
||||
from_account_id=account.id,
|
||||
)
|
||||
@ -187,8 +187,8 @@ class TestMessageService:
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
invoke_from="console",
|
||||
from_source="console",
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=None,
|
||||
from_account_id=account.id,
|
||||
)
|
||||
|
||||
@ -4,6 +4,7 @@ from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import Message
|
||||
from services import message_service
|
||||
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
|
||||
@ -36,7 +37,7 @@ def test_attach_message_extra_contents_assigns_serialized_payload(db_session_wit
|
||||
total_price=Decimal(0),
|
||||
currency="USD",
|
||||
status="normal",
|
||||
from_source="console",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=fixture.account.id,
|
||||
)
|
||||
db_session_with_containers.add(message_without_extra_content)
|
||||
|
||||
@ -11,7 +11,14 @@ from sqlalchemy.orm import Session
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import DataSourceType, FeedbackFromSource, FeedbackRating, MessageChainType, MessageFileBelongsTo
|
||||
from models.enums import (
|
||||
ConversationFromSource,
|
||||
DataSourceType,
|
||||
FeedbackFromSource,
|
||||
FeedbackRating,
|
||||
MessageChainType,
|
||||
MessageFileBelongsTo,
|
||||
)
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
@ -166,7 +173,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
name="Test conversation",
|
||||
inputs={},
|
||||
status="normal",
|
||||
from_source=FeedbackFromSource.USER,
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(conversation)
|
||||
@ -196,7 +203,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
answer_unit_price=Decimal("0.002"),
|
||||
total_price=Decimal("0.003"),
|
||||
currency="USD",
|
||||
from_source=FeedbackFromSource.USER,
|
||||
from_source=ConversationFromSource.API,
|
||||
from_account_id=conversation.from_end_user_id,
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
@ -4,6 +4,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import EndUser, Message
|
||||
from models.web import SavedMessage
|
||||
from services.app_service import AppService
|
||||
@ -132,11 +133,14 @@ class TestSavedMessageService:
|
||||
# Create a simple conversation first
|
||||
from models.model import Conversation
|
||||
|
||||
is_account = hasattr(user, "current_tenant")
|
||||
from_source = ConversationFromSource.CONSOLE if is_account else ConversationFromSource.API
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
from_source="account" if hasattr(user, "current_tenant") else "end_user",
|
||||
from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
|
||||
from_account_id=user.id if hasattr(user, "current_tenant") else None,
|
||||
from_source=from_source,
|
||||
from_end_user_id=user.id if not is_account else None,
|
||||
from_account_id=user.id if is_account else None,
|
||||
name=fake.sentence(nb_words=3),
|
||||
inputs={},
|
||||
status="normal",
|
||||
@ -150,9 +154,9 @@ class TestSavedMessageService:
|
||||
message = Message(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
from_source="account" if hasattr(user, "current_tenant") else "end_user",
|
||||
from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
|
||||
from_account_id=user.id if hasattr(user, "current_tenant") else None,
|
||||
from_source=from_source,
|
||||
from_end_user_id=user.id if not is_account else None,
|
||||
from_account_id=user.id if is_account else None,
|
||||
inputs={},
|
||||
query=fake.sentence(nb_words=5),
|
||||
message=fake.text(max_nb_chars=100),
|
||||
|
||||
@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models import Account
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import Conversation, EndUser
|
||||
from models.web import PinnedConversation
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -145,7 +146,7 @@ class TestWebConversationService:
|
||||
system_instruction_tokens=50,
|
||||
status="normal",
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
from_source="console" if isinstance(user, Account) else "api",
|
||||
from_source=ConversationFromSource.CONSOLE if isinstance(user, Account) else ConversationFromSource.API,
|
||||
from_end_user_id=user.id if isinstance(user, EndUser) else None,
|
||||
from_account_id=user.id if isinstance(user, Account) else None,
|
||||
dialogue_count=0,
|
||||
|
||||
@ -7,7 +7,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import ConversationFromSource, CreatorUserRole
|
||||
from models.model import (
|
||||
Message,
|
||||
)
|
||||
@ -165,7 +165,7 @@ class TestWorkflowRunService:
|
||||
inputs={},
|
||||
status="normal",
|
||||
mode="chat",
|
||||
from_source=CreatorUserRole.ACCOUNT,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account.id,
|
||||
)
|
||||
db_session_with_containers.add(conversation)
|
||||
@ -186,7 +186,7 @@ class TestWorkflowRunService:
|
||||
message.answer_price_unit = 0.001
|
||||
message.currency = "USD"
|
||||
message.status = "normal"
|
||||
message.from_source = CreatorUserRole.ACCOUNT
|
||||
message.from_source = ConversationFromSource.CONSOLE
|
||||
message.from_account_id = account.id
|
||||
message.workflow_run_id = workflow_run.id
|
||||
message.inputs = {"input": "test input"}
|
||||
|
||||
320
api/tests/unit_tests/controllers/console/app/test_message.py
Normal file
320
api/tests/unit_tests/controllers/console/app/test_message.py
Normal file
@ -0,0 +1,320 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.message import (
|
||||
ChatMessageListApi,
|
||||
ChatMessagesQuery,
|
||||
FeedbackExportQuery,
|
||||
MessageAnnotationCountApi,
|
||||
MessageApi,
|
||||
MessageFeedbackApi,
|
||||
MessageFeedbackExportApi,
|
||||
MessageFeedbackPayload,
|
||||
MessageSuggestedQuestionApi,
|
||||
)
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from models import App, AppMode
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user_123"
|
||||
account.timezone = "UTC"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.is_admin_or_owner = True
|
||||
account.current_tenant.current_role = "owner"
|
||||
account.has_edit_permission = True
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app_123"
|
||||
app_model.mode = AppMode.CHAT
|
||||
app_model.tenant_id = "tenant_123"
|
||||
return app_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_csrf():
|
||||
with patch("libs.login.check_csrf_token") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
import contextlib
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def setup_test_context(
|
||||
test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None
|
||||
):
|
||||
with (
|
||||
patch("extensions.ext_database.db") as mock_db,
|
||||
patch("controllers.console.app.wraps.db", mock_db),
|
||||
patch("controllers.console.wraps.db", mock_db),
|
||||
patch("controllers.console.app.message.db", mock_db),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
):
|
||||
# Set up a generic query mock that usually returns mock_app_model when getting app
|
||||
app_query_mock = MagicMock()
|
||||
app_query_mock.filter.return_value.first.return_value = mock_app_model
|
||||
app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model
|
||||
app_query_mock.where.return_value.first.return_value = mock_app_model
|
||||
app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
|
||||
data_query_mock = MagicMock()
|
||||
|
||||
def query_side_effect(*args, **kwargs):
|
||||
if args and hasattr(args[0], "__name__") and args[0].__name__ == "App":
|
||||
return app_query_mock
|
||||
return data_query_mock
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
mock_db.data_query = data_query_mock
|
||||
|
||||
# Let the caller override the stat db query logic
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()])
|
||||
full_path = f"{route_path}?{query_string}" if qs else route_path
|
||||
|
||||
with (
|
||||
patch("libs.login.current_user", proxy_mock),
|
||||
patch("flask_login.current_user", proxy_mock),
|
||||
patch("controllers.console.app.message.attach_message_extra_contents", return_value=None),
|
||||
):
|
||||
with test_app.test_request_context(full_path, method=method, json=payload):
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
|
||||
if "suggested-questions" in route_path:
|
||||
# simplistic extraction for message_id
|
||||
parts = route_path.split("chat-messages/")
|
||||
if len(parts) > 1:
|
||||
request.view_args["message_id"] = parts[1].split("/")[0]
|
||||
elif "messages/" in route_path and "chat-messages" not in route_path:
|
||||
parts = route_path.split("messages/")
|
||||
if len(parts) > 1:
|
||||
request.view_args["message_id"] = parts[1].split("/")[0]
|
||||
|
||||
api_instance = endpoint_class()
|
||||
|
||||
# Check if it has a dispatch_request or method
|
||||
if hasattr(api_instance, method.lower()):
|
||||
yield api_instance, mock_db, request.view_args
|
||||
|
||||
|
||||
class TestMessageValidators:
|
||||
def test_chat_messages_query_validators(self):
|
||||
# Test empty_to_none
|
||||
assert ChatMessagesQuery.empty_to_none("") is None
|
||||
assert ChatMessagesQuery.empty_to_none("val") == "val"
|
||||
|
||||
# Test validate_uuid
|
||||
assert ChatMessagesQuery.validate_uuid(None) is None
|
||||
assert (
|
||||
ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000")
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
|
||||
def test_message_feedback_validators(self):
|
||||
assert (
|
||||
MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000")
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
|
||||
def test_feedback_export_validators(self):
|
||||
assert FeedbackExportQuery.parse_bool(None) is None
|
||||
assert FeedbackExportQuery.parse_bool(True) is True
|
||||
assert FeedbackExportQuery.parse_bool("1") is True
|
||||
assert FeedbackExportQuery.parse_bool("0") is False
|
||||
assert FeedbackExportQuery.parse_bool("off") is False
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
FeedbackExportQuery.parse_bool("invalid")
|
||||
|
||||
|
||||
class TestMessageEndpoints:
|
||||
def test_chat_message_list_not_found(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app,
|
||||
ChatMessageListApi,
|
||||
"/apps/app_123/chat-messages",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"},
|
||||
) as (api, mock_db, v_args):
|
||||
mock_db.data_query.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
api.get(**v_args)
|
||||
|
||||
def test_chat_message_list_success(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app,
|
||||
ChatMessageListApi,
|
||||
"/apps/app_123/chat-messages",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1},
|
||||
) as (api, mock_db, v_args):
|
||||
mock_conv = MagicMock()
|
||||
mock_conv.id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = "msg_123"
|
||||
mock_msg.feedbacks = []
|
||||
mock_msg.annotation = None
|
||||
mock_msg.annotation_hit_history = None
|
||||
mock_msg.agent_thoughts = []
|
||||
mock_msg.message_files = []
|
||||
mock_msg.extra_contents = []
|
||||
mock_msg.message = {}
|
||||
mock_msg.message_metadata_dict = {}
|
||||
|
||||
# mock returns
|
||||
q_mock = mock_db.data_query
|
||||
q_mock.where.return_value.first.side_effect = [mock_conv]
|
||||
q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg]
|
||||
mock_db.session.scalar.return_value = False
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp["limit"] == 1
|
||||
assert resp["has_more"] is False
|
||||
assert len(resp["data"]) == 1
|
||||
|
||||
def test_message_feedback_not_found(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app,
|
||||
MessageFeedbackApi,
|
||||
"/apps/app_123/feedbacks",
|
||||
"POST",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"},
|
||||
) as (api, mock_db, v_args):
|
||||
mock_db.data_query.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
api.post(**v_args)
|
||||
|
||||
def test_message_feedback_success(self, app, mock_account, mock_app_model):
|
||||
payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"}
|
||||
with setup_test_context(
|
||||
app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload
|
||||
) as (api, mock_db, v_args):
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.admin_feedback = None
|
||||
mock_db.data_query.where.return_value.first.return_value = mock_msg
|
||||
|
||||
resp = api.post(**v_args)
|
||||
assert resp == {"result": "success"}
|
||||
|
||||
def test_message_annotation_count(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model
|
||||
) as (api, mock_db, v_args):
|
||||
mock_db.data_query.where.return_value.count.return_value = 5
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp == {"count": 5}
|
||||
|
||||
@patch("controllers.console.app.message.MessageService")
|
||||
def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model):
|
||||
mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"]
|
||||
|
||||
with setup_test_context(
|
||||
app,
|
||||
MessageSuggestedQuestionApi,
|
||||
"/apps/app_123/chat-messages/msg_123/suggested-questions",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
) as (api, mock_db, v_args):
|
||||
resp = api.get(**v_args)
|
||||
assert resp == {"data": ["q1", "q2"]}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected_exc"),
|
||||
[
|
||||
(MessageNotExistsError, NotFound),
|
||||
(ConversationNotExistsError, NotFound),
|
||||
(ProviderTokenNotInitError, ProviderNotInitializeError),
|
||||
(QuotaExceededError, ProviderQuotaExceededError),
|
||||
(ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError),
|
||||
(SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError),
|
||||
(Exception, InternalServerError),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.app.message.MessageService")
|
||||
def test_message_suggested_questions_errors(
|
||||
self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model
|
||||
):
|
||||
mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc()
|
||||
|
||||
with setup_test_context(
|
||||
app,
|
||||
MessageSuggestedQuestionApi,
|
||||
"/apps/app_123/chat-messages/msg_123/suggested-questions",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
) as (api, mock_db, v_args):
|
||||
with pytest.raises(expected_exc):
|
||||
api.get(**v_args)
|
||||
|
||||
@patch("services.feedback_service.FeedbackService.export_feedbacks")
|
||||
def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model):
|
||||
mock_export.return_value = {"exported": True}
|
||||
|
||||
with setup_test_context(
|
||||
app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model
|
||||
) as (api, mock_db, v_args):
|
||||
resp = api.get(**v_args)
|
||||
assert resp == {"exported": True}
|
||||
|
||||
def test_message_api_get_success(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model
|
||||
) as (api, mock_db, v_args):
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = "msg_123"
|
||||
mock_msg.feedbacks = []
|
||||
mock_msg.annotation = None
|
||||
mock_msg.annotation_hit_history = None
|
||||
mock_msg.agent_thoughts = []
|
||||
mock_msg.message_files = []
|
||||
mock_msg.extra_contents = []
|
||||
mock_msg.message = {}
|
||||
mock_msg.message_metadata_dict = {}
|
||||
|
||||
mock_db.data_query.where.return_value.first.return_value = mock_msg
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp["id"] == "msg_123"
|
||||
275
api/tests/unit_tests/controllers/console/app/test_statistic.py
Normal file
275
api/tests/unit_tests/controllers/console/app/test_statistic.py
Normal file
@ -0,0 +1,275 @@
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.statistic import (
|
||||
AverageResponseTimeStatistic,
|
||||
AverageSessionInteractionStatistic,
|
||||
DailyConversationStatistic,
|
||||
DailyMessageStatistic,
|
||||
DailyTerminalsStatistic,
|
||||
DailyTokenCostStatistic,
|
||||
TokensPerSecondStatistic,
|
||||
UserSatisfactionRateStatistic,
|
||||
)
|
||||
from models import App, AppMode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user_123"
|
||||
account.timezone = "UTC"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.is_admin_or_owner = True
|
||||
account.current_tenant.current_role = "owner"
|
||||
account.has_edit_permission = True
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app_123"
|
||||
app_model.mode = AppMode.CHAT
|
||||
app_model.tenant_id = "tenant_123"
|
||||
return app_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_csrf():
|
||||
with patch("libs.login.check_csrf_token") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def setup_test_context(
|
||||
test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None)
|
||||
):
|
||||
with (
|
||||
patch("controllers.console.app.statistic.db") as mock_db_stat,
|
||||
patch("controllers.console.app.wraps.db") as mock_db_wraps,
|
||||
patch("controllers.console.wraps.db", mock_db_wraps),
|
||||
patch(
|
||||
"controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123")
|
||||
),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
):
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.execute.return_value = mock_rs
|
||||
|
||||
mock_begin = MagicMock()
|
||||
mock_begin.__enter__.return_value = mock_conn
|
||||
mock_db_stat.engine.begin.return_value = mock_begin
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
mock_db_wraps.session.query.return_value = mock_query
|
||||
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with test_app.test_request_context(route_path, method="GET"):
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
api_instance = endpoint_class()
|
||||
response = api_instance.get(app_id="app_123")
|
||||
return response
|
||||
|
||||
|
||||
class TestStatisticEndpoints:
|
||||
def test_daily_message_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.message_count = 10
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyMessageStatistic,
|
||||
"/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["message_count"] == 10
|
||||
|
||||
def test_daily_conversation_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.conversation_count = 5
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyConversationStatistic,
|
||||
"/apps/app_123/statistics/daily-conversations",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["conversation_count"] == 5
|
||||
|
||||
def test_daily_terminals_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.terminal_count = 2
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyTerminalsStatistic,
|
||||
"/apps/app_123/statistics/daily-end-users",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["terminal_count"] == 2
|
||||
|
||||
def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.token_count = 100
|
||||
mock_row.total_price = Decimal("0.02")
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyTokenCostStatistic,
|
||||
"/apps/app_123/statistics/token-costs",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["token_count"] == 100
|
||||
assert response.json["data"][0]["total_price"] == "0.02"
|
||||
|
||||
def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.interactions = Decimal("3.523")
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
AverageSessionInteractionStatistic,
|
||||
"/apps/app_123/statistics/average-session-interactions",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["interactions"] == 3.52
|
||||
|
||||
def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.message_count = 100
|
||||
mock_row.feedback_count = 10
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
UserSatisfactionRateStatistic,
|
||||
"/apps/app_123/statistics/user-satisfaction-rate",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["rate"] == 100.0
|
||||
|
||||
def test_average_response_time_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_app_model.mode = AppMode.COMPLETION
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.latency = 1.234
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
AverageResponseTimeStatistic,
|
||||
"/apps/app_123/statistics/average-response-time",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["latency"] == 1234.0
|
||||
|
||||
def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.tokens_per_second = 15.5
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
TokensPerSecondStatistic,
|
||||
"/apps/app_123/statistics/tokens-per-second",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["tps"] == 15.5
|
||||
|
||||
@patch("controllers.console.app.statistic.parse_time_range")
|
||||
def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model):
|
||||
mock_parse.side_effect = ValueError("Invalid time")
|
||||
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
with pytest.raises(BadRequest):
|
||||
setup_test_context(
|
||||
app,
|
||||
DailyMessageStatistic,
|
||||
"/apps/app_123/statistics/daily-messages?start=invalid&end=invalid",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[],
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.statistic.parse_time_range")
|
||||
def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model):
|
||||
import datetime
|
||||
|
||||
start = datetime.datetime.now()
|
||||
end = datetime.datetime.now()
|
||||
mock_parse.return_value = (start, end)
|
||||
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyMessageStatistic,
|
||||
"/apps/app_123/statistics/daily-messages?start=something&end=something",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_parse.assert_called_once()
|
||||
@ -0,0 +1,313 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
ConversationVariableCollectionApi,
|
||||
EnvironmentVariableCollectionApi,
|
||||
NodeVariableCollectionApi,
|
||||
SystemVariableCollectionApi,
|
||||
VariableApi,
|
||||
VariableResetApi,
|
||||
WorkflowVariableCollectionApi,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from models import App, AppMode
|
||||
from models.enums import DraftVariableType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user_123"
|
||||
account.timezone = "UTC"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.is_admin_or_owner = True
|
||||
account.current_tenant.current_role = "owner"
|
||||
account.has_edit_permission = True
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app_123"
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
app_model.tenant_id = "tenant_123"
|
||||
return app_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_csrf():
|
||||
with patch("libs.login.check_csrf_token") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def setup_test_context(test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None):
|
||||
with (
|
||||
patch("controllers.console.app.wraps.db") as mock_db_wraps,
|
||||
patch("controllers.console.wraps.db", mock_db_wraps),
|
||||
patch("controllers.console.app.workflow_draft_variable.db"),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
):
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
mock_db_wraps.session.query.return_value = mock_query
|
||||
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with test_app.test_request_context(route_path, method=method, json=payload):
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
# extract node_id or variable_id from path manually since view_args overrides
|
||||
if "nodes/" in route_path:
|
||||
request.view_args["node_id"] = route_path.split("nodes/")[1].split("/")[0]
|
||||
if "variables/" in route_path:
|
||||
# simplistic extraction
|
||||
parts = route_path.split("variables/")
|
||||
if len(parts) > 1 and parts[1] and parts[1] != "reset":
|
||||
request.view_args["variable_id"] = parts[1].split("/")[0]
|
||||
|
||||
api_instance = endpoint_class()
|
||||
# we just call dispatch_request to avoid manual argument passing
|
||||
if hasattr(api_instance, method.lower()):
|
||||
func = getattr(api_instance, method.lower())
|
||||
return func(**request.view_args)
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableEndpoints:
|
||||
@staticmethod
|
||||
def _mock_workflow_variable(variable_type: DraftVariableType = DraftVariableType.NODE) -> MagicMock:
|
||||
class DummyValueType:
|
||||
def exposed_type(self):
|
||||
return DraftVariableType.NODE
|
||||
|
||||
mock_var = MagicMock()
|
||||
mock_var.app_id = "app_123"
|
||||
mock_var.id = "var_123"
|
||||
mock_var.name = "test_var"
|
||||
mock_var.description = ""
|
||||
mock_var.get_variable_type.return_value = variable_type
|
||||
mock_var.get_selector.return_value = []
|
||||
mock_var.value_type = DummyValueType()
|
||||
mock_var.edited = False
|
||||
mock_var.visible = True
|
||||
mock_var.file_id = None
|
||||
mock_var.variable_file = None
|
||||
mock_var.is_truncated.return_value = False
|
||||
mock_var.get_value.return_value.model_copy.return_value.value = "test_value"
|
||||
return mock_var
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_workflow_variable_collection_get_success(
|
||||
self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model
|
||||
):
|
||||
mock_wf_srv.return_value.is_workflow_exist.return_value = True
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_variables_without_values.return_value = WorkflowDraftVariableList(
|
||||
variables=[], total=0
|
||||
)
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/variables?page=1&limit=20",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": [], "total": 0}
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
def test_workflow_variable_collection_get_not_exist(self, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf_srv.return_value.is_workflow_exist.return_value = False
|
||||
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
setup_test_context(
|
||||
app,
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_workflow_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/variables",
|
||||
"DELETE",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_node_variable_collection_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_node_variables.return_value = WorkflowDraftVariableList(variables=[])
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
NodeVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/nodes/node_123/variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
||||
def test_node_variable_collection_get_invalid_node_id(self, app, mock_account, mock_app_model):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
setup_test_context(
|
||||
app,
|
||||
NodeVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/nodes/sys/variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_node_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
NodeVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/nodes/node_123/variables",
|
||||
"DELETE",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
|
||||
resp = setup_test_context(
|
||||
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
|
||||
)
|
||||
assert resp["id"] == "var_123"
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_get_not_found(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = None
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
setup_test_context(
|
||||
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_patch_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
VariableApi,
|
||||
"/apps/app_123/workflows/draft/variables/var_123",
|
||||
"PATCH",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
payload={"name": "new_name"},
|
||||
)
|
||||
assert resp["id"] == "var_123"
|
||||
mock_draft_srv.return_value.update_variable.assert_called_once()
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_delete_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
|
||||
resp = setup_test_context(
|
||||
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "DELETE", mock_account, mock_app_model
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
mock_draft_srv.return_value.delete_variable.assert_called_once()
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_reset_api_put_success(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
mock_draft_srv.return_value.reset_variable.return_value = None # means no content
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
VariableResetApi,
|
||||
"/apps/app_123/workflows/draft/variables/var_123/reset",
|
||||
"PUT",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_conversation_variable_collection_get(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_conversation_variables.return_value = WorkflowDraftVariableList(variables=[])
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
ConversationVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/conversation-variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_system_variable_collection_get(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_system_variables.return_value = WorkflowDraftVariableList(variables=[])
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
SystemVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/system-variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
def test_environment_variable_collection_get(self, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf = MagicMock()
|
||||
mock_wf.environment_variables = []
|
||||
mock_wf_srv.return_value.get_draft_workflow.return_value = mock_wf
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
EnvironmentVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/environment-variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
@ -0,0 +1,209 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.data_source_bearer_auth import (
|
||||
ApiKeyAuthDataSource,
|
||||
ApiKeyAuthDataSourceBinding,
|
||||
ApiKeyAuthDataSourceBindingDelete,
|
||||
)
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
|
||||
|
||||
class TestApiKeyAuthDataSource:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
|
||||
def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
mock_binding = MagicMock()
|
||||
mock_binding.id = "bind_123"
|
||||
mock_binding.category = "api_key"
|
||||
mock_binding.provider = "custom_provider"
|
||||
mock_binding.disabled = False
|
||||
mock_binding.created_at.timestamp.return_value = 1620000000
|
||||
mock_binding.updated_at.timestamp.return_value = 1620000001
|
||||
|
||||
mock_get_list.return_value = [mock_binding]
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSource()
|
||||
response = api_instance.get()
|
||||
|
||||
assert "sources" in response
|
||||
assert len(response["sources"]) == 1
|
||||
assert response["sources"][0]["provider"] == "custom_provider"
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
|
||||
def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
mock_get_list.return_value = None
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSource()
|
||||
response = api_instance.get()
|
||||
|
||||
assert "sources" in response
|
||||
assert len(response["sources"]) == 0
|
||||
|
||||
|
||||
class TestApiKeyAuthDataSourceBinding:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
|
||||
def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
method="POST",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSourceBinding()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 200
|
||||
mock_validate.assert_called_once()
|
||||
mock_create.assert_called_once()
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
|
||||
def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
mock_create.side_effect = ValueError("Invalid structure")
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
method="POST",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSourceBinding()
|
||||
with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"):
|
||||
api_instance.post()
|
||||
|
||||
|
||||
class TestApiKeyAuthDataSourceBindingDelete:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth")
|
||||
def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSourceBindingDelete()
|
||||
response = api_instance.delete("binding_123")
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 204
|
||||
mock_delete.assert_called_once_with("tenant_123", "binding_123")
|
||||
@ -0,0 +1,192 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.auth.data_source_oauth import (
|
||||
OAuthDataSource,
|
||||
OAuthDataSourceBinding,
|
||||
OAuthDataSourceCallback,
|
||||
OAuthDataSourceSync,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthDataSource:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@patch("flask_login.current_user")
|
||||
@patch("libs.login.current_user")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None)
|
||||
def test_get_oauth_url_successful(
|
||||
self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app
|
||||
):
|
||||
mock_oauth_provider = MagicMock()
|
||||
mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth"
|
||||
mock_get_providers.return_value = {"notion": mock_oauth_provider}
|
||||
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
mock_libs_user.return_value = mock_account
|
||||
mock_flask_user.return_value = mock_account
|
||||
|
||||
# also patch current_account_with_tenant
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"):
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = OAuthDataSource()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response[0]["data"] == "http://oauth.provider/auth"
|
||||
assert response[1] == 200
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@patch("flask_login.current_user")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
def test_get_oauth_url_invalid_provider(self, mock_db, mock_csrf, mock_flask_user, mock_get_providers, app):
|
||||
mock_get_providers.return_value = {"notion": MagicMock()}
|
||||
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"):
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = OAuthDataSource()
|
||||
response = api_instance.get("unknown_provider")
|
||||
|
||||
assert response[0]["error"] == "Invalid provider"
|
||||
assert response[1] == 400
|
||||
|
||||
|
||||
class TestOAuthDataSourceCallback:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_oauth_callback_successful(self, mock_get_providers, app):
|
||||
provider_mock = MagicMock()
|
||||
mock_get_providers.return_value = {"notion": provider_mock}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/callback?code=mock_code", method="GET"):
|
||||
api_instance = OAuthDataSourceCallback()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "code=mock_code" in response.location
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_oauth_callback_missing_code(self, mock_get_providers, app):
|
||||
provider_mock = MagicMock()
|
||||
mock_get_providers.return_value = {"notion": provider_mock}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/callback", method="GET"):
|
||||
api_instance = OAuthDataSourceCallback()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "error=Access denied" in response.location
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_oauth_callback_invalid_provider(self, mock_get_providers, app):
|
||||
mock_get_providers.return_value = {"notion": MagicMock()}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/invalid/callback?code=mock_code", method="GET"):
|
||||
api_instance = OAuthDataSourceCallback()
|
||||
response = api_instance.get("invalid")
|
||||
|
||||
assert response[0]["error"] == "Invalid provider"
|
||||
assert response[1] == 400
|
||||
|
||||
|
||||
class TestOAuthDataSourceBinding:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_get_binding_successful(self, mock_get_providers, app):
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_access_token.return_value = None
|
||||
mock_get_providers.return_value = {"notion": mock_provider}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=auth_code_123", method="GET"):
|
||||
api_instance = OAuthDataSourceBinding()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 200
|
||||
mock_provider.get_access_token.assert_called_once_with("auth_code_123")
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_get_binding_missing_code(self, mock_get_providers, app):
|
||||
mock_get_providers.return_value = {"notion": MagicMock()}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=", method="GET"):
|
||||
api_instance = OAuthDataSourceBinding()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response[0]["error"] == "Invalid code"
|
||||
assert response[1] == 400
|
||||
|
||||
|
||||
class TestOAuthDataSourceSync:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
def test_sync_successful(self, mock_db, mock_csrf, mock_get_providers, app):
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.sync_data_source.return_value = None
|
||||
mock_get_providers.return_value = {"notion": mock_provider}
|
||||
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"):
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = OAuthDataSourceSync()
|
||||
# The route pattern uses <uuid:binding_id>, so we just pass a string for unit testing
|
||||
response = api_instance.get("notion", "binding_123")
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 200
|
||||
mock_provider.sync_data_source.assert_called_once_with("binding_123")
|
||||
@ -0,0 +1,417 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.auth.oauth_server import (
|
||||
OAuthServerAppApi,
|
||||
OAuthServerUserAccountApi,
|
||||
OAuthServerUserAuthorizeApi,
|
||||
OAuthServerUserTokenApi,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthServerAppApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
oauth_app = MagicMock(spec=OAuthProviderApp)
|
||||
oauth_app.client_id = "test_client_id"
|
||||
oauth_app.redirect_uris = ["http://localhost/callback"]
|
||||
oauth_app.app_icon = "icon_url"
|
||||
oauth_app.app_label = "Test App"
|
||||
oauth_app.scope = "read,write"
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"},
|
||||
):
|
||||
api_instance = OAuthServerAppApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["app_icon"] == "icon_url"
|
||||
assert response["app_label"] == "Test App"
|
||||
assert response["scope"] == "read,write"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"},
|
||||
):
|
||||
api_instance = OAuthServerAppApi()
|
||||
with pytest.raises(BadRequest, match="redirect_uri is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_client_id(self, mock_get_app, mock_db, app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider",
|
||||
method="POST",
|
||||
json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"},
|
||||
):
|
||||
api_instance = OAuthServerAppApi()
|
||||
with pytest.raises(NotFound, match="client_id is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
|
||||
class TestOAuthServerUserAuthorizeApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
oauth_app = MagicMock()
|
||||
oauth_app.client_id = "test_client_id"
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.current_account_with_tenant")
|
||||
@patch("controllers.console.wraps.current_account_with_tenant")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
def test_successful_authorize(
|
||||
self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
mock_account = MagicMock()
|
||||
mock_account.id = "user_123"
|
||||
from models.account import AccountStatus
|
||||
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
|
||||
mock_current.return_value = (mock_account, MagicMock())
|
||||
mock_wrap_current.return_value = (mock_account, MagicMock())
|
||||
|
||||
mock_sign.return_value = "auth_code_123"
|
||||
|
||||
with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}):
|
||||
with patch("libs.login.current_user", mock_account):
|
||||
api_instance = OAuthServerUserAuthorizeApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["code"] == "auth_code_123"
|
||||
mock_sign.assert_called_once_with("test_client_id", "user_123")
|
||||
|
||||
|
||||
class TestOAuthServerUserTokenApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
oauth_app = MagicMock(spec=OAuthProviderApp)
|
||||
oauth_app.client_id = "test_client_id"
|
||||
oauth_app.client_secret = "test_secret"
|
||||
oauth_app.redirect_uris = ["http://localhost/callback"]
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
|
||||
def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
mock_sign.return_value = ("access_123", "refresh_123")
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["access_token"] == "access_123"
|
||||
assert response["refresh_token"] == "refresh_123"
|
||||
assert response["token_type"] == "Bearer"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="code is required"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "invalid_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="client_secret is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://invalid/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="redirect_uri is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
|
||||
def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
mock_sign.return_value = ("new_access", "new_refresh")
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["access_token"] == "new_access"
|
||||
assert response["refresh_token"] == "new_refresh"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="refresh_token is required"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "invalid_grant",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="invalid grant_type"):
|
||||
api_instance.post()
|
||||
|
||||
|
||||
class TestOAuthServerUserAccountApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
oauth_app = MagicMock(spec=OAuthProviderApp)
|
||||
oauth_app.client_id = "test_client_id"
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
|
||||
def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
mock_account = MagicMock()
|
||||
mock_account.name = "Test User"
|
||||
mock_account.email = "test@example.com"
|
||||
mock_account.avatar = "avatar_url"
|
||||
mock_account.interface_language = "en-US"
|
||||
mock_account.timezone = "UTC"
|
||||
mock_validate.return_value = mock_account
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer valid_access_token"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["name"] == "Test User"
|
||||
assert response["email"] == "test@example.com"
|
||||
assert response["avatar"] == "avatar_url"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "Authorization header is required"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "InvalidFormat"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "Invalid Authorization header format"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Basic something"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "token_type is invalid"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer "},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "Invalid Authorization header format"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
|
||||
def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
mock_validate.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer invalid_token"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "access_token or client_id is invalid"
|
||||
@ -24,13 +24,8 @@ class TestBannerApi:
|
||||
banner.status = BannerStatus.ENABLED
|
||||
banner.created_at = datetime(2024, 1, 1)
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value = query
|
||||
query.all.return_value = [banner]
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = [banner]
|
||||
|
||||
with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session):
|
||||
result = method(api)
|
||||
@ -58,16 +53,14 @@ class TestBannerApi:
|
||||
banner.status = BannerStatus.ENABLED
|
||||
banner.created_at = None
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value = query
|
||||
query.all.side_effect = [
|
||||
scalars_result = MagicMock()
|
||||
scalars_result.all.side_effect = [
|
||||
[],
|
||||
[banner],
|
||||
]
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value = scalars_result
|
||||
|
||||
with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session):
|
||||
result = method(api)
|
||||
@ -87,13 +80,8 @@ class TestBannerApi:
|
||||
api = banner_module.BannerApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value = query
|
||||
query.all.return_value = []
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = []
|
||||
|
||||
with app.test_request_context("/"), patch.object(banner_module.db, "session", session):
|
||||
result = method(api)
|
||||
|
||||
@ -260,11 +260,10 @@ class TestInstalledAppsCreateApi:
|
||||
app_entity.tenant_id = "t2"
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.side_effect = [
|
||||
recommended,
|
||||
app_entity,
|
||||
None,
|
||||
]
|
||||
# scalar() is called for recommended_app and installed_app lookups
|
||||
session.scalar.side_effect = [recommended, None]
|
||||
# get() is called for app PK lookup
|
||||
session.get.return_value = app_entity
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
@ -282,7 +281,7 @@ class TestInstalledAppsCreateApi:
|
||||
method = unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
@ -300,10 +299,10 @@ class TestInstalledAppsCreateApi:
|
||||
app_entity = MagicMock(is_public=False)
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.side_effect = [
|
||||
recommended,
|
||||
app_entity,
|
||||
]
|
||||
# scalar() returns recommended_app
|
||||
session.scalar.return_value = recommended
|
||||
# get() returns the app entity
|
||||
session.get.return_value = app_entity
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
|
||||
@ -958,8 +958,8 @@ class TestTrialSitApi:
|
||||
app_model = MagicMock()
|
||||
app_model.id = "a1"
|
||||
|
||||
with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, app_model)
|
||||
|
||||
@ -973,8 +973,8 @@ class TestTrialSitApi:
|
||||
app_model.tenant = MagicMock()
|
||||
app_model.tenant.status = TenantStatus.ARCHIVE
|
||||
|
||||
with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = site
|
||||
with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar:
|
||||
mock_scalar.return_value = site
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, app_model)
|
||||
|
||||
@ -990,10 +990,10 @@ class TestTrialSitApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module.db.session, "query") as mock_query,
|
||||
patch.object(module.db.session, "scalar") as mock_scalar,
|
||||
patch.object(module.SiteResponse, "model_validate") as mock_validate,
|
||||
):
|
||||
mock_query.return_value.where.return_value.first.return_value = site
|
||||
mock_scalar.return_value = site
|
||||
mock_validate_result = MagicMock()
|
||||
mock_validate_result.model_dump.return_value = {"name": "test", "icon": "icon"}
|
||||
mock_validate.return_value = mock_validate_result
|
||||
|
||||
@ -34,9 +34,9 @@ def test_installed_app_required_not_found():
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
scalar_mock.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
view("app-id")
|
||||
@ -54,11 +54,11 @@ def test_installed_app_required_app_deleted():
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
|
||||
patch("controllers.console.explore.wraps.db.session.delete"),
|
||||
patch("controllers.console.explore.wraps.db.session.commit"),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = installed_app
|
||||
scalar_mock.return_value = installed_app
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
view("app-id")
|
||||
@ -76,9 +76,9 @@ def test_installed_app_required_success():
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = installed_app
|
||||
scalar_mock.return_value = installed_app
|
||||
|
||||
result = view("app-id")
|
||||
assert result == installed_app
|
||||
@ -149,9 +149,9 @@ def test_trial_app_required_not_allowed():
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
scalar_mock.return_value = None
|
||||
|
||||
with pytest.raises(TrialAppNotAllowed):
|
||||
view("app-id")
|
||||
@ -170,9 +170,9 @@ def test_trial_app_required_limit_exceeded():
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
q.return_value.where.return_value.first.side_effect = [
|
||||
scalar_mock.side_effect = [
|
||||
trial_app,
|
||||
record,
|
||||
]
|
||||
@ -194,9 +194,9 @@ def test_trial_app_required_success():
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
q.return_value.where.return_value.first.side_effect = [
|
||||
scalar_mock.side_effect = [
|
||||
trial_app,
|
||||
record,
|
||||
]
|
||||
|
||||
@ -114,7 +114,7 @@ class TestBaseApiKeyResource:
|
||||
|
||||
def test_delete_key_not_found(self, tenant_context_admin, db_mock):
|
||||
resource = DummyApiKeyResource()
|
||||
db_mock.session.query.return_value.where.return_value.first.return_value = None
|
||||
db_mock.session.scalar.return_value = None
|
||||
|
||||
with patch("controllers.console.apikey._get_resource"):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
@ -125,7 +125,7 @@ class TestBaseApiKeyResource:
|
||||
|
||||
def test_delete_success(self, tenant_context_admin, db_mock):
|
||||
resource = DummyApiKeyResource()
|
||||
db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
db_mock.session.scalar.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey._get_resource"),
|
||||
|
||||
@ -328,7 +328,7 @@ class TestSystemSetup:
|
||||
def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
|
||||
"""Test NotInitValidateError when INIT_PASSWORD is set but setup not complete"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = None # No setup
|
||||
mock_db.session.scalar.return_value = None # No setup
|
||||
mock_environ_get.return_value = "some_password"
|
||||
|
||||
@setup_required
|
||||
@ -345,7 +345,7 @@ class TestSystemSetup:
|
||||
def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
|
||||
"""Test NotSetupError when no INIT_PASSWORD and setup not complete"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = None # No setup
|
||||
mock_db.session.scalar.return_value = None # No setup
|
||||
mock_environ_get.return_value = None # No INIT_PASSWORD
|
||||
|
||||
@setup_required
|
||||
|
||||
@ -55,9 +55,9 @@ class TestAccountInitApi:
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
|
||||
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.workspace.account.db.session.query") as query_mock,
|
||||
patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
|
||||
scalar_mock.return_value = MagicMock(status="unused")
|
||||
resp = method(api)
|
||||
|
||||
assert resp["result"] == "success"
|
||||
|
||||
@ -207,10 +207,10 @@ class TestMemberCancelInviteApi:
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
get_mock.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 200
|
||||
@ -226,9 +226,9 @@ class TestMemberCancelInviteApi:
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
get_mock.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, "x")
|
||||
@ -244,13 +244,13 @@ class TestMemberCancelInviteApi:
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.CannotOperateSelfError("x"),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
get_mock.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 400
|
||||
@ -266,13 +266,13 @@ class TestMemberCancelInviteApi:
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.NoPermissionError("x"),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
get_mock.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 403
|
||||
@ -288,13 +288,13 @@ class TestMemberCancelInviteApi:
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.MemberNotInTenantError(),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
get_mock.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 404
|
||||
|
||||
@ -36,7 +36,115 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestTenantListApi:
|
||||
def test_get_success(self, app):
|
||||
def test_get_success_saas_path(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
|
||||
return_value={
|
||||
"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0},
|
||||
"t2": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": 0},
|
||||
},
|
||||
) as get_plan_bulk_mock,
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["workspaces"]) == 2
|
||||
assert result["workspaces"][0]["current"] is True
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
|
||||
assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL
|
||||
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
|
||||
get_features_mock.assert_not_called()
|
||||
|
||||
def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app):
|
||||
"""Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used.
|
||||
|
||||
billing.enabled is mocked False to prove the endpoint does not gate on it for this path
|
||||
(SaaS contract treats enabled as on; display follows subscription.plan).
|
||||
"""
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
features_t2 = MagicMock()
|
||||
features_t2.billing.enabled = False
|
||||
features_t2.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
|
||||
return_value={"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0}},
|
||||
) as get_plan_bulk_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.FeatureService.get_features",
|
||||
return_value=features_t2,
|
||||
) as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
|
||||
assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL
|
||||
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
|
||||
get_features_mock.assert_called_once_with("t2")
|
||||
|
||||
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app):
|
||||
"""Test fallback to FeatureService when bulk billing returns empty result.
|
||||
|
||||
BillingService.get_plan_bulk catches exceptions internally and returns empty dict,
|
||||
so we simulate the real failure mode by returning empty dict for non-empty input.
|
||||
"""
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -54,27 +162,41 @@ class TestTenantListApi:
|
||||
)
|
||||
|
||||
features = MagicMock()
|
||||
features.billing.enabled = True
|
||||
features.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
features.billing.enabled = False
|
||||
features.billing.subscription.plan = CloudPlan.TEAM
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features),
|
||||
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
|
||||
return_value={}, # Simulates real failure: empty result for non-empty input
|
||||
) as get_plan_bulk_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.FeatureService.get_features",
|
||||
return_value=features,
|
||||
) as get_features_mock,
|
||||
patch("controllers.console.workspace.workspace.logger.warning") as logger_warning_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["workspaces"]) == 2
|
||||
assert result["workspaces"][0]["current"] is True
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
|
||||
assert result["workspaces"][1]["plan"] == CloudPlan.TEAM
|
||||
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
|
||||
assert get_features_mock.call_count == 2
|
||||
logger_warning_mock.assert_called_once()
|
||||
|
||||
def test_get_billing_disabled(self, app):
|
||||
def test_get_billing_disabled_community_path(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -87,6 +209,7 @@ class TestTenantListApi:
|
||||
|
||||
features = MagicMock()
|
||||
features.billing.enabled = False
|
||||
features.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
@ -98,15 +221,83 @@ class TestTenantListApi:
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.FeatureService.get_features",
|
||||
return_value=features,
|
||||
),
|
||||
) as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
|
||||
get_features_mock.assert_called_once_with("t1")
|
||||
|
||||
def test_get_enterprise_only_skips_feature_service(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True),
|
||||
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
|
||||
assert result["workspaces"][1]["plan"] == CloudPlan.SANDBOX
|
||||
assert result["workspaces"][0]["current"] is False
|
||||
assert result["workspaces"][1]["current"] is True
|
||||
get_features_mock.assert_not_called()
|
||||
|
||||
def test_get_enterprise_only_with_empty_tenants(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), None)
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True),
|
||||
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"] == []
|
||||
get_features_mock.assert_not_called()
|
||||
|
||||
|
||||
class TestWorkspaceListApi:
|
||||
@ -258,12 +449,12 @@ class TestSwitchWorkspaceApi:
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
|
||||
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"}
|
||||
),
|
||||
):
|
||||
query_mock.return_value.get.return_value = tenant
|
||||
get_mock.return_value = tenant
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
@ -297,9 +488,9 @@ class TestSwitchWorkspaceApi:
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
|
||||
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
|
||||
):
|
||||
query_mock.return_value.get.return_value = None
|
||||
get_mock.return_value = None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
@ -11,6 +11,7 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.task_pipeline import message_cycle_manager
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import AppMode, Conversation, Message
|
||||
|
||||
|
||||
@ -92,7 +93,7 @@ def test_init_generate_records_marks_existing_conversation():
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from=InvokeFrom.WEB_APP.value,
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id="user-id",
|
||||
from_account_id=None,
|
||||
)
|
||||
|
||||
181
api/tests/unit_tests/core/moderation/api/test_api.py
Normal file
181
api/tests/unit_tests/core/moderation/api/test_api.py
Normal file
@ -0,0 +1,181 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint
|
||||
from core.moderation.api.api import ApiModeration, ModerationInputParams, ModerationOutputParams
|
||||
from core.moderation.base import ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
|
||||
|
||||
class TestApiModeration:
|
||||
@pytest.fixture
|
||||
def api_config(self):
|
||||
return {
|
||||
"inputs_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"outputs_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"api_based_extension_id": "test-extension-id",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def api_moderation(self, api_config):
|
||||
return ApiModeration(app_id="test-app-id", tenant_id="test-tenant-id", config=api_config)
|
||||
|
||||
def test_moderation_input_params(self):
|
||||
params = ModerationInputParams(app_id="app-1", inputs={"key": "val"}, query="test query")
|
||||
assert params.app_id == "app-1"
|
||||
assert params.inputs == {"key": "val"}
|
||||
assert params.query == "test query"
|
||||
|
||||
# Test defaults
|
||||
params_default = ModerationInputParams()
|
||||
assert params_default.app_id == ""
|
||||
assert params_default.inputs == {}
|
||||
assert params_default.query == ""
|
||||
|
||||
def test_moderation_output_params(self):
|
||||
params = ModerationOutputParams(app_id="app-1", text="test text")
|
||||
assert params.app_id == "app-1"
|
||||
assert params.text == "test text"
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ModerationOutputParams()
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
|
||||
def test_validate_config_success(self, mock_get_extension, api_config):
|
||||
mock_get_extension.return_value = MagicMock(spec=APIBasedExtension)
|
||||
ApiModeration.validate_config("test-tenant-id", api_config)
|
||||
mock_get_extension.assert_called_once_with("test-tenant-id", "test-extension-id")
|
||||
|
||||
def test_validate_config_missing_extension_id(self):
|
||||
config = {
|
||||
"inputs_config": {"enabled": True},
|
||||
"outputs_config": {"enabled": True},
|
||||
}
|
||||
with pytest.raises(ValueError, match="api_based_extension_id is required"):
|
||||
ApiModeration.validate_config("test-tenant-id", config)
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
|
||||
def test_validate_config_extension_not_found(self, mock_get_extension, api_config):
|
||||
mock_get_extension.return_value = None
|
||||
with pytest.raises(ValueError, match="API-based Extension not found"):
|
||||
ApiModeration.validate_config("test-tenant-id", api_config)
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
|
||||
def test_moderation_for_inputs_enabled(self, mock_get_config, api_moderation):
|
||||
mock_get_config.return_value = {"flagged": True, "action": "direct_output", "preset_response": "Blocked by API"}
|
||||
|
||||
result = api_moderation.moderation_for_inputs(inputs={"q": "a"}, query="hello")
|
||||
|
||||
assert isinstance(result, ModerationInputsResult)
|
||||
assert result.flagged is True
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
assert result.preset_response == "Blocked by API"
|
||||
|
||||
mock_get_config.assert_called_once_with(
|
||||
APIBasedExtensionPoint.APP_MODERATION_INPUT,
|
||||
{"app_id": "test-app-id", "inputs": {"q": "a"}, "query": "hello"},
|
||||
)
|
||||
|
||||
def test_moderation_for_inputs_disabled(self):
|
||||
config = {
|
||||
"inputs_config": {"enabled": False},
|
||||
"outputs_config": {"enabled": True},
|
||||
"api_based_extension_id": "ext-id",
|
||||
}
|
||||
moderation = ApiModeration("app-id", "tenant-id", config)
|
||||
result = moderation.moderation_for_inputs(inputs={}, query="")
|
||||
|
||||
assert result.flagged is False
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
assert result.preset_response == ""
|
||||
|
||||
def test_moderation_for_inputs_no_config(self):
|
||||
moderation = ApiModeration("app-id", "tenant-id", None)
|
||||
with pytest.raises(ValueError, match="The config is not set"):
|
||||
moderation.moderation_for_inputs({}, "")
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
|
||||
def test_moderation_for_outputs_enabled(self, mock_get_config, api_moderation):
|
||||
mock_get_config.return_value = {"flagged": False, "action": "direct_output", "preset_response": ""}
|
||||
|
||||
result = api_moderation.moderation_for_outputs(text="hello world")
|
||||
|
||||
assert isinstance(result, ModerationOutputsResult)
|
||||
assert result.flagged is False
|
||||
|
||||
mock_get_config.assert_called_once_with(
|
||||
APIBasedExtensionPoint.APP_MODERATION_OUTPUT, {"app_id": "test-app-id", "text": "hello world"}
|
||||
)
|
||||
|
||||
def test_moderation_for_outputs_disabled(self):
|
||||
config = {
|
||||
"inputs_config": {"enabled": True},
|
||||
"outputs_config": {"enabled": False},
|
||||
"api_based_extension_id": "ext-id",
|
||||
}
|
||||
moderation = ApiModeration("app-id", "tenant-id", config)
|
||||
result = moderation.moderation_for_outputs(text="test")
|
||||
|
||||
assert result.flagged is False
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
|
||||
def test_moderation_for_outputs_no_config(self):
|
||||
moderation = ApiModeration("app-id", "tenant-id", None)
|
||||
with pytest.raises(ValueError, match="The config is not set"):
|
||||
moderation.moderation_for_outputs("test")
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
|
||||
@patch("core.moderation.api.api.decrypt_token")
|
||||
@patch("core.moderation.api.api.APIBasedExtensionRequestor")
|
||||
def test_get_config_by_requestor_success(self, mock_requestor_cls, mock_decrypt, mock_get_ext, api_moderation):
|
||||
mock_ext = MagicMock(spec=APIBasedExtension)
|
||||
mock_ext.api_endpoint = "http://api.test"
|
||||
mock_ext.api_key = "encrypted-key"
|
||||
mock_get_ext.return_value = mock_ext
|
||||
|
||||
mock_decrypt.return_value = "decrypted-key"
|
||||
|
||||
mock_requestor = MagicMock()
|
||||
mock_requestor.request.return_value = {"flagged": True}
|
||||
mock_requestor_cls.return_value = mock_requestor
|
||||
|
||||
params = {"some": "params"}
|
||||
result = api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
|
||||
|
||||
assert result == {"flagged": True}
|
||||
mock_get_ext.assert_called_once_with("test-tenant-id", "test-extension-id")
|
||||
mock_decrypt.assert_called_once_with("test-tenant-id", "encrypted-key")
|
||||
mock_requestor_cls.assert_called_once_with("http://api.test", "decrypted-key")
|
||||
mock_requestor.request.assert_called_once_with(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
|
||||
|
||||
def test_get_config_by_requestor_no_config(self):
|
||||
moderation = ApiModeration("app-id", "tenant-id", None)
|
||||
with pytest.raises(ValueError, match="The config is not set"):
|
||||
moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
|
||||
def test_get_config_by_requestor_extension_not_found(self, mock_get_ext, api_moderation):
|
||||
mock_get_ext.return_value = None
|
||||
with pytest.raises(ValueError, match="API-based Extension not found"):
|
||||
api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})
|
||||
|
||||
@patch("core.moderation.api.api.db.session.scalar")
|
||||
def test_get_api_based_extension(self, mock_scalar):
|
||||
mock_ext = MagicMock(spec=APIBasedExtension)
|
||||
mock_scalar.return_value = mock_ext
|
||||
|
||||
result = ApiModeration._get_api_based_extension("tenant-1", "ext-1")
|
||||
|
||||
assert result == mock_ext
|
||||
mock_scalar.assert_called_once()
|
||||
# Verify the call has the correct filters
|
||||
args, kwargs = mock_scalar.call_args
|
||||
stmt = args[0]
|
||||
# We can't easily inspect the statement without complex sqlalchemy tricks,
|
||||
# but calling it is usually enough for unit tests if we mock the result.
|
||||
207
api/tests/unit_tests/core/moderation/test_input_moderation.py
Normal file
207
api/tests/unit_tests/core/moderation/test_input_moderation.py
Normal file
@ -0,0 +1,207 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppConfig, SensitiveWordAvoidanceEntity
|
||||
from core.moderation.base import ModerationAction, ModerationError, ModerationInputsResult
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
|
||||
class TestInputModeration:
|
||||
@pytest.fixture
|
||||
def app_config(self):
|
||||
config = MagicMock(spec=AppConfig)
|
||||
config.sensitive_word_avoidance = None
|
||||
return config
|
||||
|
||||
@pytest.fixture
|
||||
def input_moderation(self):
|
||||
return InputModeration()
|
||||
|
||||
def test_check_no_sensitive_word_avoidance(self, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
flagged, final_inputs, final_query = input_moderation.check(
|
||||
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
|
||||
)
|
||||
|
||||
assert flagged is False
|
||||
assert final_inputs == inputs
|
||||
assert final_query == query
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
def test_check_not_flagged(self, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {"keywords": ["bad"]}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
flagged, final_inputs, final_query = input_moderation.check(
|
||||
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
|
||||
)
|
||||
|
||||
assert flagged is False
|
||||
assert final_inputs == inputs
|
||||
assert final_query == query
|
||||
mock_factory_cls.assert_called_once_with(
|
||||
name="keywords", app_id=app_id, tenant_id=tenant_id, config={"keywords": ["bad"]}
|
||||
)
|
||||
mock_factory.moderation_for_inputs.assert_called_once_with(dict(inputs), query)
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
@patch("core.moderation.input_moderation.TraceTask")
|
||||
def test_check_with_trace_manager(self, mock_trace_task, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
trace_manager = MagicMock(spec=TraceQueueManager)
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
input_moderation.check(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_config=app_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message_id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
trace_manager.add_trace_task.assert_called_once_with(mock_trace_task.return_value)
|
||||
mock_trace_task.assert_called_once()
|
||||
call_kwargs = mock_trace_task.call_args.kwargs
|
||||
call_args = mock_trace_task.call_args.args
|
||||
assert call_args[0] == TraceTaskName.MODERATION_TRACE
|
||||
assert call_kwargs["message_id"] == message_id
|
||||
assert call_kwargs["moderation_result"] == mock_result
|
||||
assert call_kwargs["inputs"] == inputs
|
||||
assert "timer" in call_kwargs
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
def test_check_flagged_direct_output(self, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = ModerationInputsResult(
|
||||
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="Blocked content"
|
||||
)
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
with pytest.raises(ModerationError) as excinfo:
|
||||
input_moderation.check(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_config=app_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
assert str(excinfo.value) == "Blocked content"
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
def test_check_flagged_overridden(self, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = ModerationInputsResult(
|
||||
flagged=True,
|
||||
action=ModerationAction.OVERRIDDEN,
|
||||
inputs={"input_key": "overridden_value"},
|
||||
query="overridden query",
|
||||
)
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
flagged, final_inputs, final_query = input_moderation.check(
|
||||
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
|
||||
)
|
||||
|
||||
assert flagged is True
|
||||
assert final_inputs == {"input_key": "overridden_value"}
|
||||
assert final_query == "overridden query"
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
def test_check_flagged_other_action(self, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = MagicMock()
|
||||
mock_result.flagged = True
|
||||
mock_result.action = "NONE" # Some other action
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
flagged, final_inputs, final_query = input_moderation.check(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_config=app_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
assert flagged is True
|
||||
assert final_inputs == inputs
|
||||
assert final_query == query
|
||||
234
api/tests/unit_tests/core/moderation/test_output_moderation.py
Normal file
234
api/tests/unit_tests/core/moderation/test_output_moderation.py
Normal file
@ -0,0 +1,234 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import QueueMessageReplaceEvent
|
||||
from core.moderation.base import ModerationAction, ModerationOutputsResult
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
|
||||
|
||||
class TestOutputModeration:
|
||||
@pytest.fixture
|
||||
def mock_queue_manager(self):
|
||||
return MagicMock(spec=AppQueueManager)
|
||||
|
||||
@pytest.fixture
|
||||
def moderation_rule(self):
|
||||
return ModerationRule(type="keywords", config={"keywords": "badword"})
|
||||
|
||||
@pytest.fixture
|
||||
def output_moderation(self, mock_queue_manager, moderation_rule):
|
||||
return OutputModeration(
|
||||
tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager
|
||||
)
|
||||
|
||||
def test_should_direct_output(self, output_moderation):
|
||||
assert output_moderation.should_direct_output() is False
|
||||
output_moderation.final_output = "blocked"
|
||||
assert output_moderation.should_direct_output() is True
|
||||
|
||||
def test_get_final_output(self, output_moderation):
|
||||
assert output_moderation.get_final_output() == ""
|
||||
output_moderation.final_output = "blocked"
|
||||
assert output_moderation.get_final_output() == "blocked"
|
||||
|
||||
def test_append_new_token(self, output_moderation):
|
||||
with patch.object(OutputModeration, "start_thread") as mock_start:
|
||||
output_moderation.append_new_token("hello")
|
||||
assert output_moderation.buffer == "hello"
|
||||
mock_start.assert_called_once()
|
||||
|
||||
output_moderation.thread = MagicMock()
|
||||
output_moderation.append_new_token(" world")
|
||||
assert output_moderation.buffer == "hello world"
|
||||
assert mock_start.call_count == 1
|
||||
|
||||
def test_moderation_completion_no_flag(self, output_moderation):
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
|
||||
|
||||
output, flagged = output_moderation.moderation_completion("safe content")
|
||||
|
||||
assert output == "safe content"
|
||||
assert flagged is False
|
||||
assert output_moderation.is_final_chunk is True
|
||||
|
||||
def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager):
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
mock_moderation.return_value = ModerationOutputsResult(
|
||||
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
|
||||
)
|
||||
|
||||
output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
|
||||
|
||||
assert output == "preset"
|
||||
assert flagged is True
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
args, _ = mock_queue_manager.publish.call_args
|
||||
assert isinstance(args[0], QueueMessageReplaceEvent)
|
||||
assert args[0].text == "preset"
|
||||
assert args[1] == PublishFrom.TASK_PIPELINE
|
||||
|
||||
def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager):
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
mock_moderation.return_value = ModerationOutputsResult(
|
||||
flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content"
|
||||
)
|
||||
|
||||
output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
|
||||
|
||||
assert output == "masked content"
|
||||
assert flagged is True
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
args, _ = mock_queue_manager.publish.call_args
|
||||
assert args[0].text == "masked content"
|
||||
|
||||
def test_start_thread(self, output_moderation):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
with patch("core.moderation.output_moderation.current_app") as mock_current_app:
|
||||
mock_current_app._get_current_object.return_value = mock_app
|
||||
with patch("threading.Thread") as mock_thread_class:
|
||||
mock_thread_instance = MagicMock()
|
||||
mock_thread_class.return_value = mock_thread_instance
|
||||
|
||||
thread = output_moderation.start_thread()
|
||||
|
||||
assert thread == mock_thread_instance
|
||||
mock_thread_class.assert_called_once()
|
||||
mock_thread_instance.start.assert_called_once()
|
||||
|
||||
def test_stop_thread(self, output_moderation):
|
||||
mock_thread = MagicMock()
|
||||
mock_thread.is_alive.return_value = True
|
||||
output_moderation.thread = mock_thread
|
||||
|
||||
output_moderation.stop_thread()
|
||||
assert output_moderation.thread_running is False
|
||||
|
||||
output_moderation.thread_running = True
|
||||
mock_thread.is_alive.return_value = False
|
||||
output_moderation.stop_thread()
|
||||
assert output_moderation.thread_running is True
|
||||
|
||||
@patch("core.moderation.output_moderation.ModerationFactory")
|
||||
def test_moderation_success(self, mock_factory_class, output_moderation):
|
||||
mock_factory = mock_factory_class.return_value
|
||||
mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
|
||||
mock_factory.moderation_for_outputs.return_value = mock_result
|
||||
|
||||
result = output_moderation.moderation("tenant", "app", "buffer")
|
||||
|
||||
assert result == mock_result
|
||||
mock_factory_class.assert_called_once_with(
|
||||
name="keywords", app_id="app", tenant_id="tenant", config={"keywords": "badword"}
|
||||
)
|
||||
|
||||
@patch("core.moderation.output_moderation.ModerationFactory")
|
||||
def test_moderation_exception(self, mock_factory_class, output_moderation):
|
||||
mock_factory_class.side_effect = Exception("error")
|
||||
|
||||
result = output_moderation.moderation("tenant", "app", "buffer")
|
||||
assert result is None
|
||||
|
||||
def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
|
||||
# Test exit on thread_running=False
|
||||
output_moderation.thread_running = False
|
||||
output_moderation.worker(mock_app, 10)
|
||||
# Should exit immediately
|
||||
|
||||
def test_worker_no_flag(self, output_moderation):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
|
||||
|
||||
output_moderation.buffer = "safe"
|
||||
output_moderation.is_final_chunk = True
|
||||
|
||||
# To avoid infinite loop, we'll set thread_running to False after one iteration
|
||||
def side_effect(*args, **kwargs):
|
||||
output_moderation.thread_running = False
|
||||
return mock_moderation.return_value
|
||||
|
||||
mock_moderation.side_effect = side_effect
|
||||
|
||||
output_moderation.worker(mock_app, 10)
|
||||
|
||||
assert mock_moderation.called
|
||||
|
||||
def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
mock_moderation.return_value = ModerationOutputsResult(
|
||||
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
|
||||
)
|
||||
|
||||
output_moderation.buffer = "badword"
|
||||
output_moderation.is_final_chunk = True
|
||||
|
||||
output_moderation.worker(mock_app, 10)
|
||||
|
||||
assert output_moderation.final_output == "preset"
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
# It breaks on DIRECT_OUTPUT
|
||||
|
||||
def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
# Use side_effect to change thread_running on second call
|
||||
def side_effect(*args, **kwargs):
|
||||
if mock_moderation.call_count > 1:
|
||||
output_moderation.thread_running = False
|
||||
return None
|
||||
return ModerationOutputsResult(flagged=True, action=ModerationAction.OVERRIDDEN, text="masked")
|
||||
|
||||
mock_moderation.side_effect = side_effect
|
||||
|
||||
output_moderation.buffer = "badword"
|
||||
output_moderation.is_final_chunk = True
|
||||
|
||||
output_moderation.worker(mock_app, 10)
|
||||
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
args, _ = mock_queue_manager.publish.call_args
|
||||
assert args[0].text == "masked"
|
||||
|
||||
def test_worker_chunk_too_small(self, output_moderation):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
with patch("time.sleep") as mock_sleep:
|
||||
# chunk_length < buffer_size and not is_final_chunk
|
||||
output_moderation.buffer = "123" # length 3
|
||||
output_moderation.is_final_chunk = False
|
||||
|
||||
def sleep_side_effect(seconds):
|
||||
output_moderation.thread_running = False
|
||||
|
||||
mock_sleep.side_effect = sleep_side_effect
|
||||
|
||||
output_moderation.worker(mock_app, 10) # buffer_size 10
|
||||
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
|
||||
def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
# Return None (exception or no rule)
|
||||
mock_moderation.return_value = None
|
||||
|
||||
def side_effect(*args, **kwargs):
|
||||
output_moderation.thread_running = False
|
||||
|
||||
mock_moderation.side_effect = side_effect
|
||||
|
||||
output_moderation.buffer = "something"
|
||||
output_moderation.is_final_chunk = True
|
||||
|
||||
output_moderation.worker(mock_app, 10)
|
||||
|
||||
mock_queue_manager.publish.assert_not_called()
|
||||
@ -0,0 +1,160 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from qdrant_client.http import models as rest
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import (
|
||||
TidbOnQdrantConfig,
|
||||
TidbOnQdrantVector,
|
||||
)
|
||||
|
||||
|
||||
class TestTidbOnQdrantVectorDeleteByIds:
|
||||
"""Unit tests for TidbOnQdrantVector.delete_by_ids method."""
|
||||
|
||||
@pytest.fixture
|
||||
def vector_instance(self):
|
||||
"""Create a TidbOnQdrantVector instance for testing."""
|
||||
config = TidbOnQdrantConfig(
|
||||
endpoint="http://localhost:6333",
|
||||
api_key="test_api_key",
|
||||
)
|
||||
|
||||
with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"):
|
||||
vector = TidbOnQdrantVector(
|
||||
collection_name="test_collection",
|
||||
group_id="test_group",
|
||||
config=config,
|
||||
)
|
||||
return vector
|
||||
|
||||
def test_delete_by_ids_with_multiple_ids(self, vector_instance):
|
||||
"""Test batch deletion with multiple document IDs."""
|
||||
ids = ["doc1", "doc2", "doc3"]
|
||||
|
||||
vector_instance.delete_by_ids(ids)
|
||||
|
||||
# Verify that delete was called once with MatchAny filter
|
||||
vector_instance._client.delete.assert_called_once()
|
||||
call_args = vector_instance._client.delete.call_args
|
||||
|
||||
# Check collection name
|
||||
assert call_args[1]["collection_name"] == "test_collection"
|
||||
|
||||
# Verify filter uses MatchAny with all IDs
|
||||
filter_selector = call_args[1]["points_selector"]
|
||||
filter_obj = filter_selector.filter
|
||||
assert len(filter_obj.must) == 1
|
||||
|
||||
field_condition = filter_obj.must[0]
|
||||
assert field_condition.key == "metadata.doc_id"
|
||||
assert isinstance(field_condition.match, rest.MatchAny)
|
||||
assert set(field_condition.match.any) == {"doc1", "doc2", "doc3"}
|
||||
|
||||
def test_delete_by_ids_with_single_id(self, vector_instance):
|
||||
"""Test deletion with a single document ID."""
|
||||
ids = ["doc1"]
|
||||
|
||||
vector_instance.delete_by_ids(ids)
|
||||
|
||||
# Verify that delete was called once
|
||||
vector_instance._client.delete.assert_called_once()
|
||||
call_args = vector_instance._client.delete.call_args
|
||||
|
||||
# Verify filter uses MatchAny with single ID
|
||||
filter_selector = call_args[1]["points_selector"]
|
||||
filter_obj = filter_selector.filter
|
||||
field_condition = filter_obj.must[0]
|
||||
assert isinstance(field_condition.match, rest.MatchAny)
|
||||
assert field_condition.match.any == ["doc1"]
|
||||
|
||||
def test_delete_by_ids_with_empty_list(self, vector_instance):
|
||||
"""Test deletion with empty ID list returns early without API call."""
|
||||
vector_instance.delete_by_ids([])
|
||||
|
||||
# Verify that delete was NOT called
|
||||
vector_instance._client.delete.assert_not_called()
|
||||
|
||||
def test_delete_by_ids_with_404_error(self, vector_instance):
|
||||
"""Test that 404 errors (collection not found) are handled gracefully."""
|
||||
ids = ["doc1", "doc2"]
|
||||
|
||||
# Mock a 404 error
|
||||
error = UnexpectedResponse(
|
||||
status_code=404,
|
||||
reason_phrase="Not Found",
|
||||
content=b"Collection not found",
|
||||
headers=httpx.Headers(),
|
||||
)
|
||||
vector_instance._client.delete.side_effect = error
|
||||
|
||||
# Should not raise an exception
|
||||
vector_instance.delete_by_ids(ids)
|
||||
|
||||
# Verify delete was called
|
||||
vector_instance._client.delete.assert_called_once()
|
||||
|
||||
def test_delete_by_ids_with_unexpected_error(self, vector_instance):
|
||||
"""Test that non-404 errors are re-raised."""
|
||||
ids = ["doc1", "doc2"]
|
||||
|
||||
# Mock a 500 error
|
||||
error = UnexpectedResponse(
|
||||
status_code=500,
|
||||
reason_phrase="Internal Server Error",
|
||||
content=b"Server error",
|
||||
headers=httpx.Headers(),
|
||||
)
|
||||
vector_instance._client.delete.side_effect = error
|
||||
|
||||
# Should re-raise the exception
|
||||
with pytest.raises(UnexpectedResponse) as exc_info:
|
||||
vector_instance.delete_by_ids(ids)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
|
||||
def test_delete_by_ids_with_large_batch(self, vector_instance):
|
||||
"""Test deletion with a large batch of IDs."""
|
||||
# Create 1000 IDs
|
||||
ids = [f"doc_{i}" for i in range(1000)]
|
||||
|
||||
vector_instance.delete_by_ids(ids)
|
||||
|
||||
# Verify single delete call with all IDs
|
||||
vector_instance._client.delete.assert_called_once()
|
||||
call_args = vector_instance._client.delete.call_args
|
||||
|
||||
filter_selector = call_args[1]["points_selector"]
|
||||
filter_obj = filter_selector.filter
|
||||
field_condition = filter_obj.must[0]
|
||||
|
||||
# Verify all 1000 IDs are in the batch
|
||||
assert len(field_condition.match.any) == 1000
|
||||
assert "doc_0" in field_condition.match.any
|
||||
assert "doc_999" in field_condition.match.any
|
||||
|
||||
def test_delete_by_ids_filter_structure(self, vector_instance):
|
||||
"""Test that the filter structure is correctly constructed."""
|
||||
ids = ["doc1", "doc2"]
|
||||
|
||||
vector_instance.delete_by_ids(ids)
|
||||
|
||||
call_args = vector_instance._client.delete.call_args
|
||||
filter_selector = call_args[1]["points_selector"]
|
||||
filter_obj = filter_selector.filter
|
||||
|
||||
# Verify Filter structure
|
||||
assert isinstance(filter_obj, rest.Filter)
|
||||
assert filter_obj.must is not None
|
||||
assert len(filter_obj.must) == 1
|
||||
|
||||
# Verify FieldCondition structure
|
||||
field_condition = filter_obj.must[0]
|
||||
assert isinstance(field_condition, rest.FieldCondition)
|
||||
assert field_condition.key == "metadata.doc_id"
|
||||
|
||||
# Verify MatchAny structure
|
||||
assert isinstance(field_condition.match, rest.MatchAny)
|
||||
assert field_condition.match.any == ids
|
||||
@ -0,0 +1,677 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.human_input_repository import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormRepositoryImpl,
|
||||
HumanInputFormSubmissionRepository,
|
||||
_HumanInputFormEntityImpl,
|
||||
_HumanInputFormRecipientEntityImpl,
|
||||
_InvalidTimeoutStatusError,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from dify_graph.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
UserAction,
|
||||
WebAppDeliveryMethod,
|
||||
)
|
||||
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_select(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _FakeSelect:
|
||||
def join(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
|
||||
return self
|
||||
|
||||
def where(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
|
||||
return self
|
||||
|
||||
def options(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
|
||||
return self
|
||||
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *_args, **_kwargs: _FakeSelect())
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.selectinload", lambda *_args, **_kwargs: "_loader")
|
||||
|
||||
|
||||
def _make_form_definition_json(*, include_expiration_time: bool) -> str:
|
||||
payload: dict[str, Any] = {
|
||||
"form_content": "hi",
|
||||
"inputs": [],
|
||||
"user_actions": [{"id": "submit", "title": "Submit"}],
|
||||
"rendered_content": "<p>hi</p>",
|
||||
}
|
||||
if include_expiration_time:
|
||||
payload["expiration_time"] = naive_utc_now()
|
||||
return json.dumps(payload, default=str)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DummyForm:
|
||||
id: str
|
||||
workflow_run_id: str | None
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
form_definition: str
|
||||
rendered_content: str
|
||||
expiration_time: datetime
|
||||
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
|
||||
created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
|
||||
selected_action_id: str | None = None
|
||||
submitted_data: str | None = None
|
||||
submitted_at: datetime | None = None
|
||||
submission_user_id: str | None = None
|
||||
submission_end_user_id: str | None = None
|
||||
completed_by_recipient_id: str | None = None
|
||||
status: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DummyRecipient:
|
||||
id: str
|
||||
form_id: str
|
||||
recipient_type: RecipientType
|
||||
access_token: str | None
|
||||
|
||||
|
||||
class _FakeScalarResult:
|
||||
def __init__(self, obj: Any):
|
||||
self._obj = obj
|
||||
|
||||
def first(self) -> Any:
|
||||
if isinstance(self._obj, list):
|
||||
return self._obj[0] if self._obj else None
|
||||
return self._obj
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
if self._obj is None:
|
||||
return []
|
||||
if isinstance(self._obj, list):
|
||||
return list(self._obj)
|
||||
return [self._obj]
|
||||
|
||||
|
||||
class _FakeExecuteResult:
|
||||
def __init__(self, rows: Sequence[tuple[Any, ...]]):
|
||||
self._rows = list(rows)
|
||||
|
||||
def all(self) -> list[tuple[Any, ...]]:
|
||||
return list(self._rows)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
scalars_result: Any = None,
|
||||
scalars_results: list[Any] | None = None,
|
||||
forms: dict[str, _DummyForm] | None = None,
|
||||
recipients: dict[str, _DummyRecipient] | None = None,
|
||||
execute_rows: Sequence[tuple[Any, ...]] = (),
|
||||
):
|
||||
if scalars_results is not None:
|
||||
self._scalars_queue = list(scalars_results)
|
||||
else:
|
||||
self._scalars_queue = [scalars_result]
|
||||
self._forms = forms or {}
|
||||
self._recipients = recipients or {}
|
||||
self._execute_rows = list(execute_rows)
|
||||
self.added: list[Any] = []
|
||||
|
||||
def scalars(self, _query: Any) -> _FakeScalarResult:
|
||||
if self._scalars_queue:
|
||||
value = self._scalars_queue.pop(0)
|
||||
else:
|
||||
value = None
|
||||
return _FakeScalarResult(value)
|
||||
|
||||
def execute(self, _stmt: Any) -> _FakeExecuteResult:
|
||||
return _FakeExecuteResult(self._execute_rows)
|
||||
|
||||
def get(self, model_cls: Any, obj_id: str) -> Any:
|
||||
name = getattr(model_cls, "__name__", "")
|
||||
if name == "HumanInputForm":
|
||||
return self._forms.get(obj_id)
|
||||
if name == "HumanInputFormRecipient":
|
||||
return self._recipients.get(obj_id)
|
||||
return None
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
def add_all(self, objs: Sequence[Any]) -> None:
|
||||
self.added.extend(list(objs))
|
||||
|
||||
def flush(self) -> None:
|
||||
# Simulate DB default population for attributes referenced in entity wrappers.
|
||||
for obj in self.added:
|
||||
if hasattr(obj, "id") and obj.id in (None, ""):
|
||||
obj.id = f"gen-{len(str(self.added))}"
|
||||
if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None:
|
||||
if obj.recipient_type == RecipientType.CONSOLE:
|
||||
obj.access_token = "token-console"
|
||||
elif obj.recipient_type == RecipientType.BACKSTAGE:
|
||||
obj.access_token = "token-backstage"
|
||||
else:
|
||||
obj.access_token = "token-webapp"
|
||||
|
||||
def refresh(self, _obj: Any) -> None:
|
||||
return None
|
||||
|
||||
def begin(self) -> _FakeSession:
|
||||
return self
|
||||
|
||||
def __enter__(self) -> _FakeSession:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _SessionFactoryStub:
|
||||
def __init__(self, session: _FakeSession):
|
||||
self._session = session
|
||||
|
||||
def create_session(self) -> _FakeSession:
|
||||
return self._session
|
||||
|
||||
|
||||
def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None:
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.session_factory", _SessionFactoryStub(session))
|
||||
|
||||
|
||||
def test_recipient_entity_token_raises_when_missing() -> None:
|
||||
recipient = SimpleNamespace(id="r1", access_token=None)
|
||||
entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type]
|
||||
with pytest.raises(AssertionError, match="access_token should not be None"):
|
||||
_ = entity.token
|
||||
|
||||
|
||||
def test_recipient_entity_id_and_token_success() -> None:
|
||||
recipient = SimpleNamespace(id="r1", access_token="tok")
|
||||
entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type]
|
||||
assert entity.id == "r1"
|
||||
assert entity.token == "tok"
|
||||
|
||||
|
||||
def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None:
|
||||
form = _DummyForm(
|
||||
id="f1",
|
||||
workflow_run_id="run",
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
console = _DummyRecipient(id="c1", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="ctok")
|
||||
webapp = _DummyRecipient(
|
||||
id="w1", form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok"
|
||||
)
|
||||
|
||||
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type]
|
||||
assert entity.web_app_token == "ctok"
|
||||
|
||||
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type]
|
||||
assert entity.web_app_token == "wtok"
|
||||
|
||||
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type]
|
||||
assert entity.web_app_token is None
|
||||
|
||||
|
||||
def test_form_entity_submitted_data_parsed() -> None:
|
||||
form = _DummyForm(
|
||||
id="f1",
|
||||
workflow_run_id="run",
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
submitted_data='{"a": 1}',
|
||||
submitted_at=naive_utc_now(),
|
||||
)
|
||||
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type]
|
||||
assert entity.submitted is True
|
||||
assert entity.submitted_data == {"a": 1}
|
||||
assert entity.rendered_content == "<p>x</p>"
|
||||
assert entity.selected_action_id is None
|
||||
assert entity.status == HumanInputFormStatus.WAITING
|
||||
|
||||
|
||||
def test_form_record_from_models_injects_expiration_time_when_missing() -> None:
|
||||
expiration = naive_utc_now()
|
||||
form = _DummyForm(
|
||||
id="f1",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=False),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=expiration,
|
||||
submitted_data='{"k": "v"}',
|
||||
)
|
||||
record = HumanInputFormRecord.from_models(form, None) # type: ignore[arg-type]
|
||||
assert record.definition.expiration_time == expiration
|
||||
assert record.submitted_data == {"k": "v"}
|
||||
assert record.submitted is False
|
||||
|
||||
|
||||
def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created: list[SimpleNamespace] = []
|
||||
|
||||
def fake_new(cls, form_id: str, delivery_id: str, payload: Any): # type: ignore[no-untyped-def]
|
||||
recipient = SimpleNamespace(
|
||||
id=f"{payload.TYPE}-{len(created)}",
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=payload.TYPE,
|
||||
recipient_payload=payload.model_dump_json(),
|
||||
access_token="tok",
|
||||
)
|
||||
created.append(recipient)
|
||||
return recipient
|
||||
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new))
|
||||
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
recipients = repo._create_email_recipients_from_resolved( # type: ignore[attr-defined]
|
||||
form_id="f",
|
||||
delivery_id="d",
|
||||
members=[
|
||||
_WorkspaceMemberInfo(user_id="u1", email=""),
|
||||
_WorkspaceMemberInfo(user_id="u2", email="a@example.com"),
|
||||
_WorkspaceMemberInfo(user_id="u3", email="a@example.com"),
|
||||
],
|
||||
external_emails=["", "a@example.com", "b@example.com", "b@example.com"],
|
||||
)
|
||||
assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL]
|
||||
|
||||
|
||||
def test_query_workspace_members_by_ids_empty_returns_empty() -> None:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
assert repo._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""]) == []
|
||||
|
||||
|
||||
def test_query_workspace_members_by_ids_maps_rows() -> None:
|
||||
session = _FakeSession(execute_rows=[("u1", "a@example.com"), ("u2", "b@example.com")])
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
rows = repo._query_workspace_members_by_ids(session=session, restrict_to_user_ids=["u1", "u2"])
|
||||
assert rows == [
|
||||
_WorkspaceMemberInfo(user_id="u1", email="a@example.com"),
|
||||
_WorkspaceMemberInfo(user_id="u2", email="b@example.com"),
|
||||
]
|
||||
|
||||
|
||||
def test_query_all_workspace_members_maps_rows() -> None:
|
||||
session = _FakeSession(execute_rows=[("u1", "a@example.com")])
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
rows = repo._query_all_workspace_members(session=session)
|
||||
assert rows == [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]
|
||||
|
||||
|
||||
def test_repository_init_sets_tenant_id() -> None:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
assert repo._tenant_id == "tenant"
|
||||
|
||||
|
||||
def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
|
||||
result = repo._delivery_method_to_model(
|
||||
session=MagicMock(), form_id="form-1", delivery_method=WebAppDeliveryMethod()
|
||||
)
|
||||
assert result.delivery.id == "del-1"
|
||||
assert result.delivery.form_id == "form-1"
|
||||
assert len(result.recipients) == 1
|
||||
assert result.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP
|
||||
|
||||
|
||||
def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
|
||||
called: dict[str, Any] = {}
|
||||
|
||||
def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]:
|
||||
called.update(
|
||||
{"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config}
|
||||
)
|
||||
return ["r"]
|
||||
|
||||
monkeypatch.setattr(repo, "_build_email_recipients", fake_build)
|
||||
|
||||
method = EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
|
||||
),
|
||||
subject="s",
|
||||
body="b",
|
||||
)
|
||||
)
|
||||
result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=method)
|
||||
assert result.recipients == ["r"]
|
||||
assert called["delivery_id"] == "del-1"
|
||||
|
||||
|
||||
def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
monkeypatch.setattr(
|
||||
repo,
|
||||
"_query_all_workspace_members",
|
||||
lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")],
|
||||
)
|
||||
monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
|
||||
recipients = repo._build_email_recipients(
|
||||
session=MagicMock(),
|
||||
form_id="f",
|
||||
delivery_id="d",
|
||||
recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]),
|
||||
)
|
||||
assert recipients == ["ok"]
|
||||
|
||||
|
||||
def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
|
||||
def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]:
|
||||
assert restrict_to_user_ids == ["u1"]
|
||||
return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]
|
||||
|
||||
monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query)
|
||||
monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
|
||||
recipients = repo._build_email_recipients(
|
||||
session=MagicMock(),
|
||||
form_id="f",
|
||||
delivery_id="d",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
|
||||
),
|
||||
)
|
||||
assert recipients == ["ok"]
|
||||
|
||||
|
||||
def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None]))
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
assert repo.get_form("run", "node") is None
|
||||
|
||||
form = _DummyForm(
|
||||
id="f1",
|
||||
workflow_run_id="run",
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="r1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="tok",
|
||||
)
|
||||
session = _FakeSession(scalars_results=[form, [recipient]])
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
entity = repo.get_form("run", "node")
|
||||
assert entity is not None
|
||||
assert entity.id == "f1"
|
||||
assert entity.recipients[0].id == "r1"
|
||||
assert entity.recipients[0].token == "tok"
|
||||
|
||||
|
||||
def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
|
||||
|
||||
ids = iter(["form-id", "del-web", "del-console", "del-backstage"])
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(ids))
|
||||
|
||||
session = _FakeSession()
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
|
||||
form_config = HumanInputNodeData(
|
||||
title="Title",
|
||||
delivery_methods=[],
|
||||
form_content="hello",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
)
|
||||
params = FormCreateParams(
|
||||
app_id="app",
|
||||
workflow_execution_id="run",
|
||||
node_id="node",
|
||||
form_config=form_config,
|
||||
rendered_content="<p>hello</p>",
|
||||
delivery_methods=[WebAppDeliveryMethod()],
|
||||
display_in_ui=True,
|
||||
resolved_default_values={},
|
||||
form_kind=HumanInputFormKind.RUNTIME,
|
||||
console_recipient_required=True,
|
||||
console_creator_account_id="acc-1",
|
||||
backstage_recipient_required=True,
|
||||
)
|
||||
|
||||
entity = repo.create_form(params)
|
||||
assert entity.id == "form-id"
|
||||
assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout)
|
||||
# Console token should take precedence when console recipient is present.
|
||||
assert entity.web_app_token == "token-console"
|
||||
assert len(entity.recipients) == 3
|
||||
|
||||
|
||||
def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
assert repo.get_by_token("tok") is None
|
||||
|
||||
recipient = SimpleNamespace(form=None)
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
assert repo.get_by_token("tok") is None
|
||||
|
||||
|
||||
def test_submission_repository_init_no_args() -> None:
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
assert isinstance(repo, HumanInputFormSubmissionRepository)
|
||||
|
||||
|
||||
def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f1",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = SimpleNamespace(
|
||||
id="r1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="tok",
|
||||
form=form,
|
||||
)
|
||||
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
record = repo.get_by_token("tok")
|
||||
assert record is not None
|
||||
assert record.access_token == "tok"
|
||||
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP)
|
||||
assert record is not None
|
||||
assert record.recipient_id == "r1"
|
||||
|
||||
|
||||
def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
assert repo.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE) is None
|
||||
|
||||
|
||||
def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
|
||||
|
||||
missing_session = _FakeSession(forms={})
|
||||
_patch_session_factory(monkeypatch, missing_session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
with pytest.raises(FormNotFoundError, match="form not found"):
|
||||
repo.mark_submitted(
|
||||
form_id="missing",
|
||||
recipient_id=None,
|
||||
selected_action_id="a",
|
||||
form_data={},
|
||||
submission_user_id=None,
|
||||
submission_end_user_id=None,
|
||||
)
|
||||
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=fixed_now,
|
||||
)
|
||||
recipient = _DummyRecipient(id="r", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="tok")
|
||||
session = _FakeSession(forms={form.id: form}, recipients={recipient.id: recipient})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
record = repo.mark_submitted(
|
||||
form_id=form.id,
|
||||
recipient_id=recipient.id,
|
||||
selected_action_id="approve",
|
||||
form_data={"k": "v"},
|
||||
submission_user_id="u",
|
||||
submission_end_user_id="eu",
|
||||
)
|
||||
assert form.status == HumanInputFormStatus.SUBMITTED
|
||||
assert form.submitted_at == fixed_now
|
||||
assert record.submitted_data == {"k": "v"}
|
||||
|
||||
|
||||
def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"):
|
||||
repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.SUBMITTED) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
status=HumanInputFormStatus.TIMEOUT,
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r")
|
||||
assert record.status == HumanInputFormStatus.TIMEOUT
|
||||
|
||||
|
||||
def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
status=HumanInputFormStatus.SUBMITTED,
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
with pytest.raises(FormNotFoundError, match="form already submitted"):
|
||||
repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
|
||||
|
||||
|
||||
def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
selected_action_id="a",
|
||||
submitted_data="{}",
|
||||
submission_user_id="u",
|
||||
submission_end_user_id="eu",
|
||||
completed_by_recipient_id="r",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
|
||||
assert form.status == HumanInputFormStatus.EXPIRED
|
||||
assert form.selected_action_id is None
|
||||
assert form.submitted_data is None
|
||||
assert form.submission_user_id is None
|
||||
assert form.submission_end_user_id is None
|
||||
assert form.completed_by_recipient_id is None
|
||||
assert record.status == HumanInputFormStatus.EXPIRED
|
||||
|
||||
|
||||
def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_patch_session_factory(monkeypatch, _FakeSession(forms={}))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
with pytest.raises(FormNotFoundError, match="form not found"):
|
||||
repo.mark_timeout(form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT)
|
||||
@ -1,84 +1,291 @@
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
import pytest
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
|
||||
from models import Account, WorkflowRun
|
||||
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from models import Account, CreatorUserRole, EndUser, WorkflowRun
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
|
||||
def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
real_session_factory = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = str(uuid4())
|
||||
user.current_tenant_id = str(uuid4())
|
||||
|
||||
repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=real_session_factory,
|
||||
user=user,
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = False
|
||||
repository._session_factory = MagicMock(return_value=session_context)
|
||||
return repository
|
||||
|
||||
|
||||
def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
|
||||
return WorkflowExecution.new(
|
||||
id_=execution_id,
|
||||
workflow_id="workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "hello"},
|
||||
started_at=started_at,
|
||||
)
|
||||
|
||||
|
||||
def test_save_uses_execution_started_at_when_record_does_not_exist():
|
||||
@pytest.fixture
|
||||
def mock_session_factory():
|
||||
"""Mock SQLAlchemy session factory."""
|
||||
session_factory = MagicMock(spec=sessionmaker)
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
repository = _build_repository_with_mocked_session(session)
|
||||
|
||||
started_at = datetime(2026, 1, 1, 12, 0, 0)
|
||||
execution = _build_execution(execution_id=str(uuid4()), started_at=started_at)
|
||||
|
||||
repository.save(execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == started_at
|
||||
session.commit.assert_called_once()
|
||||
session_factory.return_value.__enter__.return_value = session
|
||||
return session_factory
|
||||
|
||||
|
||||
def test_save_preserves_existing_created_at_when_record_already_exists():
|
||||
session = MagicMock()
|
||||
repository = _build_repository_with_mocked_session(session)
|
||||
@pytest.fixture
|
||||
def mock_engine():
|
||||
"""Mock SQLAlchemy Engine."""
|
||||
return MagicMock(spec=Engine)
|
||||
|
||||
execution_id = str(uuid4())
|
||||
existing_created_at = datetime(2026, 1, 1, 12, 0, 0)
|
||||
existing_run = WorkflowRun()
|
||||
existing_run.id = execution_id
|
||||
existing_run.tenant_id = repository._tenant_id
|
||||
existing_run.created_at = existing_created_at
|
||||
session.get.return_value = existing_run
|
||||
|
||||
execution = _build_execution(
|
||||
execution_id=execution_id,
|
||||
started_at=datetime(2026, 1, 1, 12, 30, 0),
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
"""Mock Account user."""
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = str(uuid4())
|
||||
account.current_tenant_id = str(uuid4())
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_end_user():
|
||||
"""Mock EndUser."""
|
||||
user = MagicMock(spec=EndUser)
|
||||
user.id = str(uuid4())
|
||||
user.tenant_id = str(uuid4())
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_execution():
|
||||
"""Sample WorkflowExecution for testing."""
|
||||
return WorkflowExecution(
|
||||
id_=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input1": "value1"},
|
||||
outputs={"output1": "result1"},
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
error_message="",
|
||||
total_tokens=100,
|
||||
total_steps=5,
|
||||
exceptions_count=0,
|
||||
started_at=datetime.now(UTC),
|
||||
finished_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
repository.save(execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == existing_created_at
|
||||
session.commit.assert_called_once()
|
||||
class TestSQLAlchemyWorkflowExecutionRepository:
|
||||
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
|
||||
app_id = "test_app_id"
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
assert repo._session_factory == mock_session_factory
|
||||
assert repo._tenant_id == mock_account.current_tenant_id
|
||||
assert repo._app_id == app_id
|
||||
assert repo._triggered_from == triggered_from
|
||||
assert repo._creator_user_id == mock_account.id
|
||||
assert repo._creator_user_role == CreatorUserRole.ACCOUNT
|
||||
|
||||
def test_init_with_engine(self, mock_engine, mock_account):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_engine,
|
||||
user=mock_account,
|
||||
app_id="test_app_id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
assert isinstance(repo._session_factory, sessionmaker)
|
||||
assert repo._session_factory.kw["bind"] == mock_engine
|
||||
|
||||
def test_init_invalid_session_factory(self, mock_account):
|
||||
with pytest.raises(ValueError, match="Invalid session_factory type"):
|
||||
SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory="invalid", user=mock_account, app_id=None, triggered_from=None
|
||||
)
|
||||
|
||||
def test_init_no_tenant_id(self, mock_session_factory):
|
||||
user = MagicMock(spec=Account)
|
||||
user.current_tenant_id = None
|
||||
|
||||
with pytest.raises(ValueError, match="User must have a tenant_id"):
|
||||
SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None
|
||||
)
|
||||
|
||||
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None
|
||||
)
|
||||
assert repo._tenant_id == mock_end_user.tenant_id
|
||||
assert repo._creator_user_role == CreatorUserRole.END_USER
|
||||
|
||||
def test_to_domain_model(self, mock_session_factory, mock_account):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
|
||||
)
|
||||
|
||||
db_model = MagicMock(spec=WorkflowRun)
|
||||
db_model.id = str(uuid4())
|
||||
db_model.workflow_id = str(uuid4())
|
||||
db_model.type = "workflow"
|
||||
db_model.version = "1.0"
|
||||
db_model.inputs_dict = {"in": "val"}
|
||||
db_model.outputs_dict = {"out": "val"}
|
||||
db_model.graph_dict = {"nodes": []}
|
||||
db_model.status = "succeeded"
|
||||
db_model.error = "some error"
|
||||
db_model.total_tokens = 50
|
||||
db_model.total_steps = 3
|
||||
db_model.exceptions_count = 1
|
||||
db_model.created_at = datetime.now(UTC)
|
||||
db_model.finished_at = datetime.now(UTC)
|
||||
|
||||
domain_model = repo._to_domain_model(db_model)
|
||||
|
||||
assert domain_model.id_ == db_model.id
|
||||
assert domain_model.workflow_id == db_model.workflow_id
|
||||
assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED
|
||||
assert domain_model.inputs == db_model.inputs_dict
|
||||
assert domain_model.error_message == "some error"
|
||||
|
||||
def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
|
||||
# Make elapsed time deterministic to avoid flaky tests
|
||||
sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC)
|
||||
|
||||
db_model = repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
assert db_model.id == sample_workflow_execution.id_
|
||||
assert db_model.tenant_id == repo._tenant_id
|
||||
assert db_model.app_id == "test_app"
|
||||
assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
|
||||
assert db_model.status == sample_workflow_execution.status.value
|
||||
assert db_model.total_tokens == sample_workflow_execution.total_tokens
|
||||
assert db_model.elapsed_time == 10.0
|
||||
|
||||
def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
# Test with empty/None fields
|
||||
sample_workflow_execution.graph = None
|
||||
sample_workflow_execution.inputs = None
|
||||
sample_workflow_execution.outputs = None
|
||||
sample_workflow_execution.error_message = None
|
||||
sample_workflow_execution.finished_at = None
|
||||
|
||||
db_model = repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
assert db_model.graph is None
|
||||
assert db_model.inputs is None
|
||||
assert db_model.outputs is None
|
||||
assert db_model.error is None
|
||||
assert db_model.elapsed_time == 0
|
||||
|
||||
def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id=None,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
db_model = repo._to_db_model(sample_workflow_execution)
|
||||
assert not hasattr(db_model, "app_id") or db_model.app_id is None
|
||||
assert db_model.tenant_id == repo._tenant_id
|
||||
|
||||
def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
|
||||
)
|
||||
|
||||
# Test triggered_from missing
|
||||
with pytest.raises(ValueError, match="triggered_from is required"):
|
||||
repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
repo._creator_user_id = None
|
||||
with pytest.raises(ValueError, match="created_by is required"):
|
||||
repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
repo._creator_user_id = "some_id"
|
||||
repo._creator_user_role = None
|
||||
with pytest.raises(ValueError, match="created_by_role is required"):
|
||||
repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
def test_save(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
session.merge.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
# Check cache
|
||||
assert sample_workflow_execution.id_ in repo._execution_cache
|
||||
cached_model = repo._execution_cache[sample_workflow_execution.id_]
|
||||
assert cached_model.id == sample_workflow_execution.id_
|
||||
|
||||
def test_save_uses_execution_started_at_when_record_does_not_exist(
|
||||
self, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
|
||||
sample_workflow_execution.started_at = started_at
|
||||
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
session.get.return_value = None
|
||||
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == started_at
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_save_preserves_existing_created_at_when_record_already_exists(
|
||||
self, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
execution_id = sample_workflow_execution.id_
|
||||
existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
|
||||
|
||||
existing_run = WorkflowRun()
|
||||
existing_run.id = execution_id
|
||||
existing_run.tenant_id = repo._tenant_id
|
||||
existing_run.created_at = existing_created_at
|
||||
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
session.get.return_value = existing_run
|
||||
|
||||
sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC)
|
||||
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == existing_created_at
|
||||
session.commit.assert_called_once()
|
||||
|
||||
@ -0,0 +1,772 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
from sqlalchemy import Engine, create_engine
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||
_deterministic_json_dump,
|
||||
_filter_by_offload_type,
|
||||
_find_first,
|
||||
_replace_or_append_offload,
|
||||
)
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from models import Account, EndUser
|
||||
from models.enums import ExecutionOffLoadType
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account:
|
||||
user = Mock(spec=Account)
|
||||
user.id = user_id
|
||||
user.current_tenant_id = tenant_id
|
||||
return user
|
||||
|
||||
|
||||
def _mock_end_user(*, tenant_id: str = "tenant", user_id: str = "user") -> EndUser:
|
||||
user = Mock(spec=EndUser)
|
||||
user.id = user_id
|
||||
user.tenant_id = tenant_id
|
||||
return user
|
||||
|
||||
|
||||
def _execution(
|
||||
*,
|
||||
execution_id: str = "exec-id",
|
||||
node_execution_id: str = "node-exec-id",
|
||||
workflow_run_id: str = "run-id",
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs: Mapping[str, Any] | None = None,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
process_data: Mapping[str, Any] | None = None,
|
||||
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
|
||||
) -> WorkflowNodeExecution:
|
||||
return WorkflowNodeExecution(
|
||||
id=execution_id,
|
||||
node_execution_id=node_execution_id,
|
||||
workflow_id="workflow-id",
|
||||
workflow_execution_id=workflow_run_id,
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
node_id="node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Title",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
status=status,
|
||||
error=None,
|
||||
elapsed_time=1.0,
|
||||
metadata=metadata,
|
||||
created_at=datetime.now(UTC),
|
||||
finished_at=None,
|
||||
)
|
||||
|
||||
|
||||
class _SessionCtx:
|
||||
def __init__(self, session: Any):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _session_factory(session: Any) -> sessionmaker:
|
||||
factory = Mock(spec=sessionmaker)
|
||||
factory.return_value = _SessionCtx(session)
|
||||
return factory
|
||||
|
||||
|
||||
def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
engine: Engine = create_engine("sqlite:///:memory:")
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=engine,
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
assert isinstance(repo._session_factory, sessionmaker)
|
||||
|
||||
sm = Mock(spec=sessionmaker)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=sm,
|
||||
user=_mock_end_user(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
assert repo._creator_user_role.value == "end_user"
|
||||
|
||||
|
||||
def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
with pytest.raises(ValueError, match="Invalid session_factory type"):
|
||||
SQLAlchemyWorkflowNodeExecutionRepository( # type: ignore[arg-type]
|
||||
session_factory=object(),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
user = _mock_account()
|
||||
user.current_tenant_id = None
|
||||
with pytest.raises(ValueError, match="User must have a tenant_id"):
|
||||
SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=user,
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created: dict[str, Any] = {}
|
||||
|
||||
class FakeTruncator:
|
||||
def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int):
|
||||
created.update(
|
||||
{
|
||||
"max_size_bytes": max_size_bytes,
|
||||
"array_element_limit": array_element_limit,
|
||||
"string_length_limit": string_length_limit,
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator",
|
||||
FakeTruncator,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
_ = repo._create_truncator()
|
||||
assert created["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE
|
||||
|
||||
|
||||
def test_helpers_find_first_and_replace_or_append_and_filter() -> None:
|
||||
assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}'
|
||||
assert _find_first([], lambda _: True) is None
|
||||
assert _find_first([1, 2, 3], lambda x: x > 1) == 2
|
||||
|
||||
off1 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
|
||||
off2 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
|
||||
assert _find_first([off1, off2], _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)) is off2
|
||||
|
||||
replaced = _replace_or_append_offload([off1, off2], WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS))
|
||||
assert len(replaced) == 2
|
||||
assert [o.type_ for o in replaced] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS]
|
||||
|
||||
|
||||
def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1})
|
||||
|
||||
# Happy path: deterministic json dump should be sorted
|
||||
db_model = repo._to_db_model(execution)
|
||||
assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1}
|
||||
assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1
|
||||
|
||||
repo._triggered_from = None
|
||||
with pytest.raises(ValueError, match="triggered_from is required"):
|
||||
repo._to_db_model(execution)
|
||||
|
||||
|
||||
def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
execution = _execution()
|
||||
db_model = repo._to_db_model(execution)
|
||||
assert db_model.app_id == "app"
|
||||
|
||||
repo._creator_user_id = None
|
||||
with pytest.raises(ValueError, match="created_by is required"):
|
||||
repo._to_db_model(execution)
|
||||
|
||||
repo._creator_user_id = "user"
|
||||
repo._creator_user_role = None
|
||||
with pytest.raises(ValueError, match="created_by_role is required"):
|
||||
repo._to_db_model(execution)
|
||||
|
||||
|
||||
def test_is_duplicate_key_error_and_regenerate_id(
|
||||
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
unique = Mock(spec=psycopg2.errors.UniqueViolation)
|
||||
duplicate_error = IntegrityError("dup", params=None, orig=unique)
|
||||
assert repo._is_duplicate_key_error(duplicate_error) is True
|
||||
assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False
|
||||
|
||||
execution = _execution(execution_id="old-id")
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "old-id"
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
|
||||
caplog.set_level(logging.WARNING)
|
||||
repo._regenerate_id_on_duplicate(execution, db_model)
|
||||
assert execution.id == "new-id"
|
||||
assert db_model.id == "new-id"
|
||||
assert any("Duplicate key conflict" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
session = MagicMock()
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "id1"
|
||||
db_model.node_execution_id = "node1"
|
||||
db_model.foo = "bar" # type: ignore[attr-defined]
|
||||
db_model.__dict__["_private"] = "x"
|
||||
|
||||
existing = SimpleNamespace()
|
||||
session.get.return_value = existing
|
||||
repo._persist_to_database(db_model)
|
||||
assert existing.foo == "bar"
|
||||
session.add.assert_not_called()
|
||||
assert repo._node_execution_cache["node1"] is db_model
|
||||
|
||||
session.reset_mock()
|
||||
session.get.return_value = None
|
||||
repo._node_execution_cache.clear()
|
||||
repo._persist_to_database(db_model)
|
||||
session.add.assert_called_once_with(db_model)
|
||||
assert repo._node_execution_cache["node1"] is db_model
|
||||
|
||||
|
||||
def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None
|
||||
|
||||
class FakeTruncator:
|
||||
def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def]
|
||||
return value, False
|
||||
|
||||
monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
|
||||
assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None
|
||||
|
||||
|
||||
def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
uploaded: dict[str, Any] = {}
|
||||
|
||||
class FakeFileService:
|
||||
def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any): # type: ignore[no-untyped-def]
|
||||
uploaded.update({"filename": filename, "content": content, "mimetype": mimetype, "user": user})
|
||||
return SimpleNamespace(id="file-id", key="file-key")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", lambda *_: FakeFileService()
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id")
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
class FakeTruncator:
|
||||
def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def]
|
||||
return {"truncated": True}, True
|
||||
|
||||
monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
|
||||
|
||||
result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS)
|
||||
assert result is not None
|
||||
assert result.truncated_value == {"truncated": True}
|
||||
assert uploaded["filename"].startswith("node_execution_exec_inputs.json")
|
||||
assert result.offload.file_id == "file-id"
|
||||
assert result.offload.type_ == ExecutionOffLoadType.INPUTS
|
||||
|
||||
|
||||
def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "id"
|
||||
db_model.node_execution_id = "node-exec"
|
||||
db_model.workflow_id = "wf"
|
||||
db_model.workflow_run_id = "run"
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node"
|
||||
db_model.node_type = NodeType.LLM
|
||||
db_model.title = "t"
|
||||
db_model.inputs = json.dumps({"trunc": "i"})
|
||||
db_model.process_data = json.dumps({"trunc": "p"})
|
||||
db_model.outputs = json.dumps({"trunc": "o"})
|
||||
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 0.1
|
||||
db_model.execution_metadata = json.dumps({"total_tokens": 3})
|
||||
db_model.created_at = datetime.now(UTC)
|
||||
db_model.finished_at = None
|
||||
|
||||
off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
|
||||
off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
|
||||
off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA)
|
||||
off_in.file = SimpleNamespace(key="k-in")
|
||||
off_out.file = SimpleNamespace(key="k-out")
|
||||
off_proc.file = SimpleNamespace(key="k-proc")
|
||||
db_model.offload_data = [off_out, off_in, off_proc]
|
||||
|
||||
def fake_load(key: str) -> bytes:
|
||||
return json.dumps({"full": key}).encode()
|
||||
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load)
|
||||
|
||||
domain = repo._to_domain_model(db_model)
|
||||
assert domain.inputs == {"full": "k-in"}
|
||||
assert domain.outputs == {"full": "k-out"}
|
||||
assert domain.process_data == {"full": "k-proc"}
|
||||
assert domain.get_truncated_inputs() == {"trunc": "i"}
|
||||
assert domain.get_truncated_outputs() == {"trunc": "o"}
|
||||
assert domain.get_truncated_process_data() == {"trunc": "p"}
|
||||
|
||||
|
||||
def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "id"
|
||||
db_model.node_execution_id = "node-exec"
|
||||
db_model.workflow_id = "wf"
|
||||
db_model.workflow_run_id = "run"
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node"
|
||||
db_model.node_type = NodeType.LLM
|
||||
db_model.title = "t"
|
||||
db_model.inputs = json.dumps({"i": 1})
|
||||
db_model.process_data = json.dumps({"p": 2})
|
||||
db_model.outputs = json.dumps({"o": 3})
|
||||
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 0.1
|
||||
db_model.execution_metadata = "{}"
|
||||
db_model.created_at = datetime.now(UTC)
|
||||
db_model.finished_at = None
|
||||
db_model.offload_data = []
|
||||
|
||||
domain = repo._to_domain_model(db_model)
|
||||
assert domain.inputs == {"i": 1}
|
||||
assert domain.outputs == {"o": 3}
|
||||
|
||||
|
||||
def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class FakeConverter:
|
||||
def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
return {"wrapped": values["a"]}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter",
|
||||
FakeConverter,
|
||||
)
|
||||
assert SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1}) == '{"wrapped": 1}'
|
||||
|
||||
|
||||
def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.first.return_value = SimpleNamespace(
|
||||
id="id",
|
||||
offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)],
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
process_data=None,
|
||||
)
|
||||
session.merge = Mock()
|
||||
session.flush = Mock()
|
||||
session.begin.return_value.__enter__ = Mock(return_value=session)
|
||||
session.begin.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
|
||||
|
||||
trunc_result = SimpleNamespace(
|
||||
truncated_value={"trunc": True},
|
||||
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None
|
||||
)
|
||||
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
|
||||
|
||||
repo.save_execution_data(execution)
|
||||
# Inputs should be truncated, outputs/process_data encoded directly
|
||||
db_model = session.merge.call_args.args[0]
|
||||
assert json.loads(db_model.inputs) == {"trunc": True}
|
||||
assert json.loads(db_model.outputs) == {"b": 2}
|
||||
assert json.loads(db_model.process_data) == {"c": 3}
|
||||
assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data)
|
||||
assert execution.get_truncated_inputs() == {"trunc": True}
|
||||
|
||||
|
||||
def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
existing = SimpleNamespace(
|
||||
id="id",
|
||||
offload_data=[],
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
process_data=None,
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.first.return_value = existing
|
||||
session.merge = Mock()
|
||||
session.flush = Mock()
|
||||
session.begin.return_value.__enter__ = Mock(return_value=session)
|
||||
session.begin.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
|
||||
|
||||
def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any:
|
||||
if values == {"b": 2}:
|
||||
return SimpleNamespace(
|
||||
truncated_value={"b": "trunc"},
|
||||
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"),
|
||||
)
|
||||
if values == {"c": 3}:
|
||||
return SimpleNamespace(
|
||||
truncated_value={"c": "trunc"},
|
||||
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"),
|
||||
)
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(repo, "_truncate_and_upload", trunc)
|
||||
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
|
||||
|
||||
repo.save_execution_data(execution)
|
||||
db_model = session.merge.call_args.args[0]
|
||||
assert json.loads(db_model.outputs) == {"b": "trunc"}
|
||||
assert json.loads(db_model.process_data) == {"c": "trunc"}
|
||||
assert execution.get_truncated_outputs() == {"b": "trunc"}
|
||||
assert execution.get_truncated_process_data() == {"c": "trunc"}
|
||||
|
||||
|
||||
def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.first.return_value = None
|
||||
session.merge = Mock()
|
||||
session.flush = Mock()
|
||||
session.begin.return_value.__enter__ = Mock(return_value=session)
|
||||
session.begin.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(inputs={"a": 1})
|
||||
fake_db_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None)
|
||||
monkeypatch.setattr(repo, "_to_db_model", lambda *_: fake_db_model)
|
||||
monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values))
|
||||
|
||||
repo.save_execution_data(execution)
|
||||
merged = session.merge.call_args.args[0]
|
||||
assert merged.inputs == '{"a": 1}'
|
||||
|
||||
|
||||
def test_save_retries_duplicate_and_logs_non_duplicate(
|
||||
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(execution_id="id")
|
||||
unique = Mock(spec=psycopg2.errors.UniqueViolation)
|
||||
duplicate_error = IntegrityError("dup", params=None, orig=unique)
|
||||
other_error = IntegrityError("other", params=None, orig=None)
|
||||
|
||||
calls = {"n": 0}
|
||||
|
||||
def persist(_db_model: Any) -> None:
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
raise duplicate_error
|
||||
|
||||
monkeypatch.setattr(repo, "_persist_to_database", persist)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
|
||||
repo.save(execution)
|
||||
assert execution.id == "new-id"
|
||||
assert repo._node_execution_cache[execution.node_execution_id] is not None
|
||||
|
||||
caplog.set_level(logging.ERROR)
|
||||
monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error))
|
||||
with pytest.raises(IntegrityError):
|
||||
repo.save(_execution(execution_id="id2", node_execution_id="node2"))
|
||||
assert any("Non-duplicate key integrity error" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_save_logs_and_reraises_on_unexpected_error(
|
||||
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
caplog.set_level(logging.ERROR)
|
||||
monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(RuntimeError("boom")))
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
repo.save(_execution(execution_id="id3", node_execution_id="node3"))
|
||||
assert any("Failed to save workflow node execution" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
class FakeStmt:
|
||||
def __init__(self) -> None:
|
||||
self.where_calls = 0
|
||||
self.order_by_args: tuple[Any, ...] | None = None
|
||||
|
||||
def where(self, *_args: Any) -> FakeStmt:
|
||||
self.where_calls += 1
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any) -> FakeStmt:
|
||||
self.order_by_args = args
|
||||
return self
|
||||
|
||||
stmt = FakeStmt()
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
|
||||
lambda _q: stmt,
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
|
||||
|
||||
model1 = SimpleNamespace(node_execution_id="n1")
|
||||
model2 = SimpleNamespace(node_execution_id=None)
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [model1, model2]
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
order = OrderConfig(order_by=["index", "missing"], order_direction="desc")
|
||||
db_models = repo.get_db_models_by_workflow_run("run", order)
|
||||
assert db_models == [model1, model2]
|
||||
assert repo._node_execution_cache["n1"] is model1
|
||||
assert stmt.order_by_args is not None
|
||||
|
||||
|
||||
def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
class FakeStmt:
|
||||
def where(self, *_args: Any) -> FakeStmt:
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any) -> FakeStmt:
|
||||
self.args = args # type: ignore[attr-defined]
|
||||
return self
|
||||
|
||||
stmt = FakeStmt()
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
|
||||
lambda _q: stmt,
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc"))
|
||||
|
||||
|
||||
def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")]
|
||||
monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models)
|
||||
monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}")
|
||||
|
||||
class FakeExecutor:
|
||||
def __enter__(self) -> FakeExecutor:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
return None
|
||||
|
||||
def map(self, func, items, timeout: int): # type: ignore[no-untyped-def]
|
||||
assert timeout == 30
|
||||
return list(map(func, items))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor",
|
||||
lambda max_workers: FakeExecutor(),
|
||||
)
|
||||
|
||||
result = repo.get_by_workflow_run("run", order_config=None)
|
||||
assert result == ["domain:db1", "domain:db2"]
|
||||
137
api/tests/unit_tests/core/schemas/test_registry.py
Normal file
137
api/tests/unit_tests/core/schemas/test_registry.py
Normal file
@ -0,0 +1,137 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.schemas.registry import SchemaRegistry
|
||||
|
||||
|
||||
class TestSchemaRegistry:
|
||||
def test_initialization(self, tmp_path):
|
||||
base_dir = tmp_path / "schemas"
|
||||
base_dir.mkdir()
|
||||
registry = SchemaRegistry(str(base_dir))
|
||||
assert registry.base_dir == base_dir
|
||||
assert registry.versions == {}
|
||||
assert registry.metadata == {}
|
||||
|
||||
def test_default_registry_singleton(self):
|
||||
registry1 = SchemaRegistry.default_registry()
|
||||
registry2 = SchemaRegistry.default_registry()
|
||||
assert registry1 is registry2
|
||||
assert isinstance(registry1, SchemaRegistry)
|
||||
|
||||
def test_load_all_versions_non_existent_dir(self, tmp_path):
|
||||
base_dir = tmp_path / "non_existent"
|
||||
registry = SchemaRegistry(str(base_dir))
|
||||
registry.load_all_versions()
|
||||
assert registry.versions == {}
|
||||
|
||||
def test_load_all_versions_filtering(self, tmp_path):
|
||||
base_dir = tmp_path / "schemas"
|
||||
base_dir.mkdir()
|
||||
(base_dir / "not_a_version_dir").mkdir()
|
||||
(base_dir / "v1").mkdir()
|
||||
(base_dir / "some_file.txt").write_text("content")
|
||||
|
||||
registry = SchemaRegistry(str(base_dir))
|
||||
with patch.object(registry, "_load_version_dir") as mock_load:
|
||||
registry.load_all_versions()
|
||||
mock_load.assert_called_once()
|
||||
assert mock_load.call_args[0][0] == "v1"
|
||||
|
||||
def test_load_version_dir_filtering(self, tmp_path):
|
||||
version_dir = tmp_path / "v1"
|
||||
version_dir.mkdir()
|
||||
(version_dir / "schema1.json").write_text("{}")
|
||||
(version_dir / "not_a_schema.txt").write_text("content")
|
||||
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
with patch.object(registry, "_load_schema") as mock_load:
|
||||
registry._load_version_dir("v1", version_dir)
|
||||
mock_load.assert_called_once()
|
||||
assert mock_load.call_args[0][1] == "schema1"
|
||||
|
||||
def test_load_version_dir_non_existent(self, tmp_path):
|
||||
version_dir = tmp_path / "non_existent"
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
registry._load_version_dir("v1", version_dir)
|
||||
assert "v1" not in registry.versions
|
||||
|
||||
def test_load_schema_success(self, tmp_path):
|
||||
schema_path = tmp_path / "test.json"
|
||||
schema_content = {"title": "Test Schema", "description": "A test schema"}
|
||||
schema_path.write_text(json.dumps(schema_content))
|
||||
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
registry.versions["v1"] = {}
|
||||
registry._load_schema("v1", "test", schema_path)
|
||||
|
||||
assert registry.versions["v1"]["test"] == schema_content
|
||||
uri = "https://dify.ai/schemas/v1/test.json"
|
||||
assert registry.metadata[uri]["title"] == "Test Schema"
|
||||
assert registry.metadata[uri]["version"] == "v1"
|
||||
|
||||
def test_load_schema_invalid_json(self, tmp_path, caplog):
|
||||
schema_path = tmp_path / "invalid.json"
|
||||
schema_path.write_text("invalid json")
|
||||
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
registry.versions["v1"] = {}
|
||||
registry._load_schema("v1", "invalid", schema_path)
|
||||
|
||||
assert "Failed to load schema v1/invalid" in caplog.text
|
||||
|
||||
def test_load_schema_os_error(self, tmp_path, caplog):
|
||||
schema_path = tmp_path / "error.json"
|
||||
schema_path.write_text("{}")
|
||||
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
registry.versions["v1"] = {}
|
||||
|
||||
with patch("builtins.open", side_effect=OSError("Read error")):
|
||||
registry._load_schema("v1", "error", schema_path)
|
||||
|
||||
assert "Failed to load schema v1/error" in caplog.text
|
||||
|
||||
def test_get_schema(self):
|
||||
registry = SchemaRegistry("/tmp")
|
||||
registry.versions = {"v1": {"test": {"type": "object"}}}
|
||||
|
||||
# Valid URI
|
||||
assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"}
|
||||
|
||||
# Invalid URI
|
||||
assert registry.get_schema("invalid-uri") is None
|
||||
|
||||
# Missing version
|
||||
assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None
|
||||
|
||||
def test_list_versions(self):
|
||||
registry = SchemaRegistry("/tmp")
|
||||
registry.versions = {"v2": {}, "v1": {}}
|
||||
assert registry.list_versions() == ["v1", "v2"]
|
||||
|
||||
def test_list_schemas(self):
|
||||
registry = SchemaRegistry("/tmp")
|
||||
registry.versions = {"v1": {"b": {}, "a": {}}}
|
||||
|
||||
assert registry.list_schemas("v1") == ["a", "b"]
|
||||
assert registry.list_schemas("v2") == []
|
||||
|
||||
def test_get_all_schemas_for_version(self):
|
||||
registry = SchemaRegistry("/tmp")
|
||||
registry.versions = {"v1": {"test": {"title": "Test Label"}}}
|
||||
|
||||
results = registry.get_all_schemas_for_version("v1")
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "test"
|
||||
assert results[0]["label"] == "Test Label"
|
||||
assert results[0]["schema"] == {"title": "Test Label"}
|
||||
|
||||
# Default label if title missing
|
||||
registry.versions["v1"]["no_title"] = {}
|
||||
results = registry.get_all_schemas_for_version("v1")
|
||||
item = next(r for r in results if r["name"] == "no_title")
|
||||
assert item["label"] == "no_title"
|
||||
|
||||
# Empty if version missing
|
||||
assert registry.get_all_schemas_for_version("v2") == []
|
||||
80
api/tests/unit_tests/core/schemas/test_schema_manager.py
Normal file
80
api/tests/unit_tests/core/schemas/test_schema_manager.py
Normal file
@ -0,0 +1,80 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.schemas.registry import SchemaRegistry
|
||||
from core.schemas.schema_manager import SchemaManager
|
||||
|
||||
|
||||
def test_init_with_provided_registry():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
assert manager.registry == mock_registry
|
||||
|
||||
|
||||
@patch("core.schemas.schema_manager.SchemaRegistry.default_registry")
|
||||
def test_init_with_default_registry(mock_default_registry):
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
mock_default_registry.return_value = mock_registry
|
||||
|
||||
manager = SchemaManager()
|
||||
|
||||
mock_default_registry.assert_called_once()
|
||||
assert manager.registry == mock_registry
|
||||
|
||||
|
||||
def test_get_all_schema_definitions():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
expected_definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}]
|
||||
mock_registry.get_all_schemas_for_version.return_value = expected_definitions
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.get_all_schema_definitions(version="v2")
|
||||
|
||||
mock_registry.get_all_schemas_for_version.assert_called_once_with("v2")
|
||||
assert result == expected_definitions
|
||||
|
||||
|
||||
def test_get_schema_by_name_success():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
mock_schema = {"type": "object"}
|
||||
mock_registry.get_schema.return_value = mock_schema
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.get_schema_by_name("my_schema", version="v1")
|
||||
|
||||
expected_uri = "https://dify.ai/schemas/v1/my_schema.json"
|
||||
mock_registry.get_schema.assert_called_once_with(expected_uri)
|
||||
assert result == {"name": "my_schema", "schema": mock_schema}
|
||||
|
||||
|
||||
def test_get_schema_by_name_not_found():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
mock_registry.get_schema.return_value = None
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.get_schema_by_name("non_existent", version="v1")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_list_available_schemas():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
expected_schemas = ["schema1", "schema2"]
|
||||
mock_registry.list_schemas.return_value = expected_schemas
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.list_available_schemas(version="v1")
|
||||
|
||||
mock_registry.list_schemas.assert_called_once_with("v1")
|
||||
assert result == expected_schemas
|
||||
|
||||
|
||||
def test_list_available_versions():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
expected_versions = ["v1", "v2"]
|
||||
mock_registry.list_versions.return_value = expected_versions
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.list_available_versions()
|
||||
|
||||
mock_registry.list_versions.assert_called_once()
|
||||
assert result == expected_versions
|
||||
@ -16,6 +16,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
@ -324,7 +325,7 @@ class TestConversationModel:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id=from_end_user_id,
|
||||
)
|
||||
|
||||
@ -345,7 +346,7 @@ class TestConversationModel:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id=str(uuid4()),
|
||||
)
|
||||
conversation._inputs = inputs
|
||||
@ -364,7 +365,7 @@ class TestConversationModel:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id=str(uuid4()),
|
||||
)
|
||||
inputs = {"query": "Hello", "context": "test"}
|
||||
@ -383,7 +384,7 @@ class TestConversationModel:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id=str(uuid4()),
|
||||
summary="Test summary",
|
||||
)
|
||||
@ -402,7 +403,7 @@ class TestConversationModel:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id=str(uuid4()),
|
||||
summary=None,
|
||||
)
|
||||
@ -425,7 +426,7 @@ class TestConversationModel:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id=str(uuid4()),
|
||||
override_model_configs='{"model": "gpt-4"}',
|
||||
)
|
||||
@ -446,7 +447,7 @@ class TestConversationModel:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id=from_end_user_id,
|
||||
dialogue_count=5,
|
||||
)
|
||||
@ -487,7 +488,7 @@ class TestMessageModel:
|
||||
message_unit_price=Decimal("0.0001"),
|
||||
answer_unit_price=Decimal("0.0002"),
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
|
||||
# Assert
|
||||
@ -511,7 +512,7 @@ class TestMessageModel:
|
||||
message_unit_price=Decimal("0.0001"),
|
||||
answer_unit_price=Decimal("0.0002"),
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
message._inputs = inputs
|
||||
|
||||
@ -533,7 +534,7 @@ class TestMessageModel:
|
||||
message_unit_price=Decimal("0.0001"),
|
||||
answer_unit_price=Decimal("0.0002"),
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
inputs = {"query": "Hello", "context": "test"}
|
||||
|
||||
@ -555,7 +556,7 @@ class TestMessageModel:
|
||||
message_unit_price=Decimal("0.0001"),
|
||||
answer_unit_price=Decimal("0.0002"),
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
override_model_configs='{"model": "gpt-4"}',
|
||||
)
|
||||
|
||||
@ -578,7 +579,7 @@ class TestMessageModel:
|
||||
message_unit_price=Decimal("0.0001"),
|
||||
answer_unit_price=Decimal("0.0002"),
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
message_metadata=json.dumps(metadata),
|
||||
)
|
||||
|
||||
@ -600,7 +601,7 @@ class TestMessageModel:
|
||||
message_unit_price=Decimal("0.0001"),
|
||||
answer_unit_price=Decimal("0.0002"),
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
message_metadata=None,
|
||||
)
|
||||
|
||||
@ -627,7 +628,7 @@ class TestMessageModel:
|
||||
answer_unit_price=Decimal("0.0002"),
|
||||
total_price=Decimal("0.0003"),
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
status="normal",
|
||||
)
|
||||
message.id = str(uuid4())
|
||||
@ -988,7 +989,7 @@ class TestModelIntegration:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id=str(uuid4()),
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
@ -1003,7 +1004,7 @@ class TestModelIntegration:
|
||||
message_unit_price=Decimal("0.0001"),
|
||||
answer_unit_price=Decimal("0.0002"),
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
message.id = message_id
|
||||
|
||||
@ -1064,7 +1065,7 @@ class TestModelIntegration:
|
||||
message_unit_price=Decimal("0.0001"),
|
||||
answer_unit_price=Decimal("0.0002"),
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
message.id = message_id
|
||||
|
||||
@ -1158,7 +1159,7 @@ class TestConversationStatusCount:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
conversation.id = str(uuid4())
|
||||
|
||||
@ -1183,7 +1184,7 @@ class TestConversationStatusCount:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
@ -1215,7 +1216,7 @@ class TestConversationStatusCount:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
@ -1307,7 +1308,7 @@ class TestConversationStatusCount:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
@ -1361,7 +1362,7 @@ class TestConversationStatusCount:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
@ -1418,7 +1419,7 @@ class TestConversationStatusCount:
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
|
||||
@ -1,135 +0,0 @@
|
||||
"""Unit tests for non-SQL helper logic in workflow run repository."""
|
||||
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction
|
||||
from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormStatus
|
||||
from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from models.workflow import WorkflowPauseReason
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
_build_human_input_required_reason,
|
||||
_PrivateWorkflowPauseEntity,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_pause() -> Mock:
|
||||
"""Create a sample WorkflowPause model."""
|
||||
pause = Mock(spec=WorkflowPauseModel)
|
||||
pause.id = "pause-123"
|
||||
pause.workflow_id = "workflow-123"
|
||||
pause.workflow_run_id = "workflow-run-123"
|
||||
pause.state_object_key = "workflow-state-123.json"
|
||||
pause.resumed_at = None
|
||||
pause.created_at = datetime.now(UTC)
|
||||
return pause
|
||||
|
||||
|
||||
class TestPrivateWorkflowPauseEntity:
|
||||
"""Test _PrivateWorkflowPauseEntity class."""
|
||||
|
||||
def test_properties(self, sample_workflow_pause: Mock) -> None:
|
||||
"""Test entity properties."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
|
||||
# Assert
|
||||
assert entity.id == sample_workflow_pause.id
|
||||
assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
|
||||
assert entity.resumed_at == sample_workflow_pause.resumed_at
|
||||
|
||||
def test_get_state(self, sample_workflow_pause: Mock) -> None:
|
||||
"""Test getting state from storage."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
expected_state = b'{"test": "state"}'
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
mock_storage.load.return_value = expected_state
|
||||
|
||||
# Act
|
||||
result = entity.get_state()
|
||||
|
||||
# Assert
|
||||
assert result == expected_state
|
||||
mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
|
||||
|
||||
def test_get_state_caching(self, sample_workflow_pause: Mock) -> None:
|
||||
"""Test state caching in get_state method."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
expected_state = b'{"test": "state"}'
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
mock_storage.load.return_value = expected_state
|
||||
|
||||
# Act
|
||||
result1 = entity.get_state()
|
||||
result2 = entity.get_state()
|
||||
|
||||
# Assert
|
||||
assert result1 == expected_state
|
||||
assert result2 == expected_state
|
||||
mock_storage.load.assert_called_once()
|
||||
|
||||
|
||||
class TestBuildHumanInputRequiredReason:
|
||||
"""Test helper that builds HumanInputRequired pause reasons."""
|
||||
|
||||
def test_prefers_backstage_token_when_available(self) -> None:
|
||||
"""Use backstage token when multiple recipient types may exist."""
|
||||
# Arrange
|
||||
expiration_time = datetime.now(UTC)
|
||||
form_definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
default_values={"name": "Alice"},
|
||||
node_title="Ask Name",
|
||||
display_in_ui=True,
|
||||
)
|
||||
form_model = HumanInputForm(
|
||||
id="form-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content="rendered",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
expiration_time=expiration_time,
|
||||
)
|
||||
reason_model = WorkflowPauseReason(
|
||||
pause_id="pause-1",
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
message="",
|
||||
)
|
||||
access_token = secrets.token_urlsafe(8)
|
||||
backstage_recipient = HumanInputFormRecipient(
|
||||
form_id="form-1",
|
||||
delivery_id="delivery-1",
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
recipient_payload=BackstageRecipientPayload().model_dump_json(),
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
# Act
|
||||
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient])
|
||||
|
||||
# Assert
|
||||
assert isinstance(reason, HumanInputRequired)
|
||||
assert reason.form_token == access_token
|
||||
assert reason.node_title == "Ask Name"
|
||||
assert reason.form_content == "content"
|
||||
assert reason.inputs[0].output_variable_name == "name"
|
||||
assert reason.actions[0].id == "approve"
|
||||
@ -1,251 +0,0 @@
|
||||
"""Unit tests for workflow run repository with status filter."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from models import WorkflowRun, WorkflowRunTriggeredFrom
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository
|
||||
|
||||
|
||||
class TestDifyAPISQLAlchemyWorkflowRunRepository:
|
||||
"""Test workflow run repository with status filtering."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(self):
|
||||
"""Create a mock session maker."""
|
||||
return MagicMock(spec=sessionmaker)
|
||||
|
||||
@pytest.fixture
|
||||
def repository(self, mock_session_maker):
|
||||
"""Create repository instance with mock session."""
|
||||
return DifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
|
||||
|
||||
def test_get_paginated_workflow_runs_without_status(self, repository, mock_session_maker):
|
||||
"""Test getting paginated workflow runs without status filter."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_runs = [MagicMock(spec=WorkflowRun) for _ in range(3)]
|
||||
mock_session.scalars.return_value.all.return_value = mock_runs
|
||||
|
||||
# Act
|
||||
result = repository.get_paginated_workflow_runs(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
limit=20,
|
||||
last_id=None,
|
||||
status=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 3
|
||||
assert result.limit == 20
|
||||
assert result.has_more is False
|
||||
|
||||
def test_get_paginated_workflow_runs_with_status_filter(self, repository, mock_session_maker):
|
||||
"""Test getting paginated workflow runs with status filter."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_runs = [MagicMock(spec=WorkflowRun, status="succeeded") for _ in range(2)]
|
||||
mock_session.scalars.return_value.all.return_value = mock_runs
|
||||
|
||||
# Act
|
||||
result = repository.get_paginated_workflow_runs(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
limit=20,
|
||||
last_id=None,
|
||||
status="succeeded",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 2
|
||||
assert all(run.status == "succeeded" for run in result.data)
|
||||
|
||||
def test_get_workflow_runs_count_without_status(self, repository, mock_session_maker):
|
||||
"""Test getting workflow runs count without status filter."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the GROUP BY query results
|
||||
mock_results = [
|
||||
("succeeded", 5),
|
||||
("failed", 2),
|
||||
("running", 1),
|
||||
]
|
||||
mock_session.execute.return_value.all.return_value = mock_results
|
||||
|
||||
# Act
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 8
|
||||
assert result["succeeded"] == 5
|
||||
assert result["failed"] == 2
|
||||
assert result["running"] == 1
|
||||
assert result["stopped"] == 0
|
||||
assert result["partial-succeeded"] == 0
|
||||
|
||||
def test_get_workflow_runs_count_with_status_filter(self, repository, mock_session_maker):
|
||||
"""Test getting workflow runs count with status filter."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the count query for succeeded status
|
||||
mock_session.scalar.return_value = 5
|
||||
|
||||
# Act
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status="succeeded",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 5
|
||||
assert result["succeeded"] == 5
|
||||
assert result["running"] == 0
|
||||
assert result["failed"] == 0
|
||||
assert result["stopped"] == 0
|
||||
assert result["partial-succeeded"] == 0
|
||||
|
||||
def test_get_workflow_runs_count_with_invalid_status(self, repository, mock_session_maker):
|
||||
"""Test that invalid status is still counted in total but not in any specific status."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock count query returning 0 for invalid status
|
||||
mock_session.scalar.return_value = 0
|
||||
|
||||
# Act
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status="invalid_status",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 0
|
||||
assert all(result[status] == 0 for status in ["running", "succeeded", "failed", "stopped", "partial-succeeded"])
|
||||
|
||||
def test_get_workflow_runs_count_with_time_range(self, repository, mock_session_maker):
|
||||
"""Test getting workflow runs count with time range filter verifies SQL query construction."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the GROUP BY query results
|
||||
mock_results = [
|
||||
("succeeded", 3),
|
||||
("running", 2),
|
||||
]
|
||||
mock_session.execute.return_value.all.return_value = mock_results
|
||||
|
||||
# Act
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status=None,
|
||||
time_range="1d",
|
||||
)
|
||||
|
||||
# Assert results
|
||||
assert result["total"] == 5
|
||||
assert result["succeeded"] == 3
|
||||
assert result["running"] == 2
|
||||
assert result["failed"] == 0
|
||||
|
||||
# Verify that execute was called (which means GROUP BY query was used)
|
||||
assert mock_session.execute.called, "execute should have been called for GROUP BY query"
|
||||
|
||||
# Verify SQL query includes time filter by checking the statement
|
||||
call_args = mock_session.execute.call_args
|
||||
assert call_args is not None, "execute should have been called with a statement"
|
||||
|
||||
# The first argument should be the SQL statement
|
||||
stmt = call_args[0][0]
|
||||
# Convert to string to inspect the query
|
||||
query_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))
|
||||
|
||||
# Verify the query includes created_at filter
|
||||
# The query should have a WHERE clause with created_at comparison
|
||||
assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), (
|
||||
"Query should include created_at filter for time range"
|
||||
)
|
||||
|
||||
def test_get_workflow_runs_count_with_status_and_time_range(self, repository, mock_session_maker):
|
||||
"""Test getting workflow runs count with both status and time range filters verifies SQL query."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the count query for running status within time range
|
||||
mock_session.scalar.return_value = 2
|
||||
|
||||
# Act
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status="running",
|
||||
time_range="1d",
|
||||
)
|
||||
|
||||
# Assert results
|
||||
assert result["total"] == 2
|
||||
assert result["running"] == 2
|
||||
assert result["succeeded"] == 0
|
||||
assert result["failed"] == 0
|
||||
|
||||
# Verify that scalar was called (which means COUNT query was used)
|
||||
assert mock_session.scalar.called, "scalar should have been called for count query"
|
||||
|
||||
# Verify SQL query includes both status and time filter
|
||||
call_args = mock_session.scalar.call_args
|
||||
assert call_args is not None, "scalar should have been called with a statement"
|
||||
|
||||
# The first argument should be the SQL statement
|
||||
stmt = call_args[0][0]
|
||||
# Convert to string to inspect the query
|
||||
query_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))
|
||||
|
||||
# Verify the query includes both filters
|
||||
assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), (
|
||||
"Query should include created_at filter for time range"
|
||||
)
|
||||
assert "status" in query_str.lower() or "workflow_runs.status" in query_str.lower(), (
|
||||
"Query should include status filter"
|
||||
)
|
||||
@ -1303,6 +1303,24 @@ class TestBillingServiceSubscriptionOperations:
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_get_plan_bulk_converts_string_expiration_date_to_int(self, mock_send_request):
|
||||
"""Test bulk plan retrieval converts string expiration_date to int."""
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1"]
|
||||
mock_send_request.return_value = {
|
||||
"data": {
|
||||
"tenant-1": {"plan": "sandbox", "expiration_date": "1735689600"},
|
||||
}
|
||||
}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert "tenant-1" in result
|
||||
assert isinstance(result["tenant-1"]["expiration_date"], int)
|
||||
assert result["tenant-1"]["expiration_date"] == 1735689600
|
||||
|
||||
def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request):
|
||||
"""Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant)."""
|
||||
# Arrange
|
||||
|
||||
@ -15,6 +15,7 @@ from sqlalchemy import asc, desc
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account, ConversationVariable
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import (
|
||||
@ -350,7 +351,7 @@ class TestConversationServiceGetConversation:
|
||||
app_model = ConversationServiceTestDataFactory.create_app_mock()
|
||||
user = ConversationServiceTestDataFactory.create_account_mock()
|
||||
conversation = ConversationServiceTestDataFactory.create_conversation_mock(
|
||||
from_account_id=user.id, from_source="console"
|
||||
from_account_id=user.id, from_source=ConversationFromSource.CONSOLE
|
||||
)
|
||||
|
||||
mock_query = mock_db_session.query.return_value
|
||||
@ -374,7 +375,7 @@ class TestConversationServiceGetConversation:
|
||||
app_model = ConversationServiceTestDataFactory.create_app_mock()
|
||||
user = ConversationServiceTestDataFactory.create_end_user_mock()
|
||||
conversation = ConversationServiceTestDataFactory.create_conversation_mock(
|
||||
from_end_user_id=user.id, from_source="api"
|
||||
from_end_user_id=user.id, from_source=ConversationFromSource.API
|
||||
)
|
||||
|
||||
mock_query = mock_db_session.query.return_value
|
||||
@ -1111,7 +1112,7 @@ class TestConversationServiceEdgeCases:
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
conversation = ConversationServiceTestDataFactory.create_conversation_mock(
|
||||
from_source="api", from_end_user_id="user-123"
|
||||
from_source=ConversationFromSource.API, from_end_user_id="user-123"
|
||||
)
|
||||
mock_session.scalars.return_value.all.return_value = [conversation]
|
||||
|
||||
@ -1143,7 +1144,7 @@ class TestConversationServiceEdgeCases:
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
conversation = ConversationServiceTestDataFactory.create_conversation_mock(
|
||||
from_source="console", from_account_id="account-123"
|
||||
from_source=ConversationFromSource.CONSOLE, from_account_id="account-123"
|
||||
)
|
||||
mock_session.scalars.return_value.all.return_value = [conversation]
|
||||
|
||||
|
||||
558
api/tests/unit_tests/services/test_metadata_service.py
Normal file
558
api/tests/unit_tests/services/test_metadata_service.py
Normal file
@ -0,0 +1,558 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
|
||||
from models.dataset import Dataset
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
MetadataArgs,
|
||||
MetadataDetail,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DocumentStub:
|
||||
id: str
|
||||
name: str
|
||||
uploader: str
|
||||
upload_date: datetime
|
||||
last_update_date: datetime
|
||||
data_source_type: str
|
||||
doc_metadata: dict[str, object] | None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(mocker: MockerFixture) -> MagicMock:
|
||||
mocked_db = mocker.patch("services.metadata_service.db")
|
||||
mocked_db.session = MagicMock()
|
||||
return mocked_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch("services.metadata_service.redis_client")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_account(mocker: MockerFixture) -> MagicMock:
|
||||
mock_user = SimpleNamespace(id="user-1")
|
||||
return mocker.patch("services.metadata_service.current_account_with_tenant", return_value=(mock_user, "tenant-1"))
|
||||
|
||||
|
||||
def _build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub:
|
||||
now = datetime(2025, 1, 1, 10, 30, tzinfo=UTC)
|
||||
return _DocumentStub(
|
||||
id=document_id,
|
||||
name=f"doc-{document_id}",
|
||||
uploader="qa@example.com",
|
||||
upload_date=now,
|
||||
last_update_date=now,
|
||||
data_source_type="upload_file",
|
||||
doc_metadata=doc_metadata,
|
||||
)
|
||||
|
||||
|
||||
def _dataset(**kwargs: Any) -> Dataset:
|
||||
return cast(Dataset, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="string", name="x" * 256)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="cannot exceed 255"):
|
||||
MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
|
||||
def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists(
|
||||
mock_db: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="string", name="priority")
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
# Assert
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin(
|
||||
mock_db: MagicMock, mock_current_account: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="string", name=BuiltInField.document_name)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Built-in fields"):
|
||||
MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
|
||||
def test_create_metadata_should_persist_metadata_when_input_is_valid(
|
||||
mock_db: MagicMock, mock_current_account: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="number", name="score")
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
# Assert
|
||||
assert result.tenant_id == "tenant-1"
|
||||
assert result.dataset_id == "dataset-1"
|
||||
assert result.type == "number"
|
||||
assert result.name == "score"
|
||||
assert result.created_by == "user-1"
|
||||
mock_db.session.add.assert_called_once_with(result)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None:
|
||||
# Arrange
|
||||
too_long_name = "x" * 256
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="cannot exceed 255"):
|
||||
MetadataService.update_metadata_name("dataset-1", "metadata-1", too_long_name)
|
||||
|
||||
|
||||
def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists(
|
||||
mock_db: MagicMock, mock_current_account: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate")
|
||||
|
||||
# Assert
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin(
|
||||
mock_db: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Built-in fields"):
|
||||
MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source)
|
||||
|
||||
# Assert
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_update_bound_documents_and_return_metadata(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC)
|
||||
mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now)
|
||||
|
||||
metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None)
|
||||
bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")]
|
||||
query_duplicate = MagicMock()
|
||||
query_duplicate.filter_by.return_value.first.return_value = None
|
||||
query_metadata = MagicMock()
|
||||
query_metadata.filter_by.return_value.first.return_value = metadata
|
||||
query_bindings = MagicMock()
|
||||
query_bindings.filter_by.return_value.all.return_value = bindings
|
||||
mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings]
|
||||
|
||||
doc_1 = _build_document("1", {"old_name": "value", "other": "keep"})
|
||||
doc_2 = _build_document("2", None)
|
||||
mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids")
|
||||
mock_get_documents.return_value = [doc_1, doc_2]
|
||||
|
||||
# Act
|
||||
result = MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name")
|
||||
|
||||
# Assert
|
||||
assert result is metadata
|
||||
assert metadata.name == "new_name"
|
||||
assert metadata.updated_by == "user-1"
|
||||
assert metadata.updated_at == fixed_now
|
||||
assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"}
|
||||
assert doc_2.doc_metadata == {"new_name": None}
|
||||
mock_get_documents.assert_called_once_with(["doc-1", "doc-2"])
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_return_none_when_metadata_does_not_exist(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
mock_logger = mocker.patch("services.metadata_service.logger")
|
||||
|
||||
query_duplicate = MagicMock()
|
||||
query_duplicate.filter_by.return_value.first.return_value = None
|
||||
query_metadata = MagicMock()
|
||||
query_metadata.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.query.side_effect = [query_duplicate, query_metadata]
|
||||
|
||||
# Act
|
||||
result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_logger.exception.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_metadata_should_remove_metadata_and_related_document_fields(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
metadata = SimpleNamespace(id="metadata-1", name="obsolete")
|
||||
bindings = [SimpleNamespace(document_id="doc-1")]
|
||||
query_metadata = MagicMock()
|
||||
query_metadata.filter_by.return_value.first.return_value = metadata
|
||||
query_bindings = MagicMock()
|
||||
query_bindings.filter_by.return_value.all.return_value = bindings
|
||||
mock_db.session.query.side_effect = [query_metadata, query_bindings]
|
||||
|
||||
document = _build_document("1", {"obsolete": "legacy", "remaining": "value"})
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document])
|
||||
|
||||
# Act
|
||||
result = MetadataService.delete_metadata("dataset-1", "metadata-1")
|
||||
|
||||
# Assert
|
||||
assert result is metadata
|
||||
assert document.doc_metadata == {"remaining": "value"}
|
||||
mock_db.session.delete.assert_called_once_with(metadata)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_delete_metadata_should_return_none_when_metadata_is_missing(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_logger = mocker.patch("services.metadata_service.logger")
|
||||
|
||||
# Act
|
||||
result = MetadataService.delete_metadata("dataset-1", "missing-id")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_logger.exception.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_get_built_in_fields_should_return_all_expected_fields() -> None:
|
||||
# Arrange
|
||||
expected_names = {
|
||||
BuiltInField.document_name,
|
||||
BuiltInField.uploader,
|
||||
BuiltInField.upload_date,
|
||||
BuiltInField.last_update_date,
|
||||
BuiltInField.source,
|
||||
}
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_built_in_fields()
|
||||
|
||||
# Assert
|
||||
assert {item["name"] for item in result} == expected_names
|
||||
assert [item["type"] for item in result] == ["string", "string", "time", "time", "string"]
|
||||
|
||||
|
||||
def test_enable_built_in_field_should_return_immediately_when_already_enabled(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
|
||||
get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
|
||||
|
||||
# Act
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
get_docs.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_enable_built_in_field_should_populate_documents_and_enable_flag(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
doc_1 = _build_document("1", {"custom": "value"})
|
||||
doc_2 = _build_document("2", None)
|
||||
mocker.patch(
|
||||
"services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
|
||||
return_value=[doc_1, doc_2],
|
||||
)
|
||||
|
||||
# Act
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
assert dataset.built_in_field_enabled is True
|
||||
assert doc_1.doc_metadata is not None
|
||||
assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1"
|
||||
assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
|
||||
assert doc_2.doc_metadata is not None
|
||||
assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com"
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_disable_built_in_field_should_return_immediately_when_already_disabled(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
|
||||
|
||||
# Act
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
get_docs.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
|
||||
document = _build_document(
|
||||
"1",
|
||||
{
|
||||
BuiltInField.document_name: "doc",
|
||||
BuiltInField.uploader: "user",
|
||||
BuiltInField.upload_date: 1.0,
|
||||
BuiltInField.last_update_date: 2.0,
|
||||
BuiltInField.source: MetadataDataSource.upload_file,
|
||||
"custom": "keep",
|
||||
},
|
||||
)
|
||||
mocker.patch(
|
||||
"services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
|
||||
return_value=[document],
|
||||
)
|
||||
|
||||
# Act
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
assert dataset.built_in_field_enabled is False
|
||||
assert document.doc_metadata == {"custom": "keep"}
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
document = _build_document("1", {"legacy": "value"})
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
|
||||
delete_chain = mock_db.session.query.return_value.filter_by.return_value
|
||||
delete_chain.delete.return_value = 1
|
||||
operation = DocumentMetadataOperation(
|
||||
document_id="1",
|
||||
metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")],
|
||||
partial_update=False,
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Assert
|
||||
assert document.doc_metadata == {"priority": "high"}
|
||||
delete_chain.delete.assert_called_once()
|
||||
assert mock_db.session.commit.call_count == 1
|
||||
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
|
||||
document = _build_document("1", {"existing": "value"})
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
operation = DocumentMetadataOperation(
|
||||
document_id="1",
|
||||
metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")],
|
||||
partial_update=True,
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Assert
|
||||
assert document.doc_metadata is not None
|
||||
assert document.doc_metadata["existing"] == "value"
|
||||
assert document.doc_metadata["new_key"] == "new_value"
|
||||
assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
|
||||
assert mock_db.session.commit.call_count == 1
|
||||
assert mock_db.session.add.call_count == 1
|
||||
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None)
|
||||
operation = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Document not found"):
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Assert
|
||||
mock_db.session.rollback.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("dataset_id", "document_id", "expected_key"),
|
||||
[
|
||||
("dataset-1", None, "dataset_metadata_lock_dataset-1"),
|
||||
(None, "doc-1", "document_metadata_lock_doc-1"),
|
||||
],
|
||||
)
|
||||
def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked(
|
||||
dataset_id: str | None,
|
||||
document_id: str | None,
|
||||
expected_key: str,
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600)
|
||||
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = 1
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="knowledge base metadata operation is running"):
|
||||
MetadataService.knowledge_base_metadata_lock_check("dataset-1", None)
|
||||
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = 1
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="document metadata operation is running"):
|
||||
MetadataService.knowledge_base_metadata_lock_check(None, "doc-1")
|
||||
|
||||
|
||||
def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(
|
||||
id="dataset-1",
|
||||
built_in_field_enabled=True,
|
||||
doc_metadata=[
|
||||
{"id": "meta-1", "name": "priority", "type": "string"},
|
||||
{"id": "built-in", "name": "ignored", "type": "string"},
|
||||
{"id": "meta-2", "name": "score", "type": "number"},
|
||||
],
|
||||
)
|
||||
count_chain = mock_db.session.query.return_value.filter_by.return_value
|
||||
count_chain.count.side_effect = [3, 1]
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_dataset_metadatas(dataset)
|
||||
|
||||
# Assert
|
||||
assert result["built_in_field_enabled"] is True
|
||||
assert result["doc_metadata"] == [
|
||||
{"id": "meta-1", "name": "priority", "type": "string", "count": 3},
|
||||
{"id": "meta-2", "name": "score", "type": "number", "count": 1},
|
||||
]
|
||||
|
||||
|
||||
def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None)
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_dataset_metadatas(dataset)
|
||||
|
||||
# Assert
|
||||
assert result == {"doc_metadata": [], "built_in_field_enabled": False}
|
||||
mock_db.session.query.assert_not_called()
|
||||
@ -0,0 +1,808 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import (
|
||||
CredentialFormSchema,
|
||||
FieldModelSchema,
|
||||
FormType,
|
||||
ModelCredentialSchema,
|
||||
ProviderCredentialSchema,
|
||||
)
|
||||
from models.provider import LoadBalancingModelConfig
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
def _build_provider_credential_schema() -> ProviderCredentialSchema:
|
||||
return ProviderCredentialSchema(
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _build_model_credential_schema() -> ModelCredentialSchema:
|
||||
return ModelCredentialSchema(
|
||||
model=FieldModelSchema(label=I18nObject(en_US="Model")),
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _build_provider_configuration(
|
||||
*,
|
||||
custom_provider: bool = False,
|
||||
load_balancing_enabled: bool | None = None,
|
||||
model_schema: ModelCredentialSchema | None = None,
|
||||
provider_schema: ProviderCredentialSchema | None = None,
|
||||
) -> MagicMock:
|
||||
provider_configuration = MagicMock()
|
||||
provider_configuration.provider = SimpleNamespace(
|
||||
provider="openai",
|
||||
model_credential_schema=model_schema,
|
||||
provider_credential_schema=provider_schema,
|
||||
)
|
||||
provider_configuration.custom_configuration = SimpleNamespace(provider=custom_provider)
|
||||
provider_configuration.extract_secret_variables.return_value = ["api_key"]
|
||||
provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: credentials
|
||||
provider_configuration.get_provider_model_setting.return_value = (
|
||||
None if load_balancing_enabled is None else SimpleNamespace(load_balancing_enabled=load_balancing_enabled)
|
||||
)
|
||||
return provider_configuration
|
||||
|
||||
|
||||
def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig:
|
||||
return cast(LoadBalancingModelConfig, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(mocker: MockerFixture) -> ModelLoadBalancingService:
|
||||
# Arrange
|
||||
provider_manager = MagicMock()
|
||||
mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=provider_manager)
|
||||
svc = ModelLoadBalancingService()
|
||||
svc.provider_manager = provider_manager
|
||||
return svc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(mocker: MockerFixture) -> MagicMock:
|
||||
# Arrange
|
||||
mocked_db = mocker.patch("services.model_load_balancing_service.db")
|
||||
mocked_db.session = MagicMock()
|
||||
return mocked_db
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "expected_provider_method"),
|
||||
[
|
||||
("enable_model_load_balancing", "enable_model_load_balancing"),
|
||||
("disable_model_load_balancing", "disable_model_load_balancing"),
|
||||
],
|
||||
)
|
||||
def test_enable_disable_model_load_balancing_should_call_provider_configuration_method_when_provider_exists(
|
||||
method_name: str,
|
||||
expected_provider_method: str,
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
|
||||
# Act
|
||||
getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
getattr(provider_configuration, expected_provider_method).assert_called_once_with(
|
||||
model="gpt-4o-mini", model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method_name",
|
||||
["enable_model_load_balancing", "disable_model_load_balancing"],
|
||||
)
|
||||
def test_enable_disable_model_load_balancing_should_raise_value_error_when_provider_missing(
|
||||
method_name: str,
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service.provider_manager.get_configurations.return_value = {}
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider openai does not exist"):
|
||||
getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value)
|
||||
|
||||
|
||||
def test_get_load_balancing_configs_should_raise_value_error_when_provider_missing(
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service.provider_manager.get_configurations.return_value = {}
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider openai does not exist"):
|
||||
service.get_load_balancing_configs("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value)
|
||||
|
||||
|
||||
def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_for_custom_provider(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(
|
||||
custom_provider=True,
|
||||
load_balancing_enabled=True,
|
||||
provider_schema=_build_provider_credential_schema(),
|
||||
)
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
config = SimpleNamespace(
|
||||
id="cfg-1",
|
||||
name="primary",
|
||||
encrypted_config=json.dumps({"api_key": "encrypted-key"}),
|
||||
credential_id="cred-1",
|
||||
enabled=True,
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config]
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.get_decrypt_decoding",
|
||||
return_value=("rsa", "cipher"),
|
||||
)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.decrypt_token_with_decoding",
|
||||
return_value="plain-key",
|
||||
)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl",
|
||||
return_value=(False, 0),
|
||||
)
|
||||
|
||||
# Act
|
||||
is_enabled, configs = service.get_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert is_enabled is True
|
||||
assert len(configs) == 2
|
||||
assert configs[0]["name"] == "__inherit__"
|
||||
assert configs[1]["name"] == "primary"
|
||||
assert configs[1]["credentials"] == {"api_key": "plain-key"}
|
||||
assert mock_db.session.add.call_count == 1
|
||||
assert mock_db.session.commit.call_count == 1
|
||||
|
||||
|
||||
def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate_json_or_decrypt_errors(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(
|
||||
custom_provider=True,
|
||||
load_balancing_enabled=None,
|
||||
provider_schema=_build_provider_credential_schema(),
|
||||
)
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
normal_config = SimpleNamespace(
|
||||
id="cfg-1",
|
||||
name="normal",
|
||||
encrypted_config=json.dumps({"api_key": "bad-encrypted"}),
|
||||
credential_id="cred-1",
|
||||
enabled=True,
|
||||
)
|
||||
inherit_config = SimpleNamespace(
|
||||
id="cfg-2",
|
||||
name="__inherit__",
|
||||
encrypted_config="not-json",
|
||||
credential_id=None,
|
||||
enabled=False,
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [
|
||||
normal_config,
|
||||
inherit_config,
|
||||
]
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.get_decrypt_decoding",
|
||||
return_value=("rsa", "cipher"),
|
||||
)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.decrypt_token_with_decoding",
|
||||
side_effect=ValueError("cannot decrypt"),
|
||||
)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl",
|
||||
return_value=(True, 15),
|
||||
)
|
||||
|
||||
# Act
|
||||
is_enabled, configs = service.get_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
config_from="predefined-model",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert is_enabled is False
|
||||
assert configs[0]["name"] == "__inherit__"
|
||||
assert configs[0]["credentials"] == {}
|
||||
assert configs[1]["credentials"] == {"api_key": "bad-encrypted"}
|
||||
assert configs[1]["in_cooldown"] is True
|
||||
assert configs[1]["ttl"] == 15
|
||||
|
||||
|
||||
def test_get_load_balancing_config_should_raise_value_error_when_provider_missing(
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service.provider_manager.get_configurations.return_value = {}
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider openai does not exist"):
|
||||
service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
|
||||
|
||||
|
||||
def test_get_load_balancing_config_should_return_none_when_config_not_found(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_exists(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: {
|
||||
"masked": credentials.get("api_key", "")
|
||||
}
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = config
|
||||
|
||||
# Act
|
||||
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
|
||||
|
||||
# Assert
|
||||
assert result == {
|
||||
"id": "cfg-1",
|
||||
"name": "primary",
|
||||
"credentials": {"masked": ""},
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
|
||||
def test_init_inherit_config_should_create_and_persist_inherit_configuration(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
model_type = ModelType.LLM
|
||||
|
||||
# Act
|
||||
inherit_config = service._init_inherit_config("tenant-1", "openai", "gpt-4o-mini", model_type)
|
||||
|
||||
# Assert
|
||||
assert inherit_config.tenant_id == "tenant-1"
|
||||
assert inherit_config.provider_name == "openai"
|
||||
assert inherit_config.model_name == "gpt-4o-mini"
|
||||
assert inherit_config.model_type == "text-generation"
|
||||
assert inherit_config.name == "__inherit__"
|
||||
mock_db.session.add.assert_called_once_with(inherit_config)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_provider_missing(
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service.provider_manager.get_configurations.return_value = {}
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider openai does not exist"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_configs_is_not_list(
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Invalid load balancing configs"):
|
||||
service.update_load_balancing_configs( # type: ignore[arg-type]
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
cast(list[dict[str, object]], "invalid-configs"),
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_config_item_is_not_dict(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config"):
|
||||
service.update_load_balancing_configs( # type: ignore[list-item]
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
cast(list[dict[str, object]], ["bad-item"]),
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_credential_id_not_found(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"credential_id": "cred-1", "enabled": True}],
|
||||
"predefined-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_name_or_enabled_is_invalid(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config name"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"enabled": True}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config enabled"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"name": "cfg-without-enabled"}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_existing_config_id_is_invalid(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
current_config = SimpleNamespace(id="cfg-1")
|
||||
mock_db.session.scalars.return_value.all.return_value = [current_config]
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config id: cfg-2"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"id": "cfg-2", "name": "invalid", "enabled": True}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_when_credentials_are_invalid_for_update_or_create(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
existing_config = SimpleNamespace(id="cfg-1", name="old", enabled=True, encrypted_config=None, updated_at=None)
|
||||
mock_db.session.scalars.return_value.all.return_value = [existing_config]
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config credentials"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"id": "cfg-1", "name": "new", "enabled": True, "credentials": "bad"}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config credentials"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"name": "new-config", "enabled": True, "credentials": "bad"}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_update_existing_create_new_and_delete_removed_configs(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
existing_config_1 = SimpleNamespace(
|
||||
id="cfg-1",
|
||||
name="existing-one",
|
||||
enabled=True,
|
||||
encrypted_config=json.dumps({"api_key": "old"}),
|
||||
updated_at=None,
|
||||
)
|
||||
existing_config_2 = SimpleNamespace(
|
||||
id="cfg-2",
|
||||
name="existing-two",
|
||||
enabled=True,
|
||||
encrypted_config=None,
|
||||
updated_at=None,
|
||||
)
|
||||
mock_db.session.scalars.return_value.all.return_value = [existing_config_1, existing_config_2]
|
||||
mocker.patch.object(service, "_custom_credentials_validate", return_value={"api_key": "encrypted"})
|
||||
mock_clear_cache = mocker.patch.object(service, "_clear_credentials_cache")
|
||||
|
||||
# Act
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[
|
||||
{"id": "cfg-1", "name": "updated-name", "enabled": False, "credentials": {"api_key": "plain"}},
|
||||
{"name": "new-config", "enabled": True, "credentials": {"api_key": "plain"}},
|
||||
],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert existing_config_1.name == "updated-name"
|
||||
assert existing_config_1.enabled is False
|
||||
assert json.loads(existing_config_1.encrypted_config) == {"api_key": "encrypted"}
|
||||
assert mock_db.session.add.call_count == 1
|
||||
mock_db.session.delete.assert_called_once_with(existing_config_2)
|
||||
assert mock_db.session.commit.call_count >= 3
|
||||
mock_clear_cache.assert_any_call("tenant-1", "cfg-1")
|
||||
mock_clear_cache.assert_any_call("tenant-1", "cfg-2")
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_raise_value_error_for_invalid_new_config_name_or_missing_credentials(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config name"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"name": "__inherit__", "enabled": True, "credentials": {"api_key": "x"}}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid load balancing config credentials"):
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"name": "new", "enabled": True}],
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
|
||||
def test_update_load_balancing_configs_should_create_from_existing_provider_credential_when_credential_id_provided(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}')
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record
|
||||
|
||||
# Act
|
||||
service.update_load_balancing_configs(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
[{"credential_id": "cred-1", "enabled": True}],
|
||||
"predefined-model",
|
||||
)
|
||||
|
||||
# Assert
|
||||
created_config = mock_db.session.add.call_args.args[0]
|
||||
assert created_config.name == "Main Credential"
|
||||
assert created_config.credential_id == "cred-1"
|
||||
assert created_config.credential_source_type == "provider"
|
||||
assert created_config.encrypted_config == '{"api_key":"enc"}'
|
||||
mock_db.session.commit.assert_called()
|
||||
|
||||
|
||||
def test_validate_load_balancing_credentials_should_raise_value_error_when_provider_missing(
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service.provider_manager.get_configurations.return_value = {}
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider openai does not exist"):
|
||||
service.validate_load_balancing_credentials(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
{"api_key": "plain"},
|
||||
)
|
||||
|
||||
|
||||
def test_validate_load_balancing_credentials_should_raise_value_error_when_config_id_is_invalid(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"):
|
||||
service.validate_load_balancing_credentials(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
{"api_key": "plain"},
|
||||
config_id="cfg-1",
|
||||
)
|
||||
|
||||
|
||||
def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_with_or_without_config(
|
||||
service: ModelLoadBalancingService,
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
existing_config = SimpleNamespace(id="cfg-1")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = existing_config
|
||||
mock_validate = mocker.patch.object(service, "_custom_credentials_validate")
|
||||
|
||||
# Act
|
||||
service.validate_load_balancing_credentials(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
{"api_key": "plain"},
|
||||
config_id="cfg-1",
|
||||
)
|
||||
service.validate_load_balancing_credentials(
|
||||
"tenant-1",
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
ModelType.LLM.value,
|
||||
{"api_key": "plain"},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert mock_validate.call_count == 2
|
||||
assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config
|
||||
assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None
|
||||
|
||||
|
||||
def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt(
|
||||
service: ModelLoadBalancingService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
load_balancing_model_config = _load_balancing_model_config(
|
||||
encrypted_config=json.dumps({"api_key": "old-encrypted-token"})
|
||||
)
|
||||
mocker.patch("services.model_load_balancing_service.encrypter.decrypt_token", return_value="old-plain-value")
|
||||
mock_encrypt = mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.encrypt_token",
|
||||
side_effect=lambda tenant_id, value: f"enc:{value}",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service._custom_credentials_validate(
|
||||
tenant_id="tenant-1",
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": HIDDEN_VALUE, "region": "us"},
|
||||
load_balancing_model_config=load_balancing_model_config,
|
||||
validate=False,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"api_key": "enc:old-plain-value", "region": "us"}
|
||||
mock_encrypt.assert_called_once_with("tenant-1", "old-plain-value")
|
||||
|
||||
|
||||
def test_custom_credentials_validate_should_handle_invalid_original_json_and_validate_with_model_schema(
|
||||
service: ModelLoadBalancingService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(model_schema=_build_model_credential_schema())
|
||||
load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json")
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.model_credentials_validate.return_value = {"api_key": "validated"}
|
||||
mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory)
|
||||
mock_encrypt = mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.encrypt_token",
|
||||
side_effect=lambda tenant_id, value: f"enc:{value}",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service._custom_credentials_validate(
|
||||
tenant_id="tenant-1",
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "plain"},
|
||||
load_balancing_model_config=load_balancing_model_config,
|
||||
validate=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"api_key": "enc:validated"}
|
||||
mock_factory.model_credentials_validate.assert_called_once()
|
||||
mock_factory.provider_credentials_validate.assert_not_called()
|
||||
mock_encrypt.assert_called_once_with("tenant-1", "validated")
|
||||
|
||||
|
||||
def test_custom_credentials_validate_should_validate_with_provider_schema_when_model_schema_absent(
|
||||
service: ModelLoadBalancingService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"}
|
||||
mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.encrypt_token",
|
||||
side_effect=lambda tenant_id, value: f"enc:{value}",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service._custom_credentials_validate(
|
||||
tenant_id="tenant-1",
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "plain"},
|
||||
validate=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"api_key": "enc:provider-validated"}
|
||||
mock_factory.provider_credentials_validate.assert_called_once()
|
||||
mock_factory.model_credentials_validate.assert_not_called()
|
||||
|
||||
|
||||
def test_get_credential_schema_should_return_model_schema_or_provider_schema_or_raise(
|
||||
service: ModelLoadBalancingService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
model_schema = _build_model_credential_schema()
|
||||
provider_schema = _build_provider_credential_schema()
|
||||
provider_configuration_with_model = _build_provider_configuration(model_schema=model_schema)
|
||||
provider_configuration_with_provider = _build_provider_configuration(provider_schema=provider_schema)
|
||||
provider_configuration_without_schema = _build_provider_configuration()
|
||||
|
||||
# Act
|
||||
schema_from_model = service._get_credential_schema(provider_configuration_with_model)
|
||||
schema_from_provider = service._get_credential_schema(provider_configuration_with_provider)
|
||||
|
||||
# Assert
|
||||
assert schema_from_model is model_schema
|
||||
assert schema_from_provider is provider_schema
|
||||
with pytest.raises(ValueError, match="No credential schema found"):
|
||||
service._get_credential_schema(provider_configuration_without_schema)
|
||||
|
||||
|
||||
def test_clear_credentials_cache_should_delete_load_balancing_cache_entry(
|
||||
service: ModelLoadBalancingService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_cache_instance = MagicMock()
|
||||
mock_cache_cls = mocker.patch(
|
||||
"services.model_load_balancing_service.ProviderCredentialsCache",
|
||||
return_value=mock_cache_instance,
|
||||
)
|
||||
|
||||
# Act
|
||||
service._clear_credentials_cache("tenant-1", "cfg-1")
|
||||
|
||||
# Assert
|
||||
mock_cache_cls.assert_called_once()
|
||||
assert mock_cache_cls.call_args.kwargs == {
|
||||
"tenant_id": "tenant-1",
|
||||
"identity_id": "cfg-1",
|
||||
"cache_type": mocker.ANY,
|
||||
}
|
||||
assert mock_cache_cls.call_args.kwargs["cache_type"].name == "LOAD_BALANCING_MODEL"
|
||||
mock_cache_instance.delete.assert_called_once()
|
||||
224
api/tests/unit_tests/services/test_oauth_server_service.py
Normal file
224
api/tests/unit_tests/services/test_oauth_server_service.py
Normal file
@ -0,0 +1,224 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from services.oauth_server import (
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY,
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY,
|
||||
OAuthGrantType,
|
||||
OAuthServerService,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch("services.oauth_server.redis_client")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(mocker: MockerFixture) -> MagicMock:
|
||||
"""Mock the OAuth server Session context manager."""
|
||||
mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object()))
|
||||
session = MagicMock()
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
mocker.patch("services.oauth_server.Session", return_value=session_cm)
|
||||
return session
|
||||
|
||||
|
||||
def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_execute_result = MagicMock()
|
||||
expected_app = MagicMock()
|
||||
mock_execute_result.scalar_one_or_none.return_value = expected_app
|
||||
mock_session.execute.return_value = mock_execute_result
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.get_oauth_provider_app("client-1")
|
||||
|
||||
# Assert
|
||||
assert result is expected_app
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_execute_result.scalar_one_or_none.assert_called_once()
|
||||
|
||||
|
||||
def test_sign_oauth_authorization_code_should_store_code_and_return_value(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
|
||||
# Act
|
||||
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
|
||||
|
||||
# Assert
|
||||
expected_code = str(deterministic_uuid)
|
||||
assert code == expected_code
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code),
|
||||
"user-1",
|
||||
ex=600,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(BadRequest, match="invalid code"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="bad-code",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
token_uuids = [
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000201"),
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000202"),
|
||||
]
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids)
|
||||
mock_redis_client.get.return_value = b"user-1"
|
||||
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
|
||||
|
||||
# Act
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="code-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token == str(token_uuids[0])
|
||||
assert refresh_token == str(token_uuids[1])
|
||||
mock_redis_client.delete.assert_called_once_with(code_key)
|
||||
mock_redis_client.set.assert_any_call(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
mock_redis_client.set.assert_any_call(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(BadRequest, match="invalid refresh token"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="stale-token",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
mock_redis_client.get.return_value = b"user-1"
|
||||
|
||||
# Act
|
||||
access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="refresh-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token == str(deterministic_uuid)
|
||||
assert returned_refresh_token == "refresh-1"
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None:
|
||||
# Arrange
|
||||
grant_type = cast(OAuthGrantType, "invalid-grant-type")
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=grant_type,
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
|
||||
# Act
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
|
||||
|
||||
# Assert
|
||||
assert refresh_token == str(deterministic_uuid)
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
|
||||
"user-2",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_oauth_access_token_should_return_none_when_token_not_found(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_oauth_access_token_should_load_user_when_token_exists(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = b"user-88"
|
||||
expected_user = MagicMock()
|
||||
mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user)
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
||||
|
||||
# Assert
|
||||
assert result is expected_user
|
||||
mock_load_user.assert_called_once_with("user-88")
|
||||
1249
api/tests/unit_tests/services/test_trigger_provider_service.py
Normal file
1249
api/tests/unit_tests/services/test_trigger_provider_service.py
Normal file
File diff suppressed because it is too large
Load Diff
259
api/tests/unit_tests/services/test_web_conversation_service.py
Normal file
259
api/tests/unit_tests/services/test_web_conversation_service.py
Normal file
@ -0,0 +1,259 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models import Account
|
||||
from models.model import App, EndUser
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_model() -> App:
|
||||
return cast(App, SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def _account(**kwargs: Any) -> Account:
|
||||
return cast(Account, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _end_user(**kwargs: Any) -> EndUser:
|
||||
return cast(EndUser, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_pagination_by_last_id_should_raise_error_when_user_is_none(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="User is required"):
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=None,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
|
||||
def test_pagination_by_last_id_should_forward_without_pin_filter_when_pinned_is_none(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
fake_user = _account(id="user-1")
|
||||
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
|
||||
mock_pagination.return_value = MagicMock()
|
||||
|
||||
# Act
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=fake_user,
|
||||
last_id="conv-9",
|
||||
limit=10,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_pagination.call_args.kwargs
|
||||
assert call_kwargs["include_ids"] is None
|
||||
assert call_kwargs["exclude_ids"] is None
|
||||
assert call_kwargs["last_id"] == "conv-9"
|
||||
assert call_kwargs["sort_by"] == "-updated_at"
|
||||
|
||||
|
||||
def test_pagination_by_last_id_should_include_only_pinned_ids_when_pinned_true(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
fake_account_cls = type("FakeAccount", (), {})
|
||||
fake_user = cast(Account, fake_account_cls())
|
||||
fake_user.id = "account-1"
|
||||
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
|
||||
mocker.patch("services.web_conversation_service.EndUser", type("FakeEndUser", (), {}))
|
||||
session.scalars.return_value.all.return_value = ["conv-1", "conv-2"]
|
||||
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
|
||||
mock_pagination.return_value = MagicMock()
|
||||
|
||||
# Act
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=fake_user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_pagination.call_args.kwargs
|
||||
assert call_kwargs["include_ids"] == ["conv-1", "conv-2"]
|
||||
assert call_kwargs["exclude_ids"] is None
|
||||
|
||||
|
||||
def test_pagination_by_last_id_should_exclude_pinned_ids_when_pinned_false(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
fake_end_user_cls = type("FakeEndUser", (), {})
|
||||
fake_user = cast(EndUser, fake_end_user_cls())
|
||||
fake_user.id = "end-user-1"
|
||||
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
|
||||
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
|
||||
session.scalars.return_value.all.return_value = ["conv-3"]
|
||||
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
|
||||
mock_pagination.return_value = MagicMock()
|
||||
|
||||
# Act
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=fake_user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=False,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_pagination.call_args.kwargs
|
||||
assert call_kwargs["include_ids"] is None
|
||||
assert call_kwargs["exclude_ids"] == ["conv-3"]
|
||||
|
||||
|
||||
def test_pin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
mocker.patch("services.web_conversation_service.ConversationService.get_conversation")
|
||||
|
||||
# Act
|
||||
WebConversationService.pin(app_model, "conv-1", None)
|
||||
|
||||
# Assert
|
||||
mock_db.session.add.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_pin_should_return_early_when_conversation_is_already_pinned(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
fake_account_cls = type("FakeAccount", (), {})
|
||||
fake_user = cast(Account, fake_account_cls())
|
||||
fake_user.id = "account-1"
|
||||
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = object()
|
||||
mock_get_conversation = mocker.patch("services.web_conversation_service.ConversationService.get_conversation")
|
||||
|
||||
# Act
|
||||
WebConversationService.pin(app_model, "conv-1", fake_user)
|
||||
|
||||
# Assert
|
||||
mock_get_conversation.assert_not_called()
|
||||
mock_db.session.add.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_pin_should_create_pinned_conversation_when_not_already_pinned(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
fake_account_cls = type("FakeAccount", (), {})
|
||||
fake_user = cast(Account, fake_account_cls())
|
||||
fake_user.id = "account-2"
|
||||
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_conversation = SimpleNamespace(id="conv-2")
|
||||
mock_get_conversation = mocker.patch(
|
||||
"services.web_conversation_service.ConversationService.get_conversation",
|
||||
return_value=mock_conversation,
|
||||
)
|
||||
|
||||
# Act
|
||||
WebConversationService.pin(app_model, "conv-2", fake_user)
|
||||
|
||||
# Assert
|
||||
mock_get_conversation.assert_called_once_with(app_model=app_model, conversation_id="conv-2", user=fake_user)
|
||||
added_obj = mock_db.session.add.call_args.args[0]
|
||||
assert added_obj.app_id == "app-1"
|
||||
assert added_obj.conversation_id == "conv-2"
|
||||
assert added_obj.created_by_role == "account"
|
||||
assert added_obj.created_by == "account-2"
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_unpin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
|
||||
# Act
|
||||
WebConversationService.unpin(app_model, "conv-1", None)
|
||||
|
||||
# Assert
|
||||
mock_db.session.delete.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_unpin_should_return_early_when_conversation_is_not_pinned(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
fake_end_user_cls = type("FakeEndUser", (), {})
|
||||
fake_user = cast(EndUser, fake_end_user_cls())
|
||||
fake_user.id = "end-user-3"
|
||||
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
|
||||
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
WebConversationService.unpin(app_model, "conv-7", fake_user)
|
||||
|
||||
# Assert
|
||||
mock_db.session.delete.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_unpin_should_delete_pinned_conversation_when_exists(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
fake_end_user_cls = type("FakeEndUser", (), {})
|
||||
fake_user = cast(EndUser, fake_end_user_cls())
|
||||
fake_user.id = "end-user-4"
|
||||
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
|
||||
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
pinned_obj = SimpleNamespace(id="pin-1")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = pinned_obj
|
||||
|
||||
# Act
|
||||
WebConversationService.unpin(app_model, "conv-8", fake_user)
|
||||
|
||||
# Assert
|
||||
mock_db.session.delete.assert_called_once_with(pinned_obj)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
379
api/tests/unit_tests/services/test_webapp_auth_service.py
Normal file
379
api/tests/unit_tests/services/test_webapp_auth_service.py
Normal file
@ -0,0 +1,379 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from models import Account, AccountStatus
|
||||
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
|
||||
|
||||
ACCOUNT_LOOKUP_PATH = "services.webapp_auth_service.AccountService.get_account_by_email_with_case_fallback"
|
||||
TOKEN_GENERATE_PATH = "services.webapp_auth_service.TokenManager.generate_token"
|
||||
TOKEN_GET_DATA_PATH = "services.webapp_auth_service.TokenManager.get_token_data"
|
||||
|
||||
|
||||
def _account(**kwargs: Any) -> Account:
|
||||
return cast(Account, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(mocker: MockerFixture) -> MagicMock:
|
||||
# Arrange
|
||||
mocked_db = mocker.patch("services.webapp_auth_service.db")
|
||||
mocked_db.session = MagicMock()
|
||||
return mocked_db
|
||||
|
||||
|
||||
def test_authenticate_should_raise_account_not_found_when_email_does_not_exist(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(AccountNotFoundError):
|
||||
WebAppAuthService.authenticate("user@example.com", "pwd")
|
||||
|
||||
|
||||
def test_authenticate_should_raise_account_login_error_when_account_is_banned(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
account = SimpleNamespace(status=AccountStatus.BANNED, password="hash", password_salt="salt")
|
||||
mocker.patch(
|
||||
ACCOUNT_LOOKUP_PATH,
|
||||
return_value=account,
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(AccountLoginError, match="Account is banned"):
|
||||
WebAppAuthService.authenticate("user@example.com", "pwd")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("password_value", [None, "hash"])
|
||||
def test_authenticate_should_raise_password_error_when_password_is_invalid(
|
||||
password_value: str | None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
account = SimpleNamespace(status=AccountStatus.ACTIVE, password=password_value, password_salt="salt")
|
||||
mocker.patch(
|
||||
ACCOUNT_LOOKUP_PATH,
|
||||
return_value=account,
|
||||
)
|
||||
mocker.patch("services.webapp_auth_service.compare_password", return_value=False)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(AccountPasswordError, match="Invalid email or password"):
|
||||
WebAppAuthService.authenticate("user@example.com", "pwd")
|
||||
|
||||
|
||||
def test_authenticate_should_return_account_when_credentials_are_valid(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
account = SimpleNamespace(status=AccountStatus.ACTIVE, password="hash", password_salt="salt")
|
||||
mocker.patch(
|
||||
ACCOUNT_LOOKUP_PATH,
|
||||
return_value=account,
|
||||
)
|
||||
mocker.patch("services.webapp_auth_service.compare_password", return_value=True)
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.authenticate("user@example.com", "pwd")
|
||||
|
||||
# Assert
|
||||
assert result is account
|
||||
|
||||
|
||||
def test_login_should_return_token_from_internal_token_builder(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
account = _account(id="a1", email="u@example.com")
|
||||
mock_get_token = mocker.patch.object(WebAppAuthService, "_get_account_jwt_token", return_value="jwt-token")
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.login(account)
|
||||
|
||||
# Assert
|
||||
assert result == "jwt-token"
|
||||
mock_get_token.assert_called_once_with(account=account)
|
||||
|
||||
|
||||
def test_get_user_through_email_should_return_none_when_account_not_found(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.get_user_through_email("missing@example.com")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_user_through_email_should_raise_unauthorized_when_account_banned(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
account = SimpleNamespace(status=AccountStatus.BANNED)
|
||||
mocker.patch(
|
||||
ACCOUNT_LOOKUP_PATH,
|
||||
return_value=account,
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(Unauthorized, match="Account is banned"):
|
||||
WebAppAuthService.get_user_through_email("user@example.com")
|
||||
|
||||
|
||||
def test_get_user_through_email_should_return_account_when_active(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
account = SimpleNamespace(status=AccountStatus.ACTIVE)
|
||||
mocker.patch(
|
||||
ACCOUNT_LOOKUP_PATH,
|
||||
return_value=account,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.get_user_through_email("user@example.com")
|
||||
|
||||
# Assert
|
||||
assert result is account
|
||||
|
||||
|
||||
def test_send_email_code_login_email_should_raise_error_when_email_not_provided() -> None:
|
||||
# Arrange
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Email must be provided"):
|
||||
WebAppAuthService.send_email_code_login_email(account=None, email=None)
|
||||
|
||||
|
||||
def test_send_email_code_login_email_should_generate_token_and_send_mail_for_account(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
account = _account(email="user@example.com")
|
||||
mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[1, 2, 3, 4, 5, 6])
|
||||
mock_generate_token = mocker.patch(TOKEN_GENERATE_PATH, return_value="token-1")
|
||||
mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US")
|
||||
|
||||
# Assert
|
||||
assert result == "token-1"
|
||||
mock_generate_token.assert_called_once()
|
||||
assert mock_generate_token.call_args.kwargs["additional_data"] == {"code": "123456"}
|
||||
mock_delay.assert_called_once_with(language="en-US", to="user@example.com", code="123456")
|
||||
|
||||
|
||||
def test_send_email_code_login_email_should_send_mail_for_email_without_account(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[0, 0, 0, 0, 0, 0])
|
||||
mocker.patch(TOKEN_GENERATE_PATH, return_value="token-2")
|
||||
mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.send_email_code_login_email(account=None, email="alt@example.com", language="zh-Hans")
|
||||
|
||||
# Assert
|
||||
assert result == "token-2"
|
||||
mock_delay.assert_called_once_with(language="zh-Hans", to="alt@example.com", code="000000")
|
||||
|
||||
|
||||
def test_get_email_code_login_data_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mock_get_data = mocker.patch(TOKEN_GET_DATA_PATH, return_value={"code": "123"})
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.get_email_code_login_data("token-abc")
|
||||
|
||||
# Assert
|
||||
assert result == {"code": "123"}
|
||||
mock_get_data.assert_called_once_with("token-abc", "email_code_login")
|
||||
|
||||
|
||||
def test_revoke_email_code_login_token_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mock_revoke = mocker.patch("services.webapp_auth_service.TokenManager.revoke_token")
|
||||
|
||||
# Act
|
||||
WebAppAuthService.revoke_email_code_login_token("token-xyz")
|
||||
|
||||
# Assert
|
||||
mock_revoke.assert_called_once_with("token-xyz", "email_code_login")
|
||||
|
||||
|
||||
def test_create_end_user_should_raise_not_found_when_site_does_not_exist(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(NotFound, match="Site not found"):
|
||||
WebAppAuthService.create_end_user("app-code", "user@example.com")
|
||||
|
||||
|
||||
def test_create_end_user_should_raise_not_found_when_app_does_not_exist(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
site = SimpleNamespace(app_id="app-1")
|
||||
app_query = MagicMock()
|
||||
app_query.where.return_value.first.return_value = None
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [site, None]
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(NotFound, match="App not found"):
|
||||
WebAppAuthService.create_end_user("app-code", "user@example.com")
|
||||
|
||||
|
||||
def test_create_end_user_should_create_and_commit_end_user_when_data_is_valid(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
site = SimpleNamespace(app_id="app-1")
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [site, app_model]
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.create_end_user("app-code", "user@example.com")
|
||||
|
||||
# Assert
|
||||
assert result.tenant_id == "tenant-1"
|
||||
assert result.app_id == "app-1"
|
||||
assert result.session_id == "user@example.com"
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_account_jwt_token_should_build_payload_and_issue_token(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
account = _account(id="a1", email="user@example.com")
|
||||
mocker.patch("services.webapp_auth_service.dify_config.ACCESS_TOKEN_EXPIRE_MINUTES", 60)
|
||||
mock_issue = mocker.patch("services.webapp_auth_service.PassportService.issue", return_value="jwt-1")
|
||||
|
||||
# Act
|
||||
token = WebAppAuthService._get_account_jwt_token(account)
|
||||
|
||||
# Assert
|
||||
assert token == "jwt-1"
|
||||
payload = mock_issue.call_args.args[0]
|
||||
assert payload["user_id"] == "a1"
|
||||
assert payload["session_id"] == "user@example.com"
|
||||
assert payload["token_source"] == "webapp_login_token"
|
||||
assert payload["auth_type"] == "internal"
|
||||
assert payload["exp"] > int(datetime.now(UTC).timestamp())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("access_mode", "expected"),
|
||||
[
|
||||
("private", True),
|
||||
("private_all", True),
|
||||
("public", False),
|
||||
],
|
||||
)
|
||||
def test_is_app_require_permission_check_should_use_access_mode_when_provided(
|
||||
access_mode: str,
|
||||
expected: bool,
|
||||
) -> None:
|
||||
# Arrange
|
||||
# Act
|
||||
result = WebAppAuthService.is_app_require_permission_check(access_mode=access_mode)
|
||||
|
||||
# Assert
|
||||
assert result is expected
|
||||
|
||||
|
||||
def test_is_app_require_permission_check_should_raise_when_no_identifier_provided() -> None:
|
||||
# Arrange
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Either app_code or app_id must be provided"):
|
||||
WebAppAuthService.is_app_require_permission_check()
|
||||
|
||||
|
||||
def test_is_app_require_permission_check_should_raise_when_app_id_cannot_be_determined(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value=None)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="App ID could not be determined"):
|
||||
WebAppAuthService.is_app_require_permission_check(app_code="app-code")
|
||||
|
||||
|
||||
def test_is_app_require_permission_check_should_return_true_when_enterprise_mode_requires_it(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
|
||||
mocker.patch(
|
||||
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
|
||||
return_value=SimpleNamespace(access_mode="private"),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.is_app_require_permission_check(app_code="app-code")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_is_app_require_permission_check_should_return_false_when_enterprise_settings_do_not_require_it(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
|
||||
return_value=SimpleNamespace(access_mode="public"),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.is_app_require_permission_check(app_id="app-1")
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("access_mode", "expected"),
|
||||
[
|
||||
("public", WebAppAuthType.PUBLIC),
|
||||
("private", WebAppAuthType.INTERNAL),
|
||||
("private_all", WebAppAuthType.INTERNAL),
|
||||
("sso_verified", WebAppAuthType.EXTERNAL),
|
||||
],
|
||||
)
|
||||
def test_get_app_auth_type_should_map_access_modes_correctly(
|
||||
access_mode: str,
|
||||
expected: WebAppAuthType,
|
||||
) -> None:
|
||||
# Arrange
|
||||
# Act
|
||||
result = WebAppAuthService.get_app_auth_type(access_mode=access_mode)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_get_app_auth_type_should_resolve_from_app_code(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
|
||||
mocker.patch(
|
||||
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
|
||||
return_value=SimpleNamespace(access_mode="private_all"),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = WebAppAuthService.get_app_auth_type(app_code="app-code")
|
||||
|
||||
# Assert
|
||||
assert result == WebAppAuthType.INTERNAL
|
||||
|
||||
|
||||
def test_get_app_auth_type_should_raise_when_no_input_provided() -> None:
|
||||
# Arrange
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Either app_code or access_mode must be provided"):
|
||||
WebAppAuthService.get_app_auth_type()
|
||||
|
||||
|
||||
def test_get_app_auth_type_should_raise_when_cannot_determine_type_from_invalid_mode() -> None:
|
||||
# Arrange
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Could not determine app authentication type"):
|
||||
WebAppAuthService.get_app_auth_type(access_mode="unknown")
|
||||
300
api/tests/unit_tests/services/test_workflow_app_service.py
Normal file
300
api/tests/unit_tests/services/test_workflow_app_service.py
Normal file
@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from models import App, WorkflowAppLog
|
||||
from models.enums import AppTriggerType, CreatorUserRole
|
||||
from services.workflow_app_service import LogView, WorkflowAppService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service() -> WorkflowAppService:
|
||||
# Arrange
|
||||
return WorkflowAppService()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_model() -> App:
|
||||
# Arrange
|
||||
return cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1"))
|
||||
|
||||
|
||||
def _workflow_app_log(**kwargs: Any) -> WorkflowAppLog:
|
||||
return cast(WorkflowAppLog, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_log_view_details_should_return_wrapped_details_and_proxy_attributes() -> None:
|
||||
# Arrange
|
||||
log = _workflow_app_log(id="log-1", status="succeeded")
|
||||
view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}})
|
||||
|
||||
# Act
|
||||
details = view.details
|
||||
proxied_status = view.status
|
||||
|
||||
# Assert
|
||||
assert details == {"trigger_metadata": {"type": "plugin"}}
|
||||
assert proxied_status == "succeeded"
|
||||
|
||||
|
||||
def test_get_paginate_workflow_app_logs_should_return_paginated_summary_when_detail_false(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
log_1 = SimpleNamespace(id="log-1")
|
||||
log_2 = SimpleNamespace(id="log-2")
|
||||
session.scalar.return_value = 3
|
||||
session.scalars.return_value.all.return_value = [log_1, log_2]
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
page=1,
|
||||
limit=2,
|
||||
detail=False,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["page"] == 1
|
||||
assert result["limit"] == 2
|
||||
assert result["total"] == 3
|
||||
assert result["has_more"] is True
|
||||
assert len(result["data"]) == 2
|
||||
assert isinstance(result["data"][0], LogView)
|
||||
assert result["data"][0].details is None
|
||||
|
||||
|
||||
def test_get_paginate_workflow_app_logs_should_return_detailed_rows_when_detail_true(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.scalar.side_effect = [1]
|
||||
log_1 = SimpleNamespace(id="log-1")
|
||||
session.execute.return_value.all.return_value = [(log_1, '{"type":"trigger_plugin"}')]
|
||||
mock_handle = mocker.patch.object(
|
||||
service,
|
||||
"handle_trigger_metadata",
|
||||
return_value={"type": "trigger_plugin", "icon": "url"},
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
keyword="run-1",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
created_at_before=None,
|
||||
created_at_after=None,
|
||||
page=1,
|
||||
limit=20,
|
||||
detail=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 1
|
||||
assert len(result["data"]) == 1
|
||||
assert result["data"][0].details == {"trigger_metadata": {"type": "trigger_plugin", "icon": "url"}}
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
|
||||
def test_get_paginate_workflow_app_logs_should_raise_when_account_filter_email_not_found(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Account not found: account@example.com"):
|
||||
service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
created_by_account="account@example.com",
|
||||
)
|
||||
|
||||
|
||||
def test_get_paginate_workflow_app_logs_should_filter_by_account_when_account_exists(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.scalar.side_effect = [SimpleNamespace(id="account-1"), 0]
|
||||
session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
created_by_account="account@example.com",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 0
|
||||
assert result["data"] == []
|
||||
|
||||
|
||||
def test_get_paginate_workflow_archive_logs_should_return_paginated_archive_items(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
log_account = SimpleNamespace(
|
||||
id="log-1",
|
||||
created_by="acc-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
workflow_run_summary={"run": "1"},
|
||||
trigger_metadata='{"type":"trigger-webhook"}',
|
||||
log_created_at="2026-01-01",
|
||||
)
|
||||
log_end_user = SimpleNamespace(
|
||||
id="log-2",
|
||||
created_by="end-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
workflow_run_summary={"run": "2"},
|
||||
trigger_metadata='{"type":"trigger-webhook"}',
|
||||
log_created_at="2026-01-02",
|
||||
)
|
||||
log_unknown = SimpleNamespace(
|
||||
id="log-3",
|
||||
created_by="other",
|
||||
created_by_role="system",
|
||||
workflow_run_summary={"run": "3"},
|
||||
trigger_metadata='{"type":"trigger-webhook"}',
|
||||
log_created_at="2026-01-03",
|
||||
)
|
||||
session.scalar.return_value = 3
|
||||
session.scalars.side_effect = [
|
||||
SimpleNamespace(all=lambda: [log_account, log_end_user, log_unknown]),
|
||||
SimpleNamespace(all=lambda: [SimpleNamespace(id="acc-1", email="a@example.com")]),
|
||||
SimpleNamespace(all=lambda: [SimpleNamespace(id="end-1", session_id="session-1")]),
|
||||
]
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_workflow_archive_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
page=1,
|
||||
limit=20,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 3
|
||||
assert len(result["data"]) == 3
|
||||
assert result["data"][0]["created_by_account"].id == "acc-1"
|
||||
assert result["data"][1]["created_by_end_user"].id == "end-1"
|
||||
assert result["data"][2]["created_by_account"] is None
|
||||
assert result["data"][2]["created_by_end_user"] is None
|
||||
|
||||
|
||||
def test_handle_trigger_metadata_should_return_empty_dict_when_metadata_missing(
|
||||
service: WorkflowAppService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
# Act
|
||||
result = service.handle_trigger_metadata("tenant-1", None)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_handle_trigger_metadata_should_enrich_plugin_icons_for_trigger_plugin(
|
||||
service: WorkflowAppService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
meta = {
|
||||
"type": AppTriggerType.TRIGGER_PLUGIN.value,
|
||||
"icon_filename": "light.png",
|
||||
"icon_dark_filename": "dark.png",
|
||||
}
|
||||
mock_icon = mocker.patch(
|
||||
"services.workflow_app_service.PluginService.get_plugin_icon_url",
|
||||
side_effect=["https://cdn/light.png", "https://cdn/dark.png"],
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
|
||||
|
||||
# Assert
|
||||
assert result["icon"] == "https://cdn/light.png"
|
||||
assert result["icon_dark"] == "https://cdn/dark.png"
|
||||
assert mock_icon.call_count == 2
|
||||
|
||||
|
||||
def test_handle_trigger_metadata_should_return_non_plugin_metadata_without_icon_lookup(
|
||||
service: WorkflowAppService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value}
|
||||
mock_icon = mocker.patch("services.workflow_app_service.PluginService.get_plugin_icon_url")
|
||||
|
||||
# Act
|
||||
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
|
||||
|
||||
# Assert
|
||||
assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value
|
||||
mock_icon.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("value", "expected"),
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
('{"k":"v"}', {"k": "v"}),
|
||||
("not-json", None),
|
||||
({"raw": True}, {"raw": True}),
|
||||
],
|
||||
)
|
||||
def test_safe_json_loads_should_handle_various_inputs(
|
||||
value: object,
|
||||
expected: object,
|
||||
service: WorkflowAppService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
# Act
|
||||
result = service._safe_json_loads(value)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_safe_parse_uuid_should_return_none_for_short_or_invalid_values(service: WorkflowAppService) -> None:
|
||||
# Arrange
|
||||
# Act
|
||||
short_result = service._safe_parse_uuid("short")
|
||||
invalid_result = service._safe_parse_uuid("x" * 40)
|
||||
|
||||
# Assert
|
||||
assert short_result is None
|
||||
assert invalid_result is None
|
||||
|
||||
|
||||
def test_safe_parse_uuid_should_return_uuid_for_valid_uuid_string(service: WorkflowAppService) -> None:
|
||||
# Arrange
|
||||
raw_uuid = str(uuid.uuid4())
|
||||
|
||||
# Act
|
||||
result = service._safe_parse_uuid(raw_uuid)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert str(result) == raw_uuid
|
||||
File diff suppressed because it is too large
Load Diff
576
api/tests/unit_tests/services/test_workspace_service.py
Normal file
576
api/tests/unit_tests/services/test_workspace_service.py
Normal file
@ -0,0 +1,576 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from models.account import Tenant
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants used throughout the tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TENANT_ID = "tenant-abc"
|
||||
ACCOUNT_ID = "account-xyz"
|
||||
FILES_BASE_URL = "https://files.example.com"
|
||||
|
||||
DB_PATH = "services.workspace_service.db"
|
||||
FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features"
|
||||
TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles"
|
||||
DIFY_CONFIG_PATH = "services.workspace_service.dify_config"
|
||||
CURRENT_USER_PATH = "services.workspace_service.current_user"
|
||||
CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers / factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_tenant(
|
||||
tenant_id: str = TENANT_ID,
|
||||
name: str = "My Workspace",
|
||||
plan: str = "sandbox",
|
||||
status: str = "active",
|
||||
custom_config: dict | None = None,
|
||||
) -> Tenant:
|
||||
"""Create a minimal Tenant-like namespace."""
|
||||
return cast(
|
||||
Tenant,
|
||||
SimpleNamespace(
|
||||
id=tenant_id,
|
||||
name=name,
|
||||
plan=plan,
|
||||
status=status,
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
custom_config_dict=custom_config or {},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _make_feature(
|
||||
can_replace_logo: bool = False,
|
||||
next_credit_reset_date: str | None = None,
|
||||
billing_plan: str = "sandbox",
|
||||
) -> MagicMock:
|
||||
"""Create a feature namespace matching what FeatureService.get_features returns."""
|
||||
feature = MagicMock()
|
||||
feature.can_replace_logo = can_replace_logo
|
||||
feature.next_credit_reset_date = next_credit_reset_date
|
||||
feature.billing.subscription.plan = billing_plan
|
||||
return feature
|
||||
|
||||
|
||||
def _make_pool(quota_limit: int, quota_used: int) -> MagicMock:
|
||||
pool = MagicMock()
|
||||
pool.quota_limit = quota_limit
|
||||
pool.quota_used = quota_used
|
||||
return pool
|
||||
|
||||
|
||||
def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace:
|
||||
return SimpleNamespace(role=role)
|
||||
|
||||
|
||||
def _tenant_info(result: object) -> dict[str, Any] | None:
|
||||
return cast(dict[str, Any] | None, result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_user() -> SimpleNamespace:
|
||||
"""Return a lightweight current_user stand-in."""
|
||||
return SimpleNamespace(id=ACCOUNT_ID)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
|
||||
"""
|
||||
Patch the common external boundaries used by WorkspaceService.get_tenant_info.
|
||||
|
||||
Returns a dict of named mocks so individual tests can customise them.
|
||||
"""
|
||||
mocker.patch(CURRENT_USER_PATH, mock_current_user)
|
||||
|
||||
mock_db_session = mocker.patch(f"{DB_PATH}.session")
|
||||
mock_query_chain = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query_chain
|
||||
mock_query_chain.where.return_value = mock_query_chain
|
||||
mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
|
||||
|
||||
mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature())
|
||||
mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False)
|
||||
mock_config = mocker.patch(DIFY_CONFIG_PATH)
|
||||
mock_config.EDITION = "SELF_HOSTED"
|
||||
mock_config.FILES_URL = FILES_BASE_URL
|
||||
|
||||
return {
|
||||
"db_session": mock_db_session,
|
||||
"query_chain": mock_query_chain,
|
||||
"get_features": mock_feature,
|
||||
"has_roles": mock_has_roles,
|
||||
"config": mock_config,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. None Tenant Handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None:
|
||||
"""get_tenant_info should short-circuit and return None for a falsy tenant."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
tenant = None
|
||||
|
||||
# Act
|
||||
result = WorkspaceService.get_tenant_info(cast(Tenant, tenant))
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None:
|
||||
"""get_tenant_info treats any falsy value as absent (e.g. empty string, 0)."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange / Act / Assert
|
||||
assert WorkspaceService.get_tenant_info("") is None # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Basic Tenant Info — happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_return_base_fields(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""get_tenant_info should always return the six base scalar fields."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["id"] == TENANT_ID
|
||||
assert result["name"] == "My Workspace"
|
||||
assert result["plan"] == "sandbox"
|
||||
assert result["status"] == "active"
|
||||
assert result["created_at"] == "2024-01-01T00:00:00Z"
|
||||
assert result["trial_end_reason"] is None
|
||||
|
||||
|
||||
def test_get_tenant_info_should_populate_role_from_tenant_account_join(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""The 'role' field should be taken from TenantAccountJoin, not the default."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin")
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["role"] == "admin"
|
||||
|
||||
|
||||
def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""
|
||||
The service asserts that TenantAccountJoin exists.
|
||||
Missing join should raise AssertionError.
|
||||
"""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["query_chain"].first.return_value = None
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(AssertionError, match="TenantAccountJoin not found"):
|
||||
WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Logo Customisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""custom_config block should appear for OWNER/ADMIN when can_replace_logo is True."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant(
|
||||
custom_config={
|
||||
"replace_webapp_logo": True,
|
||||
"remove_webapp_brand": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "custom_config" in result
|
||||
assert result["custom_config"]["remove_webapp_brand"] is True
|
||||
expected_logo_url = f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo"
|
||||
assert result["custom_config"]["replace_webapp_logo"] == expected_logo_url
|
||||
|
||||
|
||||
def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""replace_webapp_logo should be None when custom_config_dict does not have the key."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant(custom_config={}) # no replace_webapp_logo key
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["custom_config"]["replace_webapp_logo"] is None
|
||||
|
||||
|
||||
def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""custom_config should be absent when can_replace_logo is False."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "custom_config" not in result
|
||||
|
||||
|
||||
def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""custom_config block is gated on OWNER or ADMIN role."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = False # regular member
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "custom_config" not in result
|
||||
|
||||
|
||||
def test_get_tenant_info_should_use_files_url_for_logo_url(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""The logo URL should use dify_config.FILES_URL as the base."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
custom_base = "https://cdn.mycompany.io"
|
||||
basic_mocks["config"].FILES_URL = custom_base
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant(custom_config={"replace_webapp_logo": True})
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Cloud-Edition Credit Features
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CLOUD_BILLING_PLAN_NON_SANDBOX = "professional" # any plan that is not SANDBOX
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
|
||||
"""Patches for CLOUD edition tests, billing plan = professional by default."""
|
||||
mocker.patch(CURRENT_USER_PATH, mock_current_user)
|
||||
|
||||
mock_db_session = mocker.patch(f"{DB_PATH}.session")
|
||||
mock_query_chain = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query_chain
|
||||
mock_query_chain.where.return_value = mock_query_chain
|
||||
mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
|
||||
|
||||
mock_feature = mocker.patch(
|
||||
FEATURE_SERVICE_PATH,
|
||||
return_value=_make_feature(
|
||||
can_replace_logo=False,
|
||||
next_credit_reset_date="2025-02-01",
|
||||
billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX,
|
||||
),
|
||||
)
|
||||
mocker.patch(TENANT_SERVICE_PATH, return_value=False)
|
||||
mock_config = mocker.patch(DIFY_CONFIG_PATH)
|
||||
mock_config.EDITION = "CLOUD"
|
||||
mock_config.FILES_URL = FILES_BASE_URL
|
||||
|
||||
return {
|
||||
"db_session": mock_db_session,
|
||||
"query_chain": mock_query_chain,
|
||||
"get_features": mock_feature,
|
||||
"config": mock_config,
|
||||
}
|
||||
|
||||
|
||||
def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""next_credit_reset_date should be present in CLOUD edition."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
CREDIT_POOL_SERVICE_PATH,
|
||||
side_effect=[None, None], # both paid and trial pools absent
|
||||
)
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["next_credit_reset_date"] == "2025-02-01"
|
||||
|
||||
|
||||
def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""trial_credits/trial_credits_used come from the paid pool when conditions are met."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
paid_pool = _make_pool(quota_limit=1000, quota_used=200)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, return_value=paid_pool)
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 1000
|
||||
assert result["trial_credits_used"] == 200
|
||||
|
||||
|
||||
def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""quota_limit == -1 means unlimited; service should still use the paid pool."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
paid_pool = _make_pool(quota_limit=-1, quota_used=999)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == -1
|
||||
assert result["trial_credits_used"] == 999
|
||||
|
||||
|
||||
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""When paid pool is exhausted (used >= limit), switch to trial pool."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
paid_pool = _make_pool(quota_limit=500, quota_used=500) # exactly full
|
||||
trial_pool = _make_pool(quota_limit=100, quota_used=10)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 100
|
||||
assert result["trial_credits_used"] == 10
|
||||
|
||||
|
||||
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""When paid_pool is None, fall back to trial pool."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
trial_pool = _make_pool(quota_limit=50, quota_used=5)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, trial_pool])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 50
|
||||
assert result["trial_credits_used"] == 5
|
||||
|
||||
|
||||
def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""
|
||||
When the subscription plan IS SANDBOX, the paid pool branch is skipped
|
||||
entirely and we fall back to the trial pool.
|
||||
"""
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange — override billing plan to SANDBOX
|
||||
cloud_mocks["get_features"].return_value = _make_feature(
|
||||
next_credit_reset_date="2025-02-01",
|
||||
billing_plan=CloudPlan.SANDBOX,
|
||||
)
|
||||
paid_pool = _make_pool(quota_limit=1000, quota_used=0)
|
||||
trial_pool = _make_pool(quota_limit=200, quota_used=20)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 200
|
||||
assert result["trial_credits_used"] == 20
|
||||
|
||||
|
||||
def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""When both paid and trial pools are absent, trial_credits should not be set."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "trial_credits" not in result
|
||||
assert "trial_credits_used" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Self-hosted / Non-Cloud Edition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange (basic_mocks already sets EDITION = "SELF_HOSTED")
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "next_credit_reset_date" not in result
|
||||
assert "trial_credits" not in result
|
||||
assert "trial_credits_used" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. DB query integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_query_tenant_account_join_with_correct_ids(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""
|
||||
The DB query for TenantAccountJoin must be scoped to the correct
|
||||
tenant_id and current_user.id.
|
||||
"""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
tenant = _make_tenant(tenant_id="my-special-tenant")
|
||||
mock_current_user = mocker.patch(CURRENT_USER_PATH)
|
||||
mock_current_user.id = "special-user-id"
|
||||
|
||||
# Act
|
||||
WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert — db.session.query was invoked (at least once)
|
||||
basic_mocks["db_session"].query.assert_called()
|
||||
@ -0,0 +1,643 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(mocker: MockerFixture) -> MagicMock:
|
||||
# Arrange
|
||||
mocked_db = mocker.patch("services.tools.api_tools_manage_service.db")
|
||||
mocked_db.session = MagicMock()
|
||||
return mocked_db
|
||||
|
||||
|
||||
def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace:
|
||||
return SimpleNamespace(operation_id=operation_id)
|
||||
|
||||
|
||||
def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.parser_api_schema("valid-schema")
|
||||
|
||||
# Assert
|
||||
assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value
|
||||
assert len(result["credentials_schema"]) == 3
|
||||
assert "warning" in result
|
||||
|
||||
|
||||
def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
side_effect=RuntimeError("bad schema"),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"):
|
||||
ApiToolManageService.parser_api_schema("invalid")
|
||||
|
||||
|
||||
def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER)
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=expected,
|
||||
)
|
||||
extra_info: dict[str, str] = {}
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
side_effect=ValueError("parse failed"),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema: parse failed"):
|
||||
ApiToolManageService.convert_schema_to_tool_bundles("schema")
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_raise_error_when_provider_already_exists(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = object()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="provider provider-a already exists"):
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name=" provider-a ",
|
||||
icon={"emoji": "X"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
many_tools = [_tool_bundle(str(i)) for i in range(101)]
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=(many_tools, ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="the number of apis should be less than 100"):
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
icon={"emoji": "X"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
icon={"emoji": "X"},
|
||||
credentials={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_create_provider_when_input_is_valid(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
mock_controller = MagicMock()
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=mock_controller,
|
||||
)
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.encrypt.return_value = {"auth_type": "none"}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(mock_encrypter, MagicMock()),
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
icon={"emoji": "X"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=["news"],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_controller.load_bundled_tools.assert_called_once()
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.get",
|
||||
return_value=SimpleNamespace(status_code=200, text="schema-content"),
|
||||
)
|
||||
mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True})
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
|
||||
|
||||
# Assert
|
||||
assert result == {"schema": "schema-content"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status_code", [400, 404, 500])
|
||||
def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid(
|
||||
status_code: int,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.get",
|
||||
return_value=SimpleNamespace(status_code=status_code, text="schema-content"),
|
||||
)
|
||||
mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema, please check the url you provided"):
|
||||
ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
|
||||
def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found(
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="you have not added provider provider-a"):
|
||||
ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
|
||||
|
||||
|
||||
def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")])
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
controller = MagicMock()
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
|
||||
return_value=controller,
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"])
|
||||
mock_convert = mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
|
||||
side_effect=[{"name": "tool-a"}, {"name": "tool-b"}],
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
|
||||
|
||||
# Assert
|
||||
assert result == [{"name": "tool-a"}, {"name": "tool-b"}]
|
||||
assert mock_convert.call_count == 2
|
||||
|
||||
|
||||
def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found(
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="api provider provider-a does not exists"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
original_provider="provider-a",
|
||||
icon={},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_update_api_tool_provider_should_raise_error_when_auth_type_missing(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider = SimpleNamespace(credentials={}, name="old")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
original_provider="provider-a",
|
||||
icon={},
|
||||
credentials={},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider = SimpleNamespace(
|
||||
credentials={"auth_type": "none", "api_key_value": "encrypted-old"},
|
||||
name="old",
|
||||
icon="",
|
||||
schema="",
|
||||
description="",
|
||||
schema_type_str="",
|
||||
tools_str="",
|
||||
privacy_policy="",
|
||||
custom_disclaimer="",
|
||||
credentials_str="",
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
controller = MagicMock()
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=controller,
|
||||
)
|
||||
cache = MagicMock()
|
||||
encrypter = MagicMock()
|
||||
encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"}
|
||||
encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"}
|
||||
encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(encrypter, cache),
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.update_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-new",
|
||||
original_provider="provider-old",
|
||||
icon={"emoji": "E"},
|
||||
credentials={"auth_type": "none", "api_key_value": "***"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=["news"],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
assert provider.name == "provider-new"
|
||||
assert provider.privacy_policy == "privacy"
|
||||
assert provider.credentials_str != ""
|
||||
cache.delete.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="you have not added provider provider-a"):
|
||||
ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
|
||||
|
||||
|
||||
def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
provider = object()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_db.session.delete.assert_called_once_with(provider)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
expected = {"provider": "value"}
|
||||
mock_get = mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolManager.user_get_api_provider",
|
||||
return_value=expected,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a")
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
mock_get.assert_called_once_with(provider="provider-a", tenant_id="tenant-1")
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None:
|
||||
# Arrange
|
||||
schema_type = "bad-schema-type"
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema type"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=schema_type, # type: ignore[arg-type]
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
side_effect=RuntimeError("invalid"),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid tool name tool-b"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-b",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_when_auth_type_missing(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
provider_controller = MagicMock()
|
||||
tool_obj = MagicMock()
|
||||
tool_obj.fork_tool_runtime.return_value = tool_obj
|
||||
tool_obj.validate_credentials.side_effect = ValueError("validation failed")
|
||||
provider_controller.get_tool.return_value = tool_obj
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=provider_controller,
|
||||
)
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"auth_type": "none"}
|
||||
mock_encrypter.mask_plugin_credentials.return_value = {}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(mock_encrypter, MagicMock()),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"error": "validation failed"}
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
provider_controller = MagicMock()
|
||||
tool_obj = MagicMock()
|
||||
tool_obj.fork_tool_runtime.return_value = tool_obj
|
||||
tool_obj.validate_credentials.return_value = {"ok": True}
|
||||
provider_controller.get_tool.return_value = tool_obj
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=provider_controller,
|
||||
)
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"auth_type": "none"}
|
||||
mock_encrypter.mask_plugin_credentials.return_value = {}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(mock_encrypter, MagicMock()),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={"x": "1"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": {"ok": True}}
|
||||
|
||||
|
||||
def test_list_api_tools_should_return_all_user_providers_with_converted_tools(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_one = SimpleNamespace(name="p1")
|
||||
provider_two = SimpleNamespace(name="p2")
|
||||
mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two]
|
||||
|
||||
controller_one = MagicMock()
|
||||
controller_one.get_tools.return_value = ["tool-a"]
|
||||
controller_two = MagicMock()
|
||||
controller_two.get_tools.return_value = ["tool-b", "tool-c"]
|
||||
|
||||
user_provider_one = SimpleNamespace(labels=[], tools=[])
|
||||
user_provider_two = SimpleNamespace(labels=[], tools=[])
|
||||
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
|
||||
side_effect=[controller_one, controller_two],
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"])
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider",
|
||||
side_effect=[user_provider_one, user_provider_two],
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider")
|
||||
mock_convert = mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
|
||||
side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}],
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.list_api_tools("tenant-1")
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert user_provider_one.tools == [{"name": "tool-a"}]
|
||||
assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}]
|
||||
assert mock_convert.call_count == 3
|
||||
1045
api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py
Normal file
1045
api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
13
sdks/nodejs-client/pnpm-lock.yaml
generated
13
sdks/nodejs-client/pnpm-lock.yaml
generated
@ -324,79 +324,66 @@ packages:
|
||||
resolution: {integrity: sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==}
|
||||
cpu: [arm]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-arm-musleabihf@4.59.0':
|
||||
resolution: {integrity: sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==}
|
||||
cpu: [arm]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rollup/rollup-linux-arm64-gnu@4.59.0':
|
||||
resolution: {integrity: sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-arm64-musl@4.59.0':
|
||||
resolution: {integrity: sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rollup/rollup-linux-loong64-gnu@4.59.0':
|
||||
resolution: {integrity: sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==}
|
||||
cpu: [loong64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-loong64-musl@4.59.0':
|
||||
resolution: {integrity: sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==}
|
||||
cpu: [loong64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rollup/rollup-linux-ppc64-gnu@4.59.0':
|
||||
resolution: {integrity: sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==}
|
||||
cpu: [ppc64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-ppc64-musl@4.59.0':
|
||||
resolution: {integrity: sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==}
|
||||
cpu: [ppc64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rollup/rollup-linux-riscv64-gnu@4.59.0':
|
||||
resolution: {integrity: sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==}
|
||||
cpu: [riscv64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-riscv64-musl@4.59.0':
|
||||
resolution: {integrity: sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==}
|
||||
cpu: [riscv64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rollup/rollup-linux-s390x-gnu@4.59.0':
|
||||
resolution: {integrity: sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==}
|
||||
cpu: [s390x]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-x64-gnu@4.59.0':
|
||||
resolution: {integrity: sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-x64-musl@4.59.0':
|
||||
resolution: {integrity: sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rollup/rollup-openbsd-x64@4.59.0':
|
||||
resolution: {integrity: sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==}
|
||||
|
||||
@ -6,19 +6,23 @@ NEXT_PUBLIC_EDITION=SELF_HOSTED
|
||||
NEXT_PUBLIC_BASE_PATH=
|
||||
# The base URL of console application, refers to the Console base URL of WEB service if console domain is
|
||||
# different from api or web app domain.
|
||||
# example: http://cloud.dify.ai/console/api
|
||||
# example: https://cloud.dify.ai/console/api
|
||||
NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api
|
||||
# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from
|
||||
# console or api domain.
|
||||
# example: http://udify.app/api
|
||||
# example: https://udify.app/api
|
||||
NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api
|
||||
# Dev-only Hono proxy targets. The frontend keeps requesting http://localhost:5001 directly.
|
||||
# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1.
|
||||
NEXT_PUBLIC_COOKIE_DOMAIN=
|
||||
|
||||
# Dev-only Hono proxy targets.
|
||||
# The frontend keeps requesting http://localhost:5001 directly,
|
||||
# the proxy server will forward the request to the target server,
|
||||
# so that you don't need to run a separate backend server and use online API in development.
|
||||
HONO_PROXY_HOST=127.0.0.1
|
||||
HONO_PROXY_PORT=5001
|
||||
HONO_CONSOLE_API_PROXY_TARGET=
|
||||
HONO_PUBLIC_API_PROXY_TARGET=
|
||||
# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1.
|
||||
NEXT_PUBLIC_COOKIE_DOMAIN=
|
||||
|
||||
# The API PREFIX for MARKETPLACE
|
||||
NEXT_PUBLIC_MARKETPLACE_API_PREFIX=https://marketplace.dify.ai/api/v1
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Dify Frontend
|
||||
|
||||
This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app).
|
||||
This is a [Next.js] project, but you can dev with [vinext].
|
||||
|
||||
## Getting Started
|
||||
|
||||
@ -8,8 +8,11 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next
|
||||
|
||||
Before starting the web frontend service, please make sure the following environment is ready.
|
||||
|
||||
- [Node.js](https://nodejs.org)
|
||||
- [pnpm](https://pnpm.io)
|
||||
- [Node.js]
|
||||
- [pnpm]
|
||||
|
||||
You can also use [Vite+] with the corresponding `vp` commands.
|
||||
For example, use `vp install` instead of `pnpm install` and `vp test` instead of `pnpm run test`.
|
||||
|
||||
> [!TIP]
|
||||
> It is recommended to install and enable Corepack to manage package manager versions automatically:
|
||||
@ -19,7 +22,7 @@ Before starting the web frontend service, please make sure the following environ
|
||||
> corepack enable
|
||||
> ```
|
||||
>
|
||||
> Learn more: [Corepack](https://github.com/nodejs/corepack#readme)
|
||||
> Learn more: [Corepack]
|
||||
|
||||
First, install the dependencies:
|
||||
|
||||
@ -27,31 +30,14 @@ First, install the dependencies:
|
||||
pnpm install
|
||||
```
|
||||
|
||||
Then, configure the environment variables. Create a file named `.env.local` in the current directory and copy the contents from `.env.example`. Modify the values of these environment variables according to your requirements:
|
||||
Then, configure the environment variables.
|
||||
Create a file named `.env.local` in the current directory and copy the contents from `.env.example`.
|
||||
Modify the values of these environment variables according to your requirements:
|
||||
|
||||
```bash
|
||||
cp .env.example .env.local
|
||||
```
|
||||
|
||||
```txt
|
||||
# For production release, change this to PRODUCTION
|
||||
NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT
|
||||
# The deployment edition, SELF_HOSTED
|
||||
NEXT_PUBLIC_EDITION=SELF_HOSTED
|
||||
# The base URL of console application, refers to the Console base URL of WEB service if console domain is
|
||||
# different from api or web app domain.
|
||||
# example: http://cloud.dify.ai/console/api
|
||||
NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api
|
||||
NEXT_PUBLIC_COOKIE_DOMAIN=
|
||||
# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from
|
||||
# console or api domain.
|
||||
# example: http://udify.app/api
|
||||
NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api
|
||||
|
||||
# SENTRY
|
||||
NEXT_PUBLIC_SENTRY_DSN=
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> 1. When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. The frontend and backend must be under the same top-level domain in order to share authentication cookies.
|
||||
@ -61,11 +47,16 @@ Finally, run the development server:
|
||||
|
||||
```bash
|
||||
pnpm run dev
|
||||
# or if you are using vinext which provides a better development experience
|
||||
pnpm run dev:vinext
|
||||
# (optional) start the dev proxy server so that you can use online API in development
|
||||
pnpm run dev:proxy
|
||||
```
|
||||
|
||||
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
|
||||
Open <http://localhost:3000> with your browser to see the result.
|
||||
|
||||
You can start editing the file under folder `app`. The page auto-updates as you edit the file.
|
||||
You can start editing the file under folder `app`.
|
||||
The page auto-updates as you edit the file.
|
||||
|
||||
## Deploy
|
||||
|
||||
@ -91,7 +82,7 @@ pnpm run start --port=3001 --host=0.0.0.0
|
||||
|
||||
## Storybook
|
||||
|
||||
This project uses [Storybook](https://storybook.js.org/) for UI component development.
|
||||
This project uses [Storybook] for UI component development.
|
||||
|
||||
To start the storybook server, run:
|
||||
|
||||
@ -99,19 +90,24 @@ To start the storybook server, run:
|
||||
pnpm storybook
|
||||
```
|
||||
|
||||
Open [http://localhost:6006](http://localhost:6006) with your browser to see the result.
|
||||
Open <http://localhost:6006> with your browser to see the result.
|
||||
|
||||
## Lint Code
|
||||
|
||||
If your IDE is VSCode, rename `web/.vscode/settings.example.json` to `web/.vscode/settings.json` for lint code setting.
|
||||
|
||||
Then follow the [Lint Documentation](./docs/lint.md) to lint the code.
|
||||
Then follow the [Lint Documentation] to lint the code.
|
||||
|
||||
## Test
|
||||
|
||||
We use [Vitest](https://vitest.dev/) and [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) for Unit Testing.
|
||||
We use [Vitest] and [React Testing Library] for Unit Testing.
|
||||
|
||||
**📖 Complete Testing Guide**: See [web/testing/testing.md](./testing/testing.md) for detailed testing specifications, best practices, and examples.
|
||||
**📖 Complete Testing Guide**: See [web/docs/test.md] for detailed testing specifications, best practices, and examples.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> As we are using Vite+, the `vitest` command is not available.
|
||||
> Please make sure to run tests with `vp` commands.
|
||||
> For example, use `npx vp test` instead of `npx vitest`.
|
||||
|
||||
Run test:
|
||||
|
||||
@ -119,12 +115,17 @@ Run test:
|
||||
pnpm test
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Our test is not fully stable yet, and we are actively working on improving it.
|
||||
> If you encounter test failures only in CI but not locally, please feel free to ignore them and report the issue to us.
|
||||
> You can try to re-run the test in CI, and it may pass successfully.
|
||||
|
||||
### Example Code
|
||||
|
||||
If you are not familiar with writing tests, refer to:
|
||||
|
||||
- [classnames.spec.ts](./utils/classnames.spec.ts) - Utility function test example
|
||||
- [index.spec.tsx](./app/components/base/button/index.spec.tsx) - Component test example
|
||||
- [classnames.spec.ts] - Utility function test example
|
||||
- [index.spec.tsx] - Component test example
|
||||
|
||||
### Analyze Component Complexity
|
||||
|
||||
@ -134,7 +135,7 @@ Before writing tests, use the script to analyze component complexity:
|
||||
pnpm analyze-component app/components/your-component/index.tsx
|
||||
```
|
||||
|
||||
This will help you determine the testing strategy. See [web/testing/testing.md](./testing/testing.md) for details.
|
||||
This will help you determine the testing strategy. See [web/testing/testing.md] for details.
|
||||
|
||||
## Documentation
|
||||
|
||||
@ -142,4 +143,19 @@ Visit <https://docs.dify.ai> to view the full documentation.
|
||||
|
||||
## Community
|
||||
|
||||
The Dify community can be found on [Discord community](https://discord.gg/5AEfbxcd9k), where you can ask questions, voice ideas, and share your projects.
|
||||
The Dify community can be found on [Discord community], where you can ask questions, voice ideas, and share your projects.
|
||||
|
||||
[Corepack]: https://github.com/nodejs/corepack#readme
|
||||
[Discord community]: https://discord.gg/5AEfbxcd9k
|
||||
[Lint Documentation]: ./docs/lint.md
|
||||
[Next.js]: https://nextjs.org
|
||||
[Node.js]: https://nodejs.org
|
||||
[React Testing Library]: https://testing-library.com/docs/react-testing-library/intro
|
||||
[Storybook]: https://storybook.js.org
|
||||
[Vite+]: https://viteplus.dev
|
||||
[Vitest]: https://vitest.dev
|
||||
[classnames.spec.ts]: ./utils/classnames.spec.ts
|
||||
[index.spec.tsx]: ./app/components/base/button/index.spec.tsx
|
||||
[pnpm]: https://pnpm.io
|
||||
[vinext]: https://github.com/cloudflare/vinext
|
||||
[web/docs/test.md]: ./docs/test.md
|
||||
|
||||
@ -95,7 +95,7 @@ describe('Cloud Plan Payment Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
cleanup()
|
||||
toast.close()
|
||||
toast.dismiss()
|
||||
setupAppContext()
|
||||
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://pay.example.com/checkout' })
|
||||
mockInvoices.mockResolvedValue({ url: 'https://billing.example.com/invoices' })
|
||||
|
||||
@ -66,7 +66,7 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
cleanup()
|
||||
toast.close()
|
||||
toast.dismiss()
|
||||
setupAppContext()
|
||||
|
||||
// Mock window.location with minimal getter/setter (Location props are non-enumerable)
|
||||
|
||||
@ -11,8 +11,8 @@ import SideBar from '@/app/components/explore/sidebar'
|
||||
import { MediaType } from '@/hooks/use-breakpoints'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
|
||||
const { mockToastAdd } = vi.hoisted(() => ({
|
||||
mockToastAdd: vi.fn(),
|
||||
const { mockToastSuccess } = vi.hoisted(() => ({
|
||||
mockToastSuccess: vi.fn(),
|
||||
}))
|
||||
|
||||
let mockMediaType: string = MediaType.pc
|
||||
@ -53,14 +53,16 @@ vi.mock('@/service/use-explore', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: {
|
||||
add: mockToastAdd,
|
||||
close: vi.fn(),
|
||||
update: vi.fn(),
|
||||
promise: vi.fn(),
|
||||
},
|
||||
}))
|
||||
vi.mock('@/app/components/base/ui/toast', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/app/components/base/ui/toast')>()
|
||||
return {
|
||||
...actual,
|
||||
toast: {
|
||||
...actual.toast,
|
||||
success: mockToastSuccess,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
const createInstalledApp = (overrides: Partial<InstalledApp> = {}): InstalledApp => ({
|
||||
id: overrides.id ?? 'app-1',
|
||||
@ -105,9 +107,7 @@ describe('Sidebar Lifecycle Flow', () => {
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: true })
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'success',
|
||||
}))
|
||||
expect(mockToastSuccess).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// Step 2: Simulate refetch returning pinned state, then unpin
|
||||
@ -124,9 +124,7 @@ describe('Sidebar Lifecycle Flow', () => {
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: false })
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'success',
|
||||
}))
|
||||
expect(mockToastSuccess).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@ -150,10 +148,7 @@ describe('Sidebar Lifecycle Flow', () => {
|
||||
// Step 4: Uninstall API called and success toast shown
|
||||
await waitFor(() => {
|
||||
expect(mockUninstall).toHaveBeenCalledWith('app-1')
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'success',
|
||||
title: 'common.api.remove',
|
||||
}))
|
||||
expect(mockToastSuccess).toHaveBeenCalledWith('common.api.remove')
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@ -24,17 +24,11 @@ export default function CheckCode() {
|
||||
const verify = async () => {
|
||||
try {
|
||||
if (!code.trim()) {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: t('checkCode.emptyCode', { ns: 'login' }),
|
||||
})
|
||||
toast.error(t('checkCode.emptyCode', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
if (!/\d{6}/.test(code)) {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: t('checkCode.invalidCode', { ns: 'login' }),
|
||||
})
|
||||
toast.error(t('checkCode.invalidCode', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
setIsLoading(true)
|
||||
|
||||
@ -27,15 +27,12 @@ export default function CheckCode() {
|
||||
const handleGetEMailVerificationCode = async () => {
|
||||
try {
|
||||
if (!email) {
|
||||
toast.add({ type: 'error', title: t('error.emailEmpty', { ns: 'login' }) })
|
||||
toast.error(t('error.emailEmpty', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
|
||||
if (!emailRegex.test(email)) {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: t('error.emailInValid', { ns: 'login' }),
|
||||
})
|
||||
toast.error(t('error.emailInValid', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
setIsLoading(true)
|
||||
@ -48,16 +45,10 @@ export default function CheckCode() {
|
||||
router.push(`/webapp-reset-password/check-code?${params.toString()}`)
|
||||
}
|
||||
else if (res.code === 'account_not_found') {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: t('error.registrationNotAllowed', { ns: 'login' }),
|
||||
})
|
||||
toast.error(t('error.registrationNotAllowed', { ns: 'login' }))
|
||||
}
|
||||
else {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: res.data,
|
||||
})
|
||||
toast.error(res.data)
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
|
||||
@ -24,10 +24,7 @@ const ChangePasswordForm = () => {
|
||||
const [showConfirmPassword, setShowConfirmPassword] = useState(false)
|
||||
|
||||
const showErrorMessage = useCallback((message: string) => {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: message,
|
||||
})
|
||||
toast.error(message)
|
||||
}, [])
|
||||
|
||||
const getSignInUrl = () => {
|
||||
|
||||
@ -43,24 +43,15 @@ export default function CheckCode() {
|
||||
try {
|
||||
const appCode = getAppCodeFromRedirectUrl()
|
||||
if (!code.trim()) {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: t('checkCode.emptyCode', { ns: 'login' }),
|
||||
})
|
||||
toast.error(t('checkCode.emptyCode', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
if (!/\d{6}/.test(code)) {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: t('checkCode.invalidCode', { ns: 'login' }),
|
||||
})
|
||||
toast.error(t('checkCode.invalidCode', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
if (!redirectUrl || !appCode) {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: t('error.redirectUrlMissing', { ns: 'login' }),
|
||||
})
|
||||
toast.error(t('error.redirectUrlMissing', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
setIsLoading(true)
|
||||
|
||||
@ -17,10 +17,7 @@ const ExternalMemberSSOAuth = () => {
|
||||
const redirectUrl = searchParams.get('redirect_url')
|
||||
|
||||
const showErrorToast = (message: string) => {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: message,
|
||||
})
|
||||
toast.error(message)
|
||||
}
|
||||
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
|
||||
@ -22,15 +22,12 @@ export default function MailAndCodeAuth() {
|
||||
const handleGetEMailVerificationCode = async () => {
|
||||
try {
|
||||
if (!email) {
|
||||
toast.add({ type: 'error', title: t('error.emailEmpty', { ns: 'login' }) })
|
||||
toast.error(t('error.emailEmpty', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
|
||||
if (!emailRegex.test(email)) {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: t('error.emailInValid', { ns: 'login' }),
|
||||
})
|
||||
toast.error(t('error.emailInValid', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
setIsLoading(true)
|
||||
|
||||
@ -46,26 +46,20 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
|
||||
const appCode = getAppCodeFromRedirectUrl()
|
||||
const handleEmailPasswordLogin = async () => {
|
||||
if (!email) {
|
||||
toast.add({ type: 'error', title: t('error.emailEmpty', { ns: 'login' }) })
|
||||
toast.error(t('error.emailEmpty', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
if (!emailRegex.test(email)) {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: t('error.emailInValid', { ns: 'login' }),
|
||||
})
|
||||
toast.error(t('error.emailInValid', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
if (!password?.trim()) {
|
||||
toast.add({ type: 'error', title: t('error.passwordEmpty', { ns: 'login' }) })
|
||||
toast.error(t('error.passwordEmpty', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
|
||||
if (!redirectUrl || !appCode) {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: t('error.redirectUrlMissing', { ns: 'login' }),
|
||||
})
|
||||
toast.error(t('error.redirectUrlMissing', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
try {
|
||||
@ -94,15 +88,12 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
}
|
||||
else {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: res.data,
|
||||
})
|
||||
toast.error(res.data)
|
||||
}
|
||||
}
|
||||
catch (e: any) {
|
||||
if (e.code === 'authentication_failed')
|
||||
toast.add({ type: 'error', title: e.message })
|
||||
toast.error(e.message)
|
||||
}
|
||||
finally {
|
||||
setIsLoading(false)
|
||||
|
||||
@ -37,10 +37,7 @@ const SSOAuth: FC<SSOAuthProps> = ({
|
||||
const handleSSOLogin = () => {
|
||||
const appCode = getAppCodeFromRedirectUrl()
|
||||
if (!redirectUrl || !appCode) {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: t('error.invalidRedirectUrlOrAppCode', { ns: 'login' }),
|
||||
})
|
||||
toast.error(t('error.invalidRedirectUrlOrAppCode', { ns: 'login' }))
|
||||
return
|
||||
}
|
||||
setIsLoading(true)
|
||||
@ -66,10 +63,7 @@ const SSOAuth: FC<SSOAuthProps> = ({
|
||||
})
|
||||
}
|
||||
else {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: t('error.invalidSSOProtocol', { ns: 'login' }),
|
||||
})
|
||||
toast.error(t('error.invalidSSOProtocol', { ns: 'login' }))
|
||||
setIsLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,10 +91,7 @@ export default function OAuthAuthorize() {
|
||||
globalThis.location.href = url.toString()
|
||||
}
|
||||
catch (err: any) {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`,
|
||||
})
|
||||
toast.error(`${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`)
|
||||
}
|
||||
}
|
||||
|
||||
@ -102,11 +99,10 @@ export default function OAuthAuthorize() {
|
||||
const invalidParams = !client_id || !redirect_uri
|
||||
if ((invalidParams || isError) && !hasNotifiedRef.current) {
|
||||
hasNotifiedRef.current = true
|
||||
toast.add({
|
||||
type: 'error',
|
||||
title: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }),
|
||||
timeout: 0,
|
||||
})
|
||||
toast.error(
|
||||
invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }),
|
||||
{ timeout: 0 },
|
||||
)
|
||||
}
|
||||
}, [client_id, redirect_uri, isError])
|
||||
|
||||
|
||||
@ -137,10 +137,7 @@ const Apps = ({
|
||||
})
|
||||
|
||||
setIsShowCreateModal(false)
|
||||
toast.add({
|
||||
type: 'success',
|
||||
title: t('newApp.appCreated', { ns: 'app' }),
|
||||
})
|
||||
toast.success(t('newApp.appCreated', { ns: 'app' }))
|
||||
if (onSuccess)
|
||||
onSuccess()
|
||||
if (app.app_id)
|
||||
@ -149,7 +146,7 @@ const Apps = ({
|
||||
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id!, mode }, push)
|
||||
}
|
||||
catch {
|
||||
toast.add({ type: 'error', title: t('newApp.appCreateFailed', { ns: 'app' }) })
|
||||
toast.error(t('newApp.appCreateFailed', { ns: 'app' }))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -141,6 +141,145 @@ describe('useChat', () => {
|
||||
expect(result.current.chatList[0].suggestedQuestions).toEqual(['Ask Bob'])
|
||||
})
|
||||
|
||||
describe('opening statement referential stability', () => {
|
||||
it('should keep the same item reference across multiple streaming chatTree mutations', () => {
|
||||
let callbacks: HookCallbacks
|
||||
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
callbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const config = {
|
||||
opening_statement: 'Welcome!',
|
||||
suggested_questions: ['Q1', 'Q2'],
|
||||
}
|
||||
const { result } = renderHook(() => useChat(config as ChatConfig))
|
||||
|
||||
const openerInitial = result.current.chatList[0]
|
||||
expect(openerInitial.isOpeningStatement).toBe(true)
|
||||
expect(openerInitial.content).toBe('Welcome!')
|
||||
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'hello' }, {})
|
||||
})
|
||||
|
||||
act(() => {
|
||||
callbacks.onWorkflowStarted({ workflow_run_id: 'wr-1', task_id: 't-1' })
|
||||
})
|
||||
expect(result.current.chatList[0]).toBe(openerInitial)
|
||||
|
||||
act(() => {
|
||||
callbacks.onData('chunk-1 ', true, { messageId: 'm-1', conversationId: 'c-1', taskId: 't-1' })
|
||||
})
|
||||
expect(result.current.chatList.length).toBeGreaterThan(1)
|
||||
expect(result.current.chatList[0]).toBe(openerInitial)
|
||||
|
||||
act(() => {
|
||||
callbacks.onData('chunk-2 ', false, { messageId: 'm-1' })
|
||||
})
|
||||
expect(result.current.chatList[0]).toBe(openerInitial)
|
||||
|
||||
act(() => {
|
||||
callbacks.onData('chunk-3', false, { messageId: 'm-1' })
|
||||
callbacks.onMessageEnd({ metadata: { retriever_resources: [] } })
|
||||
callbacks.onWorkflowFinished({ data: { status: 'succeeded' } })
|
||||
callbacks.onCompleted()
|
||||
})
|
||||
expect(result.current.chatList[0]).toBe(openerInitial)
|
||||
expect(result.current.chatList.at(-1)!.content).toBe('chunk-1 chunk-2 chunk-3')
|
||||
})
|
||||
|
||||
it('should keep stable reference when getIntroduction identity changes but output is identical', () => {
|
||||
const config = {
|
||||
opening_statement: 'Hello {{name}}',
|
||||
suggested_questions: ['Ask about {{name}}'],
|
||||
}
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
({ fs }) => useChat(config as ChatConfig, fs as UseChatFormSettings),
|
||||
{ initialProps: { fs: { inputs: { name: 'Alice' }, inputsForm: [] } } },
|
||||
)
|
||||
|
||||
const openerBefore = result.current.chatList[0]
|
||||
expect(openerBefore.content).toBe('Hello Alice')
|
||||
expect(openerBefore.suggestedQuestions).toEqual(['Ask about Alice'])
|
||||
|
||||
rerender({ fs: { inputs: { name: 'Alice' }, inputsForm: [] } })
|
||||
|
||||
expect(result.current.chatList[0]).toBe(openerBefore)
|
||||
})
|
||||
|
||||
it('should produce a new item when the processed content actually changes', () => {
|
||||
const config = {
|
||||
opening_statement: 'Hello {{name}}',
|
||||
suggested_questions: ['Ask {{name}}'],
|
||||
}
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
({ fs }) => useChat(config as ChatConfig, fs as UseChatFormSettings),
|
||||
{ initialProps: { fs: { inputs: { name: 'Alice' }, inputsForm: [] } } },
|
||||
)
|
||||
|
||||
const before = result.current.chatList[0]
|
||||
|
||||
rerender({ fs: { inputs: { name: 'Bob' }, inputsForm: [] } })
|
||||
|
||||
const after = result.current.chatList[0]
|
||||
expect(after).not.toBe(before)
|
||||
expect(after.content).toBe('Hello Bob')
|
||||
expect(after.suggestedQuestions).toEqual(['Ask Bob'])
|
||||
})
|
||||
|
||||
it('should keep content and suggestedQuestions stable for opener already in prevChatTree even when sibling metadata changes', () => {
|
||||
let callbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
callbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const config = {
|
||||
opening_statement: 'Hello updated',
|
||||
suggested_questions: ['S1'],
|
||||
}
|
||||
const prevChatTree = [{
|
||||
id: 'opening-statement',
|
||||
content: 'old',
|
||||
isAnswer: true,
|
||||
isOpeningStatement: true,
|
||||
suggestedQuestions: [],
|
||||
}]
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useChat(config as ChatConfig, undefined, prevChatTree as ChatItemInTree[]),
|
||||
)
|
||||
|
||||
const openerBefore = result.current.chatList[0]
|
||||
expect(openerBefore.content).toBe('Hello updated')
|
||||
expect(openerBefore.suggestedQuestions).toEqual(['S1'])
|
||||
|
||||
const contentBefore = openerBefore.content
|
||||
const suggestionsBefore = openerBefore.suggestedQuestions
|
||||
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'msg' }, {})
|
||||
})
|
||||
act(() => {
|
||||
callbacks.onData('resp', true, { messageId: 'm-1', conversationId: 'c-1', taskId: 't-1' })
|
||||
})
|
||||
|
||||
expect(result.current.chatList.length).toBeGreaterThan(1)
|
||||
const openerAfter = result.current.chatList[0]
|
||||
expect(openerAfter.content).toBe(contentBefore)
|
||||
expect(openerAfter.suggestedQuestions).toBe(suggestionsBefore)
|
||||
})
|
||||
|
||||
it('should use a stable id of "opening-statement"', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useChat({ opening_statement: 'Hi' } as ChatConfig),
|
||||
)
|
||||
expect(result.current.chatList[0].id).toBe('opening-statement')
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleSend', () => {
|
||||
it('should block send if already responding', async () => {
|
||||
const { result } = renderHook(() => useChat())
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user